diff --git a/gematria/model/python/BUILD.bazel b/gematria/model/python/BUILD.bazel index de18cba4..f742ed02 100644 --- a/gematria/model/python/BUILD.bazel +++ b/gematria/model/python/BUILD.bazel @@ -118,9 +118,6 @@ gematria_py_test( size = "small", timeout = "moderate", srcs = ["model_base_test.py"], - tags = [ - "manual", - ], deps = [ ":model_base", ":options", diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index 77781318..6518752c 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -50,7 +50,7 @@ # simpler than what is actually accepted by TensorFlow, but the typing should be # sufficient for our use. Moreover, since TensorFlow and NumPy do not provide # type annotations, both the key and the value are reduced to typing.Any. -FeedDict = MutableMapping[tf.Tensor, np.ndarray] +FeedDict = MutableMapping[str, Union[np.ndarray, tf.Tensor]] # A throughput value used as a placeholder in the expected output tensors for # masked expected outputs. @@ -120,7 +120,7 @@ def after_run(self, run_context: ..., run_values: ...): ) -class ModelBase(metaclass=abc.ABCMeta): +class ModelBase(tf.Module, metaclass=abc.ABCMeta): """Base class for Gematria basic block processing models. Provides infrastructure for building basic-block-oriented models on top of @@ -321,16 +321,6 @@ def __init__( self._decayed_learning_rate = None self._loss: Optional[loss_utils.LossComputation] = None self._delta_loss: Optional[loss_utils.LossComputation] = None - self._output_tensor: Optional[tf.Tensor] = None - self._output_tensor_deltas: Optional[tf.Tensor] = None - # The mask applied to outputs/expected outputs when a basic block has - # expected outputs only in certain tasks. Tasks where the expected output is - # not available are ignored in the loss computation. - self._output_mask: Optional[tf.Tensor] = None - self._output_mask_deltas: Optional[tf.Tensor] = None - self._expected_outputs: Optional[tf.Tensor] = None - self._expected_outputs_deltas: Optional[tf.Tensor] = None - self._loss_tensor: Optional[tf.Tensor] = None self._train_step: Optional[tf.Operation] = None self._optimizer: Union[ tf.train.Optimizer, tf.train.SyncReplicasOptimizer @@ -377,43 +367,6 @@ def __init__( def initialize(self) -> None: """Initializes the model. Must be called before any other method.""" - self._expected_outputs = tf.placeholder( - dtype=self.dtype, shape=(None, self.num_tasks), name='expected_outputs' - ) - if self._create_delta_block_index: - self._delta_block_index_tensor = tf.placeholder( - _BASIC_BLOCK_INDEX_TF_DTYPE, - shape=(None,), - name='ModelBase.delta_block_index_tensor', - ) - self._create_tf_graph() - # Check that the required tensors were created by _create_tf_graph(). - expected_shape = (None, self.num_tasks) - if not self._use_deltas: - assert ( - self._output_tensor is not None - ), 'self._output_tensor was not created by self._create_tf_graph()' - assert ( - self._output_tensor_deltas is None - ), 'self._output_tensor_deltas was created with self._use_deltas == False' - assert self._output_tensor.shape.is_compatible_with(expected_shape), ( - f'Expected shape {expected_shape}, got ' - f'{self._output_tensor.shape.as_list()}' - ) - else: - assert ( - self._output_tensor_deltas is not None - ), 'self._output_tensor_deltas was not created by self._create_tf_graph()' - assert ( - self._output_tensor is None - ), 'self._output_tensor was created with self._use_deltas == True' - assert self._output_tensor_deltas.shape.is_compatible_with( - expected_shape - ), ( - f'Expected shape {expected_shape}, got ' - f'{self._output_tensor_deltas.shape.as_list()}' - ) - self._create_output_and_loss_tensors() self._create_optimizer() tf.summary.scalar('learning_rate', self._decayed_learning_rate) @@ -445,23 +398,6 @@ def task_list(self) -> Sequence[str]: """Returns the names of the tasks in the model.""" return self._task_list - @property - def output_tensor(self) -> tf.Tensor: - """Returns the tensor that contains the per-basic block outputs.""" - return self._output_tensor - - @property - def output_tensor_deltas(self) -> tf.Tensor: - """Returns the tensor that contains the output deltas. - - This property is available only when self.use_deltas is True. - """ - if not self._use_deltas: - raise AttributeError( - ModelBase._USE_DELTAS_ATTRIBUTE_ERROR_MESSAGE % 'output_tensor_deltas' - ) - return self._output_tensor_deltas - @property def loss_type(self) -> options.LossType: """Returns the type of loss used by the model.""" @@ -472,38 +408,6 @@ def loss_normalization(self) -> options.ErrorNormalization: """Returns the error normalization used when computing the loss.""" return self._loss_normalization - @property - def loss_tensor(self) -> tf.Tensor: - """Returns the loss tensor.""" - return self._loss_tensor - - @property - def absolute_mse_tensor(self) -> tf.Tensor: - """Returns the absolute loss tensor.""" - if self._loss is None: - raise ValueError( - 'self._loss is None while returning absolute mse loss tensor' - ) - return self._loss.mean_squared_error - - @property - def relative_mae_tensor(self) -> tf.Tensor: - """Returns the relative MAE (L1 loss) tensor.""" - if self._loss is None: - raise ValueError( - 'self._loss is None while returning relative mae (L1 loss) tensor' - ) - return self._loss.mean_absolute_percentage_error - - @property - def relative_mse_tensor(self) -> tf.Tensor: - """Returns the relative loss tensor.""" - if self._loss is None: - raise ValueError( - 'self._loss is None while returning relative mse loss tensor' - ) - return self._loss.mean_squared_percentage_error - @property def collected_percentile_ranks(self) -> Sequence[int]: """Returns the list of collected error percentile ranks.""" @@ -536,6 +440,44 @@ def output_tensor_names(self) -> Sequence[str]: return (ModelBase.OUTPUT_TENSOR_NAME, ModelBase.OUTPUT_TENSOR_DELTAS_NAME) return (ModelBase.OUTPUT_TENSOR_NAME,) + @abc.abstractmethod + def _forward(self, feed_dict: FeedDict) -> dict[str, tf.Tensor]: + """Implements the forward pass of the model. + + This function should be implemented in downstream models and calculate the + outputs of the model given the inputs specified in feed_dict. + """ + pass + + def __call__(self, feed_dict, train=False): + """Implements the non-model specific part of the forward pass. + + This function wraps the _forward method and does relevant calculations + when working with models that use deltas. + """ + if self._use_deltas: + output_dict = {} + + if train: + output_dict['output_mask_deltas'] = tf.nn.embedding_lookup( + feed_dict['output_mask'], + feed_dict['delta_block_index'], + name='ModelBase.output_mask_deltas', + ) + + output = self._forward(feed_dict) + + output_dict['output'] = tf.math.segment_sum( + output['output_deltas'], + feed_dict['delta_block_index'], + name=ModelBase.OUTPUT_TENSOR_NAME, + ) + output_dict['output_deltas'] = output['output_deltas'] + + return output_dict + else: + return self._forward(feed_dict) + @abc.abstractmethod def _make_model_name(self) -> str: """Returns a model name based on its class and parameters.""" @@ -547,17 +489,6 @@ def get_source_name(self, task_index: int) -> str: task_name = self.task_list[task_index] return f'{self.model_name}, task={task_name}' - @abc.abstractmethod - def _create_tf_graph(self) -> None: - """Creates the TensorFlow ops necessary to run the model. - - This method must set up self._output_tensor when self._use_deltas is True, - or self._output_tensor_deltas when self._use_deltas is False. - """ - # NOTE(ondrasej): Even though the method is marked as abstract, it may still - # appear in the super() call chain of some classes. We thus make it a no-op - # rather than raise NotImplementedError or other exception. - def _add_percentile_summaries( self, error_name: str, @@ -600,168 +531,6 @@ def _add_error_summaries(self, error_name: str, error_tensor: tf.Tensor): summary_name = f'{error_name}_{task_name}' tf.summary.scalar(summary_name, error_tensor[task_idx]) - def _create_output_and_loss_tensors(self) -> None: - """Creates the output, expected output and loss tensors. - - This method is called after calling _create_tf_graph(). - """ - self._expected_outputs = tf.placeholder( - self.dtype, - shape=(None, self.num_tasks), - name='ModelBase.expected_outputs', - ) - self._output_mask = tf.placeholder( - tf.dtypes.bool, - shape=(None, self.num_tasks), - name='ModelBase.output_mask', - ) - if self._use_deltas: - assert self._delta_block_index_tensor is not None - self._expected_outputs_deltas = tf.placeholder( - self.dtype, - shape=(None, self.num_tasks), - name='ModelBase.expected_outputs_deltas', - ) - # The mask for all instructions of a block is the same as the mask for the - # whole block. Instead of composing it in Python, we use embedding_lookup - # to stretch the per-block mask to the right shape. - self._output_mask_deltas = tf.nn.embedding_lookup( - self._output_mask, - self._delta_block_index_tensor, - name='ModelBase.output_mask_deltas', - ) - - # Tensor names are in the format "op_name:index", where "op_name" is the - # name of the op that produced the tensor, and "index" is the index of the - # output of the op. Most ops have just a single output (and terminate with - # ":0"), so we add an additional warning to the log in case the index is - # not "0" as this might cause errors from omission. - if self._output_tensor_deltas is None: - raise ValueError( - 'ModelBase._output_tensor_deltas is None while creating output,' - ' expected output and loss tensors.' - ) - output_deltas_name, output_deltas_index = ( - self._output_tensor_deltas.name.split(':') - ) - if output_deltas_name != ModelBase.OUTPUT_TENSOR_DELTAS_NAME: - logging.warning( - ( - 'ModelBase._output_tensor_deltas has invalid name.' - ' Expected %s, found %s.' - ), - ModelBase.OUTPUT_TENSOR_DELTAS_NAME, - output_deltas_name, - ) - self._output_tensor_deltas = tf.identity( - self._output_tensor_deltas, ModelBase.OUTPUT_TENSOR_DELTAS_NAME - ) - elif output_deltas_index != '0': - # NOTE(ondrasej): We can't rename the output tensor automatically, - # because the desired name is already taken, and the output index 0 is - # used by another tensor. - logging.warning( - ( - 'ModelBase._output_tensor_deltas has an unexpected index: ' - 'expected 0, found %s.' - ), - output_deltas_index, - ) - self._output_tensor = tf.math.segment_sum( - self._output_tensor_deltas, - self._delta_block_index_tensor, - name=ModelBase.OUTPUT_TENSOR_NAME, - ) - else: - if self._output_tensor is None: - raise ValueError( - 'ModelBase._output_tensor is None while creating output,' - ' expected output and loss tensors.' - ) - output_name, output_index = self._output_tensor.name.split(':') - if output_name != ModelBase.OUTPUT_TENSOR_NAME: - logging.warning( - 'ModelBase._output_tensor has invalid name. Expected %s, found %s.', - ModelBase.OUTPUT_TENSOR_NAME, - self._output_tensor.name, - ) - self._output_tensor = tf.identity( - self._output_tensor, ModelBase.OUTPUT_TENSOR_NAME - ) - elif output_index != '0': - logging.warning( - ( - 'ModelBase._output_tensor has an unexpected index: expected 0,' - ' found %s.' - ), - output_index, - ) - - self._loss = loss_utils.LossComputation( - self._output_tensor, - self._expected_outputs, - self._output_mask, - percentile_ranks=self._collected_percentile_ranks, - dtype=self.dtype, - ) - self._add_error_summaries('absolute_mse', self._loss.mean_squared_error) - self._add_error_summaries( - 'relative_mae', self._loss.mean_absolute_percentage_error - ) - self._add_error_summaries( - 'relative_mse', self._loss.mean_squared_percentage_error - ) - self._add_percentile_summaries( - 'absolute_error', - self._collected_percentile_ranks, - self._loss.absolute_error_percentiles, - ) - self._add_percentile_summaries( - 'absolute_percentage_error', - self._collected_percentile_ranks, - self._loss.absolute_percentage_error_percentiles, - ) - - loss = self._loss - if self._use_deltas: - self._delta_loss = loss_utils.LossComputation( - self._output_tensor_deltas, - self._expected_outputs_deltas, - self._output_mask_deltas, - percentile_ranks=self._collected_percentile_ranks, - dtype=self.dtype, - ) - - self._add_error_summaries( - 'absolute_mse_deltas', self._delta_loss.mean_squared_error - ) - self._add_error_summaries( - 'absolute_mae_deltas', self._delta_loss.mean_absolute_error - ) - self._add_percentile_summaries( - 'absolute_error_deltas', - self._collected_percentile_ranks, - self._delta_loss.absolute_error_percentiles, - ) - - if self._use_delta_loss: - loss = self._delta_loss - - spearman_correlations = self._make_spearman_correlations( - self._expected_outputs, self._output_tensor - ) - self._add_error_summaries('spearman', spearman_correlations) - - self._loss_tensor_per_task = loss.loss_tensor( - self._loss_normalization, self._loss_type - ) - self._loss_tensor = loss.loss_tensor( - self._loss_normalization, self._loss_type - ) - self._add_error_summaries('loss', self._loss_tensor_per_task) - self._loss_tensor = tf.reduce_mean(self._loss_tensor_per_task) - tf.summary.scalar('overall_loss', self._loss_tensor) - def _make_spearman_correlations( self, expected_outputs: tf.Tensor, output_tensor: tf.Tensor ) -> tf.Tensor: @@ -867,12 +636,6 @@ def _create_optimizer(self) -> None: ) self._decayed_learning_rate = self._learning_rate - # The list of variables to optimize. By default, the list is empty which - # means optimize all trainable variables. - variables = set() - for variable_group in self._trained_variable_groups: - variables.update(self._variable_groups.get(variable_group)) - if self._optimizer_type == options.OptimizerType.ADAM: self._optimizer = tf.train.AdamOptimizer( learning_rate=self._decayed_learning_rate @@ -902,21 +665,6 @@ def _create_optimizer(self) -> None: logging.warning( 'ModelBase._synchronous_training is True with a single worker.' ) - grads_and_vars = self._optimizer.compute_gradients( - self._loss_tensor, var_list=list(variables) if variables else None - ) - if self._grad_clip_norm: - if self._grad_clip_norm <= 0.0: - logging.warning( - 'The gradients are clipped to zero. Please revise if this is not ' - 'intended.' - ) - grads_and_vars = [ - (self._clip_if_not_none(g), v) for g, v in grads_and_vars - ] - self._train_step = self._optimizer.apply_gradients( - grads_and_vars, global_step=self.global_step - ) def get_monitored_training_session_hooks( self, @@ -997,17 +745,17 @@ def _finalize_batch(self, include_expected_outputs: bool) -> FeedDict: """ schedule = self._make_batch_feed_dict() if self._create_delta_block_index: - schedule[self._delta_block_index_tensor] = np.array( + schedule['delta_block_index'] = np.array( self._batch_delta_block_index, dtype=_BASIC_BLOCK_INDEX_NUMPY_DTYPE ) if include_expected_outputs: - schedule[self._expected_outputs] = np.reshape( + schedule['expected_outputs'] = np.reshape( np.array(self._batch_expected_outputs, dtype=self.numpy_dtype), [-1, self.num_tasks], ) - schedule[self._output_mask] = np.array(self._batch_mask, dtype=bool) + schedule['output_mask'] = np.array(self._batch_mask, dtype=bool) if self._use_deltas: - schedule[self._expected_outputs_deltas] = np.reshape( + schedule['expected_outputs_deltas'] = np.reshape( np.array( self._batch_expected_outputs_deltas, dtype=self.numpy_dtype ), @@ -1340,10 +1088,6 @@ def run_continuous_evaluation( randomize_expected_outputs=False, ) - if self._loss is None: - raise ValueError( - 'ModelBase._loss is None while running continuous evaluation.' - ) hooks = [ tf_slim.evaluation.StopAfterNEvalsHook(1), tf_slim.evaluation.SummaryAtEndHook(summary_dir, feed_dict=schedule), @@ -1373,19 +1117,16 @@ def run_continuous_evaluation( def predict( self, - sess: tf.Session, basic_blocks: Iterable[basic_block.BasicBlock], max_blocks_in_batch: Optional[int] = None, max_instructions_in_batch: Optional[int] = None, ) -> Iterable[throughput.BasicBlockWithThroughput]: """Predicts the inverse throughput using the model. - Assumes that sess has been initialized and that it contains the weights for - the model. The input sequence is iterated through only once, and the method - may be used with basic block sources such as tf.io.tf_record_iterator. + The input sequence is iterated through only once, and the method may be + used with basic block sources such as tf.io.tf_record_iterator. Args: - sess: The TensorFlow session object in which the computation is done. basic_blocks: The collection of basic blocks for which the inverse throughput is predicted. max_blocks_in_batch: The maximal number of basic blocks processed in a @@ -1411,10 +1152,9 @@ def predict( with timer.scoped('ModelBase.predict - one batch'): schedule = self.schedule_batch(batch) if self._use_deltas: - output, output_deltas = sess.run( - (self._output_tensor, self._output_tensor_deltas), - feed_dict=schedule, - ) + output_dict = self(schedule) + output = output_dict['output'] + output_deltas = output_dict['output_deltas'] output_index = 0 for block_index, block in enumerate(batch): block_len = len(block.instructions) @@ -1442,7 +1182,7 @@ def predict( ) ) else: - output = sess.run(self._output_tensor, feed_dict=schedule) + output = self(schedule)['output'] for block_index, block in enumerate(batch): throughputs = [] for task_index in range(self.num_tasks): @@ -1464,7 +1204,6 @@ def predict( def train( self, - monitored_session: tf.train.MonitoredSession, basic_block_list: Sequence[throughput.BasicBlockWithThroughput], num_epochs: int, max_blocks_in_batch: Optional[int], @@ -1475,7 +1214,6 @@ def train( """Runs training of the model on the given training data. Args: - monitored_session: The monitored training session to run the training in. basic_block_list: The collection of input basic blocks. num_epochs: The number of training steps. This value is used only for profiling and logging; the method uses monitored_session.should_stop() @@ -1502,7 +1240,6 @@ def train( def run_one_epoch(): return self.train_mini_batch( - monitored_session, basic_block_list, max_blocks_in_batch=max_blocks_in_batch, max_instructions_in_batch=max_instructions_in_batch, @@ -1531,74 +1268,116 @@ def run_one_epoch(): schedule = self.schedule_batch( batch, randomize_expected_outputs=randomize_expected_outputs ) - return self.train_batch(monitored_session, schedule) + return self.train_batch(schedule) - with timer.scoped('ModelBase.train - one batch', num_iterations=num_epochs): - stats = None - while not monitored_session.should_stop(): - stats = run_one_epoch() - logging.info('Training: %s', stats) + for _ in range(0, num_epochs): + stats = run_one_epoch() + logging.info('Training: %s', stats) return stats + def compute_loss(self, schedule: FeedDict): + output = self(schedule, train=True) + loss = loss_utils.LossComputation( + output['output'], + tf.constant(schedule['expected_outputs']), + tf.constant(schedule['output_mask']), + percentile_ranks=self._collected_percentile_ranks, + dtype=self.dtype, + ) + + if self._use_deltas: + self._delta_loss = loss_utils.LossComputation( + output['output_deltas'], + tf.constant(schedule['expected_outputs_deltas']), + output['output_mask_deltas'], + percentile_ranks=self._collected_percentile_ranks, + dtype=self.dtype, + ) + + if self._use_delta_loss: + loss = self._delta_loss + + return loss + + def compute_loss_tensor(self, schedule: FeedDict): + return tf.reduce_mean( + self.compute_loss(schedule).loss_tensor( + self._loss_normalization, self._loss_type + ) + ) + def train_batch( self, - sess: Union[tf.Session, tf.train.MonitoredSession], schedule: FeedDict, ) -> training.TrainingEpochStats: """Trains a batch based on the given schedule. Args: - sess: The TensorFlow session the training is running in. schedule: A feed_dict that describes the current batch. Returns: The loss on the training set before the training step. """ with timer.scoped('ModelBase.train_batch'): - # The keys in the are names of keyword arguments of the constructor of - # TraningEpochStats. sess.run() returns a dict that has the same keys and - # the values of these tensors in the training step. This dict can then be - # unpacked to TrainingEpochStats.__init__() as keyword arguments. - # NOTE(ondrasej): The loss tensors are created lazily, when they are first - # referenced. To be used here, the tensors need to be created (referenced) - # during graph creation, e.g. by adding them as TensorFlow summaries in - # self._create_output_and_loss_tensors(). - stats_ops = {} - stats_ops['epoch'] = self.global_step - stats_ops['loss'] = self._loss_tensor - if self._loss is None: - raise ValueError('ModelBase._loss is None inside train_batch function') - stats_ops['absolute_mse'] = self._loss.mean_squared_error - stats_ops['relative_mae'] = self._loss.mean_absolute_percentage_error - stats_ops['relative_mse'] = self._loss.mean_squared_percentage_error - stats_ops['absolute_error_percentiles'] = ( - self._loss.absolute_error_percentiles - ) - stats_ops['relative_error_percentiles'] = ( - self._loss.absolute_percentage_error_percentiles - ) - if self._use_deltas: - if self._delta_loss is None: - raise ValueError( - 'ModelBase._delta_loss is None inside train_batch function while' - ' using deltas' + # The keys of stats are the names of keyword arguments of the constructor + # of TraningEpochStats. This dict can then be unpacked to + # TrainingEpochStats.__init__() as keyword arguments. + with tf.GradientTape() as tape: + stats = {} + loss = self.compute_loss(schedule) + loss_tensor_per_task = loss.loss_tensor( + self._loss_normalization, self._loss_type + ) + loss_tensor = tf.reduce_mean(loss_tensor_per_task) + + # The list of variables to optimize. By default, the list is empty which + # means optimize all trainable variables. + variables = set() + for variable_group in self._trained_variable_groups: + variables.update( + [ + variable.ref() + for variable in self._variable_groups.get(variable_group) + ] ) - stats_ops['absolute_delta_mse'] = self._delta_loss.mean_squared_error - stats_ops['absolute_delta_mae'] = self._delta_loss.mean_absolute_error - stats_ops['absolute_delta_error_percentiles'] = ( - self._delta_loss.absolute_error_percentiles + + variables = ( + [variable.deref() for variable in variables] + if variables + else self.trainable_variables ) - # NOTE(ondrasej): We do not compute relative errors for deltas. Deltas - # are very often zero or very small, which makes the percentage error - # ill-defined. - (_, stats) = sess.run((self._train_step, stats_ops), feed_dict=schedule) + + grads = tape.gradient(loss_tensor, variables) + grads_and_vars = zip(grads, variables) + stats['loss'] = loss_tensor + stats['epoch'] = self.global_step + stats['absolute_mse'] = loss.mean_squared_error + stats['relative_mae'] = loss.mean_absolute_percentage_error + stats['relative_mse'] = loss.mean_squared_percentage_error + stats['absolute_error_percentiles'] = loss.absolute_error_percentiles + stats['relative_error_percentiles'] = ( + loss.absolute_percentage_error_percentiles + ) + + if self._grad_clip_norm: + if self._grad_clip_norm <= 0.0: + logging.warning( + 'The gradients are clipped to zero. Please revise if this is not ' + 'intended.' + ) + grads_and_vars = [ + (self._clip_if_not_none(g), v) for g, v in grads_and_vars + ] + self._train_step = self._optimizer.apply_gradients( + grads_and_vars, global_step=self.global_step + ) + return training.TrainingEpochStats( percentile_ranks=self._collected_percentile_ranks, **stats ) def train_mini_batch( self, - sess: Union[tf.Session, tf.train.MonitoredSession], basic_blocks: Sequence[throughput.BasicBlockWithThroughput], max_blocks_in_batch: int, max_instructions_in_batch: Optional[int] = None, @@ -1629,4 +1408,4 @@ def train_mini_batch( randomize_batch=True, randomize_expected_outputs=randomize_expected_outputs, ) - return self.train_batch(sess, train_schedule) + return self.train_batch(train_schedule) diff --git a/gematria/model/python/model_base_test.py b/gematria/model/python/model_base_test.py index 2f1dd28b..5b2e57fb 100644 --- a/gematria/model/python/model_base_test.py +++ b/gematria/model/python/model_base_test.py @@ -18,8 +18,7 @@ from gematria.model.python import options from gematria.testing.python import model_test import numpy as np -import tensorflow.compat.v1 as tf -import tf_keras +import tensorflow as tf # The tolerance used in tests with heavier use of float32 arithmetics. _TOLERANCE = 1e-6 @@ -50,21 +49,11 @@ def __init__(self, use_custom_output_names=False, **kwargs): self.use_custom_output_names = use_custom_output_names # @Override - def _create_tf_graph(self): + def _forward(self, feed_dict): if not self._use_deltas: - output_name = model_base.ModelBase.OUTPUT_TENSOR_NAME - if self.use_custom_output_names: - output_name = 'TestModel.output_tensor' - self._output_tensor = tf.placeholder( - self.dtype, (None, self.num_tasks), name=output_name - ) + return {'output': feed_dict['output']} else: - output_deltas_name = model_base.ModelBase.OUTPUT_TENSOR_DELTAS_NAME - if self.use_custom_output_names: - output_deltas_name = 'TestModel.output_tensor_deltas' - self._output_tensor_deltas = tf.placeholder( - self.dtype, (None, self.num_tasks), name=output_deltas_name - ) + return {'output_deltas': feed_dict['output_deltas']} # @Override def _create_optimizer(self): @@ -101,11 +90,9 @@ def _start_batch(self): # @Override def _make_batch_feed_dict(self): - output_tensor = ( - self._output_tensor_deltas if self._use_deltas else self._output_tensor - ) + output_name = 'output_deltas' if self._use_deltas else 'output' return { - output_tensor: np.array( + output_name: np.array( self._batch_collected_outputs, dtype=self.numpy_dtype ).reshape((-1, self.num_tasks)), } @@ -128,33 +115,34 @@ class TestModelWithVarGroups(model_base.ModelBase): WEIGHTS = 'weights' BIAS = 'bias' - def _create_tf_graph(self): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self._weights = {} + self._biases = {} + + for task in self.task_list: + self._weights[task] = tf.Variable([0.5], dtype=self.dtype) + self._variable_groups[TestModelWithVarGroups.WEIGHTS].append( + self._weights[task] + ) + self._biases[task] = tf.Variable([-0.5], dtype=self.dtype) + self._variable_groups[TestModelWithVarGroups.BIAS].append( + self._biases[task] + ) + + def _forward(self, feed_dict): assert not self._use_deltas, 'This model does not support seq2seq.' - self._input_tensor = tf.placeholder( - self.dtype, (None, 1), name='TestModelWithVarGroups._input_tensor' - ) output_parts = [] # NOTE(ondrasej): The weights are initialized to 0.5, and the biases are # initialized to -0.5. These initial values are intentionally chosen because # they do not provide good predictions, and are likely to be changed by the # optimizer. for task in self.task_list: - weight = tf.get_variable( - name=f'weight_{task}', - shape=(1,), - dtype=self.dtype, - initializer=tf_keras.initializers.constant(0.5), - ) - self._variable_groups[TestModelWithVarGroups.WEIGHTS].append(weight) - bias = tf.get_variable( - name=f'bias_{task}', - shape=(1,), - dtype=self.dtype, - initializer=tf_keras.initializers.constant(-0.5), + output_parts.append( + self._weights[task] * feed_dict['input'] + self._biases[task] ) - self._variable_groups[TestModelWithVarGroups.BIAS].append(bias) - output_parts.append(weight * self._input_tensor + bias) - self._output_tensor = tf.concat(output_parts, axis=1) + return {'output': tf.concat(output_parts, axis=1)} def _make_model_name(self): return 'TestModelWithVarGroups' @@ -168,7 +156,7 @@ def _add_basic_block_to_batch(self, block): def _make_batch_feed_dict(self): return { - self._input_tensor: np.array( + 'input': np.array( self._batch_block_sizes, dtype=self.numpy_dtype ).reshape((-1, 1)), } @@ -189,46 +177,6 @@ def test_initialize_model_base(self): model.initialize() - def test_output_tensor_names(self): - for use_custom_names in [True, False]: - with tf.Graph().as_default(): - model = TestModel( - dtype=tf.dtypes.float32, use_custom_output_names=use_custom_names - ) - model.initialize() - # NOTE(ondrasej): TensorFlow adds a ":\d+" to all tensor names. The - # number is the index of the tensor in the list of outputs of the op - # that produced it. In case of this model, the output tensor is the - # first output of the identity op (used for renaming the tensor). - self.assertEqual( - model.output_tensor.name, - model_base.ModelBase.OUTPUT_TENSOR_NAME + ':0', - ) - - with self.assertRaisesRegex( - AttributeError, 'output_tensor_deltas is available only' - ): - _ = model.output_tensor_deltas - - def test_output_tensor_names_seq2seq(self): - for use_custom_names in [True, False]: - with tf.Graph().as_default(): - model_seq2seq = TestModel( - use_deltas=True, - dtype=tf.dtypes.float32, - use_custom_output_names=use_custom_names, - ) - model_seq2seq.initialize() - - self.assertEqual( - model_seq2seq.output_tensor.name, - model_base.ModelBase.OUTPUT_TENSOR_NAME + ':0', - ) - self.assertEqual( - model_seq2seq.output_tensor_deltas.name, - model_base.ModelBase.OUTPUT_TENSOR_DELTAS_NAME + ':0', - ) - def test_schedule_batch_with_throughputs(self): model = TestModel(dtype=tf.dtypes.float32) model.initialize() @@ -236,7 +184,7 @@ def test_schedule_batch_with_throughputs(self): # Schedule a batch with no limits. full_schedule = model.schedule_batch(self.blocks_with_throughput) self.assertLen(self.blocks_with_throughput, model.num_visited_blocks) - expected_outputs = full_schedule[model._expected_outputs] + expected_outputs = full_schedule['expected_outputs'] self.assertEqual( expected_outputs.shape, (len(self.blocks_with_throughput), 1) ) @@ -247,17 +195,16 @@ def test_schedule_batch_with_throughputs(self): self.blocks_with_throughput, max_blocks_in_batch=batch_size ) self.assertEqual(model.num_visited_blocks, batch_size) - expected_outputs = block_batch_schedule[model._expected_outputs] + expected_outputs = block_batch_schedule['expected_outputs'] self.assertEqual(expected_outputs.shape, (batch_size, 1)) - with self.session() as sess: - output = sess.run(model.output_tensor, feed_dict=full_schedule) - self.assertAllEqual( - output, [[x + 1] for x in range(len(self.blocks_with_throughput))] - ) + output = model(full_schedule)['output'] + self.assertAllEqual( + output, [[x + 1] for x in range(len(self.blocks_with_throughput))] + ) - output = sess.run(model.output_tensor, feed_dict=block_batch_schedule) - self.assertAllEqual(output, [[x + 1] for x in range(batch_size)]) + output = model(block_batch_schedule)['output'] + self.assertAllEqual(output, [[x + 1] for x in range(batch_size)]) def test_schedule_batch_with_throughputs_with_deltas(self): model = TestModel(dtype=tf.dtypes.float32, use_deltas=True) @@ -266,12 +213,12 @@ def test_schedule_batch_with_throughputs_with_deltas(self): # Schedule a batch with no limits. full_schedule = model.schedule_batch(self.blocks_with_throughput) self.assertLen(self.blocks_with_throughput, model.num_visited_blocks) - expected_outputs = full_schedule[model._expected_outputs] + expected_outputs = full_schedule['expected_outputs'] self.assertEqual( expected_outputs.shape, (len(self.blocks_with_throughput), 1) ) - expected_outputs_prefixes = full_schedule[model._expected_outputs_deltas] + expected_outputs_prefixes = full_schedule['expected_outputs_deltas'] expected_num_prefixes = sum( len(block.instructions) for block in self.blocks ) @@ -285,12 +232,10 @@ def test_schedule_batch_with_throughputs_with_deltas(self): self.blocks_with_throughput, max_blocks_in_batch=batch_size ) self.assertEqual(model.num_visited_blocks, batch_size) - expected_outputs = block_batch_schedule[model._expected_outputs] + expected_outputs = block_batch_schedule['expected_outputs'] self.assertEqual(expected_outputs.shape, (batch_size, 1)) - expected_outputs_prefixes = block_batch_schedule[ - model._expected_outputs_deltas - ] + expected_outputs_prefixes = block_batch_schedule['expected_outputs_deltas'] expected_len_prefixes = sum( len(block.instructions) for block in self.blocks[:batch_size] ) @@ -298,39 +243,36 @@ def test_schedule_batch_with_throughputs_with_deltas(self): expected_outputs_prefixes.shape, (expected_len_prefixes, 1) ) - with self.session() as sess: - output_blocks, output_deltas = sess.run( - [model.output_tensor, model.output_tensor_deltas], - feed_dict=full_schedule, - ) + output = model(full_schedule) + output_blocks = output['output'] + output_deltas = output['output_deltas'] - expected_output_blocks = [] - expected_output_deltas = [] - for i, block in enumerate(self.blocks_with_throughput): - expected_output_blocks.append((i + 1) * len(block.block.instructions)) - for _ in block.block.instructions: - expected_output_deltas.append([i + 1]) + expected_output_blocks = [] + expected_output_deltas = [] + for i, block in enumerate(self.blocks_with_throughput): + expected_output_blocks.append((i + 1) * len(block.block.instructions)) + for _ in block.block.instructions: + expected_output_deltas.append([i + 1]) - output_blocks = np.reshape(output_blocks, 10) + output_blocks = np.reshape(output_blocks, 10) - self.assertAllEqual(output_blocks, expected_output_blocks) - self.assertAllEqual(output_deltas, expected_output_deltas) + self.assertAllEqual(output_blocks, expected_output_blocks) + self.assertAllEqual(output_deltas, expected_output_deltas) - output_blocks, output_deltas = sess.run( - [model.output_tensor, model.output_tensor_deltas], - feed_dict=block_batch_schedule, - ) - output_blocks = np.reshape(output_blocks, 3) + output = model(block_batch_schedule) + output_blocks = output['output'] + output_deltas = output['output_deltas'] + output_blocks = np.reshape(output_blocks, 3) - expected_output_blocks = [] - expected_output_deltas = [] - for i, block in enumerate(self.blocks_with_throughput[:batch_size]): - expected_output_blocks.append((i + 1) * len(block.block.instructions)) - for _ in block.block.instructions: - expected_output_deltas.append([i + 1]) + expected_output_blocks = [] + expected_output_deltas = [] + for i, block in enumerate(self.blocks_with_throughput[:batch_size]): + expected_output_blocks.append((i + 1) * len(block.block.instructions)) + for _ in block.block.instructions: + expected_output_deltas.append([i + 1]) - self.assertAllEqual(output_blocks, expected_output_blocks) - self.assertAllEqual(output_deltas, expected_output_deltas) + self.assertAllEqual(output_blocks, expected_output_blocks) + self.assertAllEqual(output_deltas, expected_output_deltas) def test_schedule_batch_and_train_with_masked_outputs(self): task_list = ('task_1', 'task_2') @@ -352,17 +294,15 @@ def test_schedule_batch_and_train_with_masked_outputs(self): feed_dict = model.schedule_batch(blocks) self.assertAllEqual( - feed_dict[model._output_mask], ((True, False), (False, True)) + feed_dict['output_mask'], ((True, False), (False, True)) ) - with self.session() as sess: - self.check_training_model( - model, - num_epochs=30, - blocks=blocks, - session=sess, - print_output_to_log=True, - ) + self.check_training_model( + model, + num_epochs=30, + blocks=blocks, + print_output_to_log=True, + ) def test_expected_outputs_delta(self): model = TestModel(dtype=tf.dtypes.float32, use_deltas=True) @@ -370,8 +310,8 @@ def test_expected_outputs_delta(self): for block in self.blocks_with_throughput: schedule = model.schedule_batch([block], randomize_batch=False) - expected_outputs = schedule[model._expected_outputs] - expected_output_deltas = schedule[model._expected_outputs_deltas] + expected_outputs = schedule['expected_outputs'] + expected_output_deltas = schedule['expected_outputs_deltas'] self.assertEqual(expected_outputs.shape, (1, 1)) self.assertEqual( @@ -396,8 +336,8 @@ def test_randomized_expected_outputs_delta(self): for block in self.blocks_with_throughput: schedule = model.schedule_batch([block], randomize_expected_outputs=True) - expected_outputs = schedule[model._expected_outputs] - expected_output_deltas = schedule[model._expected_outputs_deltas] + expected_outputs = schedule['expected_outputs'] + expected_output_deltas = schedule['expected_outputs_deltas'] self.assertEqual(expected_outputs.shape, (1, 1)) self.assertEqual( @@ -451,13 +391,12 @@ def test_schedule_batch_with_instruction_limit(self): ) model.initialize() - with self.session() as sess: - # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. - # This basic block has overall inverse throughput equal to 2.0, i.e. the - # loss with MSE and absolute error must be 1.0. - schedule = model.schedule_batch([self.blocks_with_throughput[1]]) - loss = sess.run(model.loss_tensor, feed_dict=schedule) - self.assertEqual(loss, 1.0) + # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. + # This basic block has overall inverse throughput equal to 2.0, i.e. the + # loss with MSE and absolute error must be 1.0. + schedule = model.schedule_batch([self.blocks_with_throughput[1]]) + loss = model.compute_loss_tensor(schedule) + self.assertEqual(loss, 1.0) def test_seq2seq_delta_loss(self): model = TestModel( @@ -468,15 +407,14 @@ def test_seq2seq_delta_loss(self): ) model.initialize() - with self.session() as sess: - # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. - # This basic block has 5 prefixes with inverse throughputs - # [1, 1, 1, 1, 2] and deltas [1, 0, 0, 0, 1]. The model predicts - # deltas [1, 1, 1, 1, 1] and the delta-based loss is thus - # (0 + 1 + 1 + 1 + 0)/5. - schedule = model.schedule_batch([self.blocks_with_throughput[1]]) - loss = sess.run(model.loss_tensor, feed_dict=schedule) - self.assertNear(loss, (0 + 1 + 1 + 1 + 0) / 5, 1e-6) + # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. + # This basic block has 5 prefixes with inverse throughputs + # [1, 1, 1, 1, 2] and deltas [1, 0, 0, 0, 1]. The model predicts + # deltas [1, 1, 1, 1, 1] and the delta-based loss is thus + # (0 + 1 + 1 + 1 + 0)/5. + schedule = model.schedule_batch([self.blocks_with_throughput[1]]) + loss = model.compute_loss_tensor(schedule) + self.assertNear(loss, (0 + 1 + 1 + 1 + 0) / 5, 1e-6) def test_seq2seq_no_delta_loss(self): model = TestModel( @@ -487,14 +425,13 @@ def test_seq2seq_no_delta_loss(self): ) model.initialize() - with self.session() as sess: - # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. - # This basic block has 5 prefixes with inverse throughputs - # [1, 1, 1, 4/3, 2] and deltas [1, 0, 0, 1/3, 2/3]. The model predicts - # deltas [1, 1, 1, 1, 1] and the per-basic block loss is thus (5-2)^2 = 9. - schedule = model.schedule_batch([self.blocks_with_throughput[1]]) - loss = sess.run(model.loss_tensor, feed_dict=schedule) - self.assertNear(loss, 9, 1e-6) + # Use the second block from testdata/basic_blocks_with_throughput.pbtxt. + # This basic block has 5 prefixes with inverse throughputs + # [1, 1, 1, 4/3, 2] and deltas [1, 0, 0, 1/3, 2/3]. The model predicts + # deltas [1, 1, 1, 1, 1] and the per-basic block loss is thus (5-2)^2 = 9. + schedule = model.schedule_batch([self.blocks_with_throughput[1]]) + loss = model.compute_loss_tensor(schedule) + self.assertNear(loss, 9, 1e-6) def check_predict( self, @@ -517,35 +454,33 @@ def check_predict( expected_batch_sizes: A collection of expected sizes of batches processed by model.predict(), verified by this method. """ - with self.session() as sess: - output_blocks = tuple( - model.predict( - sess, - self.blocks, - max_blocks_in_batch=max_blocks_in_batch, - max_instructions_in_batch=max_instructions_in_batch, - ) - ) - self.assertEqual(model.batch_sizes, expected_batch_sizes) - self.assertLen(output_blocks, len(self.blocks_with_throughput)) - for index, (in_block, out_block) in enumerate( - zip(self.blocks, output_blocks) - ): - # The prediction of the model is the number of calls to - # model._add_basic_block_to_batch(). There is one call per basic block, - # so we can get the expected value from the index of the basic block. - expected_inverse_throughputs = [] - for task_index in range(model.num_tasks): - expected_inverse_throughputs.append((index + 1 + task_index,)) - self.assertEqual(in_block, out_block.block) - self.assertLen(out_block.throughputs, model.num_tasks) - predicted_throughputs = [ - throughput.inverse_throughput_cycles - for throughput in out_block.throughputs - ] - self.assertSequenceEqual( - predicted_throughputs, expected_inverse_throughputs + output_blocks = tuple( + model.predict( + self.blocks, + max_blocks_in_batch=max_blocks_in_batch, + max_instructions_in_batch=max_instructions_in_batch, ) + ) + self.assertEqual(model.batch_sizes, expected_batch_sizes) + self.assertLen(output_blocks, len(self.blocks_with_throughput)) + for index, (in_block, out_block) in enumerate( + zip(self.blocks, output_blocks) + ): + # The prediction of the model is the number of calls to + # model._add_basic_block_to_batch(). There is one call per basic block, + # so we can get the expected value from the index of the basic block. + expected_inverse_throughputs = [] + for task_index in range(model.num_tasks): + expected_inverse_throughputs.append((index + 1 + task_index,)) + self.assertEqual(in_block, out_block.block) + self.assertLen(out_block.throughputs, model.num_tasks) + predicted_throughputs = [ + throughput.inverse_throughput_cycles + for throughput in out_block.throughputs + ] + self.assertSequenceEqual( + predicted_throughputs, expected_inverse_throughputs + ) def test_predict_single_batch(self): model = TestModel(dtype=tf.dtypes.float32) @@ -590,35 +525,34 @@ def test_predict_with_both_limits(self): ) def check_predict_deltas(self, model): - with self.session() as sess: - output_blocks = tuple(model.predict(sess, self.blocks)) - self.assertLen(output_blocks, len(self.blocks)) + output_blocks = tuple(model.predict(self.blocks)) + self.assertLen(output_blocks, len(self.blocks)) + + for index, (in_block, out_block) in enumerate( + zip(self.blocks, output_blocks) + ): + # Sum up delta predictions for the model with deltas. + # predictions. + num_instructions = len(in_block.instructions) + + # Inverse throughput on one prefix. + expected_throughputs = [] + for task_index in range(model.num_tasks): + pref_inv_throughputs = (index + 1 + task_index,) + + expected_throughputs.append( + throughput.BasicBlockThroughput( + inverse_throughput_cycles=( + num_instructions * (index + 1 + task_index), + ), + prefix_inverse_throughput_cycles=(pref_inv_throughputs,) + * num_instructions, + ) + ) - for index, (in_block, out_block) in enumerate( - zip(self.blocks, output_blocks) - ): - # Sum up delta predictions for the model with deltas. - # predictions. - num_instructions = len(in_block.instructions) - - # Inverse throughput on one prefix. - expected_throughputs = [] - for task_index in range(model.num_tasks): - pref_inv_throughputs = (index + 1 + task_index,) - - expected_throughputs.append( - throughput.BasicBlockThroughput( - inverse_throughput_cycles=( - num_instructions * (index + 1 + task_index), - ), - prefix_inverse_throughput_cycles=(pref_inv_throughputs,) - * num_instructions, - ) - ) - - self.assertEqual(in_block, out_block.block) - self.assertLen(out_block.throughputs, model.num_tasks) - self.assertEqual(out_block.throughputs, expected_throughputs) + self.assertEqual(in_block, out_block.block) + self.assertLen(out_block.throughputs, model.num_tasks) + self.assertEqual(out_block.throughputs, expected_throughputs) def test_predict_deltas(self): model = TestModel(dtype=tf.dtypes.float32, use_deltas=True) @@ -652,7 +586,7 @@ def test_schedule_batch_with_throughputs_multi_task(self): # Schedule a batch with no limits. full_schedule = model.schedule_batch(self.blocks_with_throughput) self.assertLen(self.blocks_with_throughput, model.num_visited_blocks) - expected_outputs = full_schedule[model._expected_outputs] + expected_outputs = full_schedule['expected_outputs'] self.assertEqual( expected_outputs.shape, (len(self.blocks_with_throughput), num_tasks) ) @@ -663,20 +597,21 @@ def test_schedule_batch_with_throughputs_multi_task(self): self.blocks_with_throughput, max_blocks_in_batch=batch_size ) self.assertEqual(model.num_visited_blocks, batch_size) - expected_outputs = block_batch_schedule[model._expected_outputs] + expected_outputs = block_batch_schedule['expected_outputs'] self.assertEqual(expected_outputs.shape, (batch_size, num_tasks)) - with self.session() as sess: - output = sess.run(model.output_tensor, feed_dict=full_schedule) - self.assertAllEqual( - output, - [[x + 1, x + 2] for x in range(len(self.blocks_with_throughput))], - ) + output = model(full_schedule)['output'] + self.assertAllEqual( + output, + [[x + 1, x + 2] for x in range(len(self.blocks_with_throughput))], + ) - output = sess.run(model.output_tensor, feed_dict=block_batch_schedule) - self.assertAllEqual(output, [[x + 1, x + 2] for x in range(batch_size)]) + output = model(block_batch_schedule)['output'] + self.assertAllEqual(output, [[x + 1, x + 2] for x in range(batch_size)]) - def test_schedule_batch_with_throughputs_multi_task_with_deltas(self): + def test_schedule_batch_with_throughputs_multi_task_with_deltas( + self, + ): tasks = ['llvm', 'test'] num_tasks = len(tasks) @@ -689,10 +624,10 @@ def test_schedule_batch_with_throughputs_multi_task_with_deltas(self): # Schedule a batch with no limits. full_schedule = model.schedule_batch(blocks) self.assertLen(blocks, model.num_visited_blocks) - expected_outputs = full_schedule[model._expected_outputs] + expected_outputs = full_schedule['expected_outputs'] self.assertEqual(expected_outputs.shape, (len(blocks), num_tasks)) - expected_outputs_prefixes = full_schedule[model._expected_outputs_deltas] + expected_outputs_prefixes = full_schedule['expected_outputs_deltas'] expected_num_prefixes = sum( len(block.block.instructions) for block in blocks ) @@ -700,28 +635,26 @@ def test_schedule_batch_with_throughputs_multi_task_with_deltas(self): expected_outputs_prefixes.shape, (expected_num_prefixes, num_tasks) ) - with self.session() as sess: - output_blocks, output_deltas = sess.run( - (model.output_tensor, model.output_tensor_deltas), - feed_dict=full_schedule, - ) + output = model(full_schedule) + output_blocks = output['output'] + output_deltas = output['output_deltas'] + + expected_output_blocks = [] + expected_output_deltas = [] + for i, block in enumerate(blocks): + block_expected_output = [] + for task in range(num_tasks): + block_expected_output.append( + (i + 1 + task) * len(block.block.instructions) + ) + expected_output_blocks.append(block_expected_output) + for _ in block.block.instructions: + expected_output_deltas.append( + [i + 1 + task for task in range(num_tasks)] + ) - expected_output_blocks = [] - expected_output_deltas = [] - for i, block in enumerate(blocks): - block_expected_output = [] - for task in range(num_tasks): - block_expected_output.append( - (i + 1 + task) * len(block.block.instructions) - ) - expected_output_blocks.append(block_expected_output) - for _ in block.block.instructions: - expected_output_deltas.append( - [i + 1 + task for task in range(num_tasks)] - ) - - self.assertAllEqual(output_blocks, expected_output_blocks) - self.assertAllEqual(output_deltas, expected_output_deltas) + self.assertAllEqual(output_blocks, expected_output_blocks) + self.assertAllEqual(output_deltas, expected_output_deltas) def test_training_with_full_variable_list(self): task_list = ['foo', 'bar'] @@ -736,19 +669,17 @@ def test_training_with_full_variable_list(self): ), ) model.initialize() - with self.session() as sess: - self.check_training_model( - model, - num_epochs=40, - blocks=self.blocks_with_throughput[0:2], - session=sess, - ) - biases = sess.run(model._variable_groups[TestModelWithVarGroups.BIAS]) - for bias in biases: - self.assertNotAlmostEqual(float(bias), -0.5) - weights = sess.run(model._variable_groups[TestModelWithVarGroups.WEIGHTS]) - for weight in weights: - self.assertNotAlmostEqual(float(weight), 0.5) + self.check_training_model( + model, + num_epochs=40, + blocks=self.blocks_with_throughput[0:2], + ) + biases = model._variable_groups[TestModelWithVarGroups.BIAS] + for bias in biases: + self.assertNotAlmostEqual(float(bias), -0.5) + weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS] + for weight in weights: + self.assertNotAlmostEqual(float(weight), 0.5) def test_training_bias_only(self): task_list = ['foo', 'bar'] @@ -760,19 +691,17 @@ def test_training_bias_only(self): trained_variable_groups=(TestModelWithVarGroups.BIAS,), ) model.initialize() - with self.session() as sess: - self.check_training_model( - model, - num_epochs=40, - blocks=self.blocks_with_throughput[0:1], - session=sess, - ) - biases = sess.run(model._variable_groups[TestModelWithVarGroups.BIAS]) - for bias in biases: - self.assertNotAlmostEqual(float(bias), -0.5) - weights = sess.run(model._variable_groups[TestModelWithVarGroups.WEIGHTS]) - for weight in weights: - self.assertAlmostEqual(float(weight), 0.5) + self.check_training_model( + model, + num_epochs=40, + blocks=self.blocks_with_throughput[0:1], + ) + biases = model._variable_groups[TestModelWithVarGroups.BIAS] + for bias in biases: + self.assertNotAlmostEqual(float(bias), -0.5) + weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS] + for weight in weights: + self.assertAlmostEqual(float(weight), 0.5) def test_grad_clipping(self): task_list = ['foo', 'bar'] @@ -785,13 +714,11 @@ def test_grad_clipping(self): trained_variable_groups=(TestModelWithVarGroups.BIAS,), ) model.initialize() - with self.session() as sess: - self.check_training_model( - model, - num_epochs=40, - blocks=self.blocks_with_throughput[0:1], - session=sess, - ) + self.check_training_model( + model, + num_epochs=40, + blocks=self.blocks_with_throughput[0:1], + ) def test_training_weight_only(self): task_list = ['foo', 'bar'] @@ -803,21 +730,18 @@ def test_training_weight_only(self): trained_variable_groups=(TestModelWithVarGroups.WEIGHTS,), ) model.initialize() - with self.session() as sess: - self.check_training_model( - model, - num_epochs=40, - blocks=self.blocks_with_throughput[0:1], - session=sess, - ) - biases = sess.run(model._variable_groups[TestModelWithVarGroups.BIAS]) - for bias in biases: - self.assertAlmostEqual(float(bias), -0.5) - weights = sess.run(model._variable_groups[TestModelWithVarGroups.WEIGHTS]) - for weight in weights: - self.assertNotAlmostEqual(float(weight), 0.5) + self.check_training_model( + model, + num_epochs=40, + blocks=self.blocks_with_throughput[0:1], + ) + biases = model._variable_groups[TestModelWithVarGroups.BIAS] + for bias in biases: + self.assertAlmostEqual(float(bias), -0.5) + weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS] + for weight in weights: + self.assertNotAlmostEqual(float(weight), 0.5) if __name__ == '__main__': - tf.disable_v2_behavior() tf.test.main() diff --git a/gematria/testing/python/model_test.py b/gematria/testing/python/model_test.py index 4d78fe4a..f003c714 100644 --- a/gematria/testing/python/model_test.py +++ b/gematria/testing/python/model_test.py @@ -118,7 +118,6 @@ def check_training_model( max_expected_min_loss=0.2, log_directory=None, print_output_to_log=False, - session=None, ): """Tests training the given model. @@ -138,17 +137,12 @@ def check_training_model( are not stored. print_output_to_log: When True, the contents of the output tensor is printed to the log at each step. - session: An optional session to run the training in. If `session` is not - None, the method will run the training in it, but it will not release - the session at the end. If `session` is None, the function will create a - session just for the training, and it will release it at the end. """ blocks = blocks or self.blocks_with_throughput - def _check_training(sess): + def _check_training(): if log_directory is not None: - tf.summary.FileWriter(logdir=log_directory, graph=sess.graph) - sess.run(tf.global_variables_initializer()) + tf.summary.FileWriter(logdir=log_directory) schedule = model.schedule_batch(blocks) # The loss at the end of the training may increase temporarily, and it is @@ -157,15 +151,15 @@ def _check_training(sess): min_mse = [math.inf] * model.num_tasks min_relative_mse = [math.inf] * model.num_tasks for epoch in range(num_epochs): - stats = model.train_batch(sess, schedule) + stats = model.train_batch(schedule) if print_output_to_log: - output = sess.run(model.output_tensor, schedule) + output = model(schedule) # The output is a 2D tensor of shape (batch_size, num_tasks). When # num_tasks == 1, the output is a 2D column tensor. We reshape it to # (num_tasks,), so that it prints on a single line. When num_tasks > 1 # we leave the shape as is. - if output.shape[1] == 1: - output = output.reshape((-1,)) + if output['output'].shape[1] == 1: + output['output'] = output['output'].reshape((-1,)) logging.info('Output: %r', output) # Check basic properties. self.assertEqual(stats.epoch, epoch + 1) @@ -193,12 +187,4 @@ def _check_training(sess): ) self.assertAllLess(min_relative_mse, max_expected_min_loss) - if session: - # If an external session was provided, just run the training in the - # session and assume that the owner will take care of releasing it - # afterwards. - _check_training(session) - else: - # Otherwise, create a session for this call and release it at the end. - with self.session() as session: - _check_training(session) + _check_training()