Port model_blocks to TF2 #281

Status: Open. Wants to merge 2 commits into base branch `users/boomanaiden154/main.port-model_blocks-to-tf2`.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
@@ -140,9 +140,6 @@ gematria_py_test(
     size = "small",
     timeout = "moderate",
     srcs = ["model_blocks_test.py"],
-    tags = [
-        "manual",
-    ],
     deps = [
         ":model_blocks",
     ],
35 changes: 3 additions & 32 deletions gematria/model/python/model_blocks.py
@@ -41,9 +41,10 @@ class ResidualConnectionLayer(tf_keras.layers.Layer):
   """
 
   # @Override
-  def build(
-      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape]
+  def __init__(
+      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape], **kwargs
   ) -> None:
+    super().__init__(**kwargs)
     output_shape, residual_shape = layer_input_shapes
     if output_shape.rank != 2:
       # NOTE(ondrasej): For simplicity, we require that the output has shape
@@ -79,36 +80,6 @@ def call(self, layer_inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     return tf.math.add(output_part, residual_part, name=self.name)
 
 
-def add_residual_connection(
-    output_part: tf.Tensor, residual_part: tf.Tensor, name: Optional[str] = None
-) -> tf.Tensor:
-  """Adds a residual connection to the output of a subnetwork.
-
-  When the shape of `output_part` and `residual_part` are the same, then they
-  are simply added, elementwise. If the shapes are different, the function
-  creates a learned linear transformation layer that transforms the residual
-  part to the right shape; in this case, the rank of the output part must be 2
-  and the first dimension must be the batch dimension.
-
-  Args:
-    output_part: The tensor that contains the output of the subnetwork.
-    residual_part: The input of the subnetwork.
-    name: The name of the residual connection. This name is used for the
-      tf.math.add operation that merges the two parts; if the linear
-      transformation is used, its name is f'{name}_transformation'.
-
-  Returns:
-    A tensor that contains the output of the network merged with the residual
-    connection.
-
-  Raises:
-    ValueError: If the rank of the output part is not two, including the batch
-      dimension.
-  """
-  residual_layer = ResidualConnectionLayer(name=name)
-  return residual_layer((output_part, residual_part))
-
-
 def cast(dtype: tf.dtypes.DType) -> snt.AbstractModule:
   """Creates a sonnet module that casts a tensor to the specified dtype."""
   return snt.Module(build=functools.partial(tf.cast, dtype=dtype))
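A minimal usage sketch of the ported layer, for review context (the snippet and its variable names are illustrative, not code from this PR): with `build` folded into `__init__`, the input shapes are supplied at construction time, the weights exist immediately, and the layer can be called eagerly. Code that previously went through the removed `add_residual_connection` helper would now construct the layer directly:

```python
import tensorflow as tf

from gematria.model.python import model_blocks

# Shapes now go to __init__; with unequal shapes the layer creates the
# learned (3, 2) projection for the residual part at construction time.
layer = model_blocks.ResidualConnectionLayer(
    layer_input_shapes=(tf.TensorShape((None, 2)), tf.TensorShape((None, 3))),
    name='residual',
)

output_part = tf.constant([[1.0, 2.0], [3.0, 4.0]])
residual_part = tf.zeros((2, 3))
# Eager equivalent of the old add_residual_connection(output_part, residual_part).
merged = layer((output_part, residual_part))
```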
169 changes: 33 additions & 136 deletions gematria/model/python/model_blocks_test.py
@@ -13,150 +13,56 @@
 # limitations under the License.
 
 from gematria.model.python import model_blocks
-import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 
 
 class ResidualConnectionLayerTest(tf.test.TestCase):
 
   def test_same_shapes(self):
-    shape = (2, 4)
+    shape = tf.TensorShape((2, 4))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    self.assertEqual(output_tensor.shape, shape)
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(shape, shape), name='residual'
+    )
 
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    self.assertEmpty(residual_layer.weights)
+
+    input_array = tf.constant(
+        [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
+    )
+    residual_array = tf.constant(
+        [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
+    )
+    output = residual_layer((input_array, residual_array))
+    self.assertEqual(output.shape, shape)
+    self.assertAllEqual(output, [[0, 3, 2, 5], [3, 8, 5, 10]])
 
   def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
+    input_shape = tf.TensorShape((None, 2))
+    residual_shape = tf.TensorShape((None, 3))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    output_tensor.shape.assert_is_compatible_with(input_shape)
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(input_shape, residual_shape), name='residual'
+    )
 
-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
-
-
-class AddResidualConnectioNTest(tf.test.TestCase):
-
-  def test_same_shapes(self):
-    shape = (2, 4)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    self.assertEqual(output_tensor.shape, shape)
-
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    residual_layer_weights = residual_layer.weights
+    self.assertLen(residual_layer_weights, 1)
+    self.assertTrue(residual_layer_weights[0].shape.is_compatible_with((3, 2)))
 
-  def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    output_tensor.shape.assert_is_compatible_with(input_shape)
-
-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
+    # In the first test, we use all zeros for the residual array. Any linear
+    # projection must map this to zeros, thus the output must be the same as
+    # input_tensor.
+    input_tensor = tf.constant([[1, 2], [3, 4]], dtype=dtype)
+    residual_tensor = tf.zeros((2, 3), dtype=dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertAllEqual(output_tensor, input_tensor)
+
+    # In the second test, we use a non-zero array. We can't test for exact
+    # values because of the random initialization of the linear transformation
+    # but with probability one, the output is different from the input.
+    residual_tensor = tf.ones((2, 3), dtype=dtype.as_numpy_dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertNotAllClose(output_tensor, input_tensor)
 
 
 class CastTest(tf.test.TestCase):
@@ -165,22 +71,13 @@ def test_int32_to_float_cast(self):
     input_shape = (4, 24)
     input_dtype = tf.dtypes.int32
     output_dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=input_dtype)
+    input_tensor = tf.ones(input_shape, input_dtype)
     cast = model_blocks.cast(output_dtype)
     output_tensor = cast(input_tensor)
 
     self.assertEqual(output_tensor.shape, input_shape)
     self.assertEqual(output_tensor.dtype, output_dtype)
 
-    with self.session() as sess:
-      input_array = np.ones(input_shape, input_dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor, feed_dict={input_tensor: input_array}
-      )
-      self.assertEqual(output_array.shape, input_shape)
-      self.assertEqual(output_array.dtype, output_dtype.as_numpy_dtype)
-
 
 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   tf.test.main()
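The test changes all follow one migration pattern: placeholders plus `sess.run` with a `feed_dict` become eager tensors flowing straight through the module, and `tf.disable_v2_behavior()` disappears from the test main. A standalone sketch of that pattern, mirroring `CastTest` (illustrative only; assumes a TF2 environment with gematria on the Python path, and the variable names are made up for the example):

```python
import tensorflow as tf

from gematria.model.python import model_blocks

# No placeholders, no session, no global variable initializer: the Sonnet
# module returned by cast() is applied directly to an eager tensor.
cast_to_float = model_blocks.cast(tf.dtypes.float32)
ints = tf.ones((4, 24), tf.dtypes.int32)
floats = cast_to_float(ints)

assert floats.shape == (4, 24)
assert floats.dtype == tf.dtypes.float32
```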