Skip to content

Commit

Permalink
Port inference to TF2
Browse files Browse the repository at this point in the history
This patch ports the inference library and associated unit tests to TF2.

Pull Request: google#284
  • Loading branch information
boomanaiden154 committed Jan 6, 2025
1 parent 4c4930e commit ac2b26c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 102 deletions.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ gematria_py_test(
name = "inference_test",
size = "small",
srcs = ["inference_test.py"],
tags = [
"manual",
],
deps = [
":inference",
":model_base",
Expand Down
10 changes: 4 additions & 6 deletions gematria/model/python/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,19 @@ def _get_num_instructions_in_block_with_throughput_proto(

def predict_for_protos(
model: model_base.ModelBase,
sess: tf.Session,
basic_blocks: Iterable[throughput_pb2.BasicBlockWithThroughputProto],
max_blocks_in_batch: Optional[int] = None,
max_instructions_in_batch: Optional[int] = None,
) -> Iterable[throughput_pb2.BasicBlockWithThroughputProto]:
"""Predicts the inverse throughput using the model.
Assumes that sess has been initialized and that it contains the weights for
the model. The input sequence is iterated through only once, and the method
may safely be used with iterable objects that read the protos from a file or
Assumes that model has been initialized and that it contains the appropriate
weights. The input sequence is iterated through only once, and the method may
safely be used with iterable objects that read the protos from a file or
generate them on the fly.
Args:
model: The model used for inference.
sess: The TensorFlow session object in which the computation is done.
basic_blocks: The collection of basic blocks for which the inverse
throughput is predicted.
max_blocks_in_batch: The maximal number of basic blocks processed in a
Expand Down Expand Up @@ -82,7 +80,7 @@ def predict_for_protos(

# Blocks are already divided into batches according to the given criteria,
# no need to use max_blocks_in_batch and max_instructions_in_batch again.
predictions = iter(model.predict(sess, blocks))
predictions = iter(model.predict(blocks))

# Inject predictions into the input protos.
for proto, is_valid in zip(protos, block_is_valid):
Expand Down
164 changes: 71 additions & 93 deletions gematria/model/python/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gematria.proto import throughput_pb2
from gematria.testing.python import model_test
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow as tf

_PrefixThroughputProto = (
throughput_pb2.ThroughputWithSourceProto.PrefixThroughputProto
Expand Down Expand Up @@ -50,22 +50,11 @@ def __init__(self, use_custom_output_names=False, **kwargs):
self.batch_sizes = []
self.use_custom_output_names = use_custom_output_names

# @Override
def _create_tf_graph(self):
def _forward(self, feed_dict):
if not self._use_deltas:
output_name = model_base.ModelBase.OUTPUT_TENSOR_NAME
if self.use_custom_output_names:
output_name = 'TestModel.output_tensor'
self._output_tensor = tf.placeholder(
self.dtype, (None, self.num_tasks), name=output_name
)
return {'output': feed_dict['output']}
else:
output_deltas_name = model_base.ModelBase.OUTPUT_TENSOR_DELTAS_NAME
if self.use_custom_output_names:
output_deltas_name = 'TestModel.output_tensor_deltas'
self._output_tensor_deltas = tf.placeholder(
self.dtype, (None, self.num_tasks), name=output_deltas_name
)
return {'output_deltas': feed_dict['output_deltas']}

# @Override
def _create_optimizer(self):
Expand Down Expand Up @@ -102,11 +91,9 @@ def _start_batch(self):

# @Override
def _make_batch_feed_dict(self):
output_tensor = (
self._output_tensor_deltas if self._use_deltas else self._output_tensor
)
output_name = 'output_deltas' if self._use_deltas else 'output'
return {
output_tensor: np.array(
output_name: np.array(
self._batch_collected_outputs, dtype=self.numpy_dtype
).reshape((-1, self.num_tasks)),
}
Expand All @@ -126,7 +113,6 @@ def _check_predict(
max_blocks_in_batch,
max_instructions_in_batch,
expected_batch_sizes,
source_name=None,
):
"""Checks the prediction of the test model with the given batch size.
Expand All @@ -141,44 +127,40 @@ def _check_predict(
passed to model.predict().
expected_batch_sizes: A collection of expected sizes of batches processed
by model.predict(), verified by this method.
source_name: A string template used with str.format() to create throughput
source names in the expected data.
"""
with self.session() as sess:
# inference.predict_for_protos() modifies the protos in-place. We need to
# make a copy to be able to compare them with the original protos.
input_protos = copy.deepcopy(self.block_protos)
output_protos = tuple(
inference.predict_for_protos(
model,
sess,
input_protos,
max_blocks_in_batch=max_blocks_in_batch,
max_instructions_in_batch=max_instructions_in_batch,
)
)
self.assertSequenceEqual(model.batch_sizes, expected_batch_sizes)
self.assertLen(output_protos, len(self.block_protos))
for index, (in_proto, out_proto) in enumerate(
zip(self.block_protos, output_protos)
):
# The prediction of the model is the number of calls to
# model._add_basic_block_to_batch(). There is one call per basic block,
# so we can get the expected value from the index of the basic block.
expected_inverse_throughputs = [*in_proto.inverse_throughputs]
for task_index in range(model.num_tasks):
expected_inverse_throughputs.append(
throughput_pb2.ThroughputWithSourceProto(
source=model.get_source_name(task_index),
inverse_throughput_cycles=(index + 1 + task_index,),
)
)
self.assertEqual(in_proto.basic_block, out_proto.basic_block)
# NOTE(ondrasej): assertSequenceEqual refuses to compare a repeated
# field of a proto with a native sequence type.
self.assertSequenceEqual(
tuple(out_proto.inverse_throughputs), expected_inverse_throughputs
# inference.predict_for_protos() modifies the protos in-place. We need to
# make a copy to be able to compare them with the original protos.
input_protos = copy.deepcopy(self.block_protos)
output_protos = tuple(
inference.predict_for_protos(
model,
input_protos,
max_blocks_in_batch=max_blocks_in_batch,
max_instructions_in_batch=max_instructions_in_batch,
)
)
self.assertSequenceEqual(model.batch_sizes, expected_batch_sizes)
self.assertLen(output_protos, len(self.block_protos))
for index, (in_proto, out_proto) in enumerate(
zip(self.block_protos, output_protos)
):
# The prediction of the model is the number of calls to
# model._add_basic_block_to_batch(). There is one call per basic block,
# so we can get the expected value from the index of the basic block.
expected_inverse_throughputs = [*in_proto.inverse_throughputs]
for task_index in range(model.num_tasks):
expected_inverse_throughputs.append(
throughput_pb2.ThroughputWithSourceProto(
source=model.get_source_name(task_index),
inverse_throughput_cycles=(index + 1 + task_index,),
)
)
self.assertEqual(in_proto.basic_block, out_proto.basic_block)
# NOTE(ondrasej): assertSequenceEqual refuses to compare a repeated
# field of a proto with a native sequence type.
self.assertSequenceEqual(
tuple(out_proto.inverse_throughputs), expected_inverse_throughputs
)

def test_predict_single_batch(self):
model = TestModel(dtype=tf.dtypes.float32)
Expand Down Expand Up @@ -230,44 +212,41 @@ def test_predict_multi_task(self):

def check_predict_deltas(self, model):
"""Checks the prediction of the model when predicting also deltas."""
with self.session() as sess:
input_protos = copy.deepcopy(self.block_protos)
output_protos = tuple(
inference.predict_for_protos(model, sess, input_protos)
)
self.assertLen(output_protos, len(self.block_protos))

for index, (in_proto, out_proto) in enumerate(
zip(self.block_protos, output_protos)
):
# The expected block-level prediction is the sum of the per-instruction delta
# predictions, so we need the number of instructions in the block.
num_instructions = len(in_proto.basic_block.canonicalized_instructions)

# Inverse throughput on one prefix.
expected_throughputs = [*in_proto.inverse_throughputs]
for task_index in range(model.num_tasks):
pref_inv_throughputs = _PrefixThroughputProto(
inverse_throughput_cycles=(index + 1 + task_index,)
)

expected_throughputs.append(
throughput_pb2.ThroughputWithSourceProto(
source=model.get_source_name(task_index),
inverse_throughput_cycles=(
num_instructions * (index + 1 + task_index),
),
prefix_inverse_throughputs=(
num_instructions * (pref_inv_throughputs,)
),
)
)

self.assertEqual(in_proto.basic_block, out_proto.basic_block)
self.assertSequenceEqual(
tuple(out_proto.inverse_throughputs), expected_throughputs
input_protos = copy.deepcopy(self.block_protos)
output_protos = tuple(inference.predict_for_protos(model, input_protos))
self.assertLen(output_protos, len(self.block_protos))

for index, (in_proto, out_proto) in enumerate(
zip(self.block_protos, output_protos)
):
# The expected block-level prediction is the sum of the per-instruction delta
# predictions, so we need the number of instructions in the block.
num_instructions = len(in_proto.basic_block.canonicalized_instructions)

# Inverse throughput on one prefix.
expected_throughputs = [*in_proto.inverse_throughputs]
for task_index in range(model.num_tasks):
pref_inv_throughputs = _PrefixThroughputProto(
inverse_throughput_cycles=(index + 1 + task_index,)
)

expected_throughputs.append(
throughput_pb2.ThroughputWithSourceProto(
source=model.get_source_name(task_index),
inverse_throughput_cycles=(
num_instructions * (index + 1 + task_index),
),
prefix_inverse_throughputs=(
num_instructions * (pref_inv_throughputs,)
),
)
)

self.assertEqual(in_proto.basic_block, out_proto.basic_block)
self.assertSequenceEqual(
tuple(out_proto.inverse_throughputs), expected_throughputs
)

def test_predict_deltas(self):
model = TestModel(dtype=tf.dtypes.float32, use_deltas=True)
model.initialize()
Expand All @@ -285,5 +264,4 @@ def test_predict_deltas_multi_task(self):


if __name__ == '__main__':
tf.disable_v2_behavior()
tf.test.main()

0 comments on commit ac2b26c

Please sign in to comment.