diff --git a/gematria/model/python/BUILD.bazel b/gematria/model/python/BUILD.bazel
index f742ed02..61c7b4e9 100644
--- a/gematria/model/python/BUILD.bazel
+++ b/gematria/model/python/BUILD.bazel
@@ -140,9 +140,6 @@ gematria_py_test(
     size = "small",
     timeout = "moderate",
     srcs = ["model_blocks_test.py"],
-    tags = [
-        "manual",
-    ],
     deps = [
         ":model_blocks",
     ],
diff --git a/gematria/model/python/model_blocks.py b/gematria/model/python/model_blocks.py
index 0cdaca45..fefb7138 100644
--- a/gematria/model/python/model_blocks.py
+++ b/gematria/model/python/model_blocks.py
@@ -41,9 +41,10 @@ class ResidualConnectionLayer(tf_keras.layers.Layer):
   """
 
   # @Override
-  def build(
-      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape]
+  def __init__(
+      self, layer_input_shapes: tuple[tf.TensorShape, tf.TensorShape], **kwargs
   ) -> None:
+    super().__init__(**kwargs)
     output_shape, residual_shape = layer_input_shapes
     if output_shape.rank != 2:
       # NOTE(ondrasej): For simplicity, we require that the output has shape
@@ -79,36 +80,6 @@ def call(self, layer_inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     return tf.math.add(output_part, residual_part, name=self.name)
 
 
-def add_residual_connection(
-    output_part: tf.Tensor, residual_part: tf.Tensor, name: Optional[str] = None
-) -> tf.Tensor:
-  """Adds a residual connection to the output of a subnetwork.
-
-  When the shape of `output_part` and `residual_part` are the same, then they
-  are simply added, elementwise. If the shapes are different, the function
-  creates a learned linear transformation layer that transforms the residual
-  part to the right shape; in this case, the rank of the output part must be 2
-  and the first dimension must be the batch dimension.
-
-  Args:
-    output_part: The tensor that contains the output of the subnetwork.
-    residual_part: The input of the subnetwork.
-    name: The name of the residual connection. This name is used for the
-      tf.math.add operation that merges the two parts; if the linear
-      transformation is used, its name is f'{name}_transformation'.
-
-  Returns:
-    A tensor that contains the output of the network merged with the residual
-    connection.
-
-  Raises:
-    ValueError: If the rank of the output part is not two, including the batch
-      dimension.
-  """
-  residual_layer = ResidualConnectionLayer(name=name)
-  return residual_layer((output_part, residual_part))
-
-
 def cast(dtype: tf.dtypes.DType) -> snt.AbstractModule:
   """Creates a sonnet module that casts a tensor to the specified dtype."""
   return snt.Module(build=functools.partial(tf.cast, dtype=dtype))
diff --git a/gematria/model/python/model_blocks_test.py b/gematria/model/python/model_blocks_test.py
index 58e48e6a..0843249f 100644
--- a/gematria/model/python/model_blocks_test.py
+++ b/gematria/model/python/model_blocks_test.py
@@ -13,150 +13,56 @@
 # limitations under the License.
 
 from gematria.model.python import model_blocks
-import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 
 
 class ResidualConnectionLayerTest(tf.test.TestCase):
 
   def test_same_shapes(self):
-    shape = (2, 4)
+    shape = tf.TensorShape((2, 4))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    self.assertEqual(output_tensor.shape, shape)
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(shape, shape), name='residual'
+    )
     self.assertEmpty(residual_layer.weights)
 
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    input_array = tf.constant(
+        [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
+    )
+    residual_array = tf.constant(
+        [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
+    )
+    output = residual_layer((input_array, residual_array))
+    self.assertEqual(output.shape, shape)
+    self.assertAllEqual(output, [[0, 3, 2, 5], [3, 8, 5, 10]])
 
   def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
+    input_shape = tf.TensorShape((None, 2))
+    residual_shape = tf.TensorShape((None, 3))
     dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    residual_layer = model_blocks.ResidualConnectionLayer(name='residual')
-    output_tensor = residual_layer((input_tensor, residual_tensor))
-    output_tensor.shape.assert_is_compatible_with(input_shape)
+    residual_layer = model_blocks.ResidualConnectionLayer(
+        layer_input_shapes=(input_shape, residual_shape), name='residual'
+    )
     residual_layer_weights = residual_layer.weights
     self.assertLen(residual_layer_weights, 1)
     self.assertTrue(residual_layer_weights[0].shape.is_compatible_with((3, 2)))
 
-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
-
-
-class AddResidualConnectioNTest(tf.test.TestCase):
-
-  def test_same_shapes(self):
-    shape = (2, 4)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    self.assertEqual(output_tensor.shape, shape)
-
-    with self.session() as sess:
-      input_array = np.array(
-          [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype.as_numpy_dtype
-      )
-      residual_array = np.array(
-          [[-1, 1, -1, 1], [-2, 2, -2, 2]], dtype=dtype.as_numpy_dtype
-      )
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, [[0, 3, 2, 5], [3, 8, 5, 10]])
+    # In the first test, we use all zeros for the residual array. Any linear
+    # projection must map this to zeros, thus the output must be the same as
+    # input_tensor.
+    input_tensor = tf.constant([[1, 2], [3, 4]], dtype=dtype)
+    residual_tensor = tf.zeros((2, 3), dtype=dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertAllEqual(output_tensor, input_tensor)
 
-  def test_different_shapes(self):
-    input_shape = (None, 2)
-    residual_shape = (None, 3)
-    dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=dtype)
-    residual_tensor = tf.placeholder(shape=residual_shape, dtype=dtype)
-    output_tensor = model_blocks.add_residual_connection(
-        input_tensor, residual_tensor, name='residual'
-    )
-    output_tensor.shape.assert_is_compatible_with(input_shape)
-
-    with self.session() as sess:
-      sess.run(tf.global_variables_initializer())
-      # In the first test, we use all zeros for the residual array. Any linear
-      # projection must map this to zeros, thus the output must be the same as
-      # input_array.
-      input_array = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
-      residual_array = np.zeros((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertAllEqual(output_array, input_array)
-
-      # In the second test, we use a non-zero array. We can't test for exact
-      # values because of the random initialization of the linear transformation
-      # but with probability one, the output is different from the input.
-      residual_array = np.ones((2, 3), dtype=dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor,
-          feed_dict={
-              input_tensor: input_array,
-              residual_tensor: residual_array,
-          },
-      )
-      self.assertNotAllClose(output_array, input_array)
+    # In the second test, we use a non-zero array. We can't test for exact
+    # values because of the random initialization of the linear transformation
+    # but with probability one, the output is different from the input.
+    residual_tensor = tf.ones((2, 3), dtype=dtype.as_numpy_dtype)
+    output_tensor = residual_layer((input_tensor, residual_tensor))
+    self.assertNotAllClose(output_tensor, input_tensor)
 
 
 class CastTest(tf.test.TestCase):
@@ -165,22 +71,13 @@ def test_int32_to_float_cast(self):
     input_shape = (4, 24)
     input_dtype = tf.dtypes.int32
     output_dtype = tf.dtypes.float32
-    input_tensor = tf.placeholder(shape=input_shape, dtype=input_dtype)
+    input_tensor = tf.ones(input_shape, input_dtype)
     cast = model_blocks.cast(output_dtype)
     output_tensor = cast(input_tensor)
     self.assertEqual(output_tensor.shape, input_shape)
     self.assertEqual(output_tensor.dtype, output_dtype)
 
-    with self.session() as sess:
-      input_array = np.ones(input_shape, input_dtype.as_numpy_dtype)
-      output_array = sess.run(
-          output_tensor, feed_dict={input_tensor: input_array}
-      )
-      self.assertEqual(output_array.shape, input_shape)
-      self.assertEqual(output_array.dtype, output_dtype.as_numpy_dtype)
-
 
 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   tf.test.main()
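A minimal usage sketch of the migrated API, assuming the state of the tree after this patch: `ResidualConnectionLayer` now receives its input shapes in `__init__` rather than in `build`, so its weights exist right after construction and the layer can be called eagerly under TF2, taking over the role of the removed `add_residual_connection` helper. The concrete shapes and tensor values below are illustrative only, not taken from the commit.

    # Sketch only; mirrors the constructor exercised by the updated tests.
    import tensorflow as tf

    from gematria.model.python import model_blocks

    output_shape = tf.TensorShape((None, 2))
    residual_shape = tf.TensorShape((None, 3))
    layer = model_blocks.ResidualConnectionLayer(
        layer_input_shapes=(output_shape, residual_shape), name='residual'
    )
    # With mismatched shapes, the (3, 2) projection weight is created in
    # __init__, before the first call.
    print([w.shape for w in layer.weights])

    # The residual part is linearly projected to the output shape, then added
    # elementwise.
    output = layer((tf.zeros((4, 2)), tf.ones((4, 3))))
    print(output.shape)  # (4, 2)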