From 31d01544badf06c55b6c3c50abec7daa76fb4346 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Wed, 12 Oct 2022 11:34:17 -0700 Subject: [PATCH] Switches Keras object serialization to new logic and changes public API for deserialize_keras_object/serialize_keras_object to the new functions. PiperOrigin-RevId: 480676373 --- .../default_8bit/default_8bit_quantize_registry_test.py | 2 -- .../default_n_bit/default_n_bit_quantize_registry_test.py | 8 ++------ .../python/core/quantization/keras/quantize_annotate.py | 4 +--- .../python/core/quantization/keras/quantize_layer.py | 4 +--- .../python/core/quantization/keras/quantize_wrapper.py | 4 +--- .../python/core/quantization/keras/quantizers_test.py | 4 +--- .../python/core/sparsity/keras/pruning_wrapper.py | 4 +--- 7 files changed, 7 insertions(+), 23 deletions(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py index 3efec2d6a..01c4d099d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py @@ -366,7 +366,6 @@ def testSerialization(self): quantize_config_from_config = deserialize_keras_object( serialized_quantize_config, - module_objects=globals(), custom_objects=default_8bit_quantize_registry._types_dict()) self.assertEqual(quantize_config, quantize_config_from_config) @@ -482,7 +481,6 @@ def testSerialization(self): quantize_config_from_config = deserialize_keras_object( serialized_quantize_config, - module_objects=globals(), custom_objects=default_8bit_quantize_registry._types_dict()) self.assertEqual(self.quantize_config, quantize_config_from_config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py index 75b9a31e7..4ab73dbdc 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py @@ -372,9 +372,7 @@ def testSerialization(self): self.assertEqual(expected_config, serialized_quantize_config) quantize_config_from_config = deserialize_keras_object( - serialized_quantize_config, - module_objects=globals(), - custom_objects=n_bit_registry._types_dict()) + serialized_quantize_config, custom_objects=n_bit_registry._types_dict()) self.assertEqual(quantize_config, quantize_config_from_config) @@ -491,9 +489,7 @@ def testSerialization(self): self.assertEqual(expected_config, serialized_quantize_config) quantize_config_from_config = deserialize_keras_object( - serialized_quantize_config, - module_objects=globals(), - custom_objects=n_bit_registry._types_dict()) + serialized_quantize_config, custom_objects=n_bit_registry._types_dict()) self.assertEqual(self.quantize_config, quantize_config_from_config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py index e41686221..11842c58d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py @@ -108,9 +108,7 @@ def from_config(cls, config): config = config.copy() quantize_config = deserialize_keras_object( - config.pop('quantize_config'), - module_objects=globals(), - custom_objects=None) + config.pop('quantize_config'), custom_objects=None) layer = tf.keras.layers.deserialize(config.pop('layer')) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py index be59458ca..70f9fb2bd 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py @@ -93,8 +93,6 @@ def from_config(cls, config): # Deserialization code should ensure Quantizer is in keras scope. quantizer = deserialize_keras_object( - config.pop('quantizer'), - module_objects=globals(), - custom_objects=None) + config.pop('quantizer'), custom_objects=None) return cls(quantizer=quantizer, **config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index 1e84dc01d..c68717fa7 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -197,9 +197,7 @@ def from_config(cls, config): # The deserialization code should ensure the QuantizeConfig is in keras # serialization scope. quantize_config = deserialize_keras_object( - config.pop('quantize_config'), - module_objects=globals(), - custom_objects=None) + config.pop('quantize_config'), custom_objects=None) layer = tf.keras.layers.deserialize(config.pop('layer')) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py index 7b3dcc3ed..ff091e2e3 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py @@ -97,9 +97,7 @@ def testSerialization(self, quantizer_type): self.assertEqual(expected_config, serialized_quantizer) quantizer_from_config = deserialize_keras_object( - serialized_quantizer, - module_objects=globals(), - custom_objects=quantizers._types_dict()) + serialized_quantizer, custom_objects=quantizers._types_dict()) self.assertEqual(quantizer, quantizer_from_config) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index 65d5a69a5..82e5b9125 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -324,9 +324,7 @@ def from_config(cls, config): 'PolynomialDecay': pruning_sched.PolynomialDecay } config['pruning_schedule'] = deserialize_keras_object( - pruning_schedule, - module_objects=globals(), - custom_objects=custom_objects) + pruning_schedule, custom_objects=custom_objects) layer = keras.layers.deserialize(config.pop('layer')) config['layer'] = layer