[WIP] More robustly check inputs to inference algorithms #777

Open · wants to merge 2 commits into master
8 changes: 8 additions & 0 deletions edward/inferences/bigan_inference.py
@@ -41,6 +41,13 @@ class BiGANInference(GANInference):
```
"""
  def __init__(self, latent_vars, data, discriminator):
    if len(latent_vars) != 1:
      raise TypeError("latent_vars must have exactly one key.")
    if len([key for key in six.iterkeys(data)
            if not (isinstance(key, tf.Tensor) and
                    "Placeholder" in key.op.type)]) != 1:
      raise TypeError("data must have exactly one key that is not a "
                      "`tf.placeholder`.")
    if not callable(discriminator):
      raise TypeError("discriminator must be a callable function.")

@@ -49,6 +56,7 @@ def __init__(self, latent_vars, data, discriminator):
    super(GANInference, self).__init__(latent_vars, data)

  def build_loss_and_gradients(self, var_list):
    # TODO
    x_true = list(six.itervalues(self.data))[0]
    x_fake = list(six.iterkeys(self.data))[0]

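
The key-counting check above hinges on `key.op.type`. A minimal sketch of how a tensor is classified, assuming TensorFlow 1.x (`x_ph` and `x_fake` are hypothetical):

```python
import tensorflow as tf

x_ph = tf.placeholder(tf.float32, [None, 784])  # produced by a Placeholder op
x_fake = tf.zeros([32, 784])                    # not produced by a Placeholder op

def is_placeholder(key):
  # Mirrors the condition used in the checks above.
  return isinstance(key, tf.Tensor) and "Placeholder" in key.op.type

print(is_placeholder(x_ph))    # True
print(is_placeholder(x_fake))  # False
```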
8 changes: 7 additions & 1 deletion edward/inferences/gan_inference.py
@@ -46,13 +46,18 @@ def __init__(self, data, discriminator):
      data: dict.
        Data dictionary which binds observed variables (of type
        `RandomVariable` or `tf.Tensor`) to their realizations (of
        type `tf.Tensor`). It can also bind placeholders (of type
        `tf.Tensor`) used in the model to their realizations.
      discriminator: function.
        Function (with parameters) to discriminate samples. It should
        output logit probabilities (real-valued) and not probabilities
        in $[0, 1]$.
    """
    if len([key for key in six.iterkeys(data)
            if not (isinstance(key, tf.Tensor) and
                    "Placeholder" in key.op.type)]) != 1:
      raise TypeError("data must have exactly one key that is not a "
                      "`tf.placeholder`.")
    if not callable(discriminator):
      raise TypeError("discriminator must be a callable function.")

@@ -111,6 +116,7 @@ def initialize(self, optimizer=None, optimizer_d=None,
    self.summarize = tf.summary.merge_all(key=self._summary_key)

  def build_loss_and_gradients(self, var_list):
    # TODO
    x_true = list(six.itervalues(self.data))[0]
    x_fake = list(six.iterkeys(self.data))[0]
    with tf.variable_scope("Disc"):
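
For reference, a data dictionary that satisfies this check binds exactly one non-placeholder tensor (the model's generated sample) to its realization. A minimal sketch following Edward's GAN usage, with hypothetical generator and discriminator networks:

```python
import tensorflow as tf
import edward as ed

def generative_network(eps):
  # Hypothetical generator mapping noise to a sample.
  return tf.layers.dense(eps, 784, activation=tf.sigmoid)

def discriminative_network(x):
  # Returns real-valued logits, not probabilities in [0, 1].
  return tf.layers.dense(x, 1)

eps = tf.random_normal([32, 100])
x = generative_network(eps)                     # the one non-placeholder key
x_ph = tf.placeholder(tf.float32, [None, 784])  # realization fed at runtime

inference = ed.GANInference(
    data={x: x_ph}, discriminator=discriminative_network)
```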
1 change: 1 addition & 0 deletions edward/inferences/gibbs.py
@@ -49,6 +49,7 @@ def __init__(self, latent_vars, proposal_vars=None, data=None):

    self.proposal_vars = proposal_vars
    super(Gibbs, self).__init__(latent_vars, data)
    # TODO: decide which input checks are needed here.

  def initialize(self, scan_order='random', *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
3 changes: 3 additions & 0 deletions edward/inferences/hmc.py
@@ -46,6 +46,9 @@ class HMC(MonteCarlo):
"""
def __init__(self, *args, **kwargs):
super(HMC, self).__init__(*args, **kwargs)
# TODO
check_latent_vars_densities(latent_vars)
check_data_densities(data)

  def initialize(self, step_size=0.25, n_steps=2, *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
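
For context, these checks run on dictionaries coming from standard HMC calls such as the following sketch (`x_data` is a hypothetical NumPy array; `Empirical` holds the chain of samples):

```python
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Empirical, Normal

x_data = np.random.randn(50).astype(np.float32)  # hypothetical observations

mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

qmu = Empirical(params=tf.Variable(tf.zeros(1000)))  # 1000 posterior samples

inference = ed.HMC({mu: qmu}, data={x: x_data})
```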
2 changes: 2 additions & 0 deletions edward/inferences/implicit_klqp.py
@@ -79,6 +79,8 @@ def __init__(self, latent_vars, data=None, discriminator=None,
    self.global_vars = global_vars
    # call grandparent's method; avoid parent (GANInference)
    super(GANInference, self).__init__(latent_vars, data)
    # TODO
    check_latent_vars_densities(global_vars)

  def initialize(self, ratio_loss='log', *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
3 changes: 3 additions & 0 deletions edward/inferences/klpq.py
@@ -43,6 +43,9 @@ class KLpq(VariationalInference):
"""
def __init__(self, *args, **kwargs):
super(KLpq, self).__init__(*args, **kwargs)
# TODO
check_latent_vars_densities(latent_vars)
check_data_densities(data)

  def initialize(self, n_samples=1, *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
4 changes: 4 additions & 0 deletions edward/inferences/klqp.py
@@ -49,6 +49,9 @@ class KLqp(VariationalInference):
"""
def __init__(self, *args, **kwargs):
super(KLqp, self).__init__(*args, **kwargs)
# TODO
check_latent_vars_densities(latent_vars)
check_data_densities(data)

  def initialize(self, n_samples=1, kl_scaling=None, *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
@@ -137,6 +140,7 @@ class ReparameterizationKLqp(VariationalInference):
"""
def __init__(self, *args, **kwargs):
super(ReparameterizationKLqp, self).__init__(*args, **kwargs)
# TODO

  def initialize(self, n_samples=1, *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
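
A typical KLqp call in which both sides of `latent_vars` implement `_log_prob` (a sketch; `x_data` is a hypothetical NumPy array; the same pattern applies to `KLpq`):

```python
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal

x_data = np.random.randn(50).astype(np.float32)  # hypothetical observations

mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

# Normal implements `_log_prob`, so {mu: qmu} satisfies
# check_latent_vars_densities.
qmu = Normal(loc=tf.Variable(0.0),
             scale=tf.nn.softplus(tf.Variable(0.0)))

inference = ed.KLqp({mu: qmu}, data={x: x_data})
```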
3 changes: 3 additions & 0 deletions edward/inferences/map.py
@@ -108,6 +108,9 @@ def __init__(self, latent_vars=None, data=None):
"PointMass random variables.")

super(MAP, self).__init__(latent_vars, data)
# TODO
check_latent_vars_densities(latent_vars)
check_data_densities(data)

  def build_loss_and_gradients(self, var_list):
    """Build loss function. Its automatic differentiation
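
MAP restricts the approximating family to point masses, so a call reaching these checks looks like the following sketch (hypothetical `x_data`):

```python
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal, PointMass

x_data = np.random.randn(50).astype(np.float32)  # hypothetical observations

mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

qmu = PointMass(params=tf.Variable(0.0))  # point estimate at a trainable location

inference = ed.MAP({mu: qmu}, data={x: x_data})
```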
30 changes: 30 additions & 0 deletions edward/util/random_variables.py
@@ -59,6 +59,24 @@ def check_data(data):
raise TypeError("Data key has an invalid type: {}".format(type(key)))


def check_data_densities(data):
  """Check that the data dictionary binds `tf.placeholder`s or
  ed.RandomVariables with `_log_prob` implemented to ed.RandomVariables
  with `_log_prob` implemented."""
  for key, value in six.iteritems(data):
    valid_key = (isinstance(key, tf.Tensor) and "Placeholder" in key.op.type) or \
        (isinstance(key, RandomVariable) and hasattr(key, '_log_prob'))
    if not valid_key:
      raise TypeError("Data key must be a tf.placeholder or an "
                      "ed.RandomVariable with the `_log_prob` method "
                      "implemented.")
    elif not (isinstance(value, RandomVariable) and hasattr(value, '_log_prob')):
      raise TypeError("Data value must be an ed.RandomVariable with the "
                      "`_log_prob` method implemented.")


def check_latent_vars(latent_vars):
  """Check that the latent variable dictionary passed during inference and
  criticism is valid.
@@ -81,6 +99,18 @@ def check_latent_vars(latent_vars):
"dtype: {}, {}".format(key.dtype, value.dtype))


def check_latent_vars_densities(latent_vars):
  """Check that the dictionary is a collection of ed.RandomVariable
  key-value pairs with `_log_prob` implemented."""
  for key, value in six.iteritems(latent_vars):
    if not (isinstance(key, RandomVariable) and hasattr(key, '_log_prob')):
      raise TypeError("Latent variable key must be an ed.RandomVariable "
                      "with the `_log_prob` method implemented.")
    elif not (isinstance(value, RandomVariable) and hasattr(value, '_log_prob')):
      raise TypeError("Latent variable value must be an ed.RandomVariable "
                      "with the `_log_prob` method implemented.")


def _copy_default(x, *args, **kwargs):
  if isinstance(x, (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
    x = copy(x, *args, **kwargs)