# abstractions.py
# forked from mattjj/pybasicbayes (146 lines, 111 loc, 3.7 KB)
import abc
import numpy as np
from util.stats import combinedata
# NOTE: data is always a (possibly masked) np.ndarray or list of (possibly
# masked) np.ndarrays.
################
# Base class #
################
class Distribution(object, metaclass=abc.ABCMeta):
    '''
    Abstract base class for all distribution objects.

    NOTE(review): the original set ``__metaclass__ = abc.ABCMeta``, which is
    the Python 2 spelling and is silently ignored by Python 3, so the abstract
    methods below were not actually enforced. The ``metaclass=`` keyword
    restores the intended enforcement (instantiating without implementing the
    abstract methods raises TypeError).
    '''

    @abc.abstractmethod
    def rvs(self, size=[]):
        '''
        random variates (samples)

        size: requested sample count/shape. (Mutable default kept for
        interface compatibility; implementations must not mutate it.)
        '''
        pass

    @abc.abstractmethod
    def log_likelihood(self, x):
        '''
        log likelihood (either log probability mass function or log
        probability density function) evaluated at x
        '''
        pass
#########################################################
# Algorithm interfaces for inference in distributions #
#########################################################
class GibbsSampling(Distribution, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for distributions that support Gibbs sampling.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def resample(self, data=[]):
        '''
        Update the distribution's parameters given ``data`` (possibly empty).
        Abstract hook — semantics are defined by concrete subclasses.
        '''
        pass
class MeanField(Distribution, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for distributions that support mean-field
    variational inference.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def expected_log_likelihood(self, x):
        '''Expected log likelihood of x under the variational factors.'''
        pass

    @abc.abstractmethod
    def meanfieldupdate(self, data, weights):
        '''Update the variational parameters given (weighted) data.'''
        pass

    def get_vlb(self):
        # Variational lower bound — optional; subclasses may override.
        raise NotImplementedError
class Collapsed(Distribution, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for collapsed (marginalized) inference.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def log_marginal_likelihood(self, data):
        '''Log marginal likelihood of data with parameters integrated out.'''
        pass

    def log_predictive(self, newdata, olddata):
        # log p(new | old) = log p(new, old) - log p(old)
        return self.log_marginal_likelihood(combinedata((newdata,olddata))) \
                - self.log_marginal_likelihood(olddata)

    def predictive(self, *args, **kwargs):
        # Exponentiate the log predictive; same signature pass-through.
        return np.exp(self.log_predictive(*args, **kwargs))
class MaxLikelihood(Distribution, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for maximum-likelihood (EM-style) parameter fitting.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def max_likelihood(self, data, weights=None):
        '''
        sets the parameters set to their maximum likelihood values given the
        (weighted) data
        '''
        pass

    @classmethod
    def max_likelihood_constructor(cls, data, weights=None):
        '''
        creates a new instance with the parameters set to their maximum
        likelihood values and the hyperparameters set to something reasonable
        along the lines of empirical Bayes

        NOTE(review): the original took ``cls`` as first parameter but was
        missing the ``@classmethod`` decorator, so calling it on a class
        would have bound the class as ``cls`` AND shifted ``data`` — the
        decorator matches the evident intent (alternate constructor).
        '''
        raise NotImplementedError

    def max_likelihood_withprior(self, data, weights=None):
        '''
        max_likelihood including prior statistics, for use with MAP EM
        '''
        raise NotImplementedError
############
# Models #
############
# what differentiates a "model" from a "distribution" in this code is latent
# state over data: a model attaches a latent variable (like a label or state
# sequence) to data, and so it 'holds onto' data. Hence the add_data method.
class Model(object, metaclass=abc.ABCMeta):
    '''
    Abstract base class for models.

    What differentiates a "model" from a "distribution" in this code is
    latent state over data: a model attaches a latent variable (like a label
    or state sequence) to data, and so it 'holds onto' data. Hence the
    add_data method.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def add_data(self, data):
        pass

    @abc.abstractmethod
    def generate(self, keep=True, **kwargs):
        '''
        Like a distribution's rvs, but this also fills in latent state over
        data and keeps references to the data.
        '''
        pass

    def rvs(self, *args, **kwargs):
        # generate() returns (data, latent...); 0th component is the data.
        # keep=False so the sampled data is not retained by the model.
        return self.generate(*args, keep=False, **kwargs)[0]
##################################################
# Algorithm interfaces for inference in models #
##################################################
class ModelGibbsSampling(Model, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for models fit by Gibbs sampling.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def resample_model(self):  # TODO niter?
        pass
class ModelMeanField(Model, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for models fit by mean-field variational inference.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def meanfield_coordinate_descent_step(self):
        # returns variational lower bound after update
        pass
class ModelEM(Model, metaclass=abc.ABCMeta):
    '''
    Algorithm interface for models fit by expectation-maximization.

    NOTE(review): ``__metaclass__`` is ignored by Python 3; the explicit
    ``metaclass=`` keyword restores abstract-method enforcement.
    '''

    @abc.abstractmethod
    def EM_step(self):
        pass