
Commit b6557b6
update comments
Created using spr 1.3.4
boomanaiden154 committed Jan 4, 2025
1 parent 9c205dc commit b6557b6
Showing 1 changed file with 17 additions and 15 deletions.
gematria/model/python/model_base.py
@@ -50,7 +50,7 @@
 # simpler than what is actually accepted by TensorFlow, but the typing should be
 # sufficient for our use. Moreover, since TensorFlow and NumPy do not provide
 # type annotations, both the key and the value are reduced to typing.Any.
-FeedDict = MutableMapping[tf.Tensor, np.ndarray]
+FeedDict = MutableMapping[str, Union[np.ndarray, tf.Tensor]]

 # A throughput value used as a placeholder in the expected output tensors for
 # masked expected outputs.
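For illustration, a minimal sketch of what the new FeedDict alias admits: string tensor names mapped to NumPy arrays or TensorFlow tensors, in place of the old TF1-style mapping from placeholder tensors to arrays. This sketch is not part of the commit, and the key names are hypothetical.

import numpy as np
import tensorflow as tf

# Hypothetical tensor names; real models define their own keys.
feed_dict = {
    'instruction_features': np.zeros((16, 4), dtype=np.float32),
    'expected_outputs': tf.constant([1.0, 2.0, 3.0, 4.0]),
}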
@@ -441,10 +441,20 @@ def output_tensor_names(self) -> Sequence[str]:
     return (ModelBase.OUTPUT_TENSOR_NAME,)

   @abc.abstractmethod
-  def _forward(self, feed_dict):
+  def _forward(self, feed_dict: FeedDict) -> dict[str, tf.Tensor]:
+    """Implements the forward pass of the model.
+
+    This function should be implemented by downstream models and should
+    calculate the outputs of the model given the inputs specified in
+    feed_dict.
+    """
     pass

   def __call__(self, feed_dict, train=False):
+    """Implements the non-model-specific part of the forward pass.
+
+    This function wraps the _forward method and performs the relevant
+    calculations when working with models that use deltas.
+    """
     if self._use_deltas:
       output_dict = {}

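To make the updated contract concrete, here is a toy sketch, not from the repository, of a _forward implementation that consumes a FeedDict and returns a dict keyed by output tensor name, as the new dict[str, tf.Tensor] annotation requires:

import tensorflow as tf

class ToyModel:
  # Stand-in for ModelBase.OUTPUT_TENSOR_NAME in the real code.
  OUTPUT_TENSOR_NAME = 'output'

  def _forward(self, feed_dict):
    # 'features' is a hypothetical input key.
    features = tf.convert_to_tensor(feed_dict['features'])
    # One entry per output tensor name, matching dict[str, tf.Tensor].
    return {self.OUTPUT_TENSOR_NAME: tf.reduce_sum(features, axis=-1)}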
@@ -1113,12 +1123,10 @@ def predict(
 ) -> Iterable[throughput.BasicBlockWithThroughput]:
   """Predicts the inverse throughput using the model.

-  Assumes that sess has been initialized and that it contains the weights for
-  the model. The input sequence is iterated through only once, and the method
-  may be used with basic block sources such as tf.io.tf_record_iterator.
+  The input sequence is iterated through only once, and the method may be
+  used with basic block sources such as tf.io.tf_record_iterator.

   Args:
-    sess: The TensorFlow session object in which the computation is done.
     basic_blocks: The collection of basic blocks for which the inverse
       throughput is predicted.
     max_blocks_in_batch: The maximal number of basic blocks processed in a
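A hedged usage sketch of the updated signature, which no longer takes a session argument; model and block_iterator are hypothetical stand-ins, while the basic_blocks and max_blocks_in_batch parameter names come from the docstring above:

# block_iterator may be any one-shot iterable, e.g. a TFRecord reader.
predictions = model.predict(
    basic_blocks=block_iterator,
    max_blocks_in_batch=64,
)
for block_with_throughput in predictions:
  print(block_with_throughput)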
@@ -1305,21 +1313,15 @@ def train_batch(
   """Trains a batch based on the given schedule.

   Args:
-    sess: The TensorFlow session the training is running in.
     schedule: A feed_dict that describes the current batch.

   Returns:
     The loss on the training set before the training step.
   """
   with timer.scoped('ModelBase.train_batch'):
-    # The keys in the are names of keyword arguments of the constructor of
-    # TraningEpochStats. sess.run() returns a dict that has the same keys and
-    # the values of these tensors in the training step. This dict can then be
-    # unpacked to TrainingEpochStats.__init__() as keyword arguments.
-    # NOTE(ondrasej): The loss tensors are created lazily, when they are first
-    # referenced. To be used here, the tensors need to be created (referenced)
-    # during graph creation, e.g. by adding them as TensorFlow summaries in
-    # self._create_output_and_loss_tensors().
+    # The keys of stats are the names of keyword arguments of the constructor
+    # of TrainingEpochStats. This dict can then be unpacked to
+    # TrainingEpochStats.__init__() as keyword arguments.
     with tf.GradientTape() as tape:
       stats = {}
       loss = self.compute_loss(schedule)
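The dict-unpacking pattern that the rewritten comment describes can be shown with a self-contained stand-in; this TrainingEpochStats is a simplified dataclass, not the real class:

import dataclasses

@dataclasses.dataclass
class TrainingEpochStats:  # simplified stand-in for the real class
  loss: float
  epoch: int

# Keys match the constructor's keyword arguments, so the dict can be
# unpacked directly into __init__.
stats = {'loss': 0.25, 'epoch': 3}
print(TrainingEpochStats(**stats))  # TrainingEpochStats(loss=0.25, epoch=3)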
