diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
index 8914876c4..42ddb5a6f 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
@@ -21,6 +21,7 @@ py_strict_test(
timeout = "long",
srcs = ["epr_test.py"],
python_version = "PY3",
+ shard_count = 4,
tags = ["requires-net:external"],
deps = [
":epr",
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
index 2e6efec5e..c21f952bb 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py
@@ -18,45 +18,98 @@
> "Scalable Model Compression by Entropy Penalized Reparameterization"
> D. Oktay, J. Ballé, S. Singh, A. Shrivastava
> https://arxiv.org/abs/1906.06624
+
+The "fast" version of EPR is inspired by the entropy code used in the paper:
+> "Optimizing the Communication-Accuracy Trade-Off in Federated Learning with
+> Rate-Distortion Theory"
+> N. Mitchell, J. Ballé, Z. Charles, J. Konečný
+> https://arxiv.org/abs/2201.02664
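+
+A minimal usage sketch (assuming `model` is a built Keras model; the
+regularization weight below is illustrative, not a recommended setting):
+
+  algorithm = EPR(regularization_weight=10.)
+  training_model = algorithm.get_training_model(model)
+  # ... compile and fit training_model as usual ...
+  compressed_model = algorithm.compress_model(training_model)
+
+`FastEPR` is a drop-in replacement for `EPR` that uses a fixed power-law
+entropy model and additionally accepts an `alpha` parameter.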
"""
import functools
-from typing import List
+from typing import Callable, List, Tuple
import tensorflow as tf
import tensorflow_compression as tfc
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
-class EPR(algorithm.WeightCompressor):
+@tf.custom_gradient
+def _to_complex_with_gradient(
+ value: tf.Tensor) -> Tuple[tf.Tensor, Callable[[tf.Tensor], tf.Tensor]]:
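+  """Bitcasts float32 to complex64; the gradient bitcasts back to float32."""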
+ return tf.bitcast(value, tf.complex64), lambda g: tf.bitcast(g, tf.float32)
+
+
+def _transform_dense_weight(
+ weight: tf.Tensor,
+ log_step: tf.Tensor,
+ quantized: bool = True) -> tf.Tensor:
+ """Transforms from latent to dense kernel or bias."""
+ step = tf.exp(log_step)
+ if not quantized:
+ weight = tfc.round_st(weight / step)
+ return weight * step
+
+
+def _transform_conv_weight(
+ kernel_rdft: tf.Tensor,
+ kernel_shape: tf.Tensor,
+ log_step: tf.Tensor,
+ quantized: bool = True) -> tf.Tensor:
+ """Transforms from latent to convolution kernel."""
+ step = tf.exp(log_step)
+ if not quantized:
+ kernel_rdft = tfc.round_st(kernel_rdft / step)
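+  # Rescale by the sqrt of the number of spatial kernel elements to undo the
+  # normalization applied in `_init_training_weights_reparam`.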
+ kernel_rdft *= step * tf.sqrt(
+ tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype))
+ kernel_rdft = _to_complex_with_gradient(kernel_rdft)
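+  # E.g. for a Conv2D kernel of shape (h, w, c_in, c_out), kernel_rdft is now
+  # a complex tensor of shape (c_in, c_out, h, w // 2 + 1).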
+ if kernel_rdft.shape.rank == 3:
+ # 1D convolution.
+ kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2])
+ return tf.transpose(kernel, (2, 0, 1))
+ else:
+ # 2D convolution.
+ kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2])
+ return tf.transpose(kernel, (2, 3, 0, 1))
+
+
+class EPRBase(algorithm.WeightCompressor):
"""Defines how to apply the EPR algorithm."""
- def __init__(self, entropy_penalty):
- self.entropy_penalty = entropy_penalty
+ _compressible_classes = (
+ tf.keras.layers.Dense,
+ tf.keras.layers.Conv1D,
+ tf.keras.layers.Conv2D,
+ )
+
+ def __init__(self, regularization_weight: float):
+ super().__init__()
+ self.regularization_weight = regularization_weight
def get_compressible_weights(self, original_layer):
- if isinstance(
- original_layer,
- (tf.keras.layers.Dense, tf.keras.layers.Conv1D, tf.keras.layers.Conv2D),
- ):
+ if isinstance(original_layer, self._compressible_classes):
if original_layer.use_bias:
return [original_layer.kernel, original_layer.bias]
else:
return [original_layer.kernel]
return []
- def init_training_weights(self, pretrained_weight: tf.Tensor):
+ def _init_training_weights_reparam(
+ self,
+ pretrained_weight: tf.Tensor) -> Tuple[
+ tf.TensorShape, tf.dtypes.DType, str]:
+ """Initializes training weights needed for reparameterization."""
shape = pretrained_weight.shape
dtype = pretrained_weight.dtype
weight_name = "bias" if shape.rank == 1 else "kernel"
if 1 <= shape.rank <= 2:
# Bias or dense kernel.
- prior_shape = []
self.add_training_weight(
name=weight_name,
- shape=pretrained_weight.shape,
- dtype=pretrained_weight.dtype,
+ shape=shape,
+ dtype=dtype,
initializer=tf.keras.initializers.Constant(pretrained_weight))
+ prior_shape = tf.TensorShape(())
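+      # A scalar prior shape means a single quantization step (and prior) is
+      # shared across the entire dense kernel or bias.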
elif 3 <= shape.rank <= 4:
# Convolution kernel.
kernel_shape = tf.shape(pretrained_weight)
@@ -66,9 +119,7 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
else:
kernel_rdft = tf.signal.rfft2d(
tf.transpose(pretrained_weight, (2, 3, 0, 1)))
- kernel_rdft = tf.stack(
- [tf.math.real(kernel_rdft), tf.math.imag(kernel_rdft)], axis=-1)
- prior_shape = tf.shape(kernel_rdft)[2:]
+ kernel_rdft = tf.bitcast(kernel_rdft, tf.float32)
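+      # The bitcast appends a trailing axis of size 2 holding the real and
+      # imaginary parts of each RDFT coefficient.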
kernel_rdft /= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), dtype))
self.add_training_weight(
name="kernel_rdft",
@@ -83,10 +134,11 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
# If True, throws warnings that int tensors have no gradient.
# trainable=False,
initializer=tf.keras.initializers.Constant(kernel_shape))
+ prior_shape = kernel_rdft.shape[2:]
else:
raise ValueError(
f"Expected bias or kernel tensor with rank between 1 and 4, received "
- f"shape {self._shape}.")
+ f"shape {shape}.")
# Logarithm of quantization step size.
log_step = tf.fill(prior_shape, tf.constant(-4, dtype=dtype))
@@ -96,7 +148,54 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
dtype=log_step.dtype,
initializer=tf.keras.initializers.Constant(log_step))
- # Logarithm of scale of prior.
+ return prior_shape, dtype, weight_name
+
+ def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model:
+ """Augments a model for training with EPR."""
+ if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network): # pylint: disable=protected-access
+ raise ValueError("`model` must be either sequential or functional.")
+
+ training_model = tf.keras.models.clone_model(
+ model, clone_function=functools.partial(
+ algorithm.create_layer_for_training, algorithm=self))
+ training_model.build(model.input.shape)
+
+    # Divide the regularization weight by the number of original model
+    # parameters so that the penalty strength is roughly independent of
+    # model size.
+ weight = self.regularization_weight / float(model.count_params())
+
+ def regularization_loss(layer, name):
+ return weight * self.regularization_loss(*layer.training_weights[name])
+
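+    # Register one penalty per compressible weight; Keras re-evaluates these
+    # callables and includes them in `training_model.losses` during training.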
+ for layer in training_model.layers:
+ if not hasattr(layer, "attr_name_map"): continue
+ for name in layer.attr_name_map.values():
+ layer.add_loss(functools.partial(regularization_loss, layer, name))
+
+ # TODO(jballe): It would be great to be able to track the entropy losses
+ # combined during training. How to do this?
+ # TODO(jballe): Some models might require training log_scale weights with a
+ # different optimizer/learning rate. How to do this?
+ return training_model
+
+ def compress_model(self, model: tf.keras.Model) -> tf.keras.Model:
+ """Compresses a model after training with EPR."""
+ if not (isinstance(model, tf.keras.Sequential) or model._is_graph_network): # pylint: disable=protected-access
+ raise ValueError("`model` must be either sequential or functional.")
+ return tf.keras.models.clone_model(
+ model, clone_function=functools.partial(
+ algorithm.create_layer_for_inference, algorithm=self))
+
+
+class EPR(EPRBase):
+ """Defines how to apply the EPR algorithm."""
+
+ def init_training_weights(self, pretrained_weight: tf.Tensor):
+ prior_shape, dtype, weight_name = self._init_training_weights_reparam(
+ pretrained_weight)
+
+    # In addition to the reparameterization weights, this variant adds a
+    # variable for the probability model (the log-scale of the prior).
log_scale = tf.fill(prior_shape, tf.constant(2.5, dtype=dtype))
self.add_training_weight(
name=f"{weight_name}_log_scale",
@@ -104,26 +203,13 @@ def init_training_weights(self, pretrained_weight: tf.Tensor):
dtype=log_scale.dtype,
initializer=tf.keras.initializers.Constant(log_scale))
- def project_training_weights(self, *training_weights) -> tf.Tensor:
+ def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor:
if len(training_weights) == 3:
# Bias or dense kernel.
- weight, log_step, _ = training_weights
- step = tf.exp(log_step)
- return tfc.round_st(weight / step) * step
+ return _transform_dense_weight(*training_weights[:-1], quantized=False)
else:
# Convolution kernel.
- kernel_rdft, kernel_shape, log_step, _ = training_weights
- step = tf.exp(log_step)
- kernel_rdft = tfc.round_st(kernel_rdft / step)
- kernel_rdft *= step * tf.sqrt(
- tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype))
- kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
- if kernel_rdft.shape.rank == 3:
- kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2])
- return tf.transpose(kernel, (2, 0, 1))
- else:
- kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2])
- return tf.transpose(kernel, (2, 3, 0, 1))
+ return _transform_conv_weight(*training_weights[:-1], quantized=False)
def compress_training_weights(
self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
@@ -140,18 +226,30 @@ def compress_training_weights(
compression=True, stateless=True, offset_heuristic=False)
string = em.compress(weight / tf.exp(log_step))
weight_shape = tf.cast(weight_shape, tf.uint16)
- return [string, weight_shape, log_step, em.cdf, em.cdf_offset]
+ log_step = tf.cast(log_step, tf.float16)
+ cdf = tf.cast(em.cdf, tf.int16)
+ cdf_offset = tf.cast(em.cdf_offset, tf.int16)
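+    # Casting to half-width types keeps the stored side information (step
+    # sizes and CDF tables) small; `decompress_weights` casts back.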
+ return [string, weight_shape, log_step, cdf, cdf_offset]
- def decompress_weights(self, string, weight_shape, log_step,
- cdf, cdf_offset) -> tf.Tensor:
+ def decompress_weights(
+ self,
+ string: tf.Tensor,
+ weight_shape: tf.Tensor,
+ log_step: tf.Tensor,
+ cdf: tf.Tensor,
+ cdf_offset: tf.Tensor) -> tf.Tensor:
weight_shape = tf.cast(weight_shape, tf.int32)
+ log_step = tf.cast(log_step, tf.float32)
+ cdf = tf.cast(cdf, tf.int32)
+ cdf_offset = tf.cast(cdf_offset, tf.int32)
if weight_shape.shape[0] <= 2:
# Bias or dense kernel.
em = tfc.ContinuousBatchedEntropyModel(
prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset,
coding_rank=weight_shape.shape[0], compression=True, stateless=True,
offset_heuristic=False)
- return em.decompress(string, weight_shape) * tf.exp(log_step)
+ weight = em.decompress(string, weight_shape)
+ return _transform_dense_weight(weight, log_step, quantized=True)
else:
# Convolution kernel.
em = tfc.ContinuousBatchedEntropyModel(
@@ -159,17 +257,10 @@ def decompress_weights(self, string, weight_shape, log_step,
coding_rank=weight_shape.shape[0] + 1, compression=True,
stateless=True, offset_heuristic=False)
kernel_rdft = em.decompress(string, weight_shape[-2:])
- kernel_rdft *= tf.exp(log_step) * tf.sqrt(
- tf.cast(tf.reduce_prod(weight_shape[:-2]), kernel_rdft.dtype))
- kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
- if weight_shape.shape[0] == 3:
- kernel = tf.signal.irfft(kernel_rdft, fft_length=weight_shape[:-2])
- return tf.transpose(kernel, (2, 0, 1))
- else:
- kernel = tf.signal.irfft2d(kernel_rdft, fft_length=weight_shape[:-2])
- return tf.transpose(kernel, (2, 3, 0, 1))
+ return _transform_conv_weight(
+ kernel_rdft, weight_shape, log_step, quantized=True)
- def compute_entropy(self, *training_weights) -> tf.Tensor:
+ def regularization_loss(self, *training_weights: tf.Tensor) -> tf.Tensor:
if len(training_weights) == 3:
# Bias or dense kernel.
weight, log_step, log_scale = training_weights
@@ -183,60 +274,72 @@ def compute_entropy(self, *training_weights) -> tf.Tensor:
_, bits = em(weight / tf.exp(log_step), training=True)
return bits
- def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model:
- """Augments a model for training with EPR."""
- # pylint: disable=protected-access
- if (not isinstance(model, tf.keras.Sequential) and
- not model._is_graph_network):
- raise ValueError(
- "`compress_model` must be either a sequential or functional model.")
- # pylint: enable=protected-access
-
- entropies = []
-
- # Number of dimensions of original model weights. Used to bring
- # entropy_penalty into a more standardized range.
- weight_dims = tf.add_n([tf.size(w) for w in model.trainable_weights])
-
- def create_layer_for_training(layer):
- if not layer.built:
- raise ValueError(
- "Applying EPR currently requires passing in a built model.")
- train_layer = algorithm.create_layer_for_training(layer, algorithm=self)
- train_layer.build(layer.input_shape)
- for name in train_layer.attr_name_map.values():
- entropy = functools.partial(
- self.compute_entropy, *train_layer.training_weights[name])
- entropies.append(entropy)
- return train_layer
-
- def compute_entropy_loss():
- total_entropy = tf.add_n([e() for e in entropies])
- entropy_penalty = self.entropy_penalty / tf.cast(
- weight_dims, total_entropy.dtype)
- return total_entropy * entropy_penalty
- training_model = tf.keras.models.clone_model(
- model, clone_function=create_layer_for_training)
- training_model.add_loss(compute_entropy_loss)
+class FastEPR(EPRBase):
+ """Defines how to apply a faster version of the EPR algorithm."""
- # TODO(jballe): It would be great to be able to track the entropy losses
- # combined during training. How to do this?
- # TODO(jballe): Some models might require training log_scale weights with a
- # different optimizer/learning rate. How to do this?
- return training_model
+ def __init__(self, regularization_weight: float, alpha: float = 1e-2):
+ super().__init__(regularization_weight)
+ self.alpha = alpha
- def compress_model(self, model: tf.keras.Model) -> tf.keras.Model:
- """Compresses a model after training with EPR."""
- # pylint: disable=protected-access
- if (not isinstance(model, tf.keras.Sequential) and
- not model._is_graph_network):
- raise ValueError(
- "`compress_model` must be either a sequential or functional model.")
- # pylint: enable=protected-access
+ def init_training_weights(self, pretrained_weight: tf.Tensor):
+    # The probability model is fixed, so only the reparameterization weights
+    # are needed.
+ self._init_training_weights_reparam(pretrained_weight)
- def create_layer_for_inference(layer):
- return algorithm.create_layer_for_inference(layer, algorithm=self)
+ def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor:
+ if len(training_weights) == 2:
+ # Bias or dense kernel.
+ return _transform_dense_weight(*training_weights, quantized=False)
+ else:
+ # Convolution kernel.
+ return _transform_conv_weight(*training_weights, quantized=False)
- return tf.keras.models.clone_model(
- model, clone_function=create_layer_for_inference)
+ def compress_training_weights(
+ self,
+ *training_weights: tf.Tensor) -> List[tf.Tensor]:
+ if len(training_weights) == 2:
+ # Bias or dense kernel.
+ weight, log_step = training_weights
+ weight_shape = tf.shape(weight)
+ else:
+ # Convolution kernel.
+ weight, weight_shape, log_step = training_weights
+ em = tfc.PowerLawEntropyModel(
+ coding_rank=weight.shape.rank, alpha=self.alpha)
+ string = em.compress(weight / tf.exp(log_step))
+ weight_shape = tf.cast(weight_shape, tf.uint16)
+ log_step = tf.cast(log_step, tf.float16)
+ return [string, weight_shape, log_step]
+
+ def decompress_weights(
+ self,
+ string: tf.Tensor,
+ weight_shape: tf.Tensor,
+ log_step: tf.Tensor) -> tf.Tensor:
+ weight_shape = tf.cast(weight_shape, tf.int32)
+ log_step = tf.cast(log_step, tf.float32)
+ if weight_shape.shape[0] <= 2:
+ # Bias or dense kernel.
+ em = tfc.PowerLawEntropyModel(
+ coding_rank=weight_shape.shape[0], alpha=self.alpha)
+ weight = em.decompress(string, weight_shape)
+ return _transform_dense_weight(weight, log_step, quantized=True)
+ else:
+ # Convolution kernel.
+ em = tfc.PowerLawEntropyModel(
+ coding_rank=weight_shape.shape[0] + 1, alpha=self.alpha)
+ kernel_rdft = em.decompress(
+ string, tf.concat([weight_shape[-2:], tf.shape(log_step)], 0))
+ return _transform_conv_weight(
+ kernel_rdft, weight_shape, log_step, quantized=True)
+
+ def regularization_loss(self, *training_weights: tf.Tensor) -> tf.Tensor:
+ if len(training_weights) == 2:
+ # Bias or dense kernel.
+ weight, log_step = training_weights
+ else:
+ # Convolution kernel.
+ weight, _, log_step = training_weights
+ em = tfc.PowerLawEntropyModel(
+ coding_rank=weight.shape.rank, alpha=self.alpha)
+ return em.penalty(weight / tf.exp(log_step))
diff --git a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
index 0ccd04757..d97560dd7 100644
--- a/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
+++ b/tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr_test.py
@@ -64,15 +64,6 @@ def evaluate_model(model):
return results["accuracy"]
-def train_and_compress_model():
- model = build_model()
- algorithm = epr.EPR(entropy_penalty=10.)
- training_model = algorithm.get_training_model(model)
- train_model(training_model)
- compressed_model = algorithm.compress_model(training_model)
- return model, training_model, compressed_model
-
-
def get_weight_size_in_bytes(weight):
if weight.dtype == tf.string:
return tf.reduce_sum(tf.strings.length(weight, unit="BYTE"))
@@ -86,17 +77,17 @@ def zip_directory(dir_name):
class EPRTest(parameterized.TestCase, tf.test.TestCase):
- def _save_models(self, model, compressed_model):
+ def get_algorithm(self, regularization_weight=1.):
+ return epr.EPR(regularization_weight=regularization_weight)
+
+ def save_model(self, model):
model_dir = self.create_tempdir().full_path
- original_model_dir = os.path.join(model_dir, "original")
- compressed_model_dir = os.path.join(model_dir, "compressed")
- model.save(original_model_dir)
- compressed_model.save(compressed_model_dir)
- return original_model_dir, compressed_model_dir
+ model.save(model_dir)
+ return model_dir
@parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1])
def test_project_training_weights_has_gradients(self, *shape):
- algorithm = epr.EPR(entropy_penalty=1.)
+ algorithm = self.get_algorithm()
init = tf.ones(shape, dtype=tf.float32)
algorithm.init_training_weights(init)
layer = tf.keras.layers.Layer()
@@ -105,29 +96,53 @@ def test_project_training_weights_has_gradients(self, *shape):
with tf.GradientTape() as tape:
weight = algorithm.project_training_weights(*layer.weights)
gradients = tape.gradient(weight, layer.weights)
- # Last weight is scale of prior. Should not have a gradient here.
self.assertAllEqual(
[g is not None for g in gradients],
- [w.dtype.is_floating for w in layer.weights[:-1]] + [False])
+ [w.dtype.is_floating and "log_scale" not in w.name
+ for w in layer.weights])
@parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1])
- def test_compute_entropy_has_gradients(self, *shape):
- algorithm = epr.EPR(entropy_penalty=1.)
+ def test_regularization_loss_has_gradients(self, *shape):
+ algorithm = self.get_algorithm()
init = tf.ones(shape, dtype=tf.float32)
algorithm.init_training_weights(init)
layer = tf.keras.layers.Layer()
for weight_repr in algorithm.weight_reprs:
layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
with tf.GradientTape() as tape:
- loss = algorithm.compute_entropy(*layer.weights)
+ loss = algorithm.regularization_loss(*layer.weights)
gradients = tape.gradient(loss, layer.weights)
self.assertAllEqual(
[g is not None for g in gradients],
[w.dtype.is_floating for w in layer.weights])
+ @parameterized.parameters(
+ ((2, 3), tf.keras.layers.Dense, 5),
+ # TODO(jballe): This fails with: 'You called `set_weights(weights)` on
+ # layer "private__training_wrapper" with a weight list of length 0, but
+ # the layer was expecting 5 weights.' Find fix.
+ # ((3, 10, 2), tf.keras.layers.Conv1D, 5, 3),
+ ((1, 8, 9, 2), tf.keras.layers.Conv2D, 5, 3),
+ )
+ def test_model_has_gradients(self, input_shape, layer_cls, *args):
+ algorithm = self.get_algorithm()
+ model = tf.keras.Sequential([layer_cls(*args, use_bias=True)])
+ inputs = tf.random.normal(input_shape)
+ model(inputs)
+ training_model = algorithm.get_training_model(model)
+ with tf.GradientTape(persistent=True) as tape:
+ tape.watch(inputs)
+ outputs = training_model(inputs)
+ loss = tf.reduce_sum(abs(outputs)) + tf.reduce_sum(training_model.losses)
+ self.assertIsNotNone(tape.gradient(loss, inputs))
+ gradients = tape.gradient(loss, training_model.trainable_weights)
+ self.assertAllEqual(
+ [g is not None for g in gradients],
+ [w.dtype.is_floating for w in training_model.trainable_weights])
+
@parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1])
def test_train_and_test_weights_are_equal(self, *shape):
- algorithm = epr.EPR(entropy_penalty=1.)
+ algorithm = self.get_algorithm()
init = tf.random.uniform(shape, dtype=tf.float32)
algorithm.init_training_weights(init)
layer = tf.keras.layers.Layer()
@@ -138,10 +153,32 @@ def test_train_and_test_weights_are_equal(self, *shape):
test_weight = algorithm.decompress_weights(*compressed_weights)
self.assertAllEqual(train_weight, test_weight)
+ @parameterized.parameters([5], [2, 3], [3, 4, 2], [2, 3, 4, 1])
+ def test_initialized_value_is_close_enough(self, *shape):
+ algorithm = self.get_algorithm()
+ init = tf.random.uniform(shape, -10., 10., dtype=tf.float32)
+ algorithm.init_training_weights(init)
+ layer = tf.keras.layers.Layer()
+ for weight_repr in algorithm.weight_reprs:
+ layer.add_weight(*weight_repr.args, **weight_repr.kwargs)
+ weight = algorithm.project_training_weights(*layer.weights)
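+    # The initial quantization step is exp(-4); uniform quantization noise
+    # with step size s has standard deviation s / sqrt(12).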
+ quantization_noise_std_dev = tf.exp(-4.) / tf.sqrt(12.)
+ self.assertLess(
+ tf.sqrt(tf.reduce_mean(tf.square(init - weight))),
+ 3. * quantization_noise_std_dev)
+
def test_reduces_model_size_at_reasonable_accuracy(self):
- model, _, compressed_model = train_and_compress_model()
- original_model_dir, compressed_model_dir = self._save_models(
- model, compressed_model)
+ algorithm = self.get_algorithm()
+ model = build_model()
+ training_model = algorithm.get_training_model(model)
+ train_model(training_model)
+ compressed_model = algorithm.compress_model(training_model)
+ original_model_dir = self.save_model(model)
+ compressed_model_dir = self.save_model(compressed_model)
+
+ with self.subTest("training_model_has_reasonable_accuracy"):
+ accuracy = evaluate_model(training_model)
+ self.assertGreater(accuracy, .9)
with self.subTest("compressed_weights_are_smaller"):
original_size = sum(
@@ -162,11 +199,25 @@ def test_reduces_model_size_at_reasonable_accuracy(self):
# rather than of each layer?
self.assertLess(compressed_size, 0.2 * original_size)
- with self.subTest("has_reasonable_accuracy"):
+ with self.subTest("compressed_model_has_reasonable_accuracy"):
compressed_model = tf.keras.models.load_model(compressed_model_dir)
accuracy = evaluate_model(compressed_model)
self.assertGreater(accuracy, .9)
+ def test_unregularized_training_model_has_reasonable_accuracy(self):
+ algorithm = self.get_algorithm(regularization_weight=0.)
+ model = build_model()
+ training_model = algorithm.get_training_model(model)
+ train_model(training_model)
+ accuracy = evaluate_model(training_model)
+ self.assertGreater(accuracy, .9)
+
+
+class FastEPRTest(EPRTest):
+
+ def get_algorithm(self, regularization_weight=1.):
+ return epr.FastEPR(regularization_weight=regularization_weight, alpha=1e-2)
+
if __name__ == "__main__":
tf.test.main()