Internal changes #819

Open · wants to merge 1 commit into master
@@ -32,8 +32,11 @@
class ModelTransformer(object):
"""Matches patterns to apply transforms in a tf.keras model graph."""

def __init__(
self, model, transforms, candidate_layers=None, layer_metadata=None):
def __init__(self,
model,
transforms,
candidate_layers=None,
layer_metadata=None):
"""Construct ModelTransformer.

Args:
@@ -68,7 +71,8 @@ def _is_functional_model(self, model):

def _inbound_node_generator(self, layer):
for inbound_node in layer['inbound_nodes']:
if len(inbound_node) > 0 and isinstance(inbound_node[0], str):
if (isinstance(inbound_node, list) and len(inbound_node) > 0 and
isinstance(inbound_node[0], str)):
# TODO(tfmot): The case for the SlicingOpLambda.
yield [inbound_node]
else:
@@ -78,12 +82,17 @@ def _get_inbound_layer_names(self, layer):
"""Return all the inbound connection layer names for the layer."""
inbound_layer_names = []
for inbound_node in self._inbound_node_generator(layer):
# TODO(b/197935452): temporary fix when the input is a dictionary of
# tensors. A comprehensive solution may be needed.
if isinstance(inbound_node, dict):
inbound_node = inbound_node.values()
for connection_info in inbound_node:
# input argument case.
inbound_layer_names.append(connection_info[0])
# **kwarg argument case.
inbound_layer_names += [
value[0] for value in connection_info[3].items()]
value[0] for value in connection_info[3].items()
]

return inbound_layer_names

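For context on the two changes above (an aside, not part of the diff): Keras serializes a layer's `inbound_nodes` in several shapes, and the new guards normalize them before layer names are read. A minimal, self-contained sketch with invented layer names:

# Illustrative only; names are invented.

# Ordinary functional connection: a node is a list of
# [layer_name, node_index, tensor_index, kwargs] connection infos.
regular_node = [['dense', 0, 0, {}]]

# TFOpLambda/SlicingOpLambda case: the node itself is one flat connection
# info, so its first element is already a string.
op_lambda_node = ['tf.__operators__.getitem', 0, 0, {}]

# Dict-of-tensors case (b/197935452): connection infos keyed by input name.
dict_node = {'input1': ['input1', 0, 0, {}], 'input2': ['input2', 0, 0, {}]}

def inbound_layer_names(inbound_node):
  # Mirror of the normalization above: wrap the flat form, flatten the
  # dict form, then read the layer name from each connection info.
  if (isinstance(inbound_node, list) and len(inbound_node) > 0 and
      isinstance(inbound_node[0], str)):
    inbound_node = [inbound_node]
  if isinstance(inbound_node, dict):
    inbound_node = inbound_node.values()
  return [info[0] for info in inbound_node]

for node in (regular_node, op_lambda_node, dict_node):
  print(inbound_layer_names(node))
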
@@ -212,10 +221,10 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):

if len(pattern.inputs) == 0:
# Leaf layer in pattern.
return LayerNode(layer, self._get_layer_weights(layer['config']['name']),
[], self._get_layer_metadata(layer['config']['name']),
self._get_layer_names_and_weights(
layer['config']['name']))
return LayerNode(
layer, self._get_layer_weights(layer['config']['name']), [],
self._get_layer_metadata(layer['config']['name']),
self._get_layer_names_and_weights(layer['config']['name']))

# There is a possible edge case where a single layer may output multiple
# tensors and multiple tensors from that layer may be used by the
@@ -313,8 +322,8 @@ def _replace_functional(self, match_layer_node, replacement_layer_node):
match_name = match_layer_node.layer['config']['name']
replacement_name = replacement_layer_node.layer['config']['name']

def _replace_layer_name_for_connection_info(
connection_info, match_name, replacement_name):
def _replace_layer_name_for_connection_info(connection_info, match_name,
replacement_name):
if connection_info[0] == match_name:
connection_info[0] = replacement_name
for key in connection_info[3]:
@@ -323,9 +332,11 @@

for consumer in consuming_layers:
for inbound_node in self._inbound_node_generator(consumer):
if isinstance(inbound_node, dict):
inbound_node = inbound_node.values()
for connection_info in inbound_node:
_replace_layer_name_for_connection_info(
connection_info, match_name, replacement_name)
_replace_layer_name_for_connection_info(connection_info, match_name,
replacement_name)

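The dict normalization above mirrors the one in _get_inbound_layer_names, applied here before consumers are rewired onto the replacement layer. A simplified sketch of the renaming itself (the kwargs branch below is an assumption; part of the original helper is collapsed out of this diff):

# Illustrative only; names are invented.
connection_info = ['old_dense', 0, 0, {'bias': ['old_dense', 0, 0]}]

def rename(info, match_name, replacement_name):
  # Positional input reference.
  if info[0] == match_name:
    info[0] = replacement_name
  # Keyword-argument references live in the kwargs dict at index 3.
  for key in info[3]:
    if info[3][key][0] == match_name:
      info[3][key][0] = replacement_name

rename(connection_info, 'old_dense', 'new_dense')
print(connection_info)  # ['new_dense', 0, 0, {'bias': ['new_dense', 0, 0]}]
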
output_consumers = self._get_output_consumers(match_layer_node.layer)
for output_consumer in output_consumers:
@@ -493,8 +504,7 @@ def _set_layer_weights(self, layer, weights_map):
for weight_tensor in layer.weights:
weight_name = self._weight_name(weight_tensor.name)
if weight_name in weights_map:
weight_value_tuples.append(
(weight_tensor, weights_map[weight_name]))
weight_value_tuples.append((weight_tensor, weights_map[weight_name]))

K.batch_set_value(weight_value_tuples)

(Second file in the diff: the ModelTransformer unit tests.)
@@ -60,7 +60,14 @@ def _get_layer(self, model, n_excluding_input, model_type):
return model.layers[n_excluding_input]

def _create_model_inputs(self, model):
return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))
if isinstance(model.input, dict):
inputs = {}
for key, input_layer in model.input.items():
inputs[key] = np.random.randn(
*self._batch(input_layer.get_shape().as_list(), 1))
return inputs
else:
return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))

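A usage sketch for the updated helper (invented model, assuming a TF 2.x Keras install): for a dict-input model, `model.input` is itself a dict of tensors, so the helper returns one random batch per input key:

import numpy as np
from tensorflow import keras

inp = {'a': keras.layers.Input((3,)), 'b': keras.layers.Input((2,))}
out = keras.layers.Concatenate()([inp['a'], inp['b']])
model = keras.Model(inp, out)

# Equivalent of _create_model_inputs(model) in the dict case: replace the
# None batch dimension of each input's shape with a batch of 1.
inputs = {
    key: np.random.randn(*([1] + tensor.get_shape().as_list()[1:]))
    for key, tensor in model.input.items()
}
print({k: v.shape for k, v in inputs.items()})  # {'a': (1, 3), 'b': (1, 2)}
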
def _simple_dense_model(self, model_type='functional'):
if model_type == 'functional':
@@ -90,9 +97,7 @@ def _nested_model(self, model_type='functional', submodel_type='functional'):
out = keras.layers.ReLU(6.0)(x)
return keras.Model(inp, out)
elif model_type == 'sequential':
return keras.Sequential(
[submodel,
keras.layers.ReLU(6.0)])
return keras.Sequential([submodel, keras.layers.ReLU(6.0)])

def _assert_config(self, expected_config, actual_config, exclude_keys=None):
"""Asserts that the two config dictionaries are equal.
@@ -216,6 +221,35 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(

self._assert_model_results_equal(model, transformed_model)

def testReplaceSingleLayerWithSingleLayer_DictInputOutput(self):
inp = {
'input1': keras.layers.Input((3,)),
'input2': keras.layers.Input((3,))
}
x1 = keras.layers.Dense(2)(inp['input1'])
x2 = keras.layers.Dense(2)(inp['input2'])
out1 = keras.layers.ReLU(6.0)(x1)
out2 = keras.layers.ReLU(6.0)(x2)
model = keras.Model(inp, {'output1': out1, 'output2': out2})

transformed_model, _ = ModelTransformer(
model, [self.ReplaceDenseLayer()]).transform()

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
self._assert_config(model.get_config(), transformed_model.get_config(),
['class_name', 'build_input_shape'])

# There are two input layers in the input dict.
self.assertEqual(
'MyDense',
self._get_layer(transformed_model, 1, 'functional').__class__.__name__)
self.assertEqual(
'MyDense',
self._get_layer(transformed_model, 2, 'functional').__class__.__name__)

self._assert_model_results_equal(model, transformed_model)

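`ReplaceDenseLayer` is defined elsewhere in this test class; based on the 'MyDense' assertions in the test above, it looks roughly like the following sketch (an approximation, not the actual definition):

class MyDense(keras.layers.Dense):
  """Sketch stand-in: a trivial Dense subclass the transform swaps in."""

class ReplaceDenseLayer(transforms.Transform):
  """Sketch: replaces every Dense layer with MyDense, keeping weights."""

  def pattern(self):
    return LayerPattern('Dense')

  def replacement(self, match_layer):
    dense_config = match_layer.layer['config']
    my_dense = MyDense(**dense_config)
    replace_layer = keras.layers.serialize(my_dense)
    replace_layer['name'] = replace_layer['config']['name']
    # Reuse the matched layer's weights so model outputs stay equal.
    return LayerNode(replace_layer, match_layer.weights, [])
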
@parameterized.parameters(['sequential', 'functional'])
def testReplaceSingleLayerWithSingleLayer_MatchParameters(self, model_type):

@@ -241,8 +275,8 @@ def replacement(self, match_layer):

model = self._simple_dense_model(model_type)

transformed_model, _ = ModelTransformer(
model, [RemoveBiasInDense()]).transform()
transformed_model, _ = ModelTransformer(model,
[RemoveBiasInDense()]).transform()

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
@@ -312,8 +346,7 @@ def replacement(self, match_layer):
layer_config['name'] = activation_layer.name

activation_layer_node = LayerNode(
layer_config,
input_layers=[match_layer])
layer_config, input_layers=[match_layer])

return activation_layer_node

@@ -371,8 +404,8 @@ def replacement(self, match_layer):
keras.layers.ReLU()])
model.set_weights(model_fused.get_weights())

transformed_model, _ = ModelTransformer(
model, [FuseReLUIntoDense()]).transform()
transformed_model, _ = ModelTransformer(model,
[FuseReLUIntoDense()]).transform()

self._assert_config(
model_fused.get_config(),
@@ -430,6 +463,7 @@ def replacement(self, match_layer):
['build_input_shape'])

def testReplaceListOfLayers_Sequential(self):

class ReplaceConvBatchNorm(transforms.Transform):
"""Replaces a ConvBatchNorm pattern with the same set of layers.

@@ -438,8 +472,8 @@ class ReplaceConvBatchNorm(transforms.Transform):
"""

def pattern(self):
return LayerPattern('BatchNormalization',
inputs=[LayerPattern('Conv2D')])
return LayerPattern(
'BatchNormalization', inputs=[LayerPattern('Conv2D')])

def replacement(self, match_layer):
# Adds a modification so the transform happens. If the layers are
@@ -457,7 +491,8 @@ def replacement(self, match_layer):
transformed_model, _ = ModelTransformer(
model, [ReplaceConvBatchNorm()]).transform()
transformed_model_layer_names = [
layer.name for layer in transformed_model.layers]
layer.name for layer in transformed_model.layers
]

self.assertEqual(model_layer_names, transformed_model_layer_names)

@@ -495,10 +530,10 @@ def replacement(self, match_layer):

model = self._simple_dense_model(model_type)
transformed_model, _ = ModelTransformer(
model,
[ReplaceReLUWithSoftmax(), ReplaceSoftmaxWithELU()],
candidate_layers=set([layer.name for layer in model.layers])
).transform()
model, [ReplaceReLUWithSoftmax(),
ReplaceSoftmaxWithELU()],
candidate_layers=set([layer.name for layer in model.layers
])).transform()

self.assertEqual(transformed_model.layers[-1].__class__.__name__, 'ELU')

@@ -515,8 +550,8 @@ def replacement(self, match_layer):

model = self._simple_dense_model(model_type)

transformed_model, _ = ModelTransformer(
model, [ReplaceWithSelf()]).transform()
transformed_model, _ = ModelTransformer(model,
[ReplaceWithSelf()]).transform()

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
@@ -689,8 +724,8 @@ def replacement(self, match_layer):
}
}

transformer = ModelTransformer(
model, [ReplaceLayerMetadata()], None, layer_metadata)
transformer = ModelTransformer(model, [ReplaceLayerMetadata()], None,
layer_metadata)
transformed_model, updated_metadata = transformer.transform()

self.assertEqual(expected_metadata, updated_metadata)
@@ -704,12 +739,12 @@ def replacement(self, match_layer):
('sequential', 'sequential'),
('sequential', 'functional'),
('functional', 'sequential'),
('functional', 'functional'),])
('functional', 'functional'),
])
def testNestedModelNoChange(self, model_type, submodel_type):
model = self._nested_model(model_type, submodel_type)

transformed_model, _ = ModelTransformer(
model, []).transform()
transformed_model, _ = ModelTransformer(model, []).transform()

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
@@ -721,6 +756,7 @@ def testNestedModelNoChange(self, model_type, submodel_type):
# Validation Tests

def testRaisesErrorForSubclassModels(self):

class MyModel(keras.Model):
pass
