Skip to content

Commit

Permalink
Merge pull request #273 from kylebgorman/predict2
Browse files Browse the repository at this point in the history
Fixes prediction across all architectures
  • Loading branch information
kylebgorman authored Dec 2, 2024
2 parents 3e1c4cd + e32444f commit 0a91f56
Show file tree
Hide file tree
Showing 10 changed files with 422 additions and 428 deletions.
24 changes: 12 additions & 12 deletions yoyodyne/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def finalize_predictions(
Returns:
torch.Tensor: finalized predictions.
"""
return util.pad_tensor_after_eos(predictions)
return util.pad_tensor_after_end(predictions)

def finalize_golds(
self,
Expand Down Expand Up @@ -197,8 +197,8 @@ def _finalize_tensor(
) -> List[torch.Tensor]:
"""Finalizes each tensor.
Truncates at EOS for each prediction and returns a List of predictions.
This does basically the same as util.pad_after_eos, but does not
Truncates at END for each prediction and returns a List of predictions.
This does basically the same as util.pad_tensor_after_end, but does not
actually pad since we do not need to return a well-formed tensor.
Args:
Expand All @@ -212,21 +212,21 @@ def _finalize_tensor(
return [tensor]
out = []
for prediction in tensor:
# Gets first instance of EOS.
eos = (prediction == special.END_IDX).nonzero(as_tuple=False)
if len(eos) > 0 and eos[0].item() < len(prediction):
# If an EOS was decoded and it is not the last one in the
# Gets first instance of END.
end = (prediction == special.END_IDX).nonzero(as_tuple=False)
if len(end) > 0 and end[0].item() < len(prediction):
# If an END was decoded and it is not the last one in the
# sequence.
eos = eos[0]
end = end[0]
else:
# Leaves tensor[i] alone.
out.append(prediction)
continue
# Hack in case the first prediction is EOS. In this case
# Hack in case the first prediction is END. In this case
# torch.split will result in an error, so we change these 0's to
# 1's, which will make the entire sequence EOS as intended.
eos[eos == 0] = 1
symbols, *_ = torch.split(prediction, eos)
# 1's, which will make the entire sequence END as intended.
end[end == 0] = 1
symbols, *_ = torch.split(prediction, end)
out.append(symbols)
return out

Expand Down
38 changes: 5 additions & 33 deletions yoyodyne/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,31 +176,6 @@ def init_embeddings(num_embed: int, embed_size: int) -> nn.Embedding:
def get_decoder(self):
raise NotImplementedError

def beam_decode(
self,
encoder_out: torch.Tensor,
mask: torch.Tensor,
beam_width: int,
):
"""Method interface for beam search.
Args:
encoder_out (torch.Tensor): encoded inputs.
encoder_mask (torch.Tensor).
beam_width (int): size of the beam; also determines the number of
hypotheses to return.
Raises:
NotImplementedError: This method needs to be overridden.
Returns:
Tuple[torch.Tensor, torch.Tensor]: the predictions tensor and the
log-likelihood of each prediction.
"""
raise NotImplementedError(
f"Beam search not implemented for {self.name} model"
)

@property
def num_parameters(self) -> int:
return sum(part.numel() for part in self.parameters())
Expand Down Expand Up @@ -308,7 +283,7 @@ def predict_step(
self,
batch: data.PaddedBatch,
batch_idx: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Runs one predict step.
This is called by the PL Trainer.
Expand All @@ -318,19 +293,16 @@ def predict_step(
batch_idx (int).
Returns:
Tuple[torch.Tensor, torch.Tensor]: position 0 are the indices of
the argmax at each timestep. Position 1 are the scores for each
history in beam search. It will be None when using greedy.
Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if
using beam search, the predictions and scores as a tuple of
tensors; if using greedy search, the predictions as a tensor.
"""
predictions = self(batch)
if self.beam_width > 1:
predictions, scores = predictions
return predictions, scores
else:
# -> B x seq_len x 1.
greedy_predictions = self._get_predicted(predictions)
return greedy_predictions, None
return self._get_predicted(predictions)

def _get_predicted(self, predictions: torch.Tensor) -> torch.Tensor:
"""Picks the best index from the vocabulary.
Expand Down
2 changes: 1 addition & 1 deletion yoyodyne/models/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def _generate_data(
) -> Iterator[Tuple[List[int], List[int]]]:
"""Helper function to manage data encoding for SED."
We want encodings without BOS or EOS tokens. This
We want encodings without BOS or END tokens. This
encodes only raw source-target text for the Maxwell library.
Args:
Expand Down
100 changes: 56 additions & 44 deletions yoyodyne/models/hard_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,46 +89,6 @@ def init_decoding(
bos, decoder_hiddens, encoder_out, encoder_mask
)

def forward(
self,
batch: data.PaddedBatch,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Runs the encoder-decoder model.
Args:
batch (data.PaddedBatch).
Returns:
Tuple[torch.Tensor,torch.Tensor]: emission probabilities for
each transition state of shape tgt_len x batch_size x src_len
x vocab_size, and transition probabilities for each transition
state of shape batch_size x src_len x src_len.
"""
encoder_out = self.source_encoder(batch.source).output
if self.has_features_encoder:
encoder_features_out = self.features_encoder(batch.features).output
# Averages to flatten embedding.
encoder_features_out = encoder_features_out.sum(
dim=1, keepdim=True
)
# Sums to flatten embedding; this is done as an alternative to the
# linear projection used in the original paper.
encoder_features_out = encoder_features_out.expand(
-1, encoder_out.shape[1], -1
)
# Concatenates with the average.
encoder_out = torch.cat(
[encoder_out, encoder_features_out], dim=-1
)
if self.training:
return self.decode(
encoder_out,
batch.source.mask,
batch.target.padded,
)
else:
return self.greedy_decode(encoder_out, batch.source.mask)

def decode(
self,
encoder_out: torch.Tensor,
Expand All @@ -137,8 +97,8 @@ def decode(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decodes a sequence given the encoded input.
Decodes until all sequences in a batch have reached [EOS] up to
length of `target` args.
Decodes until all sequences in a batch have reached END up to length of
`target` args.
Args:
encoder_out (torch.Tensor): batch of encoded input symbols
Expand Down Expand Up @@ -173,15 +133,21 @@ def decode(
all_transition_probs.append(transition_probs)
return torch.stack(all_log_probs), torch.stack(all_transition_probs)

def beam_decode(self, *args, **kwargs):
"""Overrides incompatible implementation inherited from RNNModel."""
raise NotImplementedError(
f"Beam search not implemented for {self.name} model"
)

def greedy_decode(
self,
encoder_out: torch.Tensor,
encoder_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decodes a sequence given the encoded input.
Decodes until all sequences in a batch have reached [EOS] up to
a specified length depending on the `target` args.
Decodes until all sequences in a batch have reached END up to a
specified length depending on the `target` args.
Args:
encoder_out (torch.Tensor): batch of encoded input symbols
Expand Down Expand Up @@ -307,6 +273,52 @@ def _apply_mono_mask(
)
return transition_prob

def forward(
self,
batch: data.PaddedBatch,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Runs the encoder-decoder model.
Args:
batch (data.PaddedBatch).
Returns:
Tuple[torch.Tensor,torch.Tensor]: emission probabilities for
each transition state of shape tgt_len x batch_size x src_len
x vocab_size, and transition probabilities for each transition
Raises:
NotImplementedError: beam search not implemented.
state of shape batch_size x src_len x src_len.
"""
encoder_out = self.source_encoder(batch.source).output
if self.has_features_encoder:
encoder_features_out = self.features_encoder(batch.features).output
# Averages to flatten embedding.
encoder_features_out = encoder_features_out.sum(
dim=1, keepdim=True
)
# Sums to flatten embedding; this is done as an alternative to the
# linear projection used in the original paper.
encoder_features_out = encoder_features_out.expand(
-1, encoder_out.shape[1], -1
)
# Concatenates with the average.
encoder_out = torch.cat(
[encoder_out, encoder_features_out], dim=-1
)
if self.training:
return self.decode(
encoder_out,
batch.source.mask,
batch.target.padded,
)
elif self.beam_width > 1:
# Will raise a NotImplementedError.
return self.beam_decode(encoder_out, batch.source.mask)
else:
return self.greedy_decode(encoder_out, batch.source.mask)

def training_step(
self, batch: data.PaddedBatch, batch_idx: int
) -> torch.Tensor:
Expand Down
Loading

0 comments on commit 0a91f56

Please sign in to comment.