Skip to content

Commit

Permalink
Switches Keras object serialization to new logic and changes public A…
Browse files Browse the repository at this point in the history
…PI for deserialize_keras_object/serialize_keras_object to the new functions.

PiperOrigin-RevId: 480676373
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Oct 13, 2022
1 parent 0e08dea commit 31d0154
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ def testSerialization(self):

quantize_config_from_config = deserialize_keras_object(
serialized_quantize_config,
module_objects=globals(),
custom_objects=default_8bit_quantize_registry._types_dict())

self.assertEqual(quantize_config, quantize_config_from_config)
Expand Down Expand Up @@ -482,7 +481,6 @@ def testSerialization(self):

quantize_config_from_config = deserialize_keras_object(
serialized_quantize_config,
module_objects=globals(),
custom_objects=default_8bit_quantize_registry._types_dict())

self.assertEqual(self.quantize_config, quantize_config_from_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,7 @@ def testSerialization(self):
self.assertEqual(expected_config, serialized_quantize_config)

quantize_config_from_config = deserialize_keras_object(
serialized_quantize_config,
module_objects=globals(),
custom_objects=n_bit_registry._types_dict())
serialized_quantize_config, custom_objects=n_bit_registry._types_dict())

self.assertEqual(quantize_config, quantize_config_from_config)

Expand Down Expand Up @@ -491,9 +489,7 @@ def testSerialization(self):
self.assertEqual(expected_config, serialized_quantize_config)

quantize_config_from_config = deserialize_keras_object(
serialized_quantize_config,
module_objects=globals(),
custom_objects=n_bit_registry._types_dict())
serialized_quantize_config, custom_objects=n_bit_registry._types_dict())

self.assertEqual(self.quantize_config, quantize_config_from_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ def from_config(cls, config):
config = config.copy()

quantize_config = deserialize_keras_object(
config.pop('quantize_config'),
module_objects=globals(),
custom_objects=None)
config.pop('quantize_config'), custom_objects=None)

layer = tf.keras.layers.deserialize(config.pop('layer'))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def from_config(cls, config):

# Deserialization code should ensure Quantizer is in keras scope.
quantizer = deserialize_keras_object(
config.pop('quantizer'),
module_objects=globals(),
custom_objects=None)
config.pop('quantizer'), custom_objects=None)

return cls(quantizer=quantizer, **config)
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def from_config(cls, config):
# The deserialization code should ensure the QuantizeConfig is in keras
# serialization scope.
quantize_config = deserialize_keras_object(
config.pop('quantize_config'),
module_objects=globals(),
custom_objects=None)
config.pop('quantize_config'), custom_objects=None)

layer = tf.keras.layers.deserialize(config.pop('layer'))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def testSerialization(self, quantizer_type):
self.assertEqual(expected_config, serialized_quantizer)

quantizer_from_config = deserialize_keras_object(
serialized_quantizer,
module_objects=globals(),
custom_objects=quantizers._types_dict())
serialized_quantizer, custom_objects=quantizers._types_dict())

self.assertEqual(quantizer, quantizer_from_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,7 @@ def from_config(cls, config):
'PolynomialDecay': pruning_sched.PolynomialDecay
}
config['pruning_schedule'] = deserialize_keras_object(
pruning_schedule,
module_objects=globals(),
custom_objects=custom_objects)
pruning_schedule, custom_objects=custom_objects)

layer = keras.layers.deserialize(config.pop('layer'))
config['layer'] = layer
Expand Down

0 comments on commit 31d0154

Please sign in to comment.