From e9fb0b017a2098cb31367734e85127a290e2cac5 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 14:15:53 -0500 Subject: [PATCH 1/8] Attempts to fix batch prediction support --- yoyodyne/models/base.py | 6 +- yoyodyne/models/hard_attention.py | 88 +++++++++++++++------------- yoyodyne/models/pointer_generator.py | 20 ++++--- yoyodyne/models/transformer.py | 10 ++-- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 13cc4a80..03972daf 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -191,11 +191,7 @@ def beam_decode( hypotheses to return. Raises: - NotImplementedError: This method needs to be overridden. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: the predictions tensor and the - log-likelihood of each prediction. + NotImplementedError: beam search not implemented. """ raise NotImplementedError( f"Beam search not implemented for {self.name} model" diff --git a/yoyodyne/models/hard_attention.py b/yoyodyne/models/hard_attention.py index 4d87087f..4915a223 100644 --- a/yoyodyne/models/hard_attention.py +++ b/yoyodyne/models/hard_attention.py @@ -89,46 +89,6 @@ def init_decoding( bos, decoder_hiddens, encoder_out, encoder_mask ) - def forward( - self, - batch: data.PaddedBatch, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Runs the encoder-decoder model. - - Args: - batch (data.PaddedBatch). - - Returns: - Tuple[torch.Tensor,torch.Tensor]: emission probabilities for - each transition state of shape tgt_len x batch_size x src_len - x vocab_size, and transition probabilities for each transition - state of shape batch_size x src_len x src_len. - """ - encoder_out = self.source_encoder(batch.source).output - if self.has_features_encoder: - encoder_features_out = self.features_encoder(batch.features).output - # Averages to flatten embedding. - encoder_features_out = encoder_features_out.sum( - dim=1, keepdim=True - ) - # Sums to flatten embedding; this is done as an alternative to the - # linear projection used in the original paper. - encoder_features_out = encoder_features_out.expand( - -1, encoder_out.shape[1], -1 - ) - # Concatenates with the average. - encoder_out = torch.cat( - [encoder_out, encoder_features_out], dim=-1 - ) - if self.training: - return self.decode( - encoder_out, - batch.source.mask, - batch.target.padded, - ) - else: - return self.greedy_decode(encoder_out, batch.source.mask) - def decode( self, encoder_out: torch.Tensor, @@ -307,6 +267,54 @@ def _apply_mono_mask( ) return transition_prob + def forward( + self, + batch: data.PaddedBatch, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Runs the encoder-decoder model. + + Args: + batch (data.PaddedBatch). + + Returns: + Tuple[torch.Tensor,torch.Tensor]: emission probabilities for + each transition state of shape tgt_len x batch_size x src_len + x vocab_size, and transition probabilities for each transition + + Raises: + NotImplementedError: beam search not implemented. + state of shape batch_size x src_len x src_len. + """ + encoder_out = self.source_encoder(batch.source).output + if self.has_features_encoder: + encoder_features_out = self.features_encoder(batch.features).output + # Averages to flatten embedding. + encoder_features_out = encoder_features_out.sum( + dim=1, keepdim=True + ) + # Sums to flatten embedding; this is done as an alternative to the + # linear projection used in the original paper. 
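# A standalone sketch of the shape flow in this features-flattening step,
# with hypothetical sizes (B = batch size, H = hidden size); in the model the
# two tensors come from the source and features encoders:
import torch

B, src_len, feat_len, H = 2, 7, 3, 16
encoder_out = torch.randn(B, src_len, H)
features_out = torch.randn(B, feat_len, H)
features_out = features_out.sum(dim=1, keepdim=True)  # -> B x 1 x H.
features_out = features_out.expand(-1, encoder_out.shape[1], -1)  # -> B x src_len x H.
combined = torch.cat([encoder_out, features_out], dim=-1)  # -> B x src_len x 2H.
assert combined.shape == (B, src_len, 2 * H)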
+ encoder_features_out = encoder_features_out.expand( + -1, encoder_out.shape[1], -1 + ) + # Concatenates with the average. + encoder_out = torch.cat( + [encoder_out, encoder_features_out], dim=-1 + ) + if self.training: + return self.decode( + encoder_out, + batch.source.mask, + batch.target.padded, + ) + elif self.beam_width > 1: + # Will raise a NotImplementedError. + output = self.beam_decode( + encoder_out, batch.source.mask, beam_width + ) + else: + return self.greedy_decode(encoder_out, batch.source.mask) + def training_step( self, batch: data.PaddedBatch, batch_idx: int ) -> torch.Tensor: diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index d62182f1..2944cb0b 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -166,7 +166,6 @@ def __init__(self, *args, **kwargs): "The number of encoder and decoder layers must match " f"({self.encoder_layers} != {self.decoder_layers})" ) - # We use the inherited defaults for the source embeddings/encoder. # Overrides classifier to take larger input. if not self.has_features_encoder: @@ -826,7 +825,7 @@ def decode_step( scaled_output_dist = output_dist * gen_probs return torch.log(scaled_output_dist + scaled_ptr_dist) - def _decode_greedy( + def greedy_decode( self, encoder_hidden: torch.Tensor, source_mask: torch.Tensor, @@ -907,6 +906,9 @@ def forward( Returns: torch.Tensor. + + Raises: + NotImplementedError: beam search not implemented. """ source_encoded = self.source_encoder(batch.source).output if self.training and self.teacher_forcing: @@ -930,9 +932,9 @@ def forward( if self.beam_width > 1: # Will raise a NotImplementedError. output = self.beam_decode( - encoder_out=source_encoded, - mask=batch.source.mask, - beam_width=self.beam_width, + source_encoded, + batch.source.mask, + self.beam_width, ) else: output = self.decode_step( @@ -952,13 +954,13 @@ def forward( if self.beam_width > 1: # Will raise a NotImplementedError. output = self.beam_decode( - encoder_out=source_encoded, - mask=batch.source.mask, - beam_width=self.beam_width, + source_encoded, + batch.source.mask, + self.beam_width, ) else: # -> B x seq_len x output_size. - output = self._decode_greedy( + output = self.greedy_decode( source_encoded, batch.source.mask, batch.source.padded, diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 5c5f2e88..f595dd41 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -67,7 +67,7 @@ def get_decoder(self) -> modules.transformer.TransformerDecoder: source_attention_heads=self.source_attention_heads, ) - def _decode_greedy( + def greedy_decode( self, encoder_hidden: torch.Tensor, source_mask: torch.Tensor, @@ -172,13 +172,13 @@ def forward( if self.beam_width > 1: # Will raise a NotImplementedError. output = self.beam_decode( - encoder_out=encoder_output, - mask=batch.source.mask, - beam_width=self.beam_width, + encoder_output, + batch.source.mask, + self.beam_width, ) else: # -> B x seq_len x output_size. 
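# The dispatch shared by these forward() methods, as a toy (not the exact
# yoyodyne signatures): teacher-forced decoding during training, beam search
# when beam_width > 1 (currently a stub that raises), greedy search otherwise.
def dispatch(model, encoder_out, mask, target=None):
    if model.training and target is not None:
        return model.decode(encoder_out, mask, target)
    if model.beam_width > 1:
        # Raises NotImplementedError where beam search is unsupported.
        return model.beam_decode(encoder_out, mask)
    return model.greedy_decode(encoder_out, mask)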
- output = self._decode_greedy( + output = self.greedy_decode( encoder_output, batch.source.mask, batch.target.padded if batch.target else None, From 8f057c0925cd538c57962b8d7b6d61c4479de2c5 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 14:19:26 -0500 Subject: [PATCH 2/8] Update indexes.py --- yoyodyne/data/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py index f55cc590..f0eeec48 100644 --- a/yoyodyne/data/indexes.py +++ b/yoyodyne/data/indexes.py @@ -89,7 +89,7 @@ def get_symbol(self, index: int) -> str: # Serialization support. @classmethod - def read(cls, model_dir: str, experiment: str) -> Index: + def read(cls, model_dir: str) -> Index: """Loads index. Args: From 33e4bdaf8888aa950b9725e58d065acc7982d180 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 14:19:51 -0500 Subject: [PATCH 3/8] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index ead5099b..6284a02c 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,6 @@ One must specify the following required arguments: - `--arch`: architecture, matching the one used for training - `--model_dir`: path for model metadata -- `--experiment`: name of experiment - `--checkpoint`: path to checkpoint - `--predict`: path to file containing data to be predicted - `--output`: path for predictions From 9689635246977f86a8c5b543f9d62c65d15a7d7e Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 14:42:10 -0500 Subject: [PATCH 4/8] Standardizes name: "END" not "EOS". --- yoyodyne/evaluators.py | 24 ++++++++++++------------ yoyodyne/models/base.py | 5 +---- yoyodyne/models/expert.py | 2 +- yoyodyne/models/hard_attention.py | 26 +++++++++++++++++++++----- yoyodyne/models/pointer_generator.py | 18 +++++++++--------- yoyodyne/models/rnn.py | 23 ++++++++++------------- yoyodyne/models/transducer.py | 4 ++-- yoyodyne/models/transformer.py | 12 ++++++------ yoyodyne/predict.py | 6 +++--- yoyodyne/util.py | 24 ++++++++++++------------ 10 files changed, 77 insertions(+), 67 deletions(-) diff --git a/yoyodyne/evaluators.py b/yoyodyne/evaluators.py index cc8059cc..4b7718e1 100644 --- a/yoyodyne/evaluators.py +++ b/yoyodyne/evaluators.py @@ -136,7 +136,7 @@ def finalize_predictions( Returns: torch.Tensor: finalized predictions. """ - return util.pad_tensor_after_eos(predictions) + return util.pad_tensor_after_end(predictions) def finalize_golds( self, @@ -197,8 +197,8 @@ def _finalize_tensor( ) -> List[torch.Tensor]: """Finalizes each tensor. - Truncates at EOS for each prediction and returns a List of predictions. - This does basically the same as util.pad_after_eos, but does not + Truncates at END for each prediction and returns a List of predictions. + This does basically the same as util.pad_after_end, but does not actually pad since we do not need to return a well-formed tensor. Args: @@ -212,21 +212,21 @@ def _finalize_tensor( return [tensor] out = [] for prediction in tensor: - # Gets first instance of EOS. - eos = (prediction == special.END_IDX).nonzero(as_tuple=False) - if len(eos) > 0 and eos[0].item() < len(prediction): - # If an EOS was decoded and it is not the last one in the + # Gets first instance of END. + end = (prediction == special.END_IDX).nonzero(as_tuple=False) + if len(end) > 0 and end[0].item() < len(prediction): + # If an END was decoded and it is not the last one in the # sequence. - eos = eos[0] + end = end[0] else: # Leaves tensor[i] alone. 
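# The truncation performed in this loop, shown on one concrete row using
# plain slicing in place of the torch.split call, and assuming END_IDX == 2
# purely for illustration:
import torch

prediction = torch.tensor([5, 9, 2, 0, 7])
end = (prediction == 2).nonzero(as_tuple=False)
if len(end) > 0 and end[0].item() < len(prediction):
    cut = max(end[0].item(), 1)  # The 0 -> 1 hack, in scalar form.
    symbols = prediction[:cut]  # -> tensor([5, 9]).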
out.append(prediction) continue - # Hack in case the first prediction is EOS. In this case + # Hack in case the first prediction is END. In this case # torch.split will result in an error, so we change these 0's to - # 1's, which will make the entire sequence EOS as intended. - eos[eos == 0] = 1 - symbols, *_ = torch.split(prediction, eos) + # 1's, which will make the entire sequence END as intended. + end[end == 0] = 1 + symbols, *_ = torch.split(prediction, end) out.append(symbols) return out diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 03972daf..a1f0c6bb 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -180,15 +180,12 @@ def beam_decode( self, encoder_out: torch.Tensor, mask: torch.Tensor, - beam_width: int, ): """Method interface for beam search. Args: - encoder_out (torch.Tensor): encoded inputs. + encoder_out (torch.Tensor). encoder_mask (torch.Tensor). - beam_width (int): size of the beam; also determines the number of - hypotheses to return. Raises: NotImplementedError: beam search not implemented. diff --git a/yoyodyne/models/expert.py b/yoyodyne/models/expert.py index 00662fd6..56af3af7 100644 --- a/yoyodyne/models/expert.py +++ b/yoyodyne/models/expert.py @@ -456,7 +456,7 @@ def _generate_data( ) -> Iterator[Tuple[List[int], List[int]]]: """Helper function to manage data encoding for SED." - We want encodings without BOS or EOS tokens. This + We want encodings without BOS or END tokens. This encodes only raw source-target text for the Maxwell library. Args: diff --git a/yoyodyne/models/hard_attention.py b/yoyodyne/models/hard_attention.py index 4915a223..735e0b14 100644 --- a/yoyodyne/models/hard_attention.py +++ b/yoyodyne/models/hard_attention.py @@ -97,7 +97,7 @@ def decode( ) -> Tuple[torch.Tensor, torch.Tensor]: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached [EOS] up to + Decodes until all sequences in a batch have reached up to length of `target` args. Args: @@ -133,6 +133,24 @@ def decode( all_transition_probs.append(transition_probs) return torch.stack(all_log_probs), torch.stack(all_transition_probs) + def beam_decode( + self, + encoder_out: torch.Tensor, + mask: torch.Tensor, + ): + """Overrides incompatible implementation inherited from RNNModel. + + Args: + encoder_out (torch.Tensor). + encoder_mask (torch.Tensor). + + Raises: + NotImplementedError: beam search not implemented. + """ + raise NotImplementedError( + f"Beam search not implemented for {self.name} model" + ) + def greedy_decode( self, encoder_out: torch.Tensor, @@ -140,7 +158,7 @@ def greedy_decode( ) -> Tuple[torch.Tensor, torch.Tensor]: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached [EOS] up to + Decodes until all sequences in a batch have reached up to a specified length depending on the `target` args. Args: @@ -309,9 +327,7 @@ def forward( ) elif self.beam_width > 1: # Will raise a NotImplementedError. - output = self.beam_decode( - encoder_out, batch.source.mask, beam_width - ) + return self.beam_decode(encoder_out, batch.source.mask) else: return self.greedy_decode(encoder_out, batch.source.mask) diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index 2944cb0b..1b2c56f8 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -285,7 +285,7 @@ def decode( ) -> torch.Tensor: """Decodes a sequence given the encoded input. 
- Decodes until all sequences in a batch have reached [EOS] up to + Decodes until all sequences in a batch have reached up to a specified length depending on the `target` args. Args: @@ -323,7 +323,7 @@ def decode( num_steps = ( target.size(1) if target is not None else self.max_target_length ) - # Tracks when each sequence has decoded an EOS. + # Tracks when each sequence has decoded an END. finished = torch.zeros(batch_size, device=self.device) for t in range(num_steps): # pred: B x 1 x target_vocab_size. @@ -345,11 +345,11 @@ def decode( # (i.e., student forcing, greedy decoding). else: decoder_input = self._get_predicted(output) - # Tracks which sequences have decoded an EOS. + # Tracks which sequences have decoded an END. finished = torch.logical_or( finished, (decoder_input == special.END_IDX) ) - # Breaks when all sequences have predicted an EOS symbol. If we + # Breaks when all sequences have predicted an END symbol. If we # have a target (and are thus computing loss), we only break # when we have decoded at least the the same number of steps as # the target length. @@ -842,7 +842,7 @@ def greedy_decode( source_indices (torch.Tensor): indices of the source symbols. targets (torch.Tensor, optional): the optional target tokens, which is only used for early stopping during validation - if the decoder has predicted [EOS] for every sequence in + if the decoder has predicted for every sequence in the batch. features_enc (Optional[torch.Tensor]): encoded features. features_mask (Optional[torch.Tensor]): mask for encoded features. @@ -860,7 +860,7 @@ def greedy_decode( device=self.device, ) ] - # Tracking when each sequence has decoded an EOS. + # Tracking when each sequence has decoded an END. finished = torch.zeros(batch_size, device=self.device) for _ in range(self.max_target_length): target_tensor = torch.stack(predictions, dim=1) @@ -881,11 +881,11 @@ def greedy_decode( # -> B x 1 x 1. _, pred = torch.max(last_output, dim=1) predictions.append(pred) - # Updates to track which sequences have decoded an EOS. + # Updates to track which sequences have decoded an END. finished = torch.logical_or( finished, (predictions[-1] == special.END_IDX) ) - # Breaks when all sequences have predicted an EOS symbol. If we + # Breaks when all sequences have predicted an END symbol. If we # have a target (and are thus computing loss), we only break when # we have decoded at least the the same number of steps as the # target length. @@ -945,7 +945,7 @@ def forward( target_mask, features_enc=features_encoded, ) - output = output[:, :-1, :] # Ignore EOS. + output = output[:, :-1, :] # Ignore END. else: features_encoded = None if self.has_features_encoder: diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 786707f7..73c6ee1a 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -52,7 +52,7 @@ def decode( ) -> torch.Tensor: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached [EOS] up to + Decodes until all sequences in a batch have reached up to a specified length depending on the `target` args. Args: @@ -83,7 +83,7 @@ def decode( num_steps = ( target.size(1) if target is not None else self.max_target_length ) - # Tracks when each sequence has decoded an EOS. + # Tracks when each sequence has decoded an END. finished = torch.zeros(batch_size, device=self.device) for t in range(num_steps): # pred: B x 1 x output_size. @@ -101,11 +101,11 @@ def decode( # (i.e., student forcing, greedy decoding). 
else: decoder_input = self._get_predicted(logits) - # Updates to track which sequences have decoded an EOS. + # Updates to track which sequences have decoded an END. finished = torch.logical_or( finished, (decoder_input == special.END_IDX) ) - # Breaks when all sequences have predicted an EOS symbol. If we + # Breaks when all sequences have predicted an END symbol. If we # have a target (and are thus computing loss), we only break # when we have decoded at least the the same number of steps as # the target length. @@ -121,7 +121,6 @@ def beam_decode( self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, - beam_width: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """Overrides `beam_decode` in `BaseEncoderDecoder`. @@ -150,7 +149,7 @@ def beam_decode( char_likelihoods, decoder_hiddens, ) in histories: - # Does not keep decoding a path that has hit EOS. + # Does not keep decoding a path that has hit END. if len(beam_idxs) > 1 and beam_idxs[-1] == special.END_IDX: fields = [ beam_likelihood, @@ -158,7 +157,6 @@ def beam_decode( char_likelihoods, decoder_hiddens, ] - # TODO: Beam search with beam_width. # TODO: Replace heapq with torch.max or similar? heapq.heappush(hypotheses, fields) continue @@ -197,7 +195,7 @@ def beam_decode( predictions = nn.functional.log_softmax(logits, dim=0).cpu() for j, logprob in enumerate(predictions): cl = char_loglikelihoods + [logprob] - if len(hypotheses) < beam_width: + if len(hypotheses) < self.beam_width: fields = [ beam_loglikelihood + logprob, beam_idxs + [j], @@ -249,9 +247,9 @@ def forward( Returns: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: beam search returns a tuple with a tensor of predictions of shape - beam_width x seq_len and tensor with the unnormalized sum of - symbol log-probabilities for each prediction. Greedy returns a - tensor of predictions of shape + beam_width x seq_len and tensor with the unnormalized sum + of symbol log-probabilities for each prediction. Greedy returns + a tensor of predictions of shape seq_len x batch_size x target_vocab_size. """ encoder_out = self.source_encoder(batch.source).output @@ -263,9 +261,8 @@ def forward( predictions, scores = self.beam_decode( encoder_out, batch.source.mask, - beam_width=self.beam_width, ) - # Reduce to beam_width x seq_len + # Reduces to beam_width x seq_len predictions = predictions.transpose(0, 2).squeeze(0) return predictions, scores else: diff --git a/yoyodyne/models/transducer.py b/yoyodyne/models/transducer.py index 08ea5363..b98f20f8 100644 --- a/yoyodyne/models/transducer.py +++ b/yoyodyne/models/transducer.py @@ -560,9 +560,9 @@ def convert_predictions( pred.extend(pad) predictions[i] = torch.tensor(pred, dtype=torch.int) predictions = torch.stack(predictions) - # This turns all symbols after the first EOS into PAD so prediction + # This turns all symbols after the first END into PAD so prediction # tensors match gold tensors. - return util.pad_tensor_after_eos(predictions) + return util.pad_tensor_after_end(predictions) def on_train_epoch_start(self) -> None: """Scheduler for oracle.""" diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index f595dd41..2b2899e6 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -80,7 +80,7 @@ def greedy_decode( source_mask (torch.Tensor): mask for the encoded source tokens. 
targets (torch.Tensor, optional): the optional target tokens, which is only used for early stopping during validation - if the decoder has predicted [EOS] for every sequence in + if the decoder has predicted for every sequence in the batch. Returns: @@ -96,7 +96,7 @@ def greedy_decode( device=self.device, ) ] - # Tracking when each sequence has decoded an EOS. + # Tracking when each sequence has decoded an END. finished = torch.zeros(batch_size, device=self.device) for _ in range(self.max_target_length): target_tensor = torch.stack(predictions, dim=1) @@ -110,16 +110,16 @@ def greedy_decode( target_mask, ).output logits = self.classifier(decoder_output) - last_output = logits[:, -1, :] # Ignores EOS. + last_output = logits[:, -1, :] # Ignores END. outputs.append(last_output) # -> B x 1 x 1 _, pred = torch.max(last_output, dim=1) predictions.append(pred) - # Updates to track which sequences have decoded an EOS. + # Updates to track which sequences have decoded an END. finished = torch.logical_or( finished, (predictions[-1] == special.END_IDX) ) - # Breaks when all sequences have predicted an EOS symbol. If we + # Breaks when all sequences have predicted an END symbol. If we # have a target (and are thus computing loss), we only break when # we have decoded at least the the same number of steps as the # target length. @@ -166,7 +166,7 @@ def forward( target_mask, ).output logits = self.classifier(decoder_output) - output = logits[:, :-1, :] # Ignore EOS. + output = logits[:, :-1, :] # Ignore END. else: encoder_output = self.source_encoder(batch.source).output if self.beam_width > 1: diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index e4154864..b8390ac3 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -113,7 +113,7 @@ def predict( # Beam search. tsv_writer = csv.writer(sink, delimiter="\t") for predictions, scores in trainer.predict(model, loader): - predictions = util.pad_tensor_after_eos(predictions) + predictions = util.pad_tensor_after_end(predictions) decoded_predictions = loader.dataset.decode_target(predictions) row = itertools.chain( *zip(decoded_predictions, scores.tolist()) @@ -121,8 +121,8 @@ def predict( tsv_writer.writerow(row) else: # Greedy search. - for predictions, _ in trainer.predict(model, loader): - predictions = util.pad_tensor_after_eos(predictions) + for predictions, *_ in trainer.predict(model, loader): + predictions = util.pad_tensor_after_end(predictions) for prediction in loader.dataset.decode_target(predictions): print(prediction, file=sink) diff --git a/yoyodyne/util.py b/yoyodyne/util.py index 5dc70d83..62400855 100644 --- a/yoyodyne/util.py +++ b/yoyodyne/util.py @@ -27,12 +27,12 @@ def __call__( # Tensor manipulation -def pad_tensor_after_eos( +def pad_tensor_after_end( predictions: torch.Tensor, ) -> torch.Tensor: - """Replaces everything after an EOS token with PADs. + """Replaces everything after an END token with PADs. - Cuts off tensors at the first END_IDX, and replaces the rest of the + Cuts off tensors at the first END, and replaces the rest of the predictions with PAD_IDX, as these can be erroneously decoded while the rest of the batch is finishing decoding. @@ -46,20 +46,20 @@ def pad_tensor_after_eos( if predictions.size(0) == 1: return predictions for i, prediction in enumerate(predictions): - # Gets first instance of EOS. 
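# Worked example of the truncate-then-pad behavior implemented in this loop,
# assuming END_IDX == 2 and PAD_IDX == 0 purely for illustration:
#
#   in:  [[5, 9, 2, 7, 4],      out: [[5, 9, 0, 0, 0],
#         [2, 3, 3, 3, 3]]            [2, 0, 0, 0, 0]]
#
# The second row keeps a single symbol because a split size of 0 is invalid;
# hence the 0 -> 1 hack below.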
- eos = (prediction == special.END_IDX).nonzero(as_tuple=False) - if len(eos) > 0 and eos[0].item() < len(prediction): - # If an EOS was decoded and it is not the last one in the + # Gets first instance of END. + end = (prediction == special.END_IDX).nonzero(as_tuple=False) + if len(end) > 0 and end[0].item() < len(prediction): + # If an END was decoded and it is not the last one in the # sequence. - eos = eos[0] + end = end[0] else: # Leaves predictions[i] alone. continue - # Hack in case the first prediction is EOS. In this case + # Hack in case the first prediction is END. In this case # torch.split will result in an error, so we change these 0's to - # 1's, which will make the entire sequence EOS as intended. - eos[eos == 0] = 1 - symbols, *_ = torch.split(prediction, eos) + # 1's, which will make the entire sequence END as intended. + end[end == 0] = 1 + symbols, *_ = torch.split(prediction, end) # Replaces everything after with PAD, to replace erroneous decoding # While waiting on the entire batch to finish. pads = ( From 8e794b3860d9c8fcce3cf2f87e5862900dfc33e1 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 16:26:39 -0500 Subject: [PATCH 5/8] Bugfixes. --- pyproject.toml | 2 +- yoyodyne/models/base.py | 18 -- yoyodyne/models/hard_attention.py | 20 +- yoyodyne/models/pointer_generator.py | 274 +++++++++++++++------------ yoyodyne/models/rnn.py | 150 +++++++-------- yoyodyne/models/transducer.py | 13 +- yoyodyne/models/transformer.py | 12 +- 7 files changed, 253 insertions(+), 236 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 77d5f949..997df497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ exclude = ["examples*"] [project] name = "yoyodyne" -version = "0.2.15" +version = "0.2.16" description = "Small-vocabulary neural sequence-to-sequence models" readme = "README.md" requires-python = ">= 3.9" diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index a1f0c6bb..45579e93 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -176,24 +176,6 @@ def init_embeddings(num_embed: int, embed_size: int) -> nn.Embedding: def get_decoder(self): raise NotImplementedError - def beam_decode( - self, - encoder_out: torch.Tensor, - mask: torch.Tensor, - ): - """Method interface for beam search. - - Args: - encoder_out (torch.Tensor). - encoder_mask (torch.Tensor). - - Raises: - NotImplementedError: beam search not implemented. - """ - raise NotImplementedError( - f"Beam search not implemented for {self.name} model" - ) - @property def num_parameters(self) -> int: return sum(part.numel() for part in self.parameters()) diff --git a/yoyodyne/models/hard_attention.py b/yoyodyne/models/hard_attention.py index 735e0b14..184fb9f6 100644 --- a/yoyodyne/models/hard_attention.py +++ b/yoyodyne/models/hard_attention.py @@ -97,8 +97,8 @@ def decode( ) -> Tuple[torch.Tensor, torch.Tensor]: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached up to - length of `target` args. + Decodes until all sequences in a batch have reached END up to length of + `target` args. Args: encoder_out (torch.Tensor): batch of encoded input symbols @@ -133,20 +133,8 @@ def decode( all_transition_probs.append(transition_probs) return torch.stack(all_log_probs), torch.stack(all_transition_probs) - def beam_decode( - self, - encoder_out: torch.Tensor, - mask: torch.Tensor, - ): - """Overrides incompatible implementation inherited from RNNModel. - - Args: - encoder_out (torch.Tensor). 
- encoder_mask (torch.Tensor). - - Raises: - NotImplementedError: beam search not implemented. - """ + def beam_decode(self, *args, **kwargs): + """Overrides incompatible implementation inherited from RNNModel.""" raise NotImplementedError( f"Beam search not implemented for {self.name} model" ) diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py index 1b2c56f8..360e99ac 100644 --- a/yoyodyne/models/pointer_generator.py +++ b/yoyodyne/models/pointer_generator.py @@ -208,71 +208,13 @@ def _check_layer_sizes(self) -> None: f"({self.encoder_layers} != {self.decoder_layers})" ) - def decode_step( - self, - symbol: torch.Tensor, - last_hiddens: Tuple[torch.Tensor, torch.Tensor], - source_indices: torch.Tensor, - source_enc: torch.Tensor, - source_mask: torch.Tensor, - features_enc: Optional[torch.Tensor] = None, - features_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Runs a single step of the decoder. - - This predicts a distribution for one symbol. - - Args: - symbol (torch.Tensor). - last_hiddens (Tuple[torch.Tensor, torch.Tensor]). - source_indices (torch.Tensor). - source_enc (torch.Tensor). - source_mask (torch.Tensor). - features_enc (Optional[torch.Tensor]). - features_mask (Optional[torch.Tensor]). - - Returns: - Tuple[torch.Tensor, torch.Tensor]. - """ - embedded = self.decoder.embed(symbol) - last_h0, last_c0 = last_hiddens - source_context, attention_weights = self.decoder.attention( - last_h0.transpose(0, 1), source_enc, source_mask - ) - if self.has_features_encoder: - features_context, _ = self.features_attention( - last_h0.transpose(0, 1), features_enc, features_mask - ) - # -> B x 1 x 4*hidden_size. - context = torch.cat([source_context, features_context], dim=2) - else: - context = source_context - _, (h, c) = self.decoder.module( - torch.cat((embedded, context), 2), (last_h0, last_c0) + def beam_decode(self, *args, **kwargs): + """Overrides incompatible implementation inherited from RNNModel.""" + raise NotImplementedError( + f"Beam search not implemented for {self.name} model" ) - # -> B x 1 x hidden_size - hidden = h[-1, :, :].unsqueeze(1) - output_dist = self.classifier(torch.cat([hidden, context], dim=2)) - output_dist = nn.functional.softmax(output_dist, dim=2) - # -> B x 1 x target_vocab_size. - ptr_dist = torch.zeros( - symbol.size(0), - self.target_vocab_size, - device=self.device, - dtype=attention_weights.dtype, - ).unsqueeze(1) - # Gets the attentions to the source in terms of the output generations. - # These are the "pointer" distribution. - ptr_dist.scatter_add_( - 2, source_indices.unsqueeze(1), attention_weights - ) - # Probability of generating (from output_dist). - gen_probs = self.generation_probability(context, hidden, embedded) - scaled_ptr_dist = ptr_dist * (1 - gen_probs) - scaled_output_dist = output_dist * gen_probs - return torch.log(scaled_output_dist + scaled_ptr_dist), (h, c) - def decode( + def greedy_decode( self, source_enc: torch.Tensor, source_mask: torch.Tensor, @@ -285,7 +227,7 @@ def decode( ) -> torch.Tensor: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached up to + Decodes until all sequences in a batch have reached END up to a specified length depending on the `target` args. 
Args: @@ -361,6 +303,70 @@ def decode( predictions = torch.stack(predictions).transpose(0, 1) return predictions + def decode_step( + self, + symbol: torch.Tensor, + last_hiddens: Tuple[torch.Tensor, torch.Tensor], + source_indices: torch.Tensor, + source_enc: torch.Tensor, + source_mask: torch.Tensor, + features_enc: Optional[torch.Tensor] = None, + features_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Runs a single step of the decoder. + + This predicts a distribution for one symbol. + + Args: + symbol (torch.Tensor). + last_hiddens (Tuple[torch.Tensor, torch.Tensor]). + source_indices (torch.Tensor). + source_enc (torch.Tensor). + source_mask (torch.Tensor). + features_enc (Optional[torch.Tensor]). + features_mask (Optional[torch.Tensor]). + + Returns: + Tuple[torch.Tensor, torch.Tensor]. + """ + embedded = self.decoder.embed(symbol) + last_h0, last_c0 = last_hiddens + source_context, attention_weights = self.decoder.attention( + last_h0.transpose(0, 1), source_enc, source_mask + ) + if self.has_features_encoder: + features_context, _ = self.features_attention( + last_h0.transpose(0, 1), features_enc, features_mask + ) + # -> B x 1 x 4*hidden_size. + context = torch.cat([source_context, features_context], dim=2) + else: + context = source_context + _, (h, c) = self.decoder.module( + torch.cat((embedded, context), 2), (last_h0, last_c0) + ) + # -> B x 1 x hidden_size + hidden = h[-1, :, :].unsqueeze(1) + output_dist = self.classifier(torch.cat([hidden, context], dim=2)) + output_dist = nn.functional.softmax(output_dist, dim=2) + # -> B x 1 x target_vocab_size. + ptr_dist = torch.zeros( + symbol.size(0), + self.target_vocab_size, + device=self.device, + dtype=attention_weights.dtype, + ).unsqueeze(1) + # Gets the attentions to the source in terms of the output generations. + # These are the "pointer" distribution. + ptr_dist.scatter_add_( + 2, source_indices.unsqueeze(1), attention_weights + ) + # Probability of generating (from output_dist). + gen_probs = self.generation_probability(context, hidden, embedded) + scaled_ptr_dist = ptr_dist * (1 - gen_probs) + scaled_output_dist = output_dist * gen_probs + return torch.log(scaled_output_dist + scaled_ptr_dist), (h, c) + class PointerGeneratorGRUModel(PointerGeneratorRNNModel, rnn.GRUModel): """Pointer-generator model with an GRU backend.""" @@ -459,11 +465,11 @@ def forward( Returns: torch.Tensor. """ - encoder_output = self.source_encoder(batch.source) - source_encoded = encoder_output.output - if encoder_output.has_hiddens: + encoder_out = self.source_encoder(batch.source) + source_encoded = encoder_out.output + if encoder_out.has_hiddens: last_hiddens = self._reshape_hiddens( - encoder_output.hiddens, + encoder_out.hiddens, self.source_encoder.layers, self.source_encoder.num_directions, ) @@ -472,23 +478,31 @@ def forward( len(batch), self.source_encoder.layers ) if not self.has_features_encoder: - if self.beam_width is not None and self.beam_width > 1: - raise NotImplementedError + if self.beam_width > 1: + # Will raise a NotImplementedError. 
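# The pointer-generator mixture computed in decode_step above, reduced to its
# core: a generation probability interpolates between the vocabulary
# distribution and the attention ("pointer") distribution. Toy numbers:
import torch

gen_probs = torch.tensor([[[0.7]]])  # P(generate).
output_dist = torch.tensor([[[0.1, 0.6, 0.3]]])  # Softmax over target vocab.
ptr_dist = torch.tensor([[[0.5, 0.0, 0.5]]])  # Attention scattered to vocab ids.
log_dist = torch.log(output_dist * gen_probs + ptr_dist * (1 - gen_probs))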
+ return self.beam_decode( + source_encoded, + batch.source.mask, + batch.source.padded, + last_hiddens, + self.teacher_forcing if self.training else False, + target=batch.target.padded if batch.target else None, + ) else: - predictions = self.decode( + return self.greedy_decode( source_encoded, batch.source.mask, batch.source.padded, last_hiddens, self.teacher_forcing if self.training else False, - target=(batch.target.padded if batch.target else None), + target=batch.target.padded if batch.target else None, ) else: - features_encoder_output = self.features_encoder(batch.features) - features_encoded = features_encoder_output.output - if features_encoder_output.has_hiddens: + features_encoder_out = self.features_encoder(batch.features) + features_encoded = features_encoder_out.output + if features_encoder_out.has_hiddens: last_hiddens = self._reshape_hiddens( - features_encoder_output.hiddens, + features_encoder_out.hiddens, self.features_encoder.layers, self.features_encoder.num_directions, ) @@ -496,17 +510,29 @@ def forward( last_hiddens = self.init_hiddens( len(batch), self.source_encoder.layers ) - predictions = self.decode( - source_encoded, - batch.source.mask, - batch.source.padded, - last_hiddens, - self.teacher_forcing if self.training else False, - features_enc=features_encoded, - features_mask=batch.features.mask, - target=batch.target.padded if batch.target else None, - ) - return predictions + if self.beam_width > 1: + # Will raise a NotImplementedError. + return self.beam_decode( + source_encoded, + batch.source.mask, + batch.source.padded, + last_hiddens, + self.teacher_forcing if self.training else False, + features_enc=features_encoded, + features_mask=batch.features.mask, + target=batch.target.padded if batch.target else None, + ) + else: + return self.greedy_decode( + source_encoded, + batch.source.mask, + batch.source.padded, + last_hiddens, + self.teacher_forcing if self.training else False, + features_enc=features_encoded, + features_mask=batch.features.mask, + target=batch.target.padded if batch.target else None, + ) @staticmethod def _reshape_hiddens( @@ -624,10 +650,10 @@ def forward( Returns: torch.Tensor. """ - encoder_output = self.source_encoder(batch.source) - source_encoded = encoder_output.output - if encoder_output.has_hiddens: - h_source, c_source = encoder_output.hiddens + encoder_out = self.source_encoder(batch.source) + source_encoded = encoder_out.output + if encoder_out.has_hiddens: + h_source, c_source = encoder_out.hiddens last_hiddens = self._reshape_hiddens( h_source, c_source, @@ -638,23 +664,29 @@ def forward( last_hiddens = self.init_hiddens(len(batch)) if not self.has_features_encoder: if self.beam_width > 1: - raise NotImplementedError( - f"Beam search not implemented for {self.name} model" + # Will raise a NotImplementedError. 
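# The stub pattern adopted here for unsupported beam search, as a toy:
# forward() always dispatches to beam_decode(), and models without beam
# support override it to raise rather than raising inline.
class Toy:

    name = "toy"

    def beam_decode(self, *args, **kwargs):
        raise NotImplementedError(
            f"Beam search not implemented for {self.name} model"
        )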
+ return self.beam_decode( + source_encoded, + batch.source.mask, + batch.source.padded, + last_hiddens, + self.teacher_forcing if self.training else False, + target=batch.target.padded if batch.target else None, ) else: - predictions = self.decode( + return self.greedy_decode( source_encoded, batch.source.mask, batch.source.padded, last_hiddens, self.teacher_forcing if self.training else False, - target=(batch.target.padded if batch.target else None), + target=batch.target.padded if batch.target else None, ) else: - features_encoder_output = self.features_encoder(batch.features) - features_encoded = features_encoder_output.output - if features_encoder_output.has_hiddens: - h_features, c_features = features_encoder_output.hiddens + features_encoder_out = self.features_encoder(batch.features) + features_encoded = features_encoder_out.output + if features_encoder_out.has_hiddens: + h_features, c_features = features_encoder_out.hiddens h_features, c_features = self._reshape_hiddens( h_features, c_features, @@ -664,12 +696,19 @@ def forward( else: h_features, c_features = self.init_hiddens(len(batch)) if self.beam_width > 1: - # LSTM beam search does not work with pointer generator LSTM. - raise NotImplementedError( - f"Beam search not implemented for {self.name} model" + # Will raise a NotImplementedError. + return self.beam_decode( + source_encoded, + batch.source.mask, + batch.source.padded, + last_hiddens, + self.teacher_forcing if self.training else False, + features_enc=features_encoded, + features_mask=batch.features.mask, + target=batch.target.padded if batch.target else None, ) else: - predictions = self.decode( + return self.greedy_decode( source_encoded, batch.source.mask, batch.source.padded, @@ -679,7 +718,6 @@ def forward( features_mask=batch.features.mask, target=batch.target.padded if batch.target else None, ) - return predictions @staticmethod def _reshape_hiddens( @@ -748,7 +786,7 @@ def get_decoder(self) -> modules.transformer.TransformerPointerDecoder: def decode_step( self, - encoder_outputs: torch.Tensor, + encoder_outs: torch.Tensor, source_mask: torch.Tensor, source_indices: torch.Tensor, target_tensor: torch.Tensor, @@ -765,7 +803,7 @@ def decode_step( parallel with a diagonal mask. Args: - encoder_outputs (torch.Tensor): encoded output representations. + encoder_outs (torch.Tensor): encoded output representations. source_mask (torch.Tensor): mask for the encoded source tokens. source_indices (torch.Tensor): source token vocabulary ids. target_tensor (torch.Tensor): target token vocabulary ids. @@ -779,7 +817,7 @@ def decode_step( target_seq_len is inferred form the target_tensor. """ decoder_output = self.decoder( - encoder_outputs, + encoder_outs, source_mask, target_tensor, target_mask, @@ -816,7 +854,7 @@ def decode_step( ptr_dist.scatter_add_(2, repeated_source_indices, mha_outputs) # A matrix of context vectors from applying attention to the encoder # representations w.r.t. each decoder step. - context = torch.bmm(mha_outputs, encoder_outputs) + context = torch.bmm(mha_outputs, encoder_outs) # Probability of generating (from output_dist). gen_probs = self.generation_probability( context, decoder_output, target_embeddings @@ -842,7 +880,7 @@ def greedy_decode( source_indices (torch.Tensor): indices of the source symbols. targets (torch.Tensor, optional): the optional target tokens, which is only used for early stopping during validation - if the decoder has predicted for every sequence in + if the decoder has predicted END for every sequence in the batch. 
features_enc (Optional[torch.Tensor]): encoded features. features_mask (Optional[torch.Tensor]): mask for encoded features. @@ -927,11 +965,11 @@ def forward( ) features_encoded = None if self.has_features_encoder: - features_encoder_output = self.features_encoder(batch.features) - features_encoded = features_encoder_output.output + features_encoder_out = self.features_encoder(batch.features) + features_encoded = features_encoder_out.output if self.beam_width > 1: # Will raise a NotImplementedError. - output = self.beam_decode( + return self.beam_decode( source_encoded, batch.source.mask, self.beam_width, @@ -945,29 +983,29 @@ def forward( target_mask, features_enc=features_encoded, ) - output = output[:, :-1, :] # Ignore END. + output = output[:, :-1, :] # Ignores END. + return output else: features_encoded = None if self.has_features_encoder: - features_encoder_output = self.features_encoder(batch.features) - features_encoded = features_encoder_output.output + features_encoder_out = self.features_encoder(batch.features) + features_encoded = features_encoder_out.output if self.beam_width > 1: # Will raise a NotImplementedError. - output = self.beam_decode( + return self.beam_decode( source_encoded, batch.source.mask, self.beam_width, ) else: # -> B x seq_len x output_size. - output = self.greedy_decode( + return self.greedy_decode( source_encoded, batch.source.mask, batch.source.padded, batch.target.padded if batch.target else None, features_enc=features_encoded, ) - return output @property def name(self) -> str: diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 73c6ee1a..4af7790b 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -43,80 +43,6 @@ def init_embeddings( """ return embeddings.normal_embedding(num_embeddings, embedding_size) - def decode( - self, - encoder_out: torch.Tensor, - encoder_mask: torch.Tensor, - teacher_forcing: bool, - target: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Decodes a sequence given the encoded input. - - Decodes until all sequences in a batch have reached up to - a specified length depending on the `target` args. - - Args: - encoder_out (torch.Tensor): batch of encoded input symbols. - encoder_mask (torch.Tensor): mask for the batch of encoded - input symbols. - teacher_forcing (bool): Whether or not to decode - with teacher forcing. - target (torch.Tensor, optional): target symbols; we - decode up to `len(target)` symbols. If None, we decode up to - `self.max_target_length` symbols. - - Returns: - torch.Tensor: tensor of predictions of shape seq_len x - batch_size x target_vocab_size. - """ - batch_size = encoder_mask.shape[0] - # Initializes hidden states for decoder LSTM. - decoder_hiddens = self.init_hiddens(batch_size) - # Feed in the first decoder input, as a start tag. - # -> B x 1. - decoder_input = ( - torch.tensor([special.START_IDX], device=self.device) - .repeat(batch_size) - .unsqueeze(1) - ) - predictions = [] - num_steps = ( - target.size(1) if target is not None else self.max_target_length - ) - # Tracks when each sequence has decoded an END. - finished = torch.zeros(batch_size, device=self.device) - for t in range(num_steps): - # pred: B x 1 x output_size. - decoded = self.decoder( - decoder_input, decoder_hiddens, encoder_out, encoder_mask - ) - decoder_output, decoder_hiddens = decoded.output, decoded.hiddens - logits = self.classifier(decoder_output) - predictions.append(logits.squeeze(1)) - # In teacher forcing mode the next input is the gold symbol - # for this step. 
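# The choice made in this loop, in isolation: teacher forcing feeds the gold
# symbol at step t, student forcing feeds the model's own prediction (assumed
# here to be the argmax that _get_predicted computes). Hypothetical shapes:
# logits is B x 1 x V, target is B x T.
def next_decoder_input(logits, target, t, teacher_forcing):
    if teacher_forcing:
        return target[:, t].unsqueeze(1)  # Gold symbol: B x 1.
    return logits.argmax(dim=2)  # Predicted symbol: B x 1.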
- if teacher_forcing: - decoder_input = target[:, t].unsqueeze(1) - # Otherwise we pass the top pred to the next timestep - # (i.e., student forcing, greedy decoding). - else: - decoder_input = self._get_predicted(logits) - # Updates to track which sequences have decoded an END. - finished = torch.logical_or( - finished, (decoder_input == special.END_IDX) - ) - # Breaks when all sequences have predicted an END symbol. If we - # have a target (and are thus computing loss), we only break - # when we have decoded at least the the same number of steps as - # the target length. - if finished.all(): - if target is None or decoder_input.size(-1) >= target.size( - -1 - ): - break - predictions = torch.stack(predictions) - return predictions - def beam_decode( self, encoder_out: torch.Tensor, @@ -235,6 +161,80 @@ def beam_decode( # Beam search returns the likelihoods of each history. return predictions, torch.tensor([h[0] for h in histories]) + def greedy_decode( + self, + encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, + teacher_forcing: bool, + target: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Decodes a sequence given the encoded input. + + Decodes until all sequences in a batch have reached up to + a specified length depending on the `target` args. + + Args: + encoder_out (torch.Tensor): batch of encoded input symbols. + encoder_mask (torch.Tensor): mask for the batch of encoded + input symbols. + teacher_forcing (bool): Whether or not to decode + with teacher forcing. + target (torch.Tensor, optional): target symbols; we + decode up to `len(target)` symbols. If None, we decode up to + `self.max_target_length` symbols. + + Returns: + torch.Tensor: tensor of predictions of shape seq_len x + batch_size x target_vocab_size. + """ + batch_size = encoder_mask.shape[0] + # Initializes hidden states for decoder LSTM. + decoder_hiddens = self.init_hiddens(batch_size) + # Feed in the first decoder input, as a start tag. + # -> B x 1. + decoder_input = ( + torch.tensor([special.START_IDX], device=self.device) + .repeat(batch_size) + .unsqueeze(1) + ) + predictions = [] + num_steps = ( + target.size(1) if target is not None else self.max_target_length + ) + # Tracks when each sequence has decoded an END. + finished = torch.zeros(batch_size, device=self.device) + for t in range(num_steps): + # pred: B x 1 x output_size. + decoded = self.decoder( + decoder_input, decoder_hiddens, encoder_out, encoder_mask + ) + decoder_output, decoder_hiddens = decoded.output, decoded.hiddens + logits = self.classifier(decoder_output) + predictions.append(logits.squeeze(1)) + # In teacher forcing mode the next input is the gold symbol + # for this step. + if teacher_forcing: + decoder_input = target[:, t].unsqueeze(1) + # Otherwise we pass the top pred to the next timestep + # (i.e., student forcing, greedy decoding). + else: + decoder_input = self._get_predicted(logits) + # Updates to track which sequences have decoded an END. + finished = torch.logical_or( + finished, (decoder_input == special.END_IDX) + ) + # Breaks when all sequences have predicted an END symbol. If we + # have a target (and are thus computing loss), we only break + # when we have decoded at least the the same number of steps as + # the target length. 
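# A compressed restatement of the intended stopping rule guarded below: halt
# only once every sequence has emitted END and, when a target is present (and
# loss is being computed), only after at least target-length steps; roughly:
#
#   stop = finished.all() and (target is None or t + 1 >= target.size(1))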
+ if finished.all(): + if target is None or decoder_input.size(-1) >= target.size( + -1 + ): + break + predictions = torch.stack(predictions) + return predictions + def forward( self, batch: data.PaddedBatch, @@ -266,7 +266,7 @@ def forward( predictions = predictions.transpose(0, 2).squeeze(0) return predictions, scores else: - predictions = self.decode( + predictions = self.greedy_decode( encoder_out, batch.source.mask, self.teacher_forcing if self.training else False, diff --git a/yoyodyne/models/transducer.py b/yoyodyne/models/transducer.py index b98f20f8..0f964857 100644 --- a/yoyodyne/models/transducer.py +++ b/yoyodyne/models/transducer.py @@ -96,13 +96,13 @@ def forward( last_hiddens = self.init_hiddens(source_mask.shape[0]) if self.beam_width > 1: # Will raise a NotImplementedError. - prediction = self.beam_decode( + return self.beam_decode( encoder_out=encoded, mask=batch.source.mask, beam_width=self.beam_width, ) else: - prediction, loss = self.decode( + return self.greedy_decode( encoded, last_hiddens, source_padded, @@ -113,9 +113,14 @@ def forward( target=batch.target.padded if batch.target else None, target_mask=batch.target.mask if batch.target else None, ) - return prediction, loss - def decode( + def beam_decode(self, *args, **kwargs): + """Overrides incompatible implementation inherited from RNNModel.""" + raise NotImplementedError( + f"Beam search not implemented for {self.name} model" + ) + + def greedy_decode( self, encoder_out: torch.Tensor, last_hiddens: Tuple[torch.Tensor, torch.Tensor], diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 2b2899e6..8c7bddfb 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -67,6 +67,11 @@ def get_decoder(self) -> modules.transformer.TransformerDecoder: source_attention_heads=self.source_attention_heads, ) + def beam_decode(self, *args, **kwargs): + raise NotImplementedError( + f"Beam search not implemented for {self.name} model" + ) + def greedy_decode( self, encoder_hidden: torch.Tensor, @@ -166,24 +171,23 @@ def forward( target_mask, ).output logits = self.classifier(decoder_output) - output = logits[:, :-1, :] # Ignore END. + return logits[:, :-1, :] # Ignores END. else: encoder_output = self.source_encoder(batch.source).output if self.beam_width > 1: # Will raise a NotImplementedError. - output = self.beam_decode( + return self.beam_decode( encoder_output, batch.source.mask, self.beam_width, ) else: # -> B x seq_len x output_size. - output = self.greedy_decode( + return self.greedy_decode( encoder_output, batch.source.mask, batch.target.padded if batch.target else None, ) - return output @property def name(self) -> str: From 4135e041631ac36e47620bd92604d105d1642f0e Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 16:40:30 -0500 Subject: [PATCH 6/8] More bugfixes --- yoyodyne/models/hard_attention.py | 4 ++-- yoyodyne/models/rnn.py | 4 ++-- yoyodyne/models/transformer.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/yoyodyne/models/hard_attention.py b/yoyodyne/models/hard_attention.py index 184fb9f6..5cf362fa 100644 --- a/yoyodyne/models/hard_attention.py +++ b/yoyodyne/models/hard_attention.py @@ -146,8 +146,8 @@ def greedy_decode( ) -> Tuple[torch.Tensor, torch.Tensor]: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached up to - a specified length depending on the `target` args. 
+ Decodes until all sequences in a batch have reached END up to a + specified length depending on the `target` args. Args: encoder_out (torch.Tensor): batch of encoded input symbols diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 4af7790b..80a12a77 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -170,8 +170,8 @@ def greedy_decode( ) -> torch.Tensor: """Decodes a sequence given the encoded input. - Decodes until all sequences in a batch have reached up to - a specified length depending on the `target` args. + Decodes until all sequences in a batch have reached END up to a + specified length depending on the `target` args. Args: encoder_out (torch.Tensor): batch of encoded input symbols. diff --git a/yoyodyne/models/transformer.py b/yoyodyne/models/transformer.py index 8c7bddfb..c60b7d1a 100644 --- a/yoyodyne/models/transformer.py +++ b/yoyodyne/models/transformer.py @@ -85,7 +85,7 @@ def greedy_decode( source_mask (torch.Tensor): mask for the encoded source tokens. targets (torch.Tensor, optional): the optional target tokens, which is only used for early stopping during validation - if the decoder has predicted for every sequence in + if the decoder has predicted END for every sequence in the batch. Returns: From 0b094fd9c932066d9691338d13a698abb7863b3a Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Sat, 30 Nov 2024 21:16:56 -0500 Subject: [PATCH 7/8] Late-breaking improvements --- yoyodyne/models/base.py | 13 ++- yoyodyne/models/rnn.py | 2 +- yoyodyne/models/transducer.py | 144 +++++++++++++--------------------- yoyodyne/predict.py | 2 +- 4 files changed, 61 insertions(+), 100 deletions(-) diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 45579e93..7c07967d 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -283,7 +283,7 @@ def predict_step( self, batch: data.PaddedBatch, batch_idx: int, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Runs one predict step. This is called by the PL Trainer. @@ -293,19 +293,16 @@ def predict_step( batch_idx (int). Returns: - Tuple[torch.Tensor, torch.Tensor]: position 0 are the indices of - the argmax at each timestep. Position 1 are the scores for each - history in beam search. It will be None when using greedy. - + Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if + using beam search, the predictions and scores as a tuple of + tensors; if using greedy search, the predictions as a tensor. """ predictions = self(batch) if self.beam_width > 1: predictions, scores = predictions return predictions, scores else: - # -> B x seq_len x 1. - greedy_predictions = self._get_predicted(predictions) - return greedy_predictions, None + return self._get_predicted(predictions) def _get_predicted(self, predictions: torch.Tensor) -> torch.Tensor: """Picks the best index from the vocabulary. diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 80a12a77..3fd2cf52 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -278,7 +278,7 @@ def forward( @staticmethod def add_argparse_args(parser: argparse.ArgumentParser) -> None: - """Adds LSTM configuration options to the argument parser. + """Adds RNN configuration options to the argument parser. Args: parser (argparse.ArgumentParser). 
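# With predict_step now returning a bare tensor for greedy search and a
# (predictions, scores) tuple for beam search, consumers branch on the
# configured mode; a sketch (cf. the predict.py hunk later in this patch):
def consume(trainer, model, loader):
    for out in trainer.predict(model, loader):
        if model.beam_width > 1:
            predictions, scores = out
        else:
            predictions = out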
diff --git a/yoyodyne/models/transducer.py b/yoyodyne/models/transducer.py index 0f964857..cf621256 100644 --- a/yoyodyne/models/transducer.py +++ b/yoyodyne/models/transducer.py @@ -51,69 +51,6 @@ def __init__( self.substitutions = self.actions.substitutions self.insertions = self.actions.insertions - def forward( - self, - batch: data.PaddedBatch, - ) -> Tuple[List[List[int]], torch.Tensor]: - """Runs the encoder-decoder model. - - Args: - batch (data.PaddedBatch). - - Returns: - Tuple[List[List[int]], torch.Tensor]: encoded prediction values - and loss tensor; due to transducer setup, prediction is - performed during training, so these are returned. - """ - encoder_out = self.source_encoder(batch.source) - encoded = encoder_out.output[:, 1:, :] # Ignores start symbol. - source_padded = batch.source.padded[:, 1:] - source_mask = batch.source.mask[:, 1:] - # Start of decoding. - if self.has_features_encoder: - features_encoder_out = self.features_encoder(batch.features) - features_encoded = features_encoder_out.output - if features_encoder_out.has_hiddens: - h_features, c_features = features_encoder_out.hiddens - h_features = h_features.mean(dim=0, keepdim=True).expand( - self.decoder_layers, -1, -1 - ) - c_features = c_features.mean(dim=0, keepdim=True).expand( - self.decoder_layers, -1, -1 - ) - last_hiddens = h_features, c_features - else: - last_hiddens = self.init_hiddens(source_mask.shape[0]) - features_encoded = features_encoded.mean(dim=1, keepdim=True) - encoded = torch.cat( - ( - encoded, - features_encoded.expand(-1, encoded.shape[1], -1), - ), - dim=2, - ) - else: - last_hiddens = self.init_hiddens(source_mask.shape[0]) - if self.beam_width > 1: - # Will raise a NotImplementedError. - return self.beam_decode( - encoder_out=encoded, - mask=batch.source.mask, - beam_width=self.beam_width, - ) - else: - return self.greedy_decode( - encoded, - last_hiddens, - source_padded, - source_mask, - teacher_forcing=( - self.teacher_forcing if self.training else False - ), - target=batch.target.padded if batch.target else None, - target_mask=batch.target.mask if batch.target else None, - ) - def beam_decode(self, *args, **kwargs): """Overrides incompatible implementation inherited from RNNModel.""" raise NotImplementedError( @@ -546,10 +483,10 @@ def validation_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict: val_eval_items_dict.update({"val_loss": loss}) return val_eval_items_dict - def predict_step(self, batch: Tuple[torch.tensor], batch_idx: int) -> Dict: - predictions, _ = self.forward( - batch, - ) + def predict_step( + self, batch: data.PaddedBatch, batch_idx: int + ) -> torch.Tensor: + predictions, _ = self(batch) # Evaluation requires prediction tensor. return self.convert_predictions(predictions) @@ -592,9 +529,6 @@ def name(self) -> str: raise NotImplementedError -# TODO: Implement beam decoding. - - class TransducerGRUModel(TransducerRNNModel, rnn.GRUModel): """Transducer with GRU backend.""" @@ -644,16 +578,31 @@ def forward( ) else: last_hiddens = self.init_hiddens(source_mask.shape[0]) - prediction, loss = self.decode( - encoded, - last_hiddens, - source_padded, - source_mask, - teacher_forcing=(self.teacher_forcing if self.training else False), - target=batch.target.padded if batch.target else None, - target_mask=batch.target.mask if batch.target else None, - ) - return prediction, loss + if self.beam_width > 1: + # Will raise a NotImplementedError. 
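# Shape sketch of the features-hiddens averaging used in these forward()
# methods: encoder hiddens of shape enc_layers x B x H are averaged across
# layers and broadcast to the decoder layer count (hypothetical sizes):
import torch

enc_layers, dec_layers, B, H = 2, 2, 4, 8
h = torch.randn(enc_layers, B, H)
h = h.mean(dim=0, keepdim=True).expand(dec_layers, -1, -1)  # -> dec_layers x B x H.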
+ return self.beam_decode( + encoded, + last_hiddens, + source_padded, + source_mask, + teacher_forcing=( + self.teacher_forcing if self.training else False + ), + target=batch.target.padded if batch.target else None, + target_mask=batch.target.mask if batch.target else None, + ) + else: + return self.greedy_decode( + encoded, + last_hiddens, + source_padded, + source_mask, + teacher_forcing=( + self.teacher_forcing if self.training else False + ), + target=batch.target.padded if batch.target else None, + target_mask=batch.target.mask if batch.target else None, + ) def get_decoder(self) -> modules.GRUDecoder: return modules.GRUDecoder( @@ -731,16 +680,31 @@ def forward( ) else: last_hiddens = self.init_hiddens(source_mask.shape[0]) - prediction, loss = self.decode( - encoded, - last_hiddens, - source_padded, - source_mask, - teacher_forcing=(self.teacher_forcing if self.training else False), - target=batch.target.padded if batch.target else None, - target_mask=batch.target.mask if batch.target else None, - ) - return prediction, loss + if self.beam_width > 1: + # Will raise a NotImplementedError. + return self.beam_decode( + encoded, + last_hiddens, + source_padded, + source_mask, + teacher_forcing=( + self.teacher_forcing if self.training else False + ), + target=batch.target.padded if batch.target else None, + target_mask=batch.target.mask if batch.target else None, + ) + else: + return self.greedy_decode( + encoded, + last_hiddens, + source_padded, + source_mask, + teacher_forcing=( + self.teacher_forcing if self.training else False + ), + target=batch.target.padded if batch.target else None, + target_mask=batch.target.mask if batch.target else None, + ) def init_hiddens( self, batch_size: int diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index b8390ac3..638ea5a2 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -121,7 +121,7 @@ def predict( tsv_writer.writerow(row) else: # Greedy search. - for predictions, *_ in trainer.predict(model, loader): + for predictions in trainer.predict(model, loader): predictions = util.pad_tensor_after_end(predictions) for prediction in loader.dataset.decode_target(predictions): print(prediction, file=sink) From e32444f0d45a77e58933e01e0208db6498864402 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 2 Dec 2024 11:07:41 -0500 Subject: [PATCH 8/8] fix 1 --- yoyodyne/evaluators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yoyodyne/evaluators.py b/yoyodyne/evaluators.py index 4b7718e1..8302f5cf 100644 --- a/yoyodyne/evaluators.py +++ b/yoyodyne/evaluators.py @@ -198,7 +198,7 @@ def _finalize_tensor( """Finalizes each tensor. Truncates at END for each prediction and returns a List of predictions. - This does basically the same as util.pad_after_end, but does not + This does basically the same as util.pad_tensor_after_end, but does not actually pad since we do not need to return a well-formed tensor. Args: