From b6557b622dfd6666cac7431d369c1ec82ab348bc Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Sat, 4 Jan 2025 01:35:41 +0000 Subject: [PATCH 1/2] update comments Created using spr 1.3.4 --- gematria/model/python/model_base.py | 32 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index a8d29d85..103bc069 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -50,7 +50,7 @@ # simpler than what is actually accepted by TensorFlow, but the typing should be # sufficient for our use. Moreover, since TensorFlow and NumPy do not provide # type annotations, both the key and the value are reduced to typing.Any. -FeedDict = MutableMapping[tf.Tensor, np.ndarray] +FeedDict = MutableMapping[str, Union[np.ndarray, tf.Tensor]] # A throughput value used as a placeholder in the expected output tensors for # masked expected outputs. @@ -441,10 +441,20 @@ def output_tensor_names(self) -> Sequence[str]: return (ModelBase.OUTPUT_TENSOR_NAME,) @abc.abstractmethod - def _forward(self, feed_dict): + def _forward(self, feed_dict: FeedDict) -> dict[str, tf.Tensor]: + """Implements the forward pass of the model. + + This function should be implemented in downstream models and calculate the + outputs of the model given the inputs specified in feed_dict. + """ pass def __call__(self, feed_dict, train=False): + """Implements the non-model specific part of the forward pass. + + This function wraps the _forward method and does relevant calculations + when working with models that use deltas. + """ if self._use_deltas: output_dict = {} @@ -1113,12 +1123,10 @@ def predict( ) -> Iterable[throughput.BasicBlockWithThroughput]: """Predicts the inverse throughput using the model. - Assumes that sess has been initialized and that it contains the weights for - the model. The input sequence is iterated through only once, and the method - may be used with basic block sources such as tf.io.tf_record_iterator. + The input sequence is iterated through only once, and the method may be + used with basic block sources such as tf.io.tf_record_iterator. Args: - sess: The TensorFlow session object in which the computation is done. basic_blocks: The collection of basic blocks for which the inverse throughput is predicted. max_blocks_in_batch: The maximal number of basic blocks processed in a @@ -1305,21 +1313,15 @@ def train_batch( """Trains a batch based on the given schedule. Args: - sess: The TensorFlow session the training is running in. schedule: A feed_dict that describes the current batch. Returns: The loss on the training set before the training step. """ with timer.scoped('ModelBase.train_batch'): - # The keys in the are names of keyword arguments of the constructor of - # TraningEpochStats. sess.run() returns a dict that has the same keys and - # the values of these tensors in the training step. This dict can then be - # unpacked to TrainingEpochStats.__init__() as keyword arguments. - # NOTE(ondrasej): The loss tensors are created lazily, when they are first - # referenced. To be used here, the tensors need to be created (referenced) - # during graph creation, e.g. by adding them as TensorFlow summaries in - # self._create_output_and_loss_tensors(). + # The keys of stats are the names of keyword arguments of the constructor + # of TraningEpochStats. This dict can then be unpacked to + # TrainingEpochStats.__init__() as keyword arguments. with tf.GradientTape() as tape: stats = {} loss = self.compute_loss(schedule) From 4e3ef265f913e93bf7f1942e961958121b1812a3 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Mon, 6 Jan 2025 07:15:09 +0000 Subject: [PATCH 2/2] Fix bug Created using spr 1.3.4 --- gematria/model/python/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index 103bc069..6518752c 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -472,7 +472,7 @@ def __call__(self, feed_dict, train=False): feed_dict['delta_block_index'], name=ModelBase.OUTPUT_TENSOR_NAME, ) - output_dict['output_deltas'] = feed_dict['output_deltas'] + output_dict['output_deltas'] = output['output_deltas'] return output_dict else: