diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD index 8914876c4..42ddb5a6f 100644 --- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD +++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD @@ -21,6 +21,7 @@ py_strict_test( timeout = "long", srcs = ["epr_test.py"], python_version = "PY3", + shard_count = 4, tags = ["requires-net:external"], deps = [ ":epr", diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py index 2e6efec5e..c21f952bb 100644 --- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py +++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py @@ -18,45 +18,98 @@ > "Scalable Model Compression by Entropy Penalized Reparameterization"
> D. Oktay, J. Ballé, S. Singh, A. Shrivastava
> https://arxiv.org/abs/1906.06624 + +The "fast" version of EPR is inspired by the entropy code used in the paper: +> "Optimizing the Communication-Accuracy Trade-Off in Federated Learning with +> Rate-Distortion Theory"
+> N. Mitchell, J. Ballé, Z. Charles, J. Konečný
+> https://arxiv.org/abs/2201.02664 """ import functools -from typing import List +from typing import Callable, List, Tuple import tensorflow as tf import tensorflow_compression as tfc from tensorflow_model_optimization.python.core.common.keras.compression import algorithm -class EPR(algorithm.WeightCompressor): +@tf.custom_gradient +def _to_complex_with_gradient( + value: tf.Tensor) -> Tuple[tf.Tensor, Callable[[tf.Tensor], tf.Tensor]]: + return tf.bitcast(value, tf.complex64), lambda g: tf.bitcast(g, tf.float32) + + +def _transform_dense_weight( + weight: tf.Tensor, + log_step: tf.Tensor, + quantized: bool = True) -> tf.Tensor: + """Transforms from latent to dense kernel or bias.""" + step = tf.exp(log_step) + if not quantized: + weight = tfc.round_st(weight / step) + return weight * step + + +def _transform_conv_weight( + kernel_rdft: tf.Tensor, + kernel_shape: tf.Tensor, + log_step: tf.Tensor, + quantized: bool = True) -> tf.Tensor: + """Transforms from latent to convolution kernel.""" + step = tf.exp(log_step) + if not quantized: + kernel_rdft = tfc.round_st(kernel_rdft / step) + kernel_rdft *= step * tf.sqrt( + tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype)) + kernel_rdft = _to_complex_with_gradient(kernel_rdft) + if kernel_rdft.shape.rank == 3: + # 1D convolution. + kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2]) + return tf.transpose(kernel, (2, 0, 1)) + else: + # 2D convolution. + kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2]) + return tf.transpose(kernel, (2, 3, 0, 1)) + + +class EPRBase(algorithm.WeightCompressor): """Defines how to apply the EPR algorithm.""" - def __init__(self, entropy_penalty): - self.entropy_penalty = entropy_penalty + _compressible_classes = ( + tf.keras.layers.Dense, + tf.keras.layers.Conv1D, + tf.keras.layers.Conv2D, + ) + + def __init__(self, regularization_weight: float): + super().__init__() + self.regularization_weight = regularization_weight def get_compressible_weights(self, original_layer): - if isinstance( - original_layer, - (tf.keras.layers.Dense, tf.keras.layers.Conv1D, tf.keras.layers.Conv2D), - ): + if isinstance(original_layer, self._compressible_classes): if original_layer.use_bias: return [original_layer.kernel, original_layer.bias] else: return [original_layer.kernel] return [] - def init_training_weights(self, pretrained_weight: tf.Tensor): + def _init_training_weights_reparam( + self, + pretrained_weight: tf.Tensor) -> Tuple[ + tf.TensorShape, tf.dtypes.DType, str]: + """Initializes training weights needed for reparameterization.""" shape = pretrained_weight.shape dtype = pretrained_weight.dtype weight_name = "bias" if shape.rank == 1 else "kernel" if 1 <= shape.rank <= 2: # Bias or dense kernel. - prior_shape = [] self.add_training_weight( name=weight_name, - shape=pretrained_weight.shape, - dtype=pretrained_weight.dtype, + shape=shape, + dtype=dtype, initializer=tf.keras.initializers.Constant(pretrained_weight)) + prior_shape = tf.TensorShape(()) elif 3 <= shape.rank <= 4: # Convolution kernel. kernel_shape = tf.shape(pretrained_weight) @@ -66,9 +119,7 @@ def init_training_weights(self, pretrained_weight: tf.Tensor): else: kernel_rdft = tf.signal.rfft2d( tf.transpose(pretrained_weight, (2, 3, 0, 1))) - kernel_rdft = tf.stack( - [tf.math.real(kernel_rdft), tf.math.imag(kernel_rdft)], axis=-1) - prior_shape = tf.shape(kernel_rdft)[2:] + kernel_rdft = tf.bitcast(kernel_rdft, tf.float32) kernel_rdft /= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), dtype)) self.add_training_weight( name="kernel_rdft", @@ -83,10 +134,11 @@ def init_training_weights(self, pretrained_weight: tf.Tensor): # If True, throws warnings that int tensors have no gradient. # trainable=False, initializer=tf.keras.initializers.Constant(kernel_shape)) + prior_shape = kernel_rdft.shape[2:] else: raise ValueError( f"Expected bias or kernel tensor with rank between 1 and 4, received " - f"shape {self._shape}.") + f"shape {shape}.") # Logarithm of quantization step size. log_step = tf.fill(prior_shape, tf.constant(-4, dtype=dtype)) @@ -96,7 +148,54 @@ def init_training_weights(self, pretrained_weight: tf.Tensor): dtype=log_step.dtype, initializer=tf.keras.initializers.Constant(log_step)) - # Logarithm of scale of prior. + return prior_shape, dtype, weight_name + + def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model: + """Augments a model for training with EPR.""" + if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network): # pylint: disable=protected-access + raise ValueError("`model` must be either sequential or functional.") + + training_model = tf.keras.models.clone_model( + model, clone_function=functools.partial( + algorithm.create_layer_for_training, algorithm=self)) + training_model.build(model.input.shape) + + # Divide regularization weight by number of original model parameters to + # bring it into a more standardized range. + weight = self.regularization_weight / float(model.count_params()) + + def regularization_loss(layer, name): + return weight * self.regularization_loss(*layer.training_weights[name]) + + for layer in training_model.layers: + if not hasattr(layer, "attr_name_map"): continue + for name in layer.attr_name_map.values(): + layer.add_loss(functools.partial(regularization_loss, layer, name)) + + # TODO(jballe): It would be great to be able to track the entropy losses + # combined during training. How to do this? + # TODO(jballe): Some models might require training log_scale weights with a + # different optimizer/learning rate. How to do this? + return training_model + + def compress_model(self, model: tf.keras.Model) -> tf.keras.Model: + """Compresses a model after training with EPR.""" + if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network): # pylint: disable=protected-access + raise ValueError("`model` must be either sequential or functional.") + return tf.keras.models.clone_model( + model, clone_function=functools.partial( + algorithm.create_layer_for_inference, algorithm=self)) + + +class EPR(EPRBase): + """Defines how to apply the EPR algorithm.""" + + def init_training_weights(self, pretrained_weight: tf.Tensor): + prior_shape, dtype, weight_name = self._init_training_weights_reparam( + pretrained_weight) + + # In addition to reparameterization weights, this method also needs a + # variable for the probability model (logarithm of scale of prior). log_scale = tf.fill(prior_shape, tf.constant(2.5, dtype=dtype)) self.add_training_weight( name=f"{weight_name}_log_scale", @@ -104,26 +203,13 @@ def init_training_weights(self, pretrained_weight: tf.Tensor): dtype=log_scale.dtype, initializer=tf.keras.initializers.Constant(log_scale)) - def project_training_weights(self, *training_weights) -> tf.Tensor: + def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor: if len(training_weights) == 3: # Bias or dense kernel. - weight, log_step, _ = training_weights - step = tf.exp(log_step) - return tfc.round_st(weight / step) * step + return _transform_dense_weight(*training_weights[:-1], quantized=False) else: # Convolution kernel. - kernel_rdft, kernel_shape, log_step, _ = training_weights - step = tf.exp(log_step) - kernel_rdft = tfc.round_st(kernel_rdft / step) - kernel_rdft *= step * tf.sqrt( - tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype)) - kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1)) - if kernel_rdft.shape.rank == 3: - kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2]) - return tf.transpose(kernel, (2, 0, 1)) - else: - kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2]) - return tf.transpose(kernel, (2, 3, 0, 1)) + return _transform_conv_weight(*training_weights[:-1], quantized=False) def compress_training_weights( self, *training_weights: tf.Tensor) -> List[tf.Tensor]: @@ -140,18 +226,30 @@ def compress_training_weights( compression=True, stateless=True, offset_heuristic=False) string = em.compress(weight / tf.exp(log_step)) weight_shape = tf.cast(weight_shape, tf.uint16) - return [string, weight_shape, log_step, em.cdf, em.cdf_offset] + log_step = tf.cast(log_step, tf.float16) + cdf = tf.cast(em.cdf, tf.int16) + cdf_offset = tf.cast(em.cdf_offset, tf.int16) + return [string, weight_shape, log_step, cdf, cdf_offset] - def decompress_weights(self, string, weight_shape, log_step, - cdf, cdf_offset) -> tf.Tensor: + def decompress_weights( + self, + string: tf.Tensor, + weight_shape: tf.Tensor, + log_step: tf.Tensor, + cdf: tf.Tensor, + cdf_offset: tf.Tensor) -> tf.Tensor: weight_shape = tf.cast(weight_shape, tf.int32) + log_step = tf.cast(log_step, tf.float32) + cdf = tf.cast(cdf, tf.int32) + cdf_offset = tf.cast(cdf_offset, tf.int32) if weight_shape.shape[0] <= 2: # Bias or dense kernel. em = tfc.ContinuousBatchedEntropyModel( prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset, coding_rank=weight_shape.shape[0], compression=True, stateless=True, offset_heuristic=False) - return em.decompress(string, weight_shape) * tf.exp(log_step) + weight = em.decompress(string, weight_shape) + return _transform_dense_weight(weight, log_step, quantized=True) else: # Convolution kernel. em = tfc.ContinuousBatchedEntropyModel( @@ -159,17 +257,10 @@ def decompress_weights(self, string, weight_shape, log_step, coding_rank=weight_shape.shape[0] + 1, compression=True, stateless=True, offset_heuristic=False) kernel_rdft = em.decompress(string, weight_shape[-2:]) - kernel_rdft *= tf.exp(log_step) * tf.sqrt( - tf.cast(tf.reduce_prod(weight_shape[:-2]), kernel_rdft.dtype)) - kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1)) - if weight_shape.shape[0] == 3: - kernel = tf.signal.irfft(kernel_rdft, fft_length=weight_shape[:-2]) - return tf.transpose(kernel, (2, 0, 1)) - else: - kernel = tf.signal.irfft2d(kernel_rdft, fft_length=weight_shape[:-2]) - return tf.transpose(kernel, (2, 3, 0, 1)) + return _transform_conv_weight( + kernel_rdft, weight_shape, log_step, quantized=True) - def compute_entropy(self, *training_weights) -> tf.Tensor: + def regularization_loss(self, *training_weights: tf.Tensor) -> tf.Tensor: if len(training_weights) == 3: # Bias or dense kernel. weight, log_step, log_scale = training_weights @@ -183,60 +274,72 @@ def compute_entropy(self, *training_weights) -> tf.Tensor: _, bits = em(weight / tf.exp(log_step), training=True) return bits - def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model: - """Augments a model for training with EPR.""" - # pylint: disable=protected-access - if (not isinstance(model, tf.keras.Sequential) and - not model._is_graph_network): - raise ValueError( - "`compress_model` must be either a sequential or functional model.") - # pylint: enable=protected-access - - entropies = [] - - # Number of dimensions of original model weights. Used to bring - # entropy_penalty into a more standardized range. - weight_dims = tf.add_n([tf.size(w) for w in model.trainable_weights]) - - def create_layer_for_training(layer): - if not layer.built: - raise ValueError( - "Applying EPR currently requires passing in a built model.") - train_layer = algorithm.create_layer_for_training(layer, algorithm=self) - train_layer.build(layer.input_shape) - for name in train_layer.attr_name_map.values(): - entropy = functools.partial( - self.compute_entropy, *train_layer.training_weights[name]) - entropies.append(entropy) - return train_layer - - def compute_entropy_loss(): - total_entropy = tf.add_n([e() for e in entropies]) - entropy_penalty = self.entropy_penalty / tf.cast( - weight_dims, total_entropy.dtype) - return total_entropy * entropy_penalty - training_model = tf.keras.models.clone_model( - model, clone_function=create_layer_for_training) - training_model.add_loss(compute_entropy_loss) +class FastEPR(EPRBase): + """Defines how to apply a faster version of the EPR algorithm.""" - # TODO(jballe): It would be great to be able to track the entropy losses - # combined during training. How to do this? - # TODO(jballe): Some models might require training log_scale weights with a - # different optimizer/learning rate. How to do this? - return training_model + def __init__(self, regularization_weight: float, alpha: float = 1e-2): + super().__init__(regularization_weight) + self.alpha = alpha - def compress_model(self, model: tf.keras.Model) -> tf.keras.Model: - """Compresses a model after training with EPR.""" - # pylint: disable=protected-access - if (not isinstance(model, tf.keras.Sequential) and - not model._is_graph_network): - raise ValueError( - "`compress_model` must be either a sequential or functional model.") - # pylint: enable=protected-access + def init_training_weights(self, pretrained_weight: tf.Tensor): + # The probability model is fixed, so only need reparameterization weights. + self._init_training_weights_reparam(pretrained_weight) - def create_layer_for_inference(layer): - return algorithm.create_layer_for_inference(layer, algorithm=self) + def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor: + if len(training_weights) == 2: + # Bias or dense kernel. + return _transform_dense_weight(*training_weights, quantized=False) + else: + # Convolution kernel. + return _transform_conv_weight(*training_weights, quantized=False) - return tf.keras.models.clone_model( - model, clone_function=create_layer_for_inference) + def compress_training_weights( + self, + *training_weights: tf.Tensor) -> List[tf.Tensor]: + if len(training_weights) == 2: + # Bias or dense kernel. + weight, log_step = training_weights + weight_shape = tf.shape(weight) + else: + # Convolution kernel. + weight, weight_shape, log_step = training_weights + em = tfc.PowerLawEntropyModel( + coding_rank=weight.shape.rank, alpha=self.alpha) + string = em.compress(weight / tf.exp(log_step)) + weight_shape = tf.cast(weight_shape, tf.uint16) + log_step = tf.cast(log_step, tf.float16) + return [string, weight_shape, log_step] + + def decompress_weights( + self, + string: tf.Tensor, + weight_shape: tf.Tensor, + log_step: tf.Tensor) -> tf.Tensor: + weight_shape = tf.cast(weight_shape, tf.int32) + log_step = tf.cast(log_step, tf.float32) + if weight_shape.shape[0] <= 2: + # Bias or dense kernel. + em = tfc.PowerLawEntropyModel( + coding_rank=weight_shape.shape[0], alpha=self.alpha) + weight = em.decompress(string, weight_shape) + return _transform_dense_weight(weight, log_step, quantized=True) + else: + # Convolution kernel. + em = tfc.PowerLawEntropyModel( + coding_rank=weight_shape.shape[0] + 1, alpha=self.alpha) + kernel_rdft = em.decompress( + string, tf.concat([weight_shape[-2:], tf.shape(log_step)], 0)) + return _transform_conv_weight( + kernel_rdft, weight_shape, log_step, quantized=True) + + def regularization_loss(self, *training_weights: tf.Tensor) -> tf.Tensor: + if len(training_weights) == 2: + # Bias or dense kernel. + weight, log_step = training_weights + else: + # Convolution kernel. + weight, _, log_step = training_weights + em = tfc.PowerLawEntropyModel( + coding_rank=weight.shape.rank, alpha=self.alpha) + return em.penalty(weight / tf.exp(log_step)) diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py index 0ccd04757..d97560dd7 100644 --- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py +++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py @@ -64,15 +64,6 @@ def evaluate_model(model): return results["accuracy"] -def train_and_compress_model(): - model = build_model() - algorithm = epr.EPR(entropy_penalty=10.) - training_model = algorithm.get_training_model(model) - train_model(training_model) - compressed_model = algorithm.compress_model(training_model) - return model, training_model, compressed_model - - def get_weight_size_in_bytes(weight): if weight.dtype == tf.string: return tf.reduce_sum(tf.strings.length(weight, unit="BYTE")) @@ -86,17 +77,17 @@ def zip_directory(dir_name): class EPRTest(parameterized.TestCase, tf.test.TestCase): - def _save_models(self, model, compressed_model): + def get_algorithm(self, regularization_weight=1.): + return epr.EPR(regularization_weight=regularization_weight) + + def save_model(self, model): model_dir = self.create_tempdir().full_path - original_model_dir = os.path.join(model_dir, "original") - compressed_model_dir = os.path.join(model_dir, "compressed") - model.save(original_model_dir) - compressed_model.save(compressed_model_dir) - return original_model_dir, compressed_model_dir + model.save(model_dir) + return model_dir @parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1]) def test_project_training_weights_has_gradients(self, *shape): - algorithm = epr.EPR(entropy_penalty=1.) + algorithm = self.get_algorithm() init = tf.ones(shape, dtype=tf.float32) algorithm.init_training_weights(init) layer = tf.keras.layers.Layer() @@ -105,29 +96,53 @@ def test_project_training_weights_has_gradients(self, *shape): with tf.GradientTape() as tape: weight = algorithm.project_training_weights(*layer.weights) gradients = tape.gradient(weight, layer.weights) - # Last weight is scale of prior. Should not have a gradient here. self.assertAllEqual( [g is not None for g in gradients], - [w.dtype.is_floating for w in layer.weights[:-1]] + [False]) + [w.dtype.is_floating and "log_scale" not in w.name + for w in layer.weights]) @parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1]) - def test_compute_entropy_has_gradients(self, *shape): - algorithm = epr.EPR(entropy_penalty=1.) + def test_regularization_loss_has_gradients(self, *shape): + algorithm = self.get_algorithm() init = tf.ones(shape, dtype=tf.float32) algorithm.init_training_weights(init) layer = tf.keras.layers.Layer() for weight_repr in algorithm.weight_reprs: layer.add_weight(*weight_repr.args, **weight_repr.kwargs) with tf.GradientTape() as tape: - loss = algorithm.compute_entropy(*layer.weights) + loss = algorithm.regularization_loss(*layer.weights) gradients = tape.gradient(loss, layer.weights) self.assertAllEqual( [g is not None for g in gradients], [w.dtype.is_floating for w in layer.weights]) + @parameterized.parameters( + ((2, 3), tf.keras.layers.Dense, 5), + # TODO(jballe): This fails with: 'You called `set_weights(weights)` on + # layer "private__training_wrapper" with a weight list of length 0, but + # the layer was expecting 5 weights.' Find fix. + # ((3, 10, 2), tf.keras.layers.Conv1D, 5, 3), + ((1, 8, 9, 2), tf.keras.layers.Conv2D, 5, 3), + ) + def test_model_has_gradients(self, input_shape, layer_cls, *args): + algorithm = self.get_algorithm() + model = tf.keras.Sequential([layer_cls(*args, use_bias=True)]) + inputs = tf.random.normal(input_shape) + model(inputs) + training_model = algorithm.get_training_model(model) + with tf.GradientTape(persistent=True) as tape: + tape.watch(inputs) + outputs = training_model(inputs) + loss = tf.reduce_sum(abs(outputs)) + tf.reduce_sum(training_model.losses) + self.assertIsNotNone(tape.gradient(loss, inputs)) + gradients = tape.gradient(loss, training_model.trainable_weights) + self.assertAllEqual( + [g is not None for g in gradients], + [w.dtype.is_floating for w in training_model.trainable_weights]) + @parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1]) def test_train_and_test_weights_are_equal(self, *shape): - algorithm = epr.EPR(entropy_penalty=1.) + algorithm = self.get_algorithm() init = tf.random.uniform(shape, dtype=tf.float32) algorithm.init_training_weights(init) layer = tf.keras.layers.Layer() @@ -138,10 +153,32 @@ def test_train_and_test_weights_are_equal(self, *shape): test_weight = algorithm.decompress_weights(*compressed_weights) self.assertAllEqual(train_weight, test_weight) + @parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1]) + def test_initialized_value_is_close_enough(self, *shape): + algorithm = self.get_algorithm() + init = tf.random.uniform(shape, -10., 10., dtype=tf.float32) + algorithm.init_training_weights(init) + layer = tf.keras.layers.Layer() + for weight_repr in algorithm.weight_reprs: + layer.add_weight(*weight_repr.args, **weight_repr.kwargs) + weight = algorithm.project_training_weights(*layer.weights) + quantization_noise_std_dev = tf.exp(-4.) / tf.sqrt(12.) + self.assertLess( + tf.sqrt(tf.reduce_mean(tf.square(init - weight))), + 3. * quantization_noise_std_dev) + def test_reduces_model_size_at_reasonable_accuracy(self): - model, _, compressed_model = train_and_compress_model() - original_model_dir, compressed_model_dir = self._save_models( - model, compressed_model) + algorithm = self.get_algorithm() + model = build_model() + training_model = algorithm.get_training_model(model) + train_model(training_model) + compressed_model = algorithm.compress_model(training_model) + original_model_dir = self.save_model(model) + compressed_model_dir = self.save_model(compressed_model) + + with self.subTest("training_model_has_reasonable_accuracy"): + accuracy = evaluate_model(training_model) + self.assertGreater(accuracy, .9) with self.subTest("compressed_weights_are_smaller"): original_size = sum( @@ -162,11 +199,25 @@ def test_reduces_model_size_at_reasonable_accuracy(self): # rather than of each layer? self.assertLess(compressed_size, 0.2 * original_size) - with self.subTest("has_reasonable_accuracy"): + with self.subTest("compressed_model_has_reasonable_accuracy"): compressed_model = tf.keras.models.load_model(compressed_model_dir) accuracy = evaluate_model(compressed_model) self.assertGreater(accuracy, .9) + def test_unregularized_training_model_has_reasonable_accuracy(self): + algorithm = self.get_algorithm(regularization_weight=0.) + model = build_model() + training_model = algorithm.get_training_model(model) + train_model(training_model) + accuracy = evaluate_model(training_model) + self.assertGreater(accuracy, .9) + + +class FastEPRTest(EPRTest): + + def get_algorithm(self, regularization_weight=1.): + return epr.FastEPR(regularization_weight=regularization_weight, alpha=1e-2) + if __name__ == "__main__": tf.test.main()