Port model_blocks to TF2
This patch ports model_blocks and model_blocks_test to TF2. The major
change is the removal of the add_residual_connection function. It is no
longer needed because models built on top of the residual connection
layer now construct the layer explicitly and invoke it eagerly inside
their forward pass.
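
As a minimal sketch of the new eager-mode usage (assuming the gematria model_blocks API exactly as it appears in the diff below; the removed helper is shown in a comment for contrast, and the values are illustrative):

import tensorflow as tf
from gematria.model.python import model_blocks

# TF1 (removed): model_blocks.add_residual_connection(output, residual,
# name='residual') built the layer implicitly at graph-construction time.
# TF2 (this patch): the layer is constructed explicitly with its input
# shapes and called eagerly inside the model's forward pass.
shape = tf.TensorShape((2, 4))
residual_layer = model_blocks.ResidualConnectionLayer(
    layer_input_shapes=(shape, shape), name='residual'
)
# Matching shapes: the call is a plain elementwise add; no weights created.
output = residual_layer((tf.ones(shape), tf.ones(shape)))  # -> all 2.0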

Pull Request: google#281
boomanaiden154 committed Jan 6, 2025
1 parent: 36a0c65 · commit: 8ac7b8d
Showing 3 changed files with 36 additions and 171 deletions.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
@@ -140,9 +140,6 @@ gematria_py_test(
     size = "small",
     timeout = "moderate",
     srcs = ["model_blocks_test.py"],
-    tags = [
-        "manual",
-    ],
     deps = [
         ":model_blocks",
     ],
35 changes: 3 additions & 32 deletions gematria/model/python/model_blocks.py
@@ -41,9 +41,10 @@ class ResidualConnectionLayer(tf_keras.layers.Layer):
   """

   # @Override
-  def build(
-      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape]
+  def __init__(
+      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape], **kwargs
   ) -> None:
+    super().__init__(**kwargs)
     output_shape, residual_shape = layer_input_shapes
     if output_shape.rank != 2:
       # NOTE(ondrasej): For simplicity, we require that the output has shape
@@ -79,36 +80,6 @@ def call(self, layer_inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     return tf.math.add(output_part, residual_part, name=self.name)


-def add_residual_connection(
-    output_part: tf.Tensor, residual_part: tf.Tensor, name: Optional[str] = None
-) -> tf.Tensor:
-  """Adds a residual connection to the output of a subnetwork.
-
-  When the shape of `output_part` and `residual_part` are the same, then they
-  are simply added, elementwise. If the shapes are different, the function
-  creates a learned linear transformation layer that transforms the residual
-  part to the right shape; in this case, the rank of the output part must be 2
-  and the first dimension must be the batch dimension.
-
-  Args:
-    output_part: The tensor that contains the output of the subnetwork.
-    residual_part: The input of the subnetwork.
-    name: The name of the residual connection. This name is used for the
-      tf.math.add operation that merges the two parts; if the linear
-      transformation is used, its name is f'{name}_transformation'.
-
-  Returns:
-    A tensor that contains the output of the network merged with the residual
-    connection.
-
-  Raises:
-    ValueError: If the rank of the output part is not two, including the batch
-      dimension.
-  """
-  residual_layer = ResidualConnectionLayer(name=name)
-  return residual_layer((output_part, residual_part))
-
-
 def cast(dtype: tf.dtypes.DType) -> snt.AbstractModule:
   """Creates a sonnet module that casts a tensor to the specified dtype."""
   return snt.Module(build=functools.partial(tf.cast, dtype=dtype))
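
A sketch of the shape-mismatch path described in the removed docstring; the layer keeps this behavior, only the construction point moved. Shapes and values here are illustrative, not from the patch:

import tensorflow as tf
from gematria.model.python import model_blocks

# The output part must have rank 2 (batch, features); the residual part is
# mapped to the output shape by a learned linear transformation that the
# layer now creates at construction time.
layer = model_blocks.ResidualConnectionLayer(
    layer_input_shapes=(tf.TensorShape((None, 2)), tf.TensorShape((None, 3))),
    name='residual',
)
print(layer.weights[0].shape)  # (3, 2): the projection exists before any call.
# A linear map sends zeros to zeros, so the output equals the first operand.
out = layer((tf.constant([[1.0, 2.0]]), tf.zeros((1, 3))))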
169 changes: 33 additions & 136 deletions gematria/model/python/model_blocks_test.py
@@ -13,150 +13,56 @@
 # limitations under the License.

 from gematria.model.python import model_blocks
-import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow as tf


 class ResidualConnectionLayerTest(tf.test.TestCase):

   def test_same_shapes(self):
-    shape = (2, 4)
+    shape = tf.TensorShape((2, 4))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    self.assertEqual(output_tensor.shape, shape)
-
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(shape, shape), name='residual'
+    )
+
+    self.assertEmpty(residual_layer.weights)
+
+    input_array = tf.constant(
+        [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
+    )
+    residual_array = tf.constant(
+        [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
+    )
+    output = residual_layer((input_array, residual_array))
+    self.assertEqual(output.shape, shape)
+    self.assertAllEqual(output, [[0, 3, 2, 5], [3, 8, 5, 10]])

   def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
+    input_shape = tf.TensorShape((None, 2))
+    residual_shape = tf.TensorShape((None, 3))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    output_tensor.shape.assert_is_compatible_with(input_shape)
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(input_shape, residual_shape), name='residual'
+    )

     residual_layer_weights = residual_layer.weights
     self.assertLen(residual_layer_weights, 1)
     self.assertTrue(residual_layer_weights[0].shape.is_compatible_with((3, 2)))

-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
-
-
-class AddResidualConnectioNTest(tf.test.TestCase):
-
-  def test_same_shapes(self):
-    shape = (2, 4)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    self.assertEqual(output_tensor.shape, shape)
-
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    # In the first test, we use all zeros for the residual array. Any linear
+    # projection must map this to zeros, thus the output must be the same as
+    # input_tensor.
+    input_tensor = tf.constant([[1, 2], [3, 4]], dtype=dtype)
+    residual_tensor = tf.zeros((2, 3), dtype=dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertAllEqual(output_tensor, input_tensor)

-  def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    output_tensor.shape.assert_is_compatible_with(input_shape)
-
-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
+    # In the second test, we use a non-zero array. We can't test for exact
+    # values because of the random initialization of the linear transformation
+    # but with probability one, the output is different from the input.
+    residual_tensor = tf.ones((2, 3), dtype=dtype.as_numpy_dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertNotAllClose(output_tensor, input_tensor)


 class CastTest(tf.test.TestCase):
@@ -165,22 +71,13 @@ def test_int32_to_float_cast(self):
     input_shape = (4, 24)
     input_dtype = tf.dtypes.int32
     output_dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=input_dtype)
+    input_tensor = tf.ones(input_shape, input_dtype)
     cast = model_blocks.cast(output_dtype)
     output_tensor = cast(input_tensor)

     self.assertEqual(output_tensor.shape, input_shape)
     self.assertEqual(output_tensor.dtype, output_dtype)

-    with self.session() as sess:
-      input_array = np.ones(input_shape, input_dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor, feed_dict={input_tensor: input_array}
-      )
-      self.assertEqual(output_array.shape, input_shape)
-      self.assertEqual(output_array.dtype, output_dtype.as_numpy_dtype)


 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   tf.test.main()