From 7d2cbb6429d681ea23279f9051b8eb49f56dce23 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sun, 1 Oct 2017 10:42:51 -0700 Subject: [PATCH 1/2] robustify gan inference input checks --- edward/inferences/bigan_inference.py | 8 ++++++++ edward/inferences/gan_inference.py | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/edward/inferences/bigan_inference.py b/edward/inferences/bigan_inference.py index 033c1d123..db54b5df6 100644 --- a/edward/inferences/bigan_inference.py +++ b/edward/inferences/bigan_inference.py @@ -41,6 +41,13 @@ class BiGANInference(GANInference): ``` """ def __init__(self, latent_vars, data, discriminator): + if len(latent_vars) != 1: + raise TypeError("latent_vars must have exactly one key.") + if len([key for key in six.iterkeys(data) + if not isinstance(key, tf.Tensor) or (isinstance(key, + tf.Tensor) and not "Placeholder" in key.op.type)]) != 1: + raise TypeError("data must have exactly one key that is not a " + "`tf.placeholder`.") if not callable(discriminator): raise TypeError("discriminator must be a callable function.") @@ -49,6 +56,7 @@ def __init__(self, latent_vars, data, discriminator): super(GANInference, self).__init__(latent_vars, data) def build_loss_and_gradients(self, var_list): + # TODO x_true = list(six.itervalues(self.data))[0] x_fake = list(six.iterkeys(self.data))[0] diff --git a/edward/inferences/gan_inference.py b/edward/inferences/gan_inference.py index f0dbbf069..0889f79c3 100644 --- a/edward/inferences/gan_inference.py +++ b/edward/inferences/gan_inference.py @@ -46,13 +46,18 @@ def __init__(self, data, discriminator): data: dict. Data dictionary which binds observed variables (of type `RandomVariable` or `tf.Tensor`) to their realizations (of - type `tf.Tensor`). It can also bind placeholders (of type + type `tf.Tensor`). It can also bind placeholders (of type `tf.Tensor`) used in the model to their realizations. discriminator: function. Function (with parameters) to discriminate samples. 
It should output logit probabilities (real-valued) and not probabilities in $[0, 1]$. """ + if len([key for key in six.iterkeys(data) + if not isinstance(key, tf.Tensor) or (isinstance(key, + tf.Tensor) and not "Placeholder" in key.op.type)]) != 1: + raise TypeError("data must have exactly one key that is not a " + "`tf.placeholder`.") if not callable(discriminator): raise TypeError("discriminator must be a callable function.") @@ -111,6 +116,7 @@ def initialize(self, optimizer=None, optimizer_d=None, self.summarize = tf.summary.merge_all(key=self._summary_key) def build_loss_and_gradients(self, var_list): + # TODO x_true = list(six.itervalues(self.data))[0] x_fake = list(six.iterkeys(self.data))[0] with tf.variable_scope("Disc"): From d0abba2006b8f21d118616222261d985b8589d5e Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sun, 1 Oct 2017 10:46:30 -0700 Subject: [PATCH 2/2] check log_prob's exist for all explicit methods --- edward/inferences/gibbs.py | 1 + edward/inferences/hmc.py | 3 +++ edward/inferences/implicit_klqp.py | 2 ++ edward/inferences/klpq.py | 3 +++ edward/inferences/klqp.py | 4 ++++ edward/inferences/map.py | 3 +++ edward/util/random_variables.py | 30 ++++++++++++++++++++++++++++++ 7 files changed, 46 insertions(+) diff --git a/edward/inferences/gibbs.py b/edward/inferences/gibbs.py index 32e1586b7..980ce3cbe 100644 --- a/edward/inferences/gibbs.py +++ b/edward/inferences/gibbs.py @@ -49,6 +49,7 @@ def __init__(self, latent_vars, proposal_vars=None, data=None): self.proposal_vars = proposal_vars super(Gibbs, self).__init__(latent_vars, data) + # TODO what to need here? def initialize(self, scan_order='random', *args, **kwargs): """Initialize inference algorithm. 
It initializes hyperparameters diff --git a/edward/inferences/hmc.py b/edward/inferences/hmc.py index 45ed38363..93a26e8de 100644 --- a/edward/inferences/hmc.py +++ b/edward/inferences/hmc.py @@ -46,6 +46,9 @@ class HMC(MonteCarlo): """ def __init__(self, *args, **kwargs): super(HMC, self).__init__(*args, **kwargs) + # TODO + check_latent_vars_densities(self.latent_vars) + check_data_densities(self.data) def initialize(self, step_size=0.25, n_steps=2, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters diff --git a/edward/inferences/implicit_klqp.py b/edward/inferences/implicit_klqp.py index c9e4dd3d9..33eb0fe6b 100644 --- a/edward/inferences/implicit_klqp.py +++ b/edward/inferences/implicit_klqp.py @@ -79,6 +79,8 @@ def __init__(self, latent_vars, data=None, discriminator=None, self.global_vars = global_vars # call grandparent's method; avoid parent (GANInference) super(GANInference, self).__init__(latent_vars, data) + # TODO + check_latent_vars_densities(global_vars) def initialize(self, ratio_loss='log', *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py index a362d73e9..8961dbb46 100644 --- a/edward/inferences/klpq.py +++ b/edward/inferences/klpq.py @@ -43,6 +43,9 @@ class KLpq(VariationalInference): """ def __init__(self, *args, **kwargs): super(KLpq, self).__init__(*args, **kwargs) + # TODO + check_latent_vars_densities(self.latent_vars) + check_data_densities(self.data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. 
It initializes hyperparameters diff --git a/edward/inferences/klqp.py b/edward/inferences/klqp.py index 491620d13..6ad0275bf 100644 --- a/edward/inferences/klqp.py +++ b/edward/inferences/klqp.py @@ -49,6 +49,9 @@ class KLqp(VariationalInference): """ def __init__(self, *args, **kwargs): super(KLqp, self).__init__(*args, **kwargs) + # TODO + check_latent_vars_densities(self.latent_vars) + check_data_densities(self.data) def initialize(self, n_samples=1, kl_scaling=None, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -137,6 +140,7 @@ class ReparameterizationKLqp(VariationalInference): """ def __init__(self, *args, **kwargs): super(ReparameterizationKLqp, self).__init__(*args, **kwargs) + # TODO def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters diff --git a/edward/inferences/map.py b/edward/inferences/map.py index 68268c5a4..fc1391ad3 100644 --- a/edward/inferences/map.py +++ b/edward/inferences/map.py @@ -108,6 +108,9 @@ def __init__(self, latent_vars=None, data=None): "PointMass random variables.") super(MAP, self).__init__(latent_vars, data) + # TODO + check_latent_vars_densities(self.latent_vars) + check_data_densities(self.data) def build_loss_and_gradients(self, var_list): """Build loss function. 
Its automatic differentiation diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 5b5f3d137..ede241a78 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -59,6 +59,24 @@ def check_data(data): raise TypeError("Data key has an invalid type: {}".format(type(key))) +def check_data_densities(data): + """Check that dictionary is collection of ed.RandomVariable + key-value pairs with `_log_prob` implemented.""" + for key, value in six.iteritems(data): + valid_key = (isinstance(key, tf.Tensor) and "Placeholder" in key.op.type) or \ + (isinstance(key, RandomVariable) and hasattr(key, '_log_prob')) + if not valid_key: + raise TypeError("Dictionary key must be a ed.RandomVariable with " + "the `_log_prob` method implemented") + elif not isinstance(value, RandomVariable) or not hasattr(value, '_log_prob'): + raise TypeError("Dictionary value must be a ed.RandomVariable with " + "the `_log_prob` method implemented") + for key in six.iterkeys(data): + if isinstance(key, tf.Tensor) and not "Placeholder" in key.op.type: + raise TypeError("Data key must be a ed.RandomVariable object or " + "tf.placeholder object.") + + def check_latent_vars(latent_vars): """Check that the latent variable dictionary passed during inference and criticism is valid. 
@@ -81,6 +99,18 @@ def check_latent_vars(latent_vars): "dtype: {}, {}".format(key.dtype, value.dtype)) +def check_latent_vars_densities(latent_vars): + """Check that dictionary is collection of ed.RandomVariable + key-value pairs with `_log_prob` implemented.""" + for key, value in six.iteritems(latent_vars): + if not isinstance(key, RandomVariable) or not hasattr(key, '_log_prob'): + raise TypeError("Dictionary key must be a ed.RandomVariable with " + "the `_log_prob` method implemented") + elif not isinstance(value, RandomVariable) or not hasattr(value, '_log_prob'): + raise TypeError("Dictionary value must be a ed.RandomVariable with " + "the `_log_prob` method implemented") + + def _copy_default(x, *args, **kwargs): if isinstance(x, (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)): x = copy(x, *args, **kwargs)