diff --git a/config.py b/config.py new file mode 100644 index 0000000..0365855 --- /dev/null +++ b/config.py @@ -0,0 +1,125 @@ +# Create the config +from pathlib import Path + +data_path = Path('output/posses') +experiment_dir = Path('models/posses') + +config = """ +name: "posses" +joeynmt_version: 2.0.0 + +data: + task: "S2T" # "S2T" for speech-to-text, "MT" for (text) translation + train: "{data_dir}/train" + dev: "{data_dir}/dev" + test: "{data_dir}/test" + dataset_type: "speech" # SpeechDataset takes tsv as input + src: + lang: "en_ng" + num_freq: 534 # number of frequencies of audio inputs + max_length: 3000 # much longer than text sequence! + min_length: 10 # have to be specified so that 1d-conv works! + level: "frame" # Here we specify we're working on BPEs. + tokenizer_type: "speech" + tokenizer_cfg: + specaugment: + freq_mask_n: 1 + freq_mask_f: 5 + time_mask_n: 1 + time_mask_t: 10 + time_mask_p: 1.0 + cmvn: + norm_means: True + norm_vars: True + before: True + trg: + lang: "en_ng" + max_length: 100 + lowercase: False + level: "bpe" # Here we specify we're working on BPEs. + voc_file: "{data_dir}/spm_bpe40.vocab" + tokenizer_type: "sentencepiece" + tokenizer_cfg: + model_file: "{data_dir}/spm_bpe40.model" + pretokenize: "none" + +testing: + n_best: 1 + beam_size: 5 + beam_alpha: 1.0 + batch_size: 4 + batch_type: "sentence" + max_output_length: 100 # Don't generate translations longer than this. + eval_metrics: ["wer"] # Use "wer" for ASR task, "bleu" for ST task + sacrebleu_cfg: # sacrebleu options + tokenize: "intl" # `tokenize` option in sacrebleu.corpus_bleu() function (options include: "none" (use for already tokenized test data), "13a" (default minimal tokenizer), "intl" which mostly does punctuation and unicode, etc) + +training: + #load_model: "{experiment_dir}/1.ckpt" # if uncommented, load a pre-trained model from this checkpoint + random_seed: 42 + optimizer: "adam" + normalization: "tokens" + adam_betas: [0.9, 0.98] + scheduling: "plateau" + patience: 5 + learning_rate: 0.0002 + learning_rate_min: 0.00000001 + weight_decay: 0.0 + label_smoothing: 0.1 + loss: "crossentropy-ctc" # use CrossEntropyLoss + CTCLoss + ctc_weight: 0.3 # ctc weight in interpolation + batch_size: 4 # much bigger than text! your "tokens" are "frames" now. + batch_type: "sentence" + batch_multiplier: 1 + early_stopping_metric: "wer" + epochs: 10 # Decrease for when playing around and checking of working. + validation_freq: 1000 # Set to at least once per epoch. + logging_freq: 100 + model_dir: "{experiment_dir}" + overwrite: True + shuffle: True + use_cuda: True + print_valid_sents: [0, 1, 2, 3] + keep_best_ckpts: 2 + +model: + initializer: "xavier_uniform" + bias_initializer: "zeros" + init_gain: 1.0 + embed_initializer: "xavier_uniform" + embed_init_gain: 1.0 + tied_embeddings: False # DIsable embeddings sharing between enc(audio) and dec(text) + tied_softmax: False + encoder: + type: "transformer" + num_layers: 12 # Common to use doubly bigger encoder than decoder in S2T. + num_heads: 4 + embeddings: + embedding_dim: 534 # Must be same as the frequency of the filterbank features! 
+ # typically ff_size = 4 x hidden_size + hidden_size: 256 + ff_size: 1024 + dropout: 0.1 + layer_norm: "pre" + # new for S2T: + subsample: True # enable 1d conv module + conv_kernel_sizes: [5, 5] # convolution kernel sizes (window width) + conv_channels: 512 # convolution channels + in_channels: 534 # Must be same as the embedding_dim + decoder: + type: "transformer" + num_layers: 6 + num_heads: 4 + embeddings: + embedding_dim: 256 + scale: True + dropout: 0.0 + # typically ff_size = 4 x hidden_size + hidden_size: 256 + ff_size: 1024 + dropout: 0.1 + layer_norm: "pre" +""".format(data_dir=data_path.as_posix(), + experiment_dir=experiment_dir.as_posix()) + +(data_path / 'config.yaml').write_text(config) \ No newline at end of file diff --git a/data_preprocessing.py b/data_preprocessing.py new file mode 100644 index 0000000..b3638a9 --- /dev/null +++ b/data_preprocessing.py @@ -0,0 +1,32 @@ +import argparse +from pathlib import Path + +from pose_format import Pose +from pose_format.utils.generic import pose_normalization_info, correct_wrists, reduce_holistic + + +def preprocess(srcDir, trgDir): + srcDir = Path(srcDir) + trgDir = Path(trgDir) + trgDir.mkdir(parents=True, exist_ok=True) + for path in srcDir.iterdir(): + if path.is_file() and path.suffix == ".pose": + with open(srcDir / path.name, 'rb') as pose_file: + pose = Pose.read(pose_file.read()) + pose = reduce_holistic(pose) + correct_wrists(pose) + pose = pose.normalize(pose_normalization_info(pose.header)) + with open(trgDir / path.name, 'w+b') as pose_file: + pose.write(pose_file) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--srcDir", required=True, type=str) + parser.add_argument("--trgDir", required=True, type=str) + args = parser.parse_args() + preprocess(args.srcDir, args.trgDir) + + +if __name__ == "__main__": + main() diff --git a/datasets_pose.py b/datasets_pose.py new file mode 100644 index 0000000..af8f57a --- /dev/null +++ b/datasets_pose.py @@ -0,0 +1,47 @@ +import numpy as np +from pose_format import Pose +import pandas as pd +from swu_representation import swu2data + +FrameRate = 29.97003 + + +def ms2frame(ms) -> int: + return int(ms / 1000 * FrameRate) + + +def pose_to_matrix(file_path, start_ms, end_ms): + with open(file_path, "rb") as f: + pose = Pose.read(f.read()) + pose = pose.body.data + pose = pose.reshape(pose.shape[0], pose.shape[2] * pose.shape[3]) + pose = pose[ms2frame(start_ms):ms2frame(end_ms)] + return pose + + +def load_dataset(folder_name): + + target = pd.read_csv(f'{folder_name}/target.csv') + dataset = [] + for line in target.values: + pose = pose_to_matrix(f'{folder_name}/{line[0]}', line[2], line[3]) + pose = pose.filled(fill_value=0) + utt_id = line[0].split('.')[0] + utt_id = f'{utt_id}({line[2]})' + dataset.append((utt_id, pose, swu2data(line[4]))) + return dataset + + +def extract_to_fbank(pose_data, output_path, overwrite: bool = False): + if output_path is not None and output_path.is_file() and not overwrite: + return np.load(output_path.as_posix()) + if output_path is not None: + np.save(output_path.as_posix(), pose_data) + assert output_path.is_file(), output_path + return pose_data + + +if __name__ == "__main__": + dataSet = load_dataset("Dataset") + + print(dataSet) diff --git a/joeynmt/__init__.py b/joeynmt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/joeynmt/__main__.py b/joeynmt/__main__.py new file mode 100644 index 0000000..c2ad503 --- /dev/null +++ b/joeynmt/__main__.py @@ -0,0 +1,64 @@ +import argparse + +from 
joeynmt.prediction import test, translate +from joeynmt.training import train + + +def main(): + ap = argparse.ArgumentParser("Joey NMT") + + ap.add_argument( + "mode", + choices=["train", "test", "translate"], + help="train a model or test or translate", + ) + + ap.add_argument("config_path", type=str, help="path to YAML config file") + + ap.add_argument("-c", "--ckpt", type=str, help="checkpoint for prediction") + + ap.add_argument("-o", + "--output_path", + type=str, + help="path for saving translation output") + + ap.add_argument( + "-a", + "--save_attention", + action="store_true", + help="save attention visualizations", + ) + + ap.add_argument("-s", "--save_scores", action="store_true", help="save scores") + + ap.add_argument( + "-t", + "--skip_test", + action="store_true", + help="Skip test after training", + ) + + args = ap.parse_args() + + if args.mode == "train": + train(cfg_file=args.config_path, skip_test=args.skip_test) + elif args.mode == "test": + test( + cfg_file=args.config_path, + ckpt=args.ckpt, + output_path=args.output_path, + save_attention=args.save_attention, + save_scores=args.save_scores, + ) + elif args.mode == "translate": + translate( + cfg_file=args.config_path, + ckpt=args.ckpt, + output_path=args.output_path, + ) + else: + raise ValueError("Unknown mode") + + +if __name__ == "__main__": + main() diff --git a/joeynmt/attention.py b/joeynmt/attention.py new file mode 100644 index 0000000..39f3701 --- /dev/null +++ b/joeynmt/attention.py @@ -0,0 +1,220 @@ +# coding: utf-8 +""" +Attention modules +""" +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class AttentionMechanism(nn.Module): + """ + Base attention class + """ + + def forward(self, *inputs): + raise NotImplementedError("Implement this.") + + +class BahdanauAttention(AttentionMechanism): + """ + Implements Bahdanau (MLP) attention + + Section A.1.2 in https://arxiv.org/abs/1409.0473. + """ + + def __init__(self, hidden_size: int = 1, key_size: int = 1, query_size: int = 1): + """ + Creates attention mechanism. + + :param hidden_size: size of the projection for query and key + :param key_size: size of the attention input keys + :param query_size: size of the query + """ + + super().__init__() + + self.key_layer = nn.Linear(key_size, hidden_size, bias=False) + self.query_layer = nn.Linear(query_size, hidden_size, bias=False) + self.energy_layer = nn.Linear(hidden_size, 1, bias=False) + + self.proj_keys = None # to store projected keys + self.proj_query = None # projected query + + def forward(self, query: Tensor, mask: Tensor, + values: Tensor) -> Tuple[Tensor, Tensor]: + """ + Bahdanau MLP attention forward pass. + + :param query: the item (decoder state) to compare with the keys/memory, + shape (batch_size, 1, decoder.hidden_size) + :param mask: mask out keys position (0 in invalid positions, 1 else), + shape (batch_size, 1, src_length) + :param values: values (encoder states), + shape (batch_size, src_length, encoder.hidden_size) + :return: + - context vector of shape (batch_size, 1, value_size), + - attention probabilities of shape (batch_size, 1, src_length) + """ + # pylint: disable=arguments-differ + self._check_input_shapes_forward(query=query, mask=mask, values=values) + + assert mask is not None, "mask is required" + assert self.proj_keys is not None, "projection keys have to get pre-computed" + + # We first project the query (the decoder state). + # The projected keys (the encoder states) were already pre-computed. 
+ self.compute_proj_query(query) + + # Calculate scores. + # proj_keys: batch x src_len x hidden_size + # proj_query: batch x 1 x hidden_size + scores = self.energy_layer(torch.tanh(self.proj_query + self.proj_keys)) + # scores: batch x src_len x 1 + + scores = scores.squeeze(2).unsqueeze(1) + # scores: batch x 1 x time + + # mask out invalid positions by filling the masked out parts with -inf + scores = torch.where(mask > 0, scores, scores.new_full([1], -np.inf)) + + # turn scores to probabilities + alphas = F.softmax(scores, dim=-1) # batch x 1 x time + + # the context vector is the weighted sum of the values + context = alphas @ values # batch x 1 x value_size + + return context, alphas + + def compute_proj_keys(self, keys: Tensor) -> None: + """ + Compute the projection of the keys. + Is efficient if pre-computed before receiving individual queries. + + :param keys: + :return: + """ + self.proj_keys = self.key_layer(keys) + + def compute_proj_query(self, query: Tensor): + """ + Compute the projection of the query. + + :param query: + :return: + """ + self.proj_query = self.query_layer(query) + + def _check_input_shapes_forward(self, query: Tensor, mask: Tensor, + values: Tensor) -> None: + """ + Make sure that inputs to `self.forward` are of correct shape. + Same input semantics as for `self.forward`. + + :param query: + :param mask: + :param values: + :return: + """ + assert query.shape[0] == values.shape[0] == mask.shape[0] + assert query.shape[1] == 1 == mask.shape[1] + assert query.shape[2] == self.query_layer.in_features + assert values.shape[2] == self.key_layer.in_features + assert mask.shape[2] == values.shape[1] + + def __repr__(self): + return "BahdanauAttention" + + +class LuongAttention(AttentionMechanism): + """ + Implements Luong (bilinear / multiplicative) attention. + + Eq. 8 ("general") in http://aclweb.org/anthology/D15-1166. + """ + + def __init__(self, hidden_size: int = 1, key_size: int = 1): + """ + Creates attention mechanism. + + :param hidden_size: size of the key projection layer, has to be equal + to decoder hidden size + :param key_size: size of the attention input keys + """ + + super().__init__() + self.key_layer = nn.Linear(in_features=key_size, + out_features=hidden_size, + bias=False) + self.proj_keys = None # projected keys + + def forward(self, query: Tensor, mask: Tensor, + values: Tensor) -> Tuple[Tensor, Tensor]: + """ + Luong (multiplicative / bilinear) attention forward pass. + Computes context vectors and attention scores for a given query and + all masked values and returns them. 
+ + :param query: the item (decoder state) to compare with the keys/memory, + shape (batch_size, 1, decoder.hidden_size) + :param mask: mask out keys position (0 in invalid positions, 1 else), + shape (batch_size, 1, src_length) + :param values: values (encoder states), + shape (batch_size, src_length, encoder.hidden_size) + :return: + - context vector of shape (batch_size, 1, value_size), + - attention probabilities of shape (batch_size, 1, src_length) + """ + # pylint: disable=arguments-differ + self._check_input_shapes_forward(query=query, mask=mask, values=values) + + assert self.proj_keys is not None, "projection keys have to get pre-computed" + assert mask is not None, "mask is required" + + # scores: batch_size x 1 x src_length + scores = query @ self.proj_keys.transpose(1, 2) + + # mask out invalid positions by filling the masked out parts with -inf + scores = torch.where(mask > 0, scores, scores.new_full([1], -np.inf)) + + # turn scores to probabilities + alphas = F.softmax(scores, dim=-1) # batch x 1 x src_len + + # the context vector is the weighted sum of the values + context = alphas @ values # batch x 1 x values_size + + return context, alphas + + def compute_proj_keys(self, keys: Tensor) -> None: + """ + Compute the projection of the keys and assign them to `self.proj_keys`. + This pre-computation is efficiently done for all keys + before receiving individual queries. + + :param keys: shape (batch_size, src_length, encoder.hidden_size) + """ + # proj_keys: batch x src_len x hidden_size + self.proj_keys = self.key_layer(keys) + + def _check_input_shapes_forward(self, query: Tensor, mask: Tensor, + values: Tensor) -> None: + """ + Make sure that inputs to `self.forward` are of correct shape. + Same input semantics as for `self.forward`. + + :param query: + :param mask: + :param values: + :return: + """ + assert query.shape[0] == values.shape[0] == mask.shape[0] + assert query.shape[1] == 1 == mask.shape[1] + assert query.shape[2] == self.key_layer.out_features + assert values.shape[2] == self.key_layer.in_features + assert mask.shape[2] == values.shape[1] + + def __repr__(self): + return "LuongAttention" diff --git a/joeynmt/batch.py b/joeynmt/batch.py new file mode 100644 index 0000000..2fe04d0 --- /dev/null +++ b/joeynmt/batch.py @@ -0,0 +1,186 @@ +# coding: utf-8 +""" +Implementation of a mini-batch. +""" +import logging +from typing import List, Optional + +import numpy as np +import torch +from torch import Tensor + +from joeynmt.constants import PAD_ID + +logger = logging.getLogger(__name__) + + +class Batch: + """ + Object for holding a batch of data with mask during training. + Input is yielded from `collate_fn()` called by torch.data.utils.DataLoader. + """ + + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + src: Tensor, + src_length: Tensor, + trg: Optional[Tensor], + trg_length: Optional[Tensor], + device: torch.device, + pad_index: int = PAD_ID, + has_trg: bool = True, + is_train: bool = True, + task: str = "MT", + ): + """ + Creates a new joey batch. This batch supports attributes with src and trg + length, masks, number of non-padded tokens in trg. Furthermore, it can be + sorted by src length. + + :param src: + :param src_length: + :param trg: + :param trg_length: + :param device: + :param pad_index: *must be the same for both src and trg + :param is_train: *can be used for online data augmentation, subsampling etc. 
+ :param task: task + """ + self.src: Tensor = src + self.src_length: Tensor = src_length + self.src_mask: Optional[Tensor] = None + + self.trg_input: Optional[Tensor] = None + self.trg: Optional[Tensor] = None + self.trg_mask: Optional[Tensor] = None + self.trg_length: Optional[Tensor] = None + + self.nseqs: int = self.src.size(0) + self.ntokens: Optional[int] = None + self.has_trg: bool = has_trg + self.is_train: bool = is_train + if self.is_train: + assert self.has_trg + + if self.has_trg: + assert trg is not None and trg_length is not None + # trg_input is used for teacher forcing, last one is cut off + self.trg_input: Tensor = trg[:, :-1] # shape (batch_size, seq_length) + self.trg_length: Tensor = trg_length - 1 + # trg is used for loss computation, shifted by one since BOS + self.trg: Tensor = trg[:, 1:] # shape (batch_size, seq_length) + # we exclude the padded areas (and blank areas) from the loss computation + self.trg_mask: Tensor = (self.trg != pad_index).unsqueeze(1) + self.ntokens: int = (self.trg != pad_index).data.sum().item() + + if device.type == "cuda": + self._make_cuda(device) + + self.task: str = task + if self.task == "MT": + self.src_mask: Tensor = (self.src != pad_index).unsqueeze(1) + elif self.task == "S2T": + # Note: src_mask will be re-constructed in TransformerEncoder + self.src_max_len: int = self.src.size(1) + # if multi-gpu, re-pad src so that all seqs in parallel gpus + # have the same length! + self.repad: bool = torch.cuda.device_count() > 1 + + # a batch has to contain more than one src sentence + assert self.nseqs > 0, self.nseqs + + def _make_cuda(self, device: torch.device) -> None: + """Move the batch to GPU""" + self.src = self.src.to(device) + self.src_length = self.src_length.to(device) + if self.src_mask is not None: # if self.task == "MT": + self.src_mask = self.src_mask.to(device) + + if self.has_trg: + self.trg_input = self.trg_input.to(device) + self.trg = self.trg.to(device) + self.trg_length = self.trg_length.to(device) + self.trg_mask = self.trg_mask.to(device) + + def normalize( + self, + tensor: Tensor, + normalization: str = "none", + n_gpu: int = 1, + n_accumulation: int = 1, + ) -> Tensor: + """ + Normalizes batch tensor (i.e. loss). Takes sum over multiple gpus, divides by + nseqs or ntokens, divide by n_gpu, then divide by n_accumulation. + + :param tensor: (Tensor) tensor to normalize, i.e. 
batch loss + :param normalization: (str) one of {`batch`, `tokens`, `none`} + :param n_gpu: (int) the number of gpus + :param n_accumulation: (int) the number of gradient accumulation + :return: normalized tensor + """ + if n_gpu > 1: + tensor = tensor.sum() + + if normalization == "sum": # pylint: disable=no-else-return + return tensor + elif normalization == "batch": + normalizer = self.nseqs + elif normalization == "tokens": + normalizer = self.ntokens + elif normalization == "none": + normalizer = 1 + + norm_tensor = tensor / normalizer + + if n_gpu > 1: + norm_tensor = norm_tensor / n_gpu + + if n_accumulation > 1: + norm_tensor = norm_tensor / n_accumulation + return norm_tensor + + def sort_by_src_length(self) -> List[int]: + """ + Sort by src length (descending) and return index to revert sort + + :return: list of indices + """ + _, perm_index = self.src_length.sort(0, descending=True) + rev_index = [0] * perm_index.size(0) + for new_pos, old_pos in enumerate(perm_index.cpu().numpy()): + rev_index[old_pos] = new_pos + + self.src = self.src[perm_index] + self.src_length = self.src_length[perm_index] + + if self.src_mask is not None: # if task != "S2T" + self.src_mask = self.src_mask[perm_index] + + if self.has_trg: + self.trg_input = self.trg_input[perm_index] + self.trg_mask = self.trg_mask[perm_index] + self.trg_length = self.trg_length[perm_index] + self.trg = self.trg[perm_index] + + assert max(rev_index) < len(rev_index), rev_index + return rev_index + + def score(self, log_probs: Tensor) -> np.ndarray: + """Look up the score of the trg token (ground truth) in the batch""" + scores = [] + for i in range(self.nseqs): + scores.append( + np.array([ + log_probs[i, j, ind].item() for j, ind in enumerate(self.trg[i]) + if ind != PAD_ID + ])) + # Note: each element in `scores` list can have different lengths. + return np.array(scores, dtype=object) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(nseqs={self.nseqs}, ntokens={self.ntokens}, " + f"has_trg={self.has_trg}, is_train={self.is_train}, task={self.task})") diff --git a/joeynmt/builders.py b/joeynmt/builders.py new file mode 100644 index 0000000..8ec20d1 --- /dev/null +++ b/joeynmt/builders.py @@ -0,0 +1,445 @@ +# coding: utf-8 +""" +Collection of builder functions +""" +import logging +from functools import partial +from typing import Callable, Generator, Optional + +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import ( + ExponentialLR, + ReduceLROnPlateau, + StepLR, + _LRScheduler, +) + +from joeynmt.helpers import ConfigurationError + +logger = logging.getLogger(__name__) + + +def build_gradient_clipper(config: dict) -> Optional[Callable]: + """ + Define the function for gradient clipping as specified in configuration. + If not specified, returns None. 
+ + Current options: + - "clip_grad_val": clip the gradients if they exceed this value, + see `torch.nn.utils.clip_grad_value_` + - "clip_grad_norm": clip the gradients if their norm exceeds this value, + see `torch.nn.utils.clip_grad_norm_` + + :param config: dictionary with training configurations + :return: clipping function (in-place) or None if no gradient clipping + """ + if "clip_grad_val" in config.keys() and "clip_grad_norm" in config.keys(): + raise ConfigurationError( + "You can only specify either clip_grad_val or clip_grad_norm.") + + clip_grad_fun = None + if "clip_grad_val" in config.keys(): + clip_grad_fun = partial(nn.utils.clip_grad_value_, + clip_value=config["clip_grad_val"]) + elif "clip_grad_norm" in config.keys(): + clip_grad_fun = partial(nn.utils.clip_grad_norm_, + max_norm=config["clip_grad_norm"]) + return clip_grad_fun + + +def build_optimizer(config: dict, parameters: Generator) -> Optimizer: + """ + Create an optimizer for the given parameters as specified in config. + + Except for the weight decay and initial learning rate, + default optimizer settings are used. + + Currently supported configuration settings for "optimizer": + - "sgd" (default): see `torch.optim.SGD` + - "adam": see `torch.optim.adam` + - "adagrad": see `torch.optim.adagrad` + - "adadelta": see `torch.optim.adadelta` + - "rmsprop": see `torch.optim.RMSprop` + + The initial learning rate is set according to "learning_rate" in the config. + The weight decay is set according to "weight_decay" in the config. + If they are not specified, the initial learning rate is set to 3.0e-4, the + weight decay to 0. + + Note that the scheduler state is saved in the checkpoint, so if you load + a model for further training you have to use the same type of scheduler. + + :param config: configuration dictionary + :param parameters: + :return: optimizer + """ + optimizer_name = config.get("optimizer", "sgd").lower() + + kwargs = { + "lr": config.get("learning_rate", 3.0e-4), + "weight_decay": config.get("weight_decay", 0), + } + + if optimizer_name == "adam": + kwargs["betas"] = config.get("adam_betas", (0.9, 0.999)) + optimizer = torch.optim.Adam(parameters, **kwargs) + elif optimizer_name == "adagrad": + optimizer = torch.optim.Adagrad(parameters, **kwargs) + elif optimizer_name == "adadelta": + optimizer = torch.optim.Adadelta(parameters, **kwargs) + elif optimizer_name == "rmsprop": + optimizer = torch.optim.RMSprop(parameters, **kwargs) + elif optimizer_name == "sgd": + # default + kwargs["momentum"] = config.get("momentum", 0.0) + optimizer = torch.optim.SGD(parameters, **kwargs) + else: + raise ConfigurationError("Invalid optimizer. Valid options: 'adam', " + "'adagrad', 'adadelta', 'rmsprop', 'sgd'.") + + logger.info( + "%s(%s)", + optimizer.__class__.__name__, + ", ".join([f"{k}={v}" for k, v in kwargs.items()]), + ) + return optimizer + + +def build_scheduler( + config: dict, + optimizer: Optimizer, + scheduler_mode: str, + hidden_size: int = 0, +) -> (Optional[_LRScheduler], Optional[str]): + """ + Create a learning rate scheduler if specified in config and determine when a + scheduler step should be executed. 
+ + Current options: + - "plateau": see `torch.optim.lr_scheduler.ReduceLROnPlateau` + - "decaying": see `torch.optim.lr_scheduler.StepLR` + - "exponential": see `torch.optim.lr_scheduler.ExponentialLR` + - "noam": see `joeynmt.builders.NoamScheduler` + - "warmupexponentialdecay": see + `joeynmt.builders.WarmupExponentialDecayScheduler` + - "warmupinversesquareroot": see + `joeynmt.builders.WarmupInverseSquareRootScheduler` + + If no scheduler is specified, returns (None, None) which will result in a constant + learning rate. + + :param config: training configuration + :param optimizer: optimizer for the scheduler, determines the set of parameters + which the scheduler sets the learning rate for + :param scheduler_mode: "min" or "max", depending on whether the validation score + should be minimized or maximized. Only relevant for "plateau". + :param hidden_size: encoder hidden size (required for NoamScheduler) + :return: + - scheduler: scheduler object, + - scheduler_step_at: either "validation", "epoch", "step" or "none" + """ + scheduler, scheduler_step_at = None, None + if "scheduling" in config.keys() and config["scheduling"]: + scheduler_name = config["scheduling"].lower() + kwargs = {} + if scheduler_name == "plateau": + # learning rate scheduler + kwargs = { + "mode": scheduler_mode, + "verbose": False, + "threshold_mode": "abs", + "eps": 0.0, + "factor": config.get("decrease_factor", 0.1), + "patience": config.get("patience", 10), + } + scheduler = ReduceLROnPlateau(optimizer=optimizer, **kwargs) + # scheduler step is executed after every validation + scheduler_step_at = "validation" + elif scheduler_name == "decaying": + kwargs = {"step_size": config.get("decaying_step_size", 1)} + scheduler = StepLR(optimizer=optimizer, **kwargs) + # scheduler step is executed after every epoch + scheduler_step_at = "epoch" + elif scheduler_name == "exponential": + kwargs = {"gamma": config.get("decrease_factor", 0.99)} + scheduler = ExponentialLR(optimizer=optimizer, **kwargs) + # scheduler step is executed after every epoch + scheduler_step_at = "epoch" + elif scheduler_name == "noam": + scheduler = NoamScheduler( + optimizer=optimizer, + hidden_size=hidden_size, + factor=config.get("learning_rate_factor", 1), + warmup=config.get("learning_rate_warmup", 4000), + ) + scheduler_step_at = "step" + elif scheduler_name == "warmupexponentialdecay": + scheduler = WarmupExponentialDecayScheduler( + min_rate=config.get("learning_rate_min", 1.0e-5), + decay_rate=config.get("learning_rate_decay", 0.1), + warmup=config.get("learning_rate_warmup", 4000), + peak_rate=config.get("learning_rate_peak", 1.0e-3), + decay_length=config.get("learning_rate_decay_length", 10000), + ) + scheduler_step_at = "step" + elif scheduler_name == "warmupinversesquareroot": + lr = config.get("learning_rate", 1.0e-3) + peak_rate = config.get("learning_rate_peak", lr) + scheduler = WarmupInverseSquareRootScheduler( + optimizer=optimizer, + peak_rate=peak_rate, + min_rate=config.get("learning_rate_min", 1.0e-5), + warmup=config.get("learning_rate_warmup", 10000), + ) + scheduler_step_at = "step" + + if scheduler is None: + scheduler_step_at = "none" + else: + assert scheduler_step_at in {"validation", "epoch", "step", "none"} + + # print log + if scheduler_name in [ + "noam", + "warmupexponentialdecay", + "warmupinversesquareroot", + ]: + logger.info(scheduler) + else: + logger.info( + "%s(%s)", + scheduler.__class__.__name__, + ", ".join([f"{k}={v}" for k, v in kwargs.items()]), + ) + return scheduler, scheduler_step_at + + 
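# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this diff): how the three builders
# above are typically wired together when setting up training. The toy
# nn.Linear model and the literal config values are assumptions chosen to
# mirror the "training" section of config.py; only the builder signatures and
# option names come from the code above.
import torch
from torch import nn

from joeynmt.builders import build_gradient_clipper, build_optimizer, build_scheduler

train_cfg = {
    "optimizer": "adam",
    "learning_rate": 2.0e-4,
    "weight_decay": 0.0,
    "adam_betas": [0.9, 0.98],
    "scheduling": "plateau",
    "patience": 5,
    "decrease_factor": 0.5,
    "clip_grad_norm": 1.0,
}

model = nn.Linear(256, 256)  # stand-in for the real encoder-decoder model
clip_grad_fun = build_gradient_clipper(train_cfg)           # partial of clip_grad_norm_
optimizer = build_optimizer(train_cfg, model.parameters())  # Adam(lr=2e-4, betas=(0.9, 0.98))
scheduler, step_at = build_scheduler(
    train_cfg, optimizer, scheduler_mode="min", hidden_size=256)  # "min": WER should shrink

loss = model(torch.randn(8, 256)).sum()
loss.backward()
if clip_grad_fun is not None:
    clip_grad_fun(model.parameters())  # in-place gradient clipping before the update
optimizer.step()
optimizer.zero_grad()
if step_at == "validation":        # ReduceLROnPlateau steps on a validation score
    scheduler.step(metrics=0.42)   # e.g. the current validation WER
# ---------------------------------------------------------------------------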
+class BaseScheduler: + """Base LR Scheduler + decay at "step" + """ + + def __init__(self, optimizer: torch.optim.Optimizer): + """ + :param optimizer: + """ + self.optimizer = optimizer + self._step = 0 + self._rate = 0 + self._state_dict = {"step": self._step, "rate": self._rate} + + def state_dict(self): + """Returns dictionary of values necessary to reconstruct scheduler""" + self._state_dict["step"] = self._step + self._state_dict["rate"] = self._rate + return self._state_dict + + def load_state_dict(self, state_dict): + """Given a state_dict, this function loads scheduler's state""" + self._step = state_dict["step"] + self._rate = state_dict["rate"] + + def step(self, step): + """Update parameters and rate""" + self._step = step + 1 # sync with trainer.stats.steps + rate = self._compute_rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + + def _compute_rate(self): + raise NotImplementedError + + +class NoamScheduler(BaseScheduler): + """ + The Noam learning rate scheduler used in "Attention is all you need" + See Eq. 3 in https://arxiv.org/abs/1706.03762 + """ + + def __init__( + self, + hidden_size: int, + optimizer: torch.optim.Optimizer, + factor: float = 1.0, + warmup: int = 4000, + ): + """ + Warm-up, followed by learning rate decay. + + :param hidden_size: + :param optimizer: + :param factor: decay factor + :param warmup: number of warmup steps + """ + super().__init__(optimizer) + self.warmup = warmup + self.factor = factor + self.hidden_size = hidden_size + + def _compute_rate(self): + """Implement `lrate` above""" + step = self._step + upper_bound = min(step**(-0.5), step * self.warmup**(-1.5)) + return self.factor * (self.hidden_size**(-0.5) * upper_bound) + + def state_dict(self): + """Returns dictionary of values necessary to reconstruct scheduler""" + super().state_dict() + self._state_dict["warmup"] = self.warmup + self._state_dict["factor"] = self.factor + self._state_dict["hidden_size"] = self.hidden_size + return self._state_dict + + def load_state_dict(self, state_dict): + """Given a state_dict, this function loads scheduler's state""" + super().load_state_dict(state_dict) + self.warmup = state_dict["warmup"] + self.factor = state_dict["factor"] + self.hidden_size = state_dict["hidden_size"] + + def __repr__(self): + return (f"{self.__class__.__name__}(warmup={self.warmup}, " + f"factor={self.factor}, hidden_size={self.hidden_size})") + + +class WarmupExponentialDecayScheduler(BaseScheduler): + """ + A learning rate scheduler similar to Noam, but modified: + Keep the warm up period but make it so that the decay rate can be tuneable. + The decay is exponential up to a given minimum rate. + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + peak_rate: float = 1.0e-3, + decay_length: int = 10000, + warmup: int = 4000, + decay_rate: float = 0.5, + min_rate: float = 1.0e-5, + ): + """ + Warm-up, followed by exponential learning rate decay. 
+ + :param peak_rate: maximum learning rate at peak after warmup + :param optimizer: + :param decay_length: decay length after warmup + :param decay_rate: decay rate after warmup + :param warmup: number of warmup steps + :param min_rate: minimum learning rate + """ + super().__init__(optimizer) + self.warmup = warmup + self.decay_length = decay_length + self.peak_rate = peak_rate + self.decay_rate = decay_rate + self.min_rate = min_rate + + def _compute_rate(self): + """Implement `lrate` above""" + step = self._step + warmup = self.warmup + + if step < warmup: + rate = step * self.peak_rate / warmup + else: + exponent = (step - warmup) / self.decay_length + rate = self.peak_rate * (self.decay_rate**exponent) + return max(rate, self.min_rate) + + def state_dict(self): + """Returns dictionary of values necessary to reconstruct scheduler""" + super().state_dict() + self._state_dict["warmup"] = self.warmup + self._state_dict["decay_length"] = self.decay_length + self._state_dict["peak_rate"] = self.peak_rate + self._state_dict["decay_rate"] = self.decay_rate + self._state_dict["min_rate"] = self.min_rate + return self._state_dict + + def load_state_dict(self, state_dict): + """Given a state_dict, this function loads scheduler's state""" + super().load_state_dict(state_dict) + self.warmup = state_dict["warmup"] + self.decay_length = state_dict["decay_length"] + self.peak_rate = state_dict["peak_rate"] + self.decay_rate = state_dict["decay_rate"] + self.min_rate = state_dict["min_rate"] + + def __repr__(self): + return (f"{self.__class__.__name__}(warmup={self.warmup}, " + f"decay_length={self.decay_length}, " + f"decay_rate={self.decay_rate}, " + f"peak_rate={self.peak_rate}, " + f"min_rate={self.min_rate})") + + +class WarmupInverseSquareRootScheduler(BaseScheduler): + """ + Decay the LR based on the inverse square root of the update number. + In the warmup phase, we linearly increase the learning rate. + After warmup, we decrease the learning rate as follows: + ``` + decay_factor = peak_rate * sqrt(warmup) # constant value + lr = decay_factor / sqrt(step) + ``` + cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py + """ # noqa + + def __init__( + self, + optimizer: torch.optim.Optimizer, + peak_rate: float = 1.0e-3, + warmup: int = 10000, + min_rate: float = 1.0e-5, + ): + """ + Warm-up, followed by inverse square root learning rate decay. + :param optimizer: + :param peak_rate: maximum learning rate at peak after warmup + :param warmup: number of warmup steps + :param min_rate: minimum learning rate + """ + super().__init__(optimizer) + self.warmup = warmup + self.min_rate = min_rate + self.peak_rate = peak_rate + self.decay_rate = peak_rate * (warmup**0.5) # constant value + + def _compute_rate(self): + """Implement `lrate` above""" + step = self._step + warmup = self.warmup + + if step < warmup: + # linear warmup + rate = step * self.peak_rate / warmup + else: + # decay prop. 
to the inverse square root of the update number
+            rate = self.decay_rate * (step**-0.5)
+        return max(rate, self.min_rate)
+
+    def state_dict(self):
+        """Returns dictionary of values necessary to reconstruct scheduler"""
+        super().state_dict()
+        self._state_dict["warmup"] = self.warmup
+        self._state_dict["peak_rate"] = self.peak_rate
+        self._state_dict["decay_rate"] = self.decay_rate
+        self._state_dict["min_rate"] = self.min_rate
+        return self._state_dict
+
+    def load_state_dict(self, state_dict):
+        """Given a state_dict, this function loads scheduler's state"""
+        super().load_state_dict(state_dict)
+        self.warmup = state_dict["warmup"]
+        self.decay_rate = state_dict["decay_rate"]
+        self.peak_rate = state_dict["peak_rate"]
+        self.min_rate = state_dict["min_rate"]
+
+    def __repr__(self):
+        return (f"{self.__class__.__name__}(warmup={self.warmup}, "
+                f"decay_rate={self.decay_rate:.6f}, peak_rate={self.peak_rate}, "
+                f"min_rate={self.min_rate})")
diff --git a/joeynmt/constants.py b/joeynmt/constants.py
new file mode 100644
index 0000000..150671f
--- /dev/null
+++ b/joeynmt/constants.py
@@ -0,0 +1,9 @@
+# coding: utf-8
+"""
+Defining global constants
+"""
+
+UNK_TOKEN, UNK_ID = "<unk>", 0
+PAD_TOKEN, PAD_ID = "<pad>", 1
+BOS_TOKEN, BOS_ID = "<s>", 2
+EOS_TOKEN, EOS_ID = "</s>", 3
diff --git a/joeynmt/data.py b/joeynmt/data.py
new file mode 100644
index 0000000..83ba832
--- /dev/null
+++ b/joeynmt/data.py
@@ -0,0 +1,179 @@
+# coding: utf-8
+"""
+Data module
+"""
+import logging
+from functools import partial
+from typing import Optional, Tuple
+
+from joeynmt.datasets import BaseDataset, build_dataset
+from joeynmt.helpers_for_pose import pad_features
+from joeynmt.tokenizers import build_tokenizer
+from joeynmt.vocabulary import Vocabulary, build_vocab
+
+logger = logging.getLogger(__name__)
+
+
+def load_data(
+    data_cfg: dict,
+    datasets: list = None
+) -> Tuple[Vocabulary, Vocabulary, Optional[BaseDataset], Optional[BaseDataset],
+           Optional[BaseDataset]]:
+    """
+    Load train, dev and optionally test data as specified in configuration.
+    Vocabularies are created from the training set with a limit of `voc_limit` tokens
+    and a minimum token frequency of `voc_min_freq` (specified in the configuration
+    dictionary).
+
+    The training data is filtered to include sentences up to `max_sent_length` on source
+    and target side.
+
+    If you set `random_{train|dev}_subset`, a random selection of this size is used
+    from the {train|development} set instead of the full {train|development} set.
+ + :param data_cfg: configuration dictionary for data ("data" part of config file) + :param datasets: list of dataset names to load + :returns: + - src_vocab: source vocabulary + - trg_vocab: target vocabulary + - train_data: training dataset + - dev_data: development dataset + - test_data: test dataset if given, otherwise None + """ + if datasets is None: + datasets = ["train", "dev", "test"] + + task = data_cfg.get("task", "MT").upper() + assert task in ["MT", "S2T"] + + src_cfg = data_cfg["src"] + trg_cfg = data_cfg["trg"] + + # load data from files + src_lang = src_cfg["lang"] + trg_lang = trg_cfg["lang"] + train_path = data_cfg.get("train", None) + dev_path = data_cfg.get("dev", None) + test_path = data_cfg.get("test", None) + + if train_path is None and dev_path is None and test_path is None: + raise ValueError("Please specify at least one data source path.") + + # build tokenizer + logger.info("Building tokenizer...") + tokenizer = build_tokenizer(data_cfg) + + dataset_type = data_cfg.get("dataset_type", "plain").lower() + if task == "S2T": + assert dataset_type == "speech" + dataset_cfg = data_cfg.get("dataset_cfg", {}) + + # train data + train_data = None + if "train" in datasets and train_path is not None: + train_subset = data_cfg.get("sample_train_subset", -1) + if "random_train_subset" in data_cfg: + logger.warning("`random_train_subset` option is obsolete. " + "Please use `sample_train_subset` instead.") + train_subset = data_cfg.get("random_train_subset", train_subset) + logger.info("Loading train set...") + train_data = build_dataset( + dataset_type=dataset_type, + path=train_path, + src_lang=src_lang, + trg_lang=trg_lang, + split="train", + tokenizer=tokenizer, + random_subset=train_subset, + task=task, + **dataset_cfg, + ) + + # build vocab + logger.info("Building vocabulary...") + src_vocab, trg_vocab = build_vocab(data_cfg, dataset=train_data) + + # set vocab to tokenizer + # pylint: disable=protected-access + if task == "MT": + tokenizer[src_lang].set_vocab(src_vocab._itos) + tokenizer[trg_lang].set_vocab(trg_vocab._itos) + elif task == "S2T": + tokenizer["trg"].set_vocab(trg_vocab._itos) + # pylint: enable=protected-access + + # encoding func + if task == "MT": + sequence_encoder = { + src_lang: partial(src_vocab.sentences_to_ids, bos=False, eos=True), + trg_lang: partial(trg_vocab.sentences_to_ids, bos=True, eos=True), + } + elif task == "S2T": + sequence_encoder = { + "src": partial(pad_features, embed_size=tokenizer["src"].num_freq), + "trg": partial(trg_vocab.sentences_to_ids, bos=True, eos=True), + } + if train_data is not None: + train_data.sequence_encoder = sequence_encoder + + # dev data + dev_data = None + if "dev" in datasets and dev_path is not None: + dev_subset = data_cfg.get("sample_dev_subset", -1) + if "random_dev_subset" in data_cfg: + logger.warning("`random_dev_subset` option is obsolete. 
" + "Please use `sample_dev_subset` instead.") + dev_subset = data_cfg.get("random_dev_subset", dev_subset) + logger.info("Loading dev set...") + dev_data = build_dataset( + dataset_type=dataset_type, + path=dev_path, + src_lang=src_lang, + trg_lang=trg_lang, + split="dev", + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=dev_subset, + task=task, + **dataset_cfg, + ) + + # test data + test_data = None + if "test" in datasets and test_path is not None: + logger.info("Loading test set...") + test_data = build_dataset( + dataset_type=dataset_type, + path=test_path, + src_lang=src_lang, + trg_lang=trg_lang, + split="test", + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=-1, # no subsampling for test + task=task, + **dataset_cfg, + ) + logger.info("Data loaded.") + + # Log statistics of data and vocabulary + logger.info("Train dataset: %s", train_data) + logger.info("Valid dataset: %s", dev_data) + logger.info(" Test dataset: %s", test_data) + + if train_data: + src = "" if src_vocab is None else "\n\t[SRC] " + " ".join( + train_data.get_item(idx=0, lang=train_data.src_lang, is_train=False)) + trg = "\n\t[TRG] " + " ".join( + train_data.get_item(idx=0, lang=train_data.trg_lang, is_train=False)) + logger.info("First training example:%s%s", src, trg) + + if src_vocab is not None: + logger.info("First 10 Src tokens: %s", src_vocab.log_vocab(10)) + logger.info("First 10 Trg tokens: %s", trg_vocab.log_vocab(10)) + + if src_vocab is not None: + logger.info("Number of unique Src tokens (vocab_size): %d", len(src_vocab)) + logger.info("Number of unique Trg tokens (vocab_size): %d", len(trg_vocab)) + + return src_vocab, trg_vocab, train_data, dev_data, test_data diff --git a/joeynmt/data_augmentation.py b/joeynmt/data_augmentation.py new file mode 100644 index 0000000..bb67940 --- /dev/null +++ b/joeynmt/data_augmentation.py @@ -0,0 +1,105 @@ +# coding: utf-8 +""" +Data Augmentation +""" +import math +from typing import Optional + +import numpy as np + + +class SpecAugment: + """ + SpecAugment (https://arxiv.org/abs/1904.08779) + cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/data/audio/feature_transforms/specaugment.py + """ # noqa + + def __init__(self, + freq_mask_n: int = 2, + freq_mask_f: int = 27, + time_mask_n: int = 2, + time_mask_t: int = 40, + time_mask_p: float = 1.0, + mask_value: Optional[float] = None): + + self.freq_mask_n = freq_mask_n + self.freq_mask_f = freq_mask_f + self.time_mask_n = time_mask_n + self.time_mask_t = time_mask_t + self.time_mask_p = time_mask_p + self.mask_value = mask_value + + def __call__(self, spectrogram: np.ndarray) -> np.ndarray: + assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." + + distorted = spectrogram.copy() # make a copy of input spectrogram. + num_frames, num_freqs = spectrogram.shape + mask_value = self.mask_value + + if mask_value is None: # if no value was specified, use local mean. 
+ mask_value = spectrogram.mean() + + if num_frames == 0: + return spectrogram + + if num_freqs < self.freq_mask_f: + return spectrogram + + for _i in range(self.freq_mask_n): + f = np.random.randint(0, self.freq_mask_f) + f0 = np.random.randint(0, num_freqs - f) + if f != 0: + distorted[:, f0:f0 + f] = mask_value + + max_time_mask_t = min(self.time_mask_t, + math.floor(num_frames * self.time_mask_p)) + if max_time_mask_t < 1: + return distorted + + for _i in range(self.time_mask_n): + t = np.random.randint(0, max_time_mask_t) + t0 = np.random.randint(0, num_frames - t) + if t != 0: + distorted[t0:t0 + t, :] = mask_value + + assert distorted.shape == spectrogram.shape + return distorted + + def __repr__(self): + return (f"{self.__class__.__name__}(freq_mask_n={self.freq_mask_n}, " + f"freq_mask_f={self.freq_mask_f}, time_mask_n={self.time_mask_n}, " + f"time_mask_t={self.time_mask_t}, time_mask_p={self.time_mask_p})") + + +class CMVN: + """ + CMVN: Cepstral Mean and Variance Normalization (Utterance-level) + cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/data/audio/feature_transforms/utterance_cmvn.py + """ # noqa + + def __init__(self, + norm_means: bool = True, + norm_vars: bool = True, + before: bool = True): + self.norm_means = norm_means + self.norm_vars = norm_vars + self.before = before + + def __call__(self, x: np.ndarray) -> np.ndarray: + orig_shape = x.shape + mean = x.mean(axis=0) + square_sums = (x**2).sum(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + if self.norm_vars: + var = square_sums / x.shape[0] - mean**2 + std = np.sqrt(np.maximum(var, 1e-10)) + x = np.divide(x, std) + + assert orig_shape == x.shape + return x + + def __repr__(self): + return (f"{self.__class__.__name__}(norm_means={self.norm_means}, " + f"norm_vars={self.norm_vars}, before={self.before})") diff --git a/joeynmt/datasets.py b/joeynmt/datasets.py new file mode 100644 index 0000000..ed3fe6b --- /dev/null +++ b/joeynmt/datasets.py @@ -0,0 +1,1049 @@ +# coding: utf-8 +""" +Dataset module +""" +import logging +import random +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple, Union + +import pandas as pd +import torch +from torch.utils.data import ( + BatchSampler, + DataLoader, + Dataset, + RandomSampler, + Sampler, + SequentialSampler, +) + +from joeynmt.batch import Batch +from joeynmt.constants import PAD_ID +from joeynmt.helpers import ConfigurationError, read_list_from_file +from joeynmt.tokenizers import BasicTokenizer, SpeechProcessor + +logger = logging.getLogger(__name__) +CPU_DEVICE = torch.device("cpu") + + +class BaseDataset(Dataset): + """ + BaseDataset which loads and looks up data. + - holds pointer to tokenizers, encoding functions. + + :param path: path to data directory + :param src_lang: source language code, i.e. `en` + :param trg_lang: target language code, i.e. 
`de` + :param has_trg: bool indicator if trg exists + :param split: bool indicator for train set or not + :param tokenizer: tokenizer objects + :param sequence_encoder: encoding functions + """ + + def __init__( + self, + path: str, + src_lang: str, + trg_lang: str, + split: int = "train", + has_trg: bool = True, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "MT", + ): + self.path = path + self.src_lang = src_lang + self.trg_lang = trg_lang + self.has_trg = has_trg + self.split = split + if self.split == "train": + assert self.has_trg + + _place_holder = {self.src_lang: None, self.trg_lang: None} + self.tokenizer = _place_holder if tokenizer is None else tokenizer + self.sequence_encoder = (_place_holder + if sequence_encoder is None else sequence_encoder) + + # for ransom subsampling + self.random_subset = random_subset + + self.task = task + + def sample_random_subset(self, seed: int = 42) -> None: + # pylint: disable=unused-argument + assert ( + self.split != "test" and self.__len__() > self.random_subset > 0 + ), f"Can only subsample from train or dev set larger than {self.random_subset}." + + def reset_random_subset(self): + raise NotImplementedError + + def load_data(self, path: Path, **kwargs) -> Any: + """ + load data + - preprocessing (lowercasing etc) is applied here. + """ + raise NotImplementedError + + def get_item(self, idx: int, lang: str) -> List[str]: + """ + seek one src/trg item of given index. + - tokenization is applied here. + - length-filtering, bpe-dropout etc also triggered if self.split == "train" + """ + raise NotImplementedError + + def __getitem__(self, idx: Union[int, str]) -> Tuple[List[str], List[str]]: + """lookup one item pair of given index.""" + src, trg = None, None + src = self.get_item(idx=idx, lang=self.src_lang) + if self.has_trg: + trg = self.get_item(idx=idx, lang=self.trg_lang) + if trg is None: + src = None + return src, trg + + def get_list(self, + lang: str, + postproccessed: bool = False, + tokenized: bool = False) -> Union[List[str], List[List[str]]]: + """get data column-wise.""" + raise NotImplementedError + + @property + def src(self) -> List[str]: + """get detokenized preprocessed data in src language.""" + return self.get_list(self.src_lang, postproccessed=False, tokenized=False) + + @property + def trg(self) -> List[str]: + """get detokenized preprocessed data in trg language.""" + return self.get_list(self.trg_lang, postproccessed=False, tokenized=False) \ + if self.has_trg else [] + + def collate_fn( + self, + batch: List[Tuple], + pad_index: int = PAD_ID, + device: torch.device = CPU_DEVICE, + ) -> Batch: + """ + Custom collate function. + See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn for details. + Please override the batch class here. 
(not in TrainManager) + + :param batch: + :param pad_index: + :param device: + :return: joeynmt batch object + """ + + def _is_valid(s, t, has_trg): + # pylint: disable=no-else-return + if has_trg: + return s is not None and t is not None + else: + return s is not None + + batch = [(s, t) for s, t in batch if _is_valid(s, t, self.has_trg)] + src_list, trg_list = zip(*batch) + assert len(batch) == len(src_list), (len(batch), len(src_list)) + assert all(s is not None for s in src_list), src_list + src, src_length = self.sequence_encoder[self.src_lang](src_list) + + if self.has_trg: + assert all(t is not None for t in trg_list), trg_list + trg, trg_length = self.sequence_encoder[self.trg_lang](trg_list) + else: + assert all(t is None for t in trg_list) + trg, trg_length = None, None + + return Batch( + src=(torch.tensor(src).long() + if self.task == "MT" else torch.tensor(src).float()), + src_length=torch.tensor(src_length).long(), + trg=torch.tensor(trg).long() if trg else None, + trg_length=torch.tensor(trg_length).long() if trg_length else None, + device=device, + pad_index=pad_index, + has_trg=self.has_trg, + is_train=self.split == "train", + task=self.task, + ) + + def make_iter( + self, + batch_size: int, + batch_type: str = "sentence", + seed: int = 42, + shuffle: bool = False, + num_workers: int = 0, + pad_index: int = PAD_ID, + device: torch.device = CPU_DEVICE, + ) -> DataLoader: + """ + Returns a torch DataLoader for a torch Dataset. (no bucketing) + + :param batch_size: size of the batches the iterator prepares + :param batch_type: measure batch size by sentence count or by token count + :param seed: random seed for shuffling + :param shuffle: whether to shuffle the data before each epoch + (for testing, no effect even if set to True) + :param num_workers: number of cpus for multiprocessing + :param pad_index: + :param device: + :return: torch DataLoader + """ + # sampler + sampler: Sampler[int] # (type annotation) + if shuffle and self.split == "train": + generator = torch.Generator() + generator.manual_seed(seed) + sampler = RandomSampler(self, generator=generator) + else: + sampler = SequentialSampler(self) + + # batch generator + if batch_type == "sentence": + batch_sampler = SentenceBatchSampler(sampler, + batch_size=batch_size, + drop_last=False) + elif batch_type == "token": + batch_sampler = TokenBatchSampler(sampler, + batch_size=batch_size, + drop_last=False) + else: + raise ConfigurationError(f"{batch_type}: Unknown batch type") + + assert self.sequence_encoder[self.src_lang] is not None + if self.has_trg: + assert self.sequence_encoder[self.trg_lang] is not None + + # data iterator + return DataLoader( + dataset=self, + batch_sampler=batch_sampler, + collate_fn=partial(self.collate_fn, pad_index=pad_index, device=device), + num_workers=num_workers, + ) + + def __len__(self) -> int: + raise NotImplementedError + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(split={self.split}, len={self.__len__()}, " + f"src_lang={self.src_lang}, trg_lang={self.trg_lang}, " + f"has_trg={self.has_trg}, random_subset={self.random_subset})") + + +class PlaintextDataset(BaseDataset): + """ + PlaintextDataset which stores plain text pairs. + - used for text file data in the format of one sentence per line. 
+ """ + + def __init__( + self, + path: str, + src_lang: str, + trg_lang: str, + split: int = "train", + has_trg: bool = True, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "MT", + **kwargs, + ): + super().__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + + # load data + self.data = self.load_data(path, **kwargs) + self._initial_len = len(self.data[self.src_lang]) + + # for random subsampling + self.idx_map = [] + + def load_data(self, path: str, **kwargs) -> Any: + + def _pre_process(seq, lang): + if self.tokenizer[lang] is not None: + seq = [self.tokenizer[lang].pre_process(s) for s in seq if len(s) > 0] + return seq + + path = Path(path) + src_file = path.with_suffix(f"{path.suffix}.{self.src_lang}") + assert src_file.is_file(), f"{src_file} not found. Abort." + + src_list = read_list_from_file(src_file) + data = {self.src_lang: _pre_process(src_list, self.src_lang)} + + if self.has_trg: + trg_file = path.with_suffix(f"{path.suffix}.{self.trg_lang}") + assert trg_file.is_file(), f"{trg_file} not found. Abort." + + trg_list = read_list_from_file(trg_file) + data[self.trg_lang] = _pre_process(trg_list, self.trg_lang) + assert len(src_list) == len(trg_list) + return data + + def sample_random_subset(self, seed: int = 42) -> None: + super().sample_random_subset(seed) # check validity + + random.seed(seed) # resample every epoch: seed += epoch_no + self.idx_map = list(random.sample(range(self._initial_len), self.random_subset)) + + def reset_random_subset(self): + self.idx_map = [] + + def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]: + line = self._look_up_item(idx, lang) + is_train = self.split == "train" if is_train is None else is_train + item = self.tokenizer[lang](line, is_train=is_train) + + if item is None: + logger.debug("Skip %d-th instance (%s): {%s}", idx, lang, line) + return item + + def _look_up_item(self, idx: int, lang: str) -> str: + try: + if len(self.idx_map) > 0: + idx = self.idx_map[idx] + line = self.data[lang][idx] + return line + except Exception as e: + logger.error(idx, self._initial_len) + raise Exception from e + + def get_list(self, + lang: str, + postproccessed: bool = False, + tokenized: bool = False) -> Union[List[str], List[List[str]]]: + """ + Return list of preprocessed sentences in the given language. + (not length-filtered, no bpe-dropout) + """ + item_list = [] + for idx in range(self.__len__()): + item = self._look_up_item(idx, lang) + if postproccessed: + item = self.tokenizer[lang].post_process(item) + elif tokenized: + item = self.tokenizer[lang](self._look_up_item(idx, lang), + is_train=False) + item_list.append(item) + return item_list + + def __len__(self) -> int: + if len(self.idx_map) > 0: + return len(self.idx_map) + return self._initial_len + + +class TsvDataset(BaseDataset): + """ + TsvDataset which handles data in tsv format. + - file_name should be specified without extention `.tsv` + - needs src_lang and trg_lang (i.e. `en`, `de`) in header. 
+ """ + + def __init__( + self, + path: str, + src_lang: str, + trg_lang: str, + split: int = "train", + has_trg: bool = True, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "MT", + **kwargs, + ): + super().__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + + # load tsv file + self.df = self.load_data(path, **kwargs) + + # for random subsampling + self._initial_df = None + + def load_data(self, path: str, **kwargs) -> Any: + path = Path(path) + file_path = path.with_suffix(f"{path.suffix}.tsv") + assert file_path.is_file(), f"{file_path} not found. Abort." + + # read tsv data + try: + # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported + import pandas as pd + + df = pd.read_csv( + file_path.as_posix(), + sep="\t", + header=0, + encoding="utf-8", + escapechar="\\", + quoting=3, + na_filter=False, + ) + # TODO: use `chunksize` for online data loading. + assert self.src_lang in df.columns + df[self.src_lang] = df[self.src_lang].apply( + self.tokenizer[self.src_lang].pre_process) + + if self.trg_lang not in df.columns: + self.has_trg = False + assert self.split == "test" + if self.has_trg: + df[self.trg_lang] = df[self.trg_lang].apply( + self.tokenizer[self.trg_lang].pre_process) + return df + + except ImportError as e: + logger.error(e) + raise ImportError from e + + def sample_random_subset(self, seed: int = 42) -> None: + super().sample_random_subset(seed) # check validity + + if self._initial_df is None: + self._initial_df = self.df.copy(deep=True) + + self.df = self._initial_df.sample( + n=self.random_subset, + replace=False, + random_state=seed, # resample every epoch: seed += epoch_no + ).reset_index() + + def reset_random_subset(self): + if self._initial_df is not None: + self.df = self._initial_df + self._initial_df = None + + def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]: + line = self.df.iloc[idx][lang] + is_train = self.split == "train" if is_train is None else is_train + item = self.tokenizer[lang](line, is_train=is_train) + return item + + def get_list(self, + lang: str, + postproccessed: bool = False, + tokenized: bool = False) -> Union[List[str], List[List[str]]]: + if postproccessed: + sents = self.df[lang].apply(self.tokenizer[lang].post_process).to_list() + elif tokenized: + sents = self.df[lang].apply(self.tokenizer[lang]).to_list() + else: + sents = self.df[lang].to_list() + return sents + + def __len__(self) -> int: + return len(self.df) + + +class SpeechDataset(TsvDataset): + """ + Speech Dataset + """ + + def __init__( + self, + path: str, + src_lang: str = "src", + trg_lang: str = "trg", + split: int = "train", + has_trg: bool = True, + tokenizer: Dict[str, Union[BasicTokenizer, SpeechProcessor]] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "S2T", + **kwargs, + ): + super().__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + + # load tsv file + self.df = self.load_data(path, **kwargs) + + assert isinstance(self.tokenizer["src"], SpeechProcessor) + self.tokenizer["src"].root_path = Path(path).parent + + # for random subsampling + self._initial_df = None + + def 
load_data(self, path: str, **kwargs) -> Any: + path = Path(path) + file_path = path.with_suffix(f"{path.suffix}.tsv") + assert file_path.is_file(), f"{file_path} not found. Abort." + + # read tsv data + try: + # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported + import pandas as pd + + # TODO: use `chunksize` for online data loading. + dtype = {'id': str, 'src': str, 'trg': str, 'n_frames': int} + df = pd.read_csv( + file_path.as_posix(), + sep="\t", + header=0, + encoding="utf-8", + escapechar="\\", + quoting=3, + na_filter=False, + dtype=dtype, + ) + + # WARNING: instances shorter than the kernel size cannot be convolved. + min_length = int(self.tokenizer["src"].min_length) + df['n_frames'] = df[df['n_frames'] > min_length]['n_frames'] + + # drop invalid rows + df = df.replace(r'^\s*$', float("nan"), regex=True) + df = df.dropna() + + assert "src" in df.columns + + if "trg" not in df.columns: + self.has_trg = False + assert self.split == "test" + + if self.has_trg: + df["trg"] = df["trg"].apply(self.tokenizer["trg"].pre_process) + return df + + except ImportError as e: + logger.error(e) + raise ImportError from e + + @property + def src(self) -> List[str]: + return self.df["src"] + + +class StreamDataset(BaseDataset): + """ + StreamDataset which interacts with stream inputs. + - called by `translate()` func in `prediction.py`. + """ + + def __init__( + self, + path: str, + src_lang: str, + trg_lang: str, + split: int = "test", + has_trg: bool = False, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "MT", + **kwargs, + ): + # pylint: disable=unused-argument + super().__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + # place holder + self.cache = {} + + def set_item(self, src_line: str, trg_line: str = None) -> None: + """ + Set input text to the cache. + + :param src_line: (str) + :param trg_line: (str) + """ + assert isinstance(src_line, str) and src_line.strip() != "", \ + "The input sentence is empty! Please make sure " \ + "that you are feeding a valid input." + + idx = len(self.cache) + src_line = self.tokenizer[self.src_lang].pre_process(src_line) + + if self.has_trg: + trg_line = self.tokenizer[self.trg_lang].pre_process(trg_line) + self.cache[idx] = (src_line, trg_line) + + def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]: + # pylint: disable=unused-argument + assert idx in self.cache, (idx, self.cache) + assert lang in [self.src_lang, self.trg_lang] + if lang == self.trg_lang: + assert self.has_trg + + line = {} + src_line, trg_line = self.cache[idx] + line[self.src_lang] = src_line + line[self.trg_lang] = trg_line + + item = self.tokenizer[lang](line[lang], is_train=False) + return item + + def reset_cache(self) -> None: + self.cache = {} + + def __len__(self) -> int: + return len(self.cache) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(split={self.split}, len={len(self.cache)}, " + f"src_lang={self.src_lang}, trg_lang={self.trg_lang}, " + f"has_trg={self.has_trg}, random_subset={self.random_subset})") + + +class SpeechStreamDataset(SpeechDataset): + """ + SpeechStreamDataset which interacts with audio file inputs. + - called by `translate()` func in `prediction.py`. 
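+    - Rough usage sketch (the feature-file path below is an assumption):
+          ds.set_item("/abs/path/to/utterance.npy")  # appends a row to `self.df`
+          features = ds.get_item(0, "src")           # loaded via the SpeechProcessor
+          ds.reset_cache()                           # back to the empty placeholder df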
+ """ + + def __init__( + self, + path: str, + src_lang: str = "src", + trg_lang: str = "trg", + split: int = "test", + has_trg: bool = False, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "S2T", + **kwargs, + ): + # pylint: disable=unused-argument + super(TsvDataset, self).__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + + # place holder (empty dataframe) + try: + # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported + import pandas as pd + self.df = pd.DataFrame({'id': [], 'src': [], 'n_frames': []}) + + assert isinstance(self.tokenizer["src"], SpeechProcessor) + self.tokenizer["src"].root_path = Path("") + + except ImportError as e: + logger.error(e) + raise ImportError from e + + def set_item(self, src_line: str, trg_line: str = None) -> None: + """ + Set input text to the cache. + + :param src_line: (str) absolute path to an audio file + :param trg_line: (str) + """ + assert isinstance(src_line, str) and src_line.strip() != "", \ + "The input sentence is empty! Please make sure " \ + "that you are feeding a valid input." + + assert (Path(self.tokenizer["src"].root_path) / src_line).is_file(), \ + f"{src_line} not found. Please provide the abosolute path to a file!" + + min_length = int(self.tokenizer["src"].min_length) # minimum length + row = {"id": str(len(self.df)), "src": src_line, "n_frames": min_length} + + if trg_line: + row["trg"] = self.tokenizer[self.trg_lang].pre_process(trg_line) + + row_df = pd.DataFrame.from_records([row]) + self.df = pd.concat([self.df, row_df], ignore_index=True) + + def reset_cache(self) -> None: + self.df = pd.DataFrame({'id': [], 'src': [], 'n_frames': []}) + + +class BaseHuggingfaceDataset(BaseDataset): + """ + Wrapper for Huggingface's dataset object + cf.) 
https://huggingface.co/docs/datasets + """ + + def __init__( + self, + path: str, + src_lang: str, + trg_lang: str, + has_trg: bool = True, + tokenizer: Dict[str, BasicTokenizer] = None, + sequence_encoder: Dict[str, Callable] = None, + random_subset: int = -1, + task: str = "MT", + **kwargs, + ): + super().__init__( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=kwargs["split"], + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + ) + # load data + self.dataset = self.load_data(path, **kwargs) + self._kwargs = kwargs # should contain arguments passed to `load_dataset()` + + def load_data(self, path: str, **kwargs) -> Any: + # pylint: disable=import-outside-toplevel + try: + from datasets import config, load_dataset, load_from_disk + if Path(path, config.DATASET_STATE_JSON_FILENAME).exists(): + return load_from_disk(path) + return load_dataset(path, **kwargs) + + except ImportError as e: + logger.error(e) + raise ImportError from e + + def sample_random_subset(self, seed: int = 42) -> None: + super().sample_random_subset(seed) # check validity + + # resample every epoch: seed += epoch_no + self.dataset = self.dataset.shuffle(seed=seed).select(range(self.random_subset)) + + def reset_random_subset(self) -> None: + # reload from cache + self.dataset = self.load_data(self.path, **self._kwargs) + + def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]: + # lookup + line = self.dataset[idx] + assert lang in line, (line, lang) + + # tokenize + is_train = self.split == "train" if is_train is None else is_train + item = self.tokenizer[lang](line[lang], is_train=is_train) + return item + + def get_list(self, + lang: str, + postproccessed: bool = False, + tokenized: bool = False) -> List[str]: + if postproccessed: # pylint: disable=no-else-return + if f"post_{lang}" not in self.dataset: + self.dataset = self.dataset.map( + lambda item: + {f"post_{lang}": self.tokenizer[lang].post_process(item[lang])}, + desc=f"Postprocessing {lang}...") + return self.dataset[f"post_{lang}"] + elif tokenized: + if f"tok_{lang}" not in self.dataset: + self.dataset = self.dataset.map( + lambda item: + {f"tok_{lang}": self.tokenizer[lang](item[lang], is_train=False)}, + desc=f"Tokenizing {lang}...") + return self.dataset[f"tok_{lang}"] + else: + return self.dataset[lang] + + def __len__(self) -> int: + return self.dataset.num_rows + + def __repr__(self) -> str: + ret = (f"{self.__class__.__name__}(len={self.__len__()}, " + f"src_lang={self.src_lang}, trg_lang={self.trg_lang}, " + f"has_trg={self.has_trg}, random_subset={self.random_subset}") + for k, v in self._kwargs.items(): + ret += f", {k}={v}" + ret += ")" + return ret + + +class HuggingfaceDataset(BaseHuggingfaceDataset): + """ + Wrapper for Huggingface's `datasets.features.Translation` class + cf.) 
https://github.com/huggingface/datasets/blob/master/src/datasets/features/translation.py + """ # noqa + + def load_data(self, path: str, **kwargs) -> Any: + dataset = super().load_data(path=path, **kwargs) + + # rename columns + if "translation" in dataset.features: + # check language pair + lang_pair = dataset.features["translation"].languages + assert self.src_lang in lang_pair, (self.src_lang, lang_pair) + + # rename columns + columns = {f"translation.{self.src_lang}": self.src_lang} + if self.has_trg: + assert self.trg_lang in lang_pair, (self.trg_lang, lang_pair) + columns[f"translation.{self.trg_lang}"] = self.trg_lang + + # flatten + dataset = dataset.flatten() + + elif f"{self.src_lang}_sentence" in dataset.features: + # rename columns + columns = {f"{self.src_lang}_sentence": self.src_lang} + if self.has_trg: + assert f"{self.trg_lang}_sentence" in dataset.features + columns[f"{self.trg_lang}_sentence"] = self.trg_lang + + else: + pass + # TODO: support other field names + dataset = dataset.rename_columns(columns) + + # preprocess (lowercase, pretokenize, etc.) + def _pre_process(item): + sl = self.src_lang + tl = self.trg_lang + ret = {sl: self.tokenizer[sl].pre_process(item[sl])} + if self.has_trg: + ret[tl] = self.tokenizer[tl].pre_process(item[tl]) + return ret + + def _drop_nan(item): + sl = self.src_lang + tl = self.trg_lang + is_src_valid = item[sl] is not None and len(item[sl]) > 0 + if self.has_trg: + is_trg_valid = item[tl] is not None and len(item[tl]) > 0 + return is_src_valid and is_trg_valid + return is_src_valid + + dataset = dataset.filter(_drop_nan, desc="Dropping NaN...") + return dataset.map(_pre_process, desc="Preprocessing...") + + +def build_dataset( + dataset_type: str, + path: str, + src_lang: str, + trg_lang: str, + split: str, + tokenizer: Dict = None, + sequence_encoder: Dict = None, + random_subset: int = -1, + task: str = "MT", + **kwargs, +): + """ + Builds a dataset. 
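+
+    Illustrative call (a sketch: the path and the prebuilt `tokenizer` /
+    `sequence_encoder` dicts are assumptions):
+
+        train_data = build_dataset(dataset_type="tsv", path="data/train",
+                                   src_lang="en", trg_lang="de", split="train",
+                                   tokenizer=tokenizer,
+                                   sequence_encoder=sequence_encoder)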
+ + :param dataset_type: (str) one of {`plain`, `tsv`, `stream`, `huggingface`} + :param path: (str) either a local file name or + dataset name to download from remote + :param src_lang: (str) language code for source + :param trg_lang: (str) language code for target + :param split: (str) one of {`train`, `dev`, `test`} + :param tokenizer: tokenizer objects for both source and target + :param sequence_encoder: encoding functions for both source and target + :param random_subset: (int) number of random subset; -1 means no subsampling + :param task: (str) + :return: loaded Dataset + """ + dataset = None + has_trg = True # by default, we expect src-trg pairs + + if dataset_type == "plain": + if not Path(path).with_suffix(f"{Path(path).suffix}.{trg_lang}").is_file(): + # no target is given -> create dataset from src only + has_trg = False + dataset = PlaintextDataset( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + **kwargs, + ) + elif dataset_type == "tsv": + dataset = TsvDataset( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + **kwargs, + ) + elif dataset_type == "speech": + dataset = SpeechDataset( + path=path, + src_lang="src", + trg_lang="trg", + split=split, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + **kwargs, + ) + elif dataset_type == "stream": + dataset = StreamDataset( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + split="test", + has_trg=False, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=-1, + task=task, + **kwargs, + ) + elif dataset_type == "speech_stream": + dataset = SpeechStreamDataset( + path=None, + src_lang="src", + trg_lang="trg", + split="test", + has_trg=False, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=-1, + task=task, + **kwargs, + ) + elif dataset_type == "huggingface": + # "split" should be specified in kwargs + if "split" not in kwargs: + kwargs["split"] = "validation" if split == "dev" else split + dataset = HuggingfaceDataset( + path=path, + src_lang=src_lang, + trg_lang=trg_lang, + has_trg=has_trg, + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + random_subset=random_subset, + task=task, + **kwargs, + ) + else: + ConfigurationError(f"{dataset_type}: Unknown dataset type.") + return dataset + + +class SentenceBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices based on num of instances. + An instance longer than dataset.max_len will be filtered out. + + :param sampler: Base sampler. Can be any iterable object + :param batch_size: Size of mini-batch. 
+ :param drop_last: If `True`, the sampler will drop the last batch if its size + would be less than `batch_size` + """ + + def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool): + super().__init__(sampler, batch_size, drop_last) + + def __iter__(self): + batch = [] + d = self.sampler.data_source + for idx in self.sampler: + src, trg = d[idx] # pylint: disable=unused-variable + if src is not None: # otherwise drop instance + batch.append(idx) + if len(batch) >= self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + +class TokenBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices based on num of tokens + (incl. padding). An instance longer than dataset.max_len or shorter than + dataset.min_len will be filtered out. + * no bucketing implemented + + :param sampler: Base sampler. Can be any iterable object + :param batch_size: Size of mini-batch. + :param drop_last: If `True`, the sampler will drop the last batch if + its size would be less than `batch_size` + """ + + def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool): + super().__init__(sampler, batch_size, drop_last) + + def __iter__(self): + batch = [] + max_tokens = 0 + d = self.sampler.data_source + for idx in self.sampler: + src, trg = d[idx] # call __getitem__() + if src is not None: # otherwise drop instance + src_len = 0 if src is None else len(src) + trg_len = 0 if trg is None else len(trg) + n_tokens = 0 if src_len == 0 else max(src_len + 1, trg_len + 2) + batch.append(idx) + if n_tokens > max_tokens: + max_tokens = n_tokens + if max_tokens * len(batch) >= self.batch_size: + yield batch + batch = [] + max_tokens = 0 + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + raise NotImplementedError diff --git a/joeynmt/decoders.py b/joeynmt/decoders.py new file mode 100644 index 0000000..36d18c0 --- /dev/null +++ b/joeynmt/decoders.py @@ -0,0 +1,608 @@ +# coding: utf-8 +""" +Various decoders +""" +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from joeynmt.attention import BahdanauAttention, LuongAttention +from joeynmt.encoders import Encoder +from joeynmt.helpers import ConfigurationError, freeze_params, subsequent_mask +from joeynmt.transformer_layers import PositionalEncoding, TransformerDecoderLayer + + +class Decoder(nn.Module): + """ + Base decoder class + """ + + # pylint: disable=abstract-method + + @property + def output_size(self): + """ + Return the output size (size of the target vocabulary) + + :return: + """ + return self._output_size + + +class RecurrentDecoder(Decoder): + """A conditional RNN decoder with attention.""" + + # pylint: disable=too-many-arguments,unused-argument + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + rnn_type: str = "gru", + emb_size: int = 0, + hidden_size: int = 0, + encoder: Encoder = None, + attention: str = "bahdanau", + num_layers: int = 1, + vocab_size: int = 0, + dropout: float = 0.0, + emb_dropout: float = 0.0, + hidden_dropout: float = 0.0, + init_hidden: str = "bridge", + input_feeding: bool = True, + freeze: bool = False, + **kwargs, + ) -> None: + """ + Create a recurrent decoder with attention. 
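+
+        Instantiation sketch (the `encoder` object and all sizes are assumptions):
+
+            decoder = RecurrentDecoder(rnn_type="lstm", emb_size=256,
+                                       hidden_size=512, encoder=encoder,
+                                       attention="luong", vocab_size=8000)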
+ + :param rnn_type: rnn type, valid options: "lstm", "gru" + :param emb_size: target embedding size + :param hidden_size: size of the RNN + :param encoder: encoder connected to this decoder + :param attention: type of attention, valid options: "bahdanau", "luong" + :param num_layers: number of recurrent layers + :param vocab_size: target vocabulary size + :param hidden_dropout: Is applied to the input to the attentional layer. + :param dropout: Is applied between RNN layers. + :param emb_dropout: Is applied to the RNN input (word embeddings). + :param init_hidden: If "bridge" (default), the decoder hidden states are + initialized from a projection of the last encoder state, + if "zeros" they are initialized with zeros, + if "last" they are identical to the last encoder state + (only if they have the same size) + :param input_feeding: Use Luong's input feeding. + :param freeze: Freeze the parameters of the decoder during training. + :param kwargs: + """ + + super().__init__() + + self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False) + self.type = rnn_type + self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False) + self.hidden_size = hidden_size + self.emb_size = emb_size + + rnn = nn.GRU if rnn_type == "gru" else nn.LSTM + + self.input_feeding = input_feeding + if self.input_feeding: # Luong-style + # combine embedded prev word +attention vector before feeding to rnn + self.rnn_input_size = emb_size + hidden_size + else: + # just feed prev word embedding + self.rnn_input_size = emb_size + + # the decoder RNN + self.rnn = rnn( + self.rnn_input_size, + hidden_size, + num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + ) + + # combine output with context vector before output layer (Luong-style) + self.att_vector_layer = nn.Linear(hidden_size + encoder.output_size, + hidden_size, + bias=True) + + self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False) + self._output_size = vocab_size + + if attention == "bahdanau": + self.attention = BahdanauAttention( + hidden_size=hidden_size, + key_size=encoder.output_size, + query_size=hidden_size, + ) + elif attention == "luong": + self.attention = LuongAttention(hidden_size=hidden_size, + key_size=encoder.output_size) + else: + raise ConfigurationError( + f"Unknown attention mechanism: " + f"{attention}. Valid options: 'bahdanau', 'luong'.") + + self.num_layers = num_layers + self.hidden_size = hidden_size + + # to initialize from the final encoder state of last layer + self.init_hidden_option = init_hidden + if self.init_hidden_option == "bridge": + self.bridge_layer = nn.Linear(encoder.output_size, hidden_size, bias=True) + elif self.init_hidden_option == "last": + if encoder.output_size != self.hidden_size: + if encoder.output_size != 2 * self.hidden_size: # bidirectional + raise ConfigurationError( + f"For initializing the decoder state with the " + f"last encoder state, their sizes have to match " + f"(encoder: {encoder.output_size} " + f"vs. decoder: {self.hidden_size})") + if freeze: + freeze_params(self) + + self.ctc_output_layer = None # not supported + + def _check_shapes_input_forward_step( + self, + prev_embed: Tensor, + prev_att_vector: Tensor, + encoder_output: Tensor, + src_mask: Tensor, + hidden: Tensor, + ) -> None: + """ + Make sure the input shapes to `self._forward_step` are correct. + Same inputs as `self._forward_step`. 
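+        For reference, the shapes asserted below are: prev_embed (batch, 1, emb_size),
+        prev_att_vector (batch, 1, hidden_size), encoder_output (batch, src_len, enc_size),
+        src_mask (batch, 1, src_len) and hidden (num_layers, batch, hidden_size).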
+ + :param prev_embed: + :param prev_att_vector: + :param encoder_output: + :param src_mask: + :param hidden: + """ + assert prev_embed.shape[1:] == torch.Size([1, self.emb_size]) + assert prev_att_vector.shape[1:] == torch.Size([1, self.hidden_size]) + assert prev_att_vector.shape[0] == prev_embed.shape[0] + assert encoder_output.shape[0] == prev_embed.shape[0] + assert len(encoder_output.shape) == 3 + assert src_mask.shape[0] == prev_embed.shape[0] + assert src_mask.shape[1] == 1 + assert src_mask.shape[2] == encoder_output.shape[1] + if isinstance(hidden, tuple): # for lstm + hidden = hidden[0] + assert hidden.shape[0] == self.num_layers + assert hidden.shape[1] == prev_embed.shape[0] + assert hidden.shape[2] == self.hidden_size + + def _check_shapes_input_forward( + self, + trg_embed: Tensor, + encoder_output: Tensor, + encoder_hidden: Tensor, + src_mask: Tensor, + hidden: Tensor = None, + prev_att_vector: Tensor = None, + ) -> None: + """ + Make sure that inputs to `self.forward` are of correct shape. + Same input semantics as for `self.forward`. + + :param trg_embed: + :param encoder_output: + :param encoder_hidden: + :param src_mask: + :param hidden: + :param prev_att_vector: + """ + assert len(encoder_output.shape) == 3 + if encoder_hidden is not None: + assert len(encoder_hidden.shape) == 2 + assert encoder_hidden.shape[-1] == encoder_output.shape[-1] + assert src_mask.shape[1] == 1 + assert src_mask.shape[0] == encoder_output.shape[0] + assert src_mask.shape[2] == encoder_output.shape[1] + assert trg_embed.shape[0] == encoder_output.shape[0] + assert trg_embed.shape[2] == self.emb_size + if hidden is not None: + if isinstance(hidden, tuple): # for lstm + hidden = hidden[0] + assert hidden.shape[1] == encoder_output.shape[0] + assert hidden.shape[2] == self.hidden_size + if prev_att_vector is not None: + assert prev_att_vector.shape[0] == encoder_output.shape[0] + assert prev_att_vector.shape[2] == self.hidden_size + assert prev_att_vector.shape[1] == 1 + + def _forward_step( + self, + prev_embed: Tensor, + prev_att_vector: Tensor, # context or att vector + encoder_output: Tensor, + src_mask: Tensor, + hidden: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Perform a single decoder step (1 token). + + 1. `rnn_input`: concat(prev_embed, prev_att_vector [possibly empty]) + 2. update RNN with `rnn_input` + 3. 
calculate attention and context/attention vector + + :param prev_embed: embedded previous token, + shape (batch_size, 1, embed_size) + :param prev_att_vector: previous attention vector, + shape (batch_size, 1, hidden_size) + :param encoder_output: encoder hidden states for attention context, + shape (batch_size, src_length, encoder.output_size) + :param src_mask: src mask, 1s for area before , 0s elsewhere + shape (batch_size, 1, src_length) + :param hidden: previous hidden state, + shape (num_layers, batch_size, hidden_size) + :return: + - att_vector: new attention vector (batch_size, 1, hidden_size), + - hidden: new hidden state with shape (batch_size, 1, hidden_size), + - att_probs: attention probabilities (batch_size, 1, src_len) + """ + + # shape checks + self._check_shapes_input_forward_step( + prev_embed=prev_embed, + prev_att_vector=prev_att_vector, + encoder_output=encoder_output, + src_mask=src_mask, + hidden=hidden, + ) + + if self.input_feeding: + # concatenate the input with the previous attention vector + rnn_input = torch.cat([prev_embed, prev_att_vector], dim=2) + else: + rnn_input = prev_embed + + rnn_input = self.emb_dropout(rnn_input) + + # rnn_input: batch x 1 x emb+2*enc_size + _, hidden = self.rnn(rnn_input, hidden) + + # use new (top) decoder layer as attention query + if isinstance(hidden, tuple): + query = hidden[0][-1].unsqueeze(1) + else: + query = hidden[-1].unsqueeze(1) # [#layers, B, D] -> [B, 1, D] + + # compute context vector using attention mechanism + # only use last layer for attention mechanism + # key projections are pre-computed + context, att_probs = self.attention(query=query, + values=encoder_output, + mask=src_mask) + + # return attention vector (Luong) + # combine context with decoder hidden state before prediction + att_vector_input = torch.cat([query, context], dim=2) + # batch x 1 x 2*enc_size+hidden_size + att_vector_input = self.hidden_dropout(att_vector_input) + + att_vector = torch.tanh(self.att_vector_layer(att_vector_input)) + + # output: batch x 1 x hidden_size + return att_vector, hidden, att_probs + + def forward( + self, + trg_embed: Tensor, + encoder_output: Tensor, + encoder_hidden: Tensor, + src_mask: Tensor, + unroll_steps: int, + hidden: Tensor = None, + prev_att_vector: Tensor = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Unroll the decoder one step at a time for `unroll_steps` steps. For every step, + the `_forward_step` function is called internally. + + During training, the target inputs (`trg_embed') are already known for the full + sequence, so the full unrol is done. In this case, `hidden` and + `prev_att_vector` are None. + + For inference, this function is called with one step at a time since embedded + targets are the predictions from the previous time step. In this case, `hidden` + and `prev_att_vector` are fed from the output of the previous call of this + function (from the 2nd step on). + + `src_mask` is needed to mask out the areas of the encoder states that should not + receive any attention, which is everything after the first . + + The `encoder_output` are the hidden states from the encoder and are used as + context for the attention. + + The `encoder_hidden` is the last encoder hidden state that is used to initialize + the first hidden decoder state (when `self.init_hidden_option` is "bridge" or + "last"). 
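+
+        Shape sketch (numbers are illustrative only): with batch_size=2, unroll_steps=7,
+        hidden_size=256 and vocab_size=8000, `outputs` comes out as (2, 7, 8000),
+        `att_vectors` as (2, 7, 256) and `att_probs` as (2, 7, src_length).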
+ + :param trg_embed: embedded target inputs, + shape (batch_size, trg_length, embed_size) + :param encoder_output: hidden states from the encoder, + shape (batch_size, src_length, encoder.output_size) + :param encoder_hidden: last state from the encoder, + shape (batch_size, encoder.output_size) + :param src_mask: mask for src states: 0s for padded areas, + 1s for the rest, shape (batch_size, 1, src_length) + :param unroll_steps: number of steps to unroll the decoder RNN + :param hidden: previous decoder hidden state, + if not given it's initialized as in `self.init_hidden`, + shape (batch_size, num_layers, hidden_size) + :param prev_att_vector: previous attentional vector, + if not given it's initialized with zeros, + shape (batch_size, 1, hidden_size) + :return: + - outputs: shape (batch_size, unroll_steps, vocab_size), + - hidden: last hidden state (batch_size, num_layers, hidden_size), + - att_probs: attention probabilities + with shape (batch_size, unroll_steps, src_length), + - att_vectors: attentional vectors + with shape (batch_size, unroll_steps, hidden_size) + - None + """ + # initialize decoder hidden state from final encoder hidden state + if hidden is None and encoder_hidden is not None: + hidden = self._init_hidden(encoder_hidden) + else: + # DataParallel splits batch along the 0th dim. + # Place back the batch_size to the 1st dim here. + if isinstance(hidden, tuple): + h, c = hidden + hidden = ( + h.permute(1, 0, 2).contiguous(), + c.permute(1, 0, 2).contiguous(), + ) + else: + hidden = hidden.permute(1, 0, 2).contiguous() + # shape (num_layers, batch_size, hidden_size) + + # shape checks + self._check_shapes_input_forward( + trg_embed=trg_embed, + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + src_mask=src_mask, + hidden=hidden, + prev_att_vector=prev_att_vector, + ) + + # pre-compute projected encoder outputs + # (the "keys" for the attention mechanism) + # this is only done for efficiency + if hasattr(self.attention, "compute_proj_keys"): + self.attention.compute_proj_keys(keys=encoder_output) + + # here we store all intermediate attention vectors (used for prediction) + att_vectors = [] + att_probs = [] + + batch_size = encoder_output.size(0) + + if prev_att_vector is None: + with torch.no_grad(): + prev_att_vector = encoder_output.new_zeros( + [batch_size, 1, self.hidden_size]) + + # unroll the decoder RNN for `unroll_steps` steps + for i in range(unroll_steps): + prev_embed = trg_embed[:, i].unsqueeze(1) # batch, 1, emb + prev_att_vector, hidden, att_prob = self._forward_step( + prev_embed=prev_embed, + prev_att_vector=prev_att_vector, + encoder_output=encoder_output, + src_mask=src_mask, + hidden=hidden, + ) + att_vectors.append(prev_att_vector) + att_probs.append(att_prob) + + att_vectors = torch.cat(att_vectors, dim=1) + # att_vectors: batch, unroll_steps, hidden_size + att_probs = torch.cat(att_probs, dim=1) + # att_probs: batch, unroll_steps, src_length + outputs = self.output_layer(att_vectors) + # outputs: batch, unroll_steps, vocab_size + + # DataParallel gathers batches along the 0th dim. + # Put batch_size dim to the 0th position. 
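+        # i.e. (num_layers, batch_size, hidden_size) -> (batch_size, num_layers, hidden_size)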
+ if isinstance(hidden, tuple): + h, c = hidden + hidden = ( + h.permute(1, 0, 2).contiguous(), + c.permute(1, 0, 2).contiguous(), + ) + assert hidden[0].size(0) == batch_size + else: + hidden = hidden.permute(1, 0, 2).contiguous() + assert hidden.size(0) == batch_size + # shape (batch_size, num_layers, hidden_size) + + return outputs, hidden, att_probs, att_vectors, None + + def _init_hidden(self, + encoder_final: Tensor = None) -> Tuple[Tensor, Optional[Tensor]]: + """ + Returns the initial decoder state, conditioned on the final encoder state of the + last encoder layer. + + In case of `self.init_hidden_option == "bridge"` and a given `encoder_final`, + this is a projection of the encoder state. + + In case of `self.init_hidden_option == "last"` and a size-matching + `encoder_final`, this is set to the encoder state. If the encoder is twice as + large as the decoder state (e.g. when bi-directional), just use the forward + hidden state. + + In case of `self.init_hidden_option == "zero"`, it is initialized with zeros. + + For LSTMs we initialize both the hidden state and the memory cell with the same + projection/copy of the encoder hidden state. + + All decoder layers are initialized with the same initial values. + + :param encoder_final: final state from the last layer of the encoder, + shape (batch_size, encoder_hidden_size) + :return: hidden state if GRU, (hidden state, memory cell) if LSTM, + shape (batch_size, hidden_size) + """ + batch_size = encoder_final.size(0) + + # for multiple layers: is the same for all layers + if self.init_hidden_option == "bridge" and encoder_final is not None: + # num_layers x batch_size x hidden_size + hidden = (torch.tanh(self.bridge_layer(encoder_final)).unsqueeze(0).repeat( + self.num_layers, 1, 1)) + elif self.init_hidden_option == "last" and encoder_final is not None: + # special case: encoder is bidirectional: use only forward state + if encoder_final.shape[1] == 2 * self.hidden_size: # bidirectional + encoder_final = encoder_final[:, :self.hidden_size] + hidden = encoder_final.unsqueeze(0).repeat(self.num_layers, 1, 1) + else: # initialize with zeros + with torch.no_grad(): + hidden = encoder_final.new_zeros(self.num_layers, batch_size, + self.hidden_size) + + return (hidden, hidden) if isinstance(self.rnn, nn.LSTM) else hidden + + def __repr__(self): + return (f"{self.__class__.__name__}(rnn={self.rnn}, " + f"attention={self.attention})") + + +class TransformerDecoder(Decoder): + """ + A transformer decoder with N masked layers. + Decoder layers are masked so that an attention head cannot see the future. + """ + + # pylint: disable=unused-argument + def __init__( + self, + num_layers: int = 4, + num_heads: int = 8, + hidden_size: int = 512, + ff_size: int = 2048, + dropout: float = 0.1, + emb_dropout: float = 0.1, + vocab_size: int = 1, + freeze: bool = False, + **kwargs, + ): + """ + Initialize a Transformer decoder. 
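+
+        Instantiation sketch (all sizes are assumptions, not recommended values):
+
+            decoder = TransformerDecoder(num_layers=6, num_heads=8, hidden_size=512,
+                                         ff_size=2048, dropout=0.1, vocab_size=32000,
+                                         layer_norm="pre")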
+ + :param num_layers: number of Transformer layers + :param num_heads: number of heads for each layer + :param hidden_size: hidden size + :param ff_size: position-wise feed-forward size + :param dropout: dropout probability (1-keep) + :param emb_dropout: dropout probability for embeddings + :param vocab_size: size of the output vocabulary + :param freeze: set to True keep all decoder parameters fixed + :param kwargs: + """ + super().__init__() + + self._hidden_size = hidden_size + self._output_size = vocab_size + + # create num_layers decoder layers and put them in a list + self.layers = nn.ModuleList([ + TransformerDecoderLayer( + size=hidden_size, + ff_size=ff_size, + num_heads=num_heads, + dropout=dropout, + alpha=kwargs.get("alpha", 1.0), + layer_norm=kwargs.get("layer_norm", "post"), + ) for _ in range(num_layers) + ]) + + self.pe = PositionalEncoding(hidden_size) + self.layer_norm = (nn.LayerNorm(hidden_size, eps=1e-6) if kwargs.get( + "layer_norm", "post") == "pre" else None) + + self.emb_dropout = nn.Dropout(p=emb_dropout) + self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False) + + if freeze: + freeze_params(self) + + self.ctc_output_layer = None + encoder_output_size = kwargs.get("encoder_output_size_for_ctc", None) + if encoder_output_size is not None: + self.ctc_output_layer = nn.Linear(encoder_output_size, + vocab_size, + bias=False) + + def forward( + self, + trg_embed: Tensor, + encoder_output: Tensor, + encoder_hidden: Tensor, + src_mask: Tensor, + unroll_steps: int, + hidden: Tensor, + trg_mask: Tensor, + **kwargs, + ): + """ + Transformer decoder forward pass. + + :param trg_embed: embedded targets + :param encoder_output: source representations + :param encoder_hidden: unused + :param src_mask: + :param unroll_steps: unused + :param hidden: unused + :param trg_mask: to mask out target paddings + Note that a subsequent mask is applied here. 
+ :param kwargs: + :return: + - decoder_output: shape (batch_size, seq_len, vocab_size) + - decoder_hidden: shape (batch_size, seq_len, emb_size) + - att_probs: shape (batch_size, trg_length, src_length), + - None + - ctc_output + """ + assert trg_mask is not None, "trg_mask required for Transformer" + + x = self.pe(trg_embed) # add position encoding to word embedding + x = self.emb_dropout(x) + + trg_mask = trg_mask & subsequent_mask(trg_embed.size(1)).type_as(trg_mask) + + last_layer = len(self.layers) - 1 + return_attention = kwargs.get("return_attention", False) + for i, layer in enumerate(self.layers): + x, att = layer(x=x, + memory=encoder_output, + src_mask=src_mask, + trg_mask=trg_mask, + return_attention=(return_attention and i == last_layer)) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + out = self.output_layer(x) + + ctc_output = None if self.ctc_output_layer is None \ + else self.ctc_output_layer(encoder_output) + + return out, x, att, None, ctc_output + + def __repr__(self): + return (f"{self.__class__.__name__}(num_layers={len(self.layers)}, " + f"num_heads={self.layers[0].trg_trg_att.num_heads}, " + f"alpha={self.layers[0].alpha}, " + f'layer_norm="{self.layers[0]._layer_norm_position}", ' + f'ctc_layer={self.ctc_output_layer is not None})') diff --git a/joeynmt/embeddings.py b/joeynmt/embeddings.py new file mode 100644 index 0000000..4ea4d25 --- /dev/null +++ b/joeynmt/embeddings.py @@ -0,0 +1,126 @@ +# coding: utf-8 +""" +Embedding module +""" + +import logging +import math +from pathlib import Path +from typing import Dict + +import torch +from torch import Tensor, nn + +from joeynmt.helpers import freeze_params +from joeynmt.vocabulary import Vocabulary + +logger = logging.getLogger(__name__) + + +class Embeddings(nn.Module): + """ + Simple embeddings class + """ + + def __init__( + self, + embedding_dim: int = 64, + scale: bool = False, + vocab_size: int = 0, + padding_idx: int = 1, + freeze: bool = False, + **kwargs, + ): + """ + Create new embeddings for the vocabulary. + Use scaling for the Transformer. + + :param embedding_dim: + :param scale: + :param vocab_size: + :param padding_idx: + :param freeze: freeze the embeddings during training + """ + # pylint: disable=unused-argument + super().__init__() + + self.embedding_dim = embedding_dim + self.scale = scale + self.vocab_size = vocab_size + self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx) + + if freeze: + freeze_params(self) + + def forward(self, x: Tensor) -> Tensor: + """ + Perform lookup for input `x` in the embedding table. + + :param x: index in the vocabulary + :return: embedded representation for `x` + """ + if self.scale: + return self.lut(x) * math.sqrt(self.embedding_dim) + return self.lut(x) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(" + f"embedding_dim={self.embedding_dim}, " + f"vocab_size={self.vocab_size})") + + # from fairseq + def load_from_file(self, embed_path: Path, vocab: Vocabulary) -> None: + """Load pretrained embedding weights from text file. + + - First line is expected to contain vocabulary size and dimension. + The dimension has to match the model's specified embedding size, + the vocabulary size is used in logging only. + - Each line should contain word and embedding weights + separated by spaces. + - The pretrained vocabulary items that are not part of the + joeynmt's vocabulary will be ignored (not loaded from the file). 
+ - The initialization (specified in config["model"]["embed_initializer"]) + of joeynmt's vocabulary items that are not part of the + pretrained vocabulary will be kept (not overwritten in this func). + - This function should be called after initialization! + + Example: + 2 5 + the -0.0230 -0.0264 0.0287 0.0171 0.1403 + at -0.0395 -0.1286 0.0275 0.0254 -0.0932 + + :param embed_path: embedding weights text file + :param vocab: Vocabulary object + """ + # pylint: disable=logging-too-many-args + + embed_dict: Dict[int, Tensor] = {} + # parse file + with embed_path.open("r", encoding="utf-8", errors="ignore") as f_embed: + vocab_size, d = map(int, f_embed.readline().split()) + assert self.embedding_dim == d, "Embedding dimension doesn't match." + for line in f_embed.readlines(): + tokens = line.rstrip().split(" ") + if tokens[0] in vocab.specials or not vocab.is_unk(tokens[0]): + embed_dict[vocab.lookup(tokens[0])] = torch.FloatTensor( + [float(t) for t in tokens[1:]]) + + logger.warning( + "Loaded %d of %d (%%) tokens in the pre-trained WE.", + len(embed_dict), + vocab_size, + len(embed_dict) / vocab_size, + ) + + # assign + for idx, weights in embed_dict.items(): + if idx < self.vocab_size: + assert self.embedding_dim == len(weights) + self.lut.weight.data[idx] = weights + + logger.warning( + "Loaded %d of %d (%%) tokens of the JoeyNMT's vocabulary.", + len(embed_dict), + len(vocab), + len(embed_dict) / len(vocab), + ) diff --git a/joeynmt/encoders.py b/joeynmt/encoders.py new file mode 100644 index 0000000..43ed978 --- /dev/null +++ b/joeynmt/encoders.py @@ -0,0 +1,429 @@ +# coding: utf-8 +""" +Various encoders +""" +import logging +from typing import List, Tuple + +import torch +from torch import Tensor, nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from joeynmt.constants import PAD_ID +from joeynmt.helpers import freeze_params, lengths_to_padding_mask, pad +from joeynmt.transformer_layers import ( + ConformerEncoderLayer, + PositionalEncoding, + TransformerEncoderLayer, +) + +logger = logging.getLogger(__name__) + + +class Encoder(nn.Module): + """ + Base encoder class + """ + + # pylint: disable=abstract-method + @property + def output_size(self): + """ + Return the output size + + :return: + """ + return self._output_size + + +class RecurrentEncoder(Encoder): + """Encodes a sequence of word embeddings""" + + # pylint: disable=unused-argument + def __init__( + self, + rnn_type: str = "gru", + hidden_size: int = 1, + emb_size: int = 1, + num_layers: int = 1, + dropout: float = 0.0, + emb_dropout: float = 0.0, + bidirectional: bool = True, + freeze: bool = False, + **kwargs, + ) -> None: + """ + Create a new recurrent encoder. + + :param rnn_type: RNN type: `gru` or `lstm`. + :param hidden_size: Size of each RNN. + :param emb_size: Size of the word embeddings. + :param num_layers: Number of encoder RNN layers. + :param dropout: Is applied between RNN layers. + :param emb_dropout: Is applied to the RNN input (word embeddings). + :param bidirectional: Use a bi-directional RNN. 
+ :param freeze: freeze the parameters of the encoder during training + :param kwargs: + """ + super().__init__() + + self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False) + self.type = rnn_type + self.emb_size = emb_size + + rnn = nn.GRU if rnn_type == "gru" else nn.LSTM + + self.rnn = rnn( + emb_size, + hidden_size, + num_layers, + batch_first=True, + bidirectional=bidirectional, + dropout=dropout if num_layers > 1 else 0.0, + ) + + self._output_size = 2 * hidden_size if bidirectional else hidden_size + + if freeze: + freeze_params(self) + + def _check_shapes_input_forward(self, embed_src: Tensor, src_length: Tensor, + mask: Tensor) -> None: + """ + Make sure the shape of the inputs to `self.forward` are correct. + Same input semantics as `self.forward`. + + :param embed_src: embedded source tokens + :param src_length: source length + :param mask: source mask + """ + # pylint: disable=unused-argument + assert embed_src.shape[0] == src_length.shape[0] + assert embed_src.shape[2] == self.emb_size + # assert mask.shape == embed_src.shape + assert len(src_length.shape) == 1 + + def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, + **kwargs) -> Tuple[Tensor, Tensor, Tensor]: + """ + Applies a bidirectional RNN to sequence of embeddings x. + The input mini-batch x needs to be sorted by src length. + x and mask should have the same dimensions [batch, time, dim]. + + :param embed_src: embedded src inputs, + shape (batch_size, src_len, embed_size) + :param src_length: length of src inputs + (counting tokens before padding), shape (batch_size) + :param mask: indicates padding areas (zeros where padding), shape + (batch_size, src_len, embed_size) + :param kwargs: + :return: + - output: hidden states with + shape (batch_size, max_length, directions*hidden), + - hidden_concat: last hidden state with + shape (batch_size, directions*hidden) + """ + self._check_shapes_input_forward(embed_src=embed_src, + src_length=src_length, + mask=mask) + total_length = embed_src.size(1) + + # apply dropout to the rnn input + embed_src = self.emb_dropout(embed_src) + + packed = pack_padded_sequence(embed_src, src_length.cpu(), batch_first=True) + output, hidden = self.rnn(packed) + + if isinstance(hidden, tuple): + hidden, memory_cell = hidden # pylint: disable=unused-variable + + output, _ = pad_packed_sequence(output, + batch_first=True, + total_length=total_length) + # hidden: dir*layers x batch x hidden + # output: batch x max_length x directions*hidden + batch_size = hidden.size()[1] + # separate final hidden states by layer and direction + hidden_layerwise = hidden.view( + self.rnn.num_layers, + 2 if self.rnn.bidirectional else 1, + batch_size, + self.rnn.hidden_size, + ) + # final_layers: layers x directions x batch x hidden + + # concatenate the final states of the last layer for each directions + # thanks to pack_padded_sequence final states don't include padding + fwd_hidden_last = hidden_layerwise[-1:, 0] + bwd_hidden_last = hidden_layerwise[-1:, 1] + + # only feed the final state of the top-most layer to the decoder + # pylint: disable=no-member + hidden_concat = torch.cat([fwd_hidden_last, bwd_hidden_last], dim=2).squeeze(0) + # final: batch x directions*hidden + + assert hidden_concat.size(0) == output.size(0), ( + hidden_concat.size(), + output.size(), + ) + return output, hidden_concat, None + + def __repr__(self): + return f"{self.__class__.__name__}(rnn={self.rnn})" + + +class TransformerEncoder(Encoder): + """ + Transformer Encoder + """ + + def __init__( + self, + 
hidden_size: int = 512, + ff_size: int = 2048, + num_layers: int = 8, + num_heads: int = 4, + dropout: float = 0.1, + emb_dropout: float = 0.1, + freeze: bool = False, + **kwargs, + ): + """ + Initializes the Transformer. + :param hidden_size: hidden size and size of embeddings + :param ff_size: position-wise feed-forward layer size. + (Typically this is 2*hidden_size.) + :param num_layers: number of layers + :param num_heads: number of heads for multi-headed attention + :param dropout: dropout probability for Transformer layers + :param emb_dropout: Is applied to the input (word embeddings). + :param freeze: freeze the parameters of the encoder during training + :param kwargs: + """ + super().__init__() + + self._output_size = hidden_size + + # build all (num_layers) layers + self.layers = nn.ModuleList([ + TransformerEncoderLayer( + size=hidden_size, + ff_size=ff_size, + num_heads=num_heads, + dropout=dropout, + alpha=kwargs.get("alpha", 1.0), + layer_norm=kwargs.get("layer_norm", "pre"), + ) for _ in range(num_layers) + ]) + + self.pe = PositionalEncoding(hidden_size) + self.emb_dropout = nn.Dropout(p=emb_dropout) + + self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) \ + if kwargs.get("layer_norm", "pre") == "pre" else None + + if freeze: + freeze_params(self) + + # conv1d subsampling for audio inputs + self.subsample = kwargs.get("subsample", False) + if self.subsample: + self.subsampler = Conv1dSubsampler(kwargs["in_channels"], + kwargs["conv_channels"], hidden_size, + kwargs.get("conv_kernel_sizes", [3, 3])) + self.pad_index = kwargs.get("pad_index", PAD_ID) + assert self.pad_index is not None + + def forward( + self, + embed_src: Tensor, + src_length: Tensor, + mask: Tensor = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Pass the input (and mask) through each layer in turn. + Applies a Transformer encoder to sequence of embeddings x. + The input mini-batch x needs to be sorted by src length. + x and mask should have the same dimensions [batch, time, dim]. + + :param embed_src: embedded src inputs, + shape (batch_size, src_len, embed_size) + :param src_length: length of src inputs + (counting tokens before padding), shape (batch_size) + :param mask: indicates padding areas (zeros where padding), shape + (batch_size, 1, src_len) + :return: + - output: hidden states with shape (batch_size, max_length, hidden) + - None + - mask + """ + if self.subsample: + embed_src, src_length = self.subsampler(embed_src, src_length) + + if mask is None: + mask = lengths_to_padding_mask(src_length).unsqueeze(1) + + x = self.pe(embed_src) # add position encoding to word embeddings + x = self.emb_dropout(x) + + for layer in self.layers: + x = layer(x, mask) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + if kwargs.get('repad', False) and "src_max_len" in kwargs and self.subsample: + x, mask = self._repad(x, mask, kwargs["src_max_len"]) + assert src_length.size() == (x.size(0), ), (src_length.size(), x.size()) + assert mask.size() == (x.size(0), 1, x.size(1)), (mask.size(), x.size()) + return x, None, mask + + def _repad(self, x, mask, src_max_len): + # re-pad `x` and `mask` so that all seqs in parallel gpus have the same len! 
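+        # map the pre-subsampling max length through the conv stack, then pad the
+        # subsampled features (and their mask) up to that common length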
+ src_max_len = int( + self.subsampler.get_out_seq_lens_tensor( + torch.tensor(src_max_len).float()).item()) + x = pad(x, src_max_len, pad_index=self.pad_index, dim=1) + mask = pad(mask, src_max_len, pad_index=self.pad_index, dim=-1) + return x, mask + + def __repr__(self): + return (f"{self.__class__.__name__}(num_layers={len(self.layers)}, " + f"num_heads={self.layers[0].src_src_att.num_heads}, " + f"alpha={self.layers[0].alpha}, " + f'layer_norm="{self.layers[0]._layer_norm_position}", ' + f'subsample={self.subsample})') + + +class Conv1dSubsampler(nn.Module): + """ + Convolutional subsampler: a stack of 1D convolution (along temporal dimension) + followed by non-linear activation via gated linear units + (https://arxiv.org/abs/1911.08460) + cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/models/speech_to_text/s2t_transformer.py + + :param in_channels: the number of input channels (embed_size = num_freq) + :param mid_channels: the number of intermediate channels + :param out_channels: the number of output channels (hidden_size) + :param kernel_sizes: the kernel size for each convolutional layer + :return: + - output tensor + - sequence length after subsampling + """ # noqa: E501 + + def __init__(self, + in_channels: int, + mid_channels: int, + out_channels: int = None, + kernel_sizes: List[int] = (3, 3)): + super().__init__() + + self.kernel_sizes = kernel_sizes + self.n_layers = len(kernel_sizes) + self.conv_layers = nn.ModuleList( + nn.Conv1d( + in_channels if i == 0 else mid_channels // 2, + mid_channels if i < self.n_layers - 1 else out_channels * 2, + k, + stride=2, + padding=k // 2, + ) for i, k in enumerate(kernel_sizes)) + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for k in self.kernel_sizes: + out = ((out.float() + 2 * (k // 2) - (k - 1) - 1) / 2 + 1).floor().long() + return out + + def forward(self, src_tokens, src_lengths): + # reshape after DataParallel batch split + max_len = torch.max(src_lengths).item() + assert max_len > 0, "empty batch!" 
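+        # DataParallel may hand over a shard padded to the global max length;
+        # trim the time dimension to this shard's longest sequence first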
+ if src_tokens.size(1) != max_len: + src_tokens = src_tokens[:, :max_len, :] + assert src_tokens.size(1) == max_len, (src_tokens.size(), max_len, src_lengths) + + _, in_seq_len, _ = src_tokens.size() # -> B x T x (C x D) + x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + x = conv(x) + x = nn.functional.glu(x, dim=1) + _, _, out_seq_len = x.size() + x = x.transpose(1, 2).contiguous() # -> B x T x (C x D) + out_seq_lens = self.get_out_seq_lens_tensor(src_lengths) + + assert x.size(1) == torch.max(out_seq_lens).item(), \ + (x.size(), in_seq_len, out_seq_len, out_seq_lens) + return x, out_seq_lens + + +class ConformerEncoder(TransformerEncoder): + """ + Conformer Encoder + """ + + def __init__( + self, + hidden_size: int = 512, + ff_size: int = 2048, + num_layers: int = 8, + num_heads: int = 4, + dropout: float = 0.1, + emb_dropout: float = 0.1, + freeze: bool = False, + **kwargs, + ): + super().__init__() + + self._output_size = hidden_size + + # build all (num_layers) layers + self.layers = nn.ModuleList([ + ConformerEncoderLayer(size=hidden_size, + ff_size=ff_size, + num_heads=num_heads, + dropout=dropout, + alpha=kwargs.get("alpha", 1.0), + layer_norm=kwargs.get("layer_norm", "pre"), + depthwise_conv_kernel_size=kwargs.get( + "depthwise_conv_kernel_size", 31)) + for _ in range(num_layers) + ]) + + self.pe = PositionalEncoding(hidden_size) + self.emb_dropout = nn.Dropout(p=emb_dropout) + self.linear = nn.Linear(hidden_size, hidden_size) + + if freeze: + freeze_params(self) + + # conv1d subsampling for audio inputs + self.subsampler = Conv1dSubsampler(kwargs["in_channels"], + kwargs["conv_channels"], hidden_size, + kwargs.get("conv_kernel_sizes", [3, 3])) + self.pad_index = kwargs.get("pad_index", PAD_ID) + assert self.pad_index is not None + + def forward( + self, + embed_src: Tensor, + src_length: Tensor, + mask: Tensor = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Tensor]: + x, src_length = self.subsampler(embed_src, src_length) # always subsample + mask = lengths_to_padding_mask(src_length).unsqueeze(1) # recompute src mask + + x = self.pe(x) # add position encoding to spectrogram features + x = self.linear(x) + x = self.emb_dropout(x) + + for layer in self.layers: + x = layer(x, mask) # T x B x C + + if kwargs.get('repad', False) and "src_max_len" in kwargs: + x, mask = self._repad(x, mask, kwargs["src_max_len"]) + assert src_length.size() == (x.size(0), ), (src_length.size(), x.size()) + assert mask.size() == (x.size(0), 1, x.size(1)), (mask.size(), x.size()) + return x, None, mask diff --git a/joeynmt/helpers.py b/joeynmt/helpers.py new file mode 100644 index 0000000..a86e392 --- /dev/null +++ b/joeynmt/helpers.py @@ -0,0 +1,730 @@ +# coding: utf-8 +""" +Collection of helper functions +""" +from __future__ import annotations + +import copy +import functools +import logging +import operator +import random +import re +import shutil +import sys +import unicodedata +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import packaging +import pkg_resources +import torch +import yaml +from torch import Tensor, nn +from torch.multiprocessing import cpu_count +from torch.nn.functional import pad as _pad +from torch.utils.tensorboard import SummaryWriter + +from joeynmt.constants import PAD_ID +from joeynmt.plotting import plot_heatmap + +np.set_printoptions(linewidth=sys.maxsize) # format for printing numpy array + + +class ConfigurationError(Exception): + """Custom exception for 
misspecifications of configuration""" + + +def make_model_dir(model_dir: Path, overwrite: bool = False) -> Path: + """ + Create a new directory for the model. + + :param model_dir: path to model directory + :param overwrite: whether to overwrite an existing directory + :return: path to model directory + """ + model_dir = model_dir.absolute() + if model_dir.is_dir(): + if not overwrite: + raise FileExistsError(f"Model directory {model_dir} exists " + f"and overwriting is disabled.") + # delete previous directory to start with empty dir again + shutil.rmtree(model_dir) + model_dir.mkdir(parents=True) # create model_dir recursively + return model_dir + + +def make_logger(log_dir: Path = None, mode: str = "train") -> str: + """ + Create a logger for logging the training/testing process. + + :param log_dir: path to file where log is stored as well + :param mode: log file name. 'train', 'test' or 'translate' + :return: joeynmt version number + """ + logger = logging.getLogger("") # root logger + version = pkg_resources.require("joeynmt")[0].version + + # add handlers only once. + if len(logger.handlers) == 0: + logger.setLevel(level=logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(name)s - %(message)s") + + if log_dir is not None: + if log_dir.is_dir(): + log_file = log_dir / f"{mode}.log" + + fh = logging.FileHandler(log_file.as_posix(), encoding="utf-8") + fh.setLevel(level=logging.DEBUG) + logger.addHandler(fh) + fh.setFormatter(formatter) + + sh = logging.StreamHandler() + sh.setLevel(logging.INFO) + sh.setFormatter(formatter) + + logger.addHandler(sh) + logger.info("Hello! This is Joey-NMT (version %s).", version) + + return version + + +def check_version(pkg_version: str, cfg_version: str) -> None: + """ + Check joeynmt version + + :param pkg_version: package version number + :param cfg_version: version number specified in config + """ + joeynmt_version = packaging.version.parse(pkg_version) + config_version = packaging.version.parse(cfg_version) + # check if the major version number matches + # pylint: disable=use-maxsplit-arg + assert joeynmt_version.major == config_version.major, ( + f"You are using JoeyNMT version {str(joeynmt_version)}, " + f'but {str(config_version)} is expected in the given config.') + + +def log_cfg(cfg: Dict, prefix: str = "cfg") -> None: + """ + Write configuration to log. + + :param cfg: configuration to log + :param prefix: prefix for logging + """ + logger = logging.getLogger(__name__) + for k, v in cfg.items(): + if isinstance(v, dict): + p = ".".join([prefix, k]) + log_cfg(v, prefix=p) + else: + p = ".".join([prefix, k]) + logger.info("%34s : %s", p, v) + + +def clones(module: nn.Module, n: int) -> nn.ModuleList: + """ + Produce N identical layers. Transformer helper function. + + :param module: the module to clone + :param n: clone this many times + :return: cloned modules + """ + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def subsequent_mask(size: int) -> Tensor: + """ + Mask out subsequent positions (to prevent attending to future positions) + Transformer helper function. + + :param size: size of mask (2nd and 3rd dim) + :return: Tensor with 0s and 1s of shape (1, size, size) + """ + ones = torch.ones(size, size, dtype=torch.bool) + return torch.tril(ones, out=ones).unsqueeze(0) + + +def set_seed(seed: int) -> None: + """ + Set the random seed for modules torch, numpy and random. 
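+    (If CUDA is available, the CUDA seeds are set as well and cuDNN is switched to
+    deterministic mode.)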
+ + :param seed: random seed + """ + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.backends.cudnn.deterministic = True + torch.cuda.manual_seed_all(seed) + + +def load_config(path: Union[Path, str] = "configs/default.yaml") -> Dict: + """ + Loads and parses a YAML configuration file. + + :param path: path to YAML configuration file + :return: configuration dictionary + """ + if isinstance(path, str): + path = Path(path) + with path.open("r", encoding="utf-8") as ymlfile: + cfg = yaml.safe_load(ymlfile) + return cfg + + +def write_list_to_file(output_path: Path, array: List[Any]) -> None: + """ + Write list of str to file in `output_path`. + + :param output_path: output file path + :param array: list of strings + """ + with output_path.open("w", encoding="utf-8") as opened_file: + for entry in array: + opened_file.write(f"{entry}\n") + + +def read_list_from_file(input_path: Path) -> List[str]: + """ + Read list of str from file in `input_path`. + + :param input_path: input file path + :return: list of strings + """ + if input_path is None: + return [] + return [ + line.rstrip("\n") + for line in input_path.read_text(encoding="utf-8").splitlines() + ] + + +def parse_train_args(cfg: Dict, mode: str = "training") -> Tuple: + """Parse and validate train args specified in config file""" + logger = logging.getLogger(__name__) + + model_dir: Path = Path(cfg["model_dir"]) + assert model_dir.is_dir(), f"{model_dir} not found." + + use_cuda: bool = cfg["use_cuda"] and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + n_gpu: int = torch.cuda.device_count() if use_cuda else 0 + num_workers: int = cfg.get("num_workers", 0) + if num_workers > 0: + num_workers = min(cpu_count(), num_workers) + + # normalization + normalization: str = cfg.get("normalization", "batch") + if normalization not in ["batch", "tokens", "none"]: + raise ConfigurationError( + "Invalid `normalization` option. Valid options: {`batch`, `token`, `none`}." + ) + + # model initialization + def _load_path(path): + load_path = cfg.get(path, None) + if load_path is not None: + load_path = Path(load_path) + assert load_path.is_file(), load_path + return load_path + + load_model: Optional[Path] = _load_path("load_model") + + # fp16 + fp16: bool = cfg.get("fp16", False) + + if mode == "prediction": + return model_dir, load_model, device, n_gpu, num_workers, normalization, fp16 + + # layer initialization + load_encoder: Optional[Path] = _load_path("load_encoder") + load_decoder: Optional[Path] = _load_path("load_decoder") + + # objective + loss_type: str = cfg.get("loss", "crossentropy") + if loss_type not in ["crossentropy", "crossentropy-ctc"]: + raise ConfigurationError( + "Invalid `loss` type. Valid option: {`crossentropy`, `crossentropy-ctc`}.") + ctc_weight: float = cfg.get("ctc_weight", 0.0) + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + label_smoothing: float = cfg.get("label_smoothing", 0.0) + + # minimum learning rate for early stopping + learning_rate_min: float = cfg.get("learning_rate_min", 1.0e-8) + + # save/delete checkpoints + keep_best_ckpts: int = int(cfg.get("keep_best_ckpts", 5)) + _keep_last_ckpts: Optional[int] = cfg.get("keep_last_ckpts", None) + if _keep_last_ckpts is not None: # backward compatibility + keep_best_ckpts = _keep_last_ckpts + logger.warning("`keep_last_ckpts` option is outdated. 
" + "Please use `keep_best_ckpts`, instead.") + + # logging, validation + logging_freq: int = cfg.get("logging_freq", 100) + validation_freq: int = cfg.get("validation_freq", 1000) + log_valid_sents: List[int] = cfg.get("print_valid_sents", [0, 1, 2]) + + # early stopping + early_stopping_metric: str = cfg.get("early_stopping_metric", "ppl").lower() + if early_stopping_metric not in ["acc", "loss", "ppl", "bleu", "chrf", "wer"]: + raise ConfigurationError( + "Invalid setting for `early_stopping_metric`. " + "Valid options: {`acc`, `loss`, `ppl`, `bleu`, `chrf`, `wer`}.") + + # data & batch handling + seed: int = cfg.get("random_seed", 42) + shuffle: bool = cfg.get("shuffle", True) + epochs: int = cfg["epochs"] + max_updates: float = cfg.get("updates", np.inf) + batch_size: int = cfg["batch_size"] + batch_type: str = cfg.get("batch_type", "sentence") + if batch_type not in ["sentence", "token"]: + raise ConfigurationError( + "Invalid `batch_type` option. Valid options: {`sentence`, `token`}.") + batch_multiplier: int = cfg.get("batch_multiplier", 1) + + # resume training process + reset_best_ckpt = cfg.get("reset_best_ckpt", False) + reset_scheduler = cfg.get("reset_scheduler", False) + reset_optimizer = cfg.get("reset_optimizer", False) + reset_iter_state = cfg.get("reset_iter_state", False) + + return ( + model_dir, + load_model, + load_encoder, + load_decoder, + loss_type, + ctc_weight, + label_smoothing, + normalization, + learning_rate_min, + keep_best_ckpts, + logging_freq, + validation_freq, + log_valid_sents, + early_stopping_metric, + seed, + shuffle, + epochs, + max_updates, + batch_size, + batch_type, + batch_multiplier, + device, + n_gpu, + num_workers, + fp16, + reset_best_ckpt, + reset_scheduler, + reset_optimizer, + reset_iter_state, + ) + + +def parse_test_args(cfg: Dict) -> Tuple: + """Parse test args""" + logger = logging.getLogger(__name__) + + # batch options + batch_size: int = cfg.get("batch_size", 64) + batch_type: str = cfg.get("batch_type", "sentences") + if batch_type not in ["sentence", "token"]: + raise ConfigurationError( + "Invalid `batch_type` option. Valid options: {`sentence`, `token`}.") + if batch_size > 1000 and batch_type == "sentence": + logger.warning( + "WARNING: Are you sure you meant to work on huge batches like this? " + "`batch_size` is > 1000 for sentence-batching. Consider decreasing it " + "or switching to `batch_type: 'token'`.") + + # limit on generation length + max_output_length: int = cfg.get("max_output_length", -1) + min_output_length: int = cfg.get("min_output_length", 1) + + # eval metrics + if "eval_metrics" in cfg: + eval_metrics = [s.strip().lower() for s in cfg["eval_metrics"]] + elif "eval_metric" in cfg: + eval_metrics = [cfg["eval_metric"].strip().lower()] + logger.warning( + "`eval_metric` option is obsolete. Please use `eval_metrics`, instead.") + else: + eval_metrics = [] + for eval_metric in eval_metrics: + if eval_metric not in [ + "bleu", "chrf", "token_accuracy", "sequence_accuracy", "wer" + ]: + raise ConfigurationError( + "Invalid setting for `eval_metrics`. Valid options: {'bleu', 'chrf', " + "'token_accuracy', 'sequence_accuracy', 'wer'}.") + + # sacrebleu cfg + sacrebleu_cfg: Dict = cfg.get("sacrebleu_cfg", {}) + if "sacrebleu" in cfg: + sacrebleu_cfg: Dict = cfg["sacrebleu"] + logger.warning( + "`sacrebleu` option is obsolete. 
Please use `sacrebleu_cfg`, instead.") + + # beam search options + n_best: int = cfg.get("n_best", 1) + beam_size: int = cfg.get("beam_size", 1) + beam_alpha: float = cfg.get("beam_alpha", -1) + if "alpha" in cfg: + beam_alpha = cfg["alpha"] + logger.warning("`alpha` option is obsolete. Please use `beam_alpha`, instead.") + assert beam_size > 0, "Beam size must be > 0." + assert n_best > 0, "N-best size must be > 0." + assert n_best <= beam_size, "`n_best` must be smaller than or equal to `beam_size`." + + # control options + return_attention: bool = cfg.get("return_attention", False) + return_prob: str = cfg.get("return_prob", "none") + if return_prob not in ["hyp", "ref", "none"]: + raise ConfigurationError( + "Invalid `return_prob` option. Valid options: {`hyp`, `ref`, `none`}.") + generate_unk: bool = cfg.get("generate_unk", True) + repetition_penalty: float = cfg.get("repetition_penalty", -1) + if 0 < repetition_penalty < 1: + raise ConfigurationError( + "Repetition penalty must be > 1. (-1 indicates no repetition penalty.)") + no_repeat_ngram_size: int = cfg.get("no_repeat_ngram_size", -1) + + return ( + batch_size, + batch_type, + max_output_length, + min_output_length, + eval_metrics, + sacrebleu_cfg, + beam_size, + beam_alpha, + n_best, + return_attention, + return_prob, + generate_unk, + repetition_penalty, + no_repeat_ngram_size, + ) + + +def store_attention_plots( + attentions: np.ndarray, + targets: List[List[str]], + sources: List[List[str]], + output_prefix: str, + indices: List[int], + tb_writer: Optional[SummaryWriter] = None, + steps: int = 0, +) -> None: + """ + Saves attention plots. + + :param attentions: attention scores + :param targets: list of tokenized targets + :param sources: list of tokenized sources + :param output_prefix: prefix for attention plots + :param indices: indices selected for plotting + :param tb_writer: Tensorboard summary writer (optional) + :param steps: current training steps, needed for tb_writer + :param dpi: resolution for images + """ + for i in indices: + if i >= len(sources): + continue + plot_file = f"{output_prefix}.{i}.png" + src = sources[i] + trg = targets[i] + attention_scores = attentions[i].T + try: + fig = plot_heatmap( + scores=attention_scores, + column_labels=trg, + row_labels=src, + output_path=plot_file, + dpi=100, + ) + if tb_writer is not None: + # lower resolution for tensorboard + fig = plot_heatmap( + scores=attention_scores, + column_labels=trg, + row_labels=src, + output_path=None, + dpi=50, + ) + tb_writer.add_figure(f"attention/{i}.", fig, global_step=steps) + except Exception: # pylint: disable=broad-except + print(f"Couldn't plot example {i}: " + f"src len {len(src)}, trg len {len(trg)}, " + f"attention scores shape {attention_scores.shape}") + continue + + +def get_latest_checkpoint(ckpt_dir: Path) -> Optional[Path]: + """ + Returns the latest checkpoint (by creation time, not the steps number!) + from the given directory. + If there is no checkpoint in this directory, returns None + + :param ckpt_dir: + :return: latest checkpoint file + """ + list_of_files = ckpt_dir.glob("*.ckpt") + latest_checkpoint = None + if list_of_files: + latest_checkpoint = max(list_of_files, key=lambda f: f.stat().st_ctime) + + # check existence + if latest_checkpoint is None: + raise FileNotFoundError(f"No checkpoint found in directory {ckpt_dir}.") + return latest_checkpoint + + +def load_checkpoint(path: Path, device: torch.device) -> Dict: + """ + Load model from saved checkpoint. 
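+ The checkpoint is returned as a plain dict; typically its "model_state" entry is passed to model.load_state_dict() (see the hub interface below).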
+ + :param path: path to checkpoint + :param device: cuda device name or cpu + :return: checkpoint (dict) + """ + logger = logging.getLogger(__name__) + assert path.is_file(), f"Checkpoint {path} not found." + checkpoint = torch.load(path, map_location=device) + logger.info("Load model from %s.", path.resolve()) + return checkpoint + + +def resolve_ckpt_path(load_model: Path, model_dir: Path) -> Path: + """ + Resolve checkpoint path + + :param load_model: config entry (cfg['training']['load_model']) or CLI arg (--ckpt) + :param model_dir: Path(cfg['training']['model_dir']) + :return: resolved checkpoint path + """ + if load_model is None: + if (model_dir / "best.ckpt").is_file(): + load_model = model_dir / "best.ckpt" + else: + load_model = get_latest_checkpoint(model_dir) + assert load_model.is_file(), load_model + return load_model + + +def tile(x: Tensor, count: int, dim=0) -> Tensor: + """ + Tiles x on dimension dim count times. From OpenNMT. Used for beam search. + + :param x: tensor to tile + :param count: number of tiles + :param dim: dimension along which the tensor is tiled + :return: tiled tensor + """ + if isinstance(x, tuple): + h, c = x + return tile(h, count, dim=dim), tile(c, count, dim=dim) + + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + # yapf: disable + x = (x.view(batch, -1) + .transpose(0, 1) + .repeat(count, 1) + .transpose(0, 1) + .contiguous() + .view(*out_size)) + if dim != 0: + x = x.permute(perm).contiguous() + return x + + +def freeze_params(module: nn.Module) -> None: + """ + Freeze the parameters of this module, + i.e. do not update them during training + + :param module: freeze parameters of this module + """ + for _, p in module.named_parameters(): + p.requires_grad = False + + +def delete_ckpt(to_delete: Path) -> None: + """ + Delete checkpoint + + :param to_delete: checkpoint file to be deleted + """ + logger = logging.getLogger(__name__) + try: + logger.info("delete %s", to_delete.as_posix()) + to_delete.unlink() + + except FileNotFoundError as e: + logger.warning( + "Wanted to delete old checkpoint %s but " + "file does not exist. (%s)", + to_delete, + e, + ) + + +def symlink_update(target: Path, link_name: Path) -> Optional[Path]: + """ + This function finds the file that the symlink currently points to, sets it + to the new target, and returns the previous target if it exists. + + :param target: A path to a file that we want the symlink to point to. + no parent dir, filename only, i.e. "10000.ckpt" + :param link_name: This is the name of the symlink that we want to update. + link name with parent dir, i.e. "models/my_model/best.ckpt" + + :return: + - current_last: This is the previous target of the symlink, before it is + updated in this function. If the symlink did not exist before or did + not have a target, None is returned instead. + """ + if link_name.is_symlink(): + current_last = link_name.resolve() + link_name.unlink() + link_name.symlink_to(target) + return current_last + link_name.symlink_to(target) + return None + + +def flatten(array: List[List[Any]]) -> List[Any]: + """ + Flatten a nested 2D list. faster even with a very long array than + [item for subarray in array for item in subarray] or newarray.extend(). 
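+ e.g. flatten([[1, 2], [3]]) returns [1, 2, 3].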
+ + :param array: a nested list + :return: flattened list + """ + return functools.reduce(operator.iconcat, array, []) + + +def expand_reverse_index(reverse_index: List[int], n_best: int = 1) -> List[int]: + """ + Expand resort_reverse_index for n_best prediction + + ex. 1) reverse_index = [1, 0, 2] and n_best = 2, then this will return + [2, 3, 0, 1, 4, 5]. + + ex. 2) reverse_index = [1, 0, 2] and n_best = 3, then this will return + [3, 4, 5, 0, 1, 2, 6, 7, 8] + + :param reverse_index: reverse_index returned from batch.sort_by_src_length() + :param n_best: + :return: expanded sort_reverse_index + """ + if n_best == 1: + return reverse_index + + resort_reverse_index = [] + for ix in reverse_index: + for n in range(0, n_best): + resort_reverse_index.append(ix * n_best + n) + assert len(resort_reverse_index) == len(reverse_index) * n_best + return resort_reverse_index + + +def remove_extra_spaces(s: str) -> str: + """ + Remove extra spaces + - used in pre_process() / post_process() in tokenizer.py + + :param s: input string + :return: string w/o extra white spaces + """ + s = re.sub("\u200b", "", s) + s = re.sub("[  ]+", " ", s) + + s = s.replace(" ?", "?") + s = s.replace(" !", "!") + s = s.replace(" ,", ",") + s = s.replace(" .", ".") + s = s.replace(" :", ":") + return s.strip() + + +def unicode_normalize(s: str) -> str: + """ + apply unicodedata NFKC normalization + - used in pre_process() in tokenizer.py + + cf.) https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_ter.py + + :param s: input string + :return: normalized string + """ # noqa: E501 + s = unicodedata.normalize("NFKC", s) + s = s.replace("’", "'") + s = s.replace("“", '"') + s = s.replace("”", '"') + return s + + +def remove_punctuation(s: str, space: chr): + """ + Remove punctuation based on Unicode category. + Taken from https://github.com/pytorch/fairseq/blob/main/fairseq/scoring/tokenizer.py + cf.) 
https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_ter.py + + :param s: input string + :param space: character for white space (delimiter special char) + :return: string without punctuation + """ # noqa: E501 + return space.join(t for t in s.split(space) + if not all(unicodedata.category(c)[0] == "P" for c in t)) + + + def lengths_to_padding_mask(lengths: Tensor) -> Tensor: + """ + get padding mask according to the given lengths + + :param lengths: length list in shape (batch_size, 1) + :return: mask + """ + bsz, max_lengths = lengths.size(0), torch.max(lengths).item() + mask = torch.arange(max_lengths).to(lengths.device).view(1, max_lengths) + mask = mask.expand(bsz, -1) >= lengths.view(bsz, 1).expand(-1, max_lengths) + return ~mask + + + def pad(x: Tensor, max_len: int, pad_index: int = PAD_ID, dim: int = 1) -> Tensor: + """ + pad tensor + + :param x: tensor in shape (batch_size, seq_len, *) + :param max_len: max_length + :param pad_index: index of the pad token + :param dim: dimension to pad + :return: padded tensor + """ + if pad_index is None: + pad_index = PAD_ID + + if dim == 1: + _, seq_len, _ = x.size() + offset = max_len - seq_len + new_x = _pad(x, (0, 0, 0, offset, 0, 0), "constant", pad_index) \ + if x.size(dim) < max_len else x + elif dim == -1: + _, _, seq_len = x.size() + offset = max_len - seq_len + new_x = _pad(x, (0, offset), "constant", pad_index) \ + if x.size(dim) < max_len else x + assert new_x.size(dim) == max_len, (x.size(), offset, new_x.size(), max_len) + return new_x diff --git a/joeynmt/helpers_for_pose.py b/joeynmt/helpers_for_pose.py new file mode 100644 index 0000000..be3f7ba --- /dev/null +++ b/joeynmt/helpers_for_pose.py @@ -0,0 +1,165 @@ +# coding: utf-8 +""" +Collection of helper functions for audio processing +""" +import io +import logging +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as ta_kaldi +import torchaudio.sox_effects as ta_sox + +from joeynmt.constants import PAD_ID + +logger = logging.getLogger(__name__) + + +# from fairseq +def _convert_to_mono(waveform: torch.FloatTensor, sample_rate: int) \ + -> torch.FloatTensor: + if waveform.shape[0] > 1: + effects = [["channels", "1"]] + return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] + return waveform + + +# from fairseq +def _get_torchaudio_fbank(waveform: torch.FloatTensor, + sample_rate: int, + n_bins: int = 80) -> np.ndarray: + """Get mel-filter bank features via TorchAudio.""" + features = ta_kaldi.fbank(waveform, + num_mel_bins=n_bins, + sample_frequency=sample_rate) + return features.numpy() + + +# from fairseq +def extract_fbank_features(waveform: torch.FloatTensor, + sample_rate: int, + output_path: Optional[Path] = None, + n_mel_bins: int = 80, + overwrite: bool = False) -> Optional[np.ndarray]: + # pylint: disable=inconsistent-return-statements + + if output_path is not None and output_path.is_file() and not overwrite: + return np.load(output_path.as_posix()) + + _waveform = _convert_to_mono(waveform, sample_rate) + _waveform = _waveform * (2**15) # Kaldi compliance: 16-bit signed integers + + try: + features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins) + except Exception as e: + raise ValueError(f"torchaudio failed to extract mel filterbank features " + f"at: {output_path.stem}.
{e}") from e + + if output_path is not None: + np.save(output_path.as_posix(), features) + assert output_path.is_file(), output_path + + return features + + +# from fairseq +def _is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +# from fairseq +def _get_features_from_zip(path, byte_offset, byte_size): + with path.open("rb") as f: + f.seek(byte_offset) + data = f.read(byte_size) + byte_features = io.BytesIO(data) + if len(data) > 1 and _is_npy_data(data): + features = np.load(byte_features) + else: + raise ValueError(f'Unknown file format for ' + f'"{path}" [{byte_offset}:{byte_size}]') + return features + + +# from fairseq +def get_n_frames(wave_length: int, sample_rate: int): + duration_ms = int(wave_length / sample_rate * 1000) + n_frames = int(1 + (duration_ms - 25) / 10) + return n_frames + + +# from fairseq +def get_features(root_path: Path, fbank_path: str) -> np.ndarray: + """Get speech features from ZIP file + accessed via byte offset and length + + :return: (np.ndarray) speech features in shape of (num_frames, num_freq) + """ + _path, *extra = fbank_path.split(":") + _path = root_path / _path + if not _path.is_file(): + raise FileNotFoundError(f"File not found: {_path}") + + if len(extra) == 0: + if _path.suffix == ".npy": + features = np.load(_path.as_posix()) + elif _path.suffix in [".mp3", ".wav"]: + waveform, sample_rate = torchaudio.load(_path.as_posix()) + features = extract_fbank_features(waveform, sample_rate) + else: + raise ValueError(f"Invalid file type: {_path}") + elif len(extra) == 2: + assert _path.suffix == ".zip" + extra = [int(i) for i in extra] + features = _get_features_from_zip(_path, extra[0], extra[1]) + else: + raise ValueError(f"Invalid path: {root_path / fbank_path}") + + assert len(features.shape) == 2, "spectrogram must be a 2-D array." + return features + + +def pad_features( + feat_list: List[np.ndarray], + embed_size: int = 80, + pad_index: int = PAD_ID, +) -> Tuple[np.ndarray, List[int]]: + """ + Pad continuous feature representation in batch. + called in batch construction (not in data loading) + + :param feat_list: list of features + :param embed_size: (int) number of frequencies + :param pad_index: pad index + :returns: + - features np.ndarray, (batch_size, src_len, embed_size) + - lengths List[int], (batch_size) + """ + max_len = max([int(f.shape[0]) for f in feat_list]) + batch_size = len(feat_list) + + # encoder input has shape of (batch_size, src_len, embed_size) + # (see encoder.forward()) + features = np.zeros((batch_size, max_len, embed_size), dtype=np.float32) + features.fill(float(pad_index)) + lengths = [] + + for i, f in enumerate(feat_list): + length = min(int(f.shape[0]), max_len) + assert length > 0, "empty feature!" 
+ features[i, :length, :] = f[:length, :] + lengths.append(length) + + m = max(lengths) + if m < features.shape[1]: + features = features[:, :m, :] + + # validation + assert max(lengths) == features.shape[1] + assert embed_size == features.shape[2] + assert sum(lengths) > 0 + + return features, lengths diff --git a/joeynmt/hub_interface.py b/joeynmt/hub_interface.py new file mode 100644 index 0000000..7f941ae --- /dev/null +++ b/joeynmt/hub_interface.py @@ -0,0 +1,301 @@ +# coding: utf-8 +""" +Torch Hub Interface +""" +import logging +from functools import partial +from pathlib import Path +from typing import Dict, List, NamedTuple, Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +from torch import nn + +from joeynmt.constants import EOS_TOKEN +from joeynmt.datasets import BaseDataset, build_dataset +from joeynmt.helpers import ( + load_checkpoint, + load_config, + parse_train_args, + resolve_ckpt_path, +) +from joeynmt.helpers_for_pose import pad_features +from joeynmt.model import Model, build_model +from joeynmt.plotting import plot_heatmap +from joeynmt.prediction import predict +from joeynmt.tokenizers import build_tokenizer +from joeynmt.vocabulary import build_vocab + +logger = logging.getLogger(__name__) + +PredictionOutput = NamedTuple( + "PredictionOutput", + [ + ("translation", List[str]), + ("tokens", Optional[List[List[str]]]), + ("token_probs", Optional[List[List[float]]]), + ("sequence_probs", Optional[List[float]]), + ("attention_probs", Optional[List[List[float]]]), + ], +) + + +def _check_file_path(path: Union[str, Path], model_dir: Path) -> Path: + """Check torch hub cache path""" + if path is None: + return None + p = Path(path) if isinstance(path, str) else path + if not p.is_file(): + p = model_dir / p.name + assert p.is_file(), p + return p + + +def _from_pretrained( + model_name_or_path: Union[str, Path], + ckpt_file: Union[str, Path] = None, + cfg_file: Union[str, Path] = "config.yaml", + **kwargs, +): + """Prepare model and data placeholder""" + # model dir + model_dir = Path(model_name_or_path) if isinstance(model_name_or_path, + str) else model_name_or_path + assert model_dir.is_dir(), model_dir + + # cfg file + cfg_file = _check_file_path(cfg_file, model_dir) + assert cfg_file.is_file(), cfg_file + cfg = load_config(cfg_file) + cfg.update(kwargs) + + task = cfg["data"].get("task", "MT") + assert task in ["MT", "S2T"], "`task` must be either `MT` or `S2T`." 
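+ # "MT" is plain text-to-text translation; "S2T" feeds continuous feature frames into the encoder (speech features, or pose features in this project)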
+ + # rewrite paths in cfg + for side in ["src", "trg"]: + if task == "S2T" and side == "src": + assert cfg["data"]["dataset_type"] == "speech" + assert cfg["data"][side]["tokenizer_type"] == "speech" + else: + data_side = cfg["data"][side] + data_side["voc_file"] = _check_file_path(data_side["voc_file"], + model_dir).as_posix() + if "tokenizer_cfg" in data_side: + for tok_model in ["codes", "model_file"]: + if tok_model in data_side["tokenizer_cfg"]: + data_side["tokenizer_cfg"][tok_model] = _check_file_path( + data_side["tokenizer_cfg"][tok_model], + model_dir).as_posix() + + if "load_model" in cfg["training"]: + cfg["training"]["load_model"] = _check_file_path(cfg["training"]["load_model"], + model_dir).as_posix() + if not Path(cfg["training"]["model_dir"]).is_dir(): + cfg["training"]["model_dir"] = model_dir.as_posix() + + # parse and validate cfg + (_, load_model_path, device, n_gpu, num_workers, normalization, + fp16) = parse_train_args(cfg["training"], mode="prediction") + + # read vocabs + src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir) + + # build model + model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab) + + # load model state from disk + logger.info("Preparing a joeynmt model...") + ckpt_file = _check_file_path(ckpt_file, model_dir) + load_model_path = load_model_path if ckpt_file is None else ckpt_file + ckpt = resolve_ckpt_path(load_model_path, model_dir) + model_checkpoint = load_checkpoint(ckpt, device=device) + model.load_state_dict(model_checkpoint["model_state"]) + + # create stream dataset + src_lang = cfg["data"]["src"]["lang"] + trg_lang = cfg["data"]["trg"]["lang"] + tokenizer = build_tokenizer(cfg["data"]) + if task == "MT": + sequence_encoder = { + src_lang: partial(src_vocab.sentences_to_ids, bos=False, eos=True), + trg_lang: partial(trg_vocab.sentences_to_ids, bos=True, eos=True), + } + elif task == "S2T": + sequence_encoder = { + "src": partial(pad_features, embed_size=tokenizer["src"].num_freq), + "trg": partial(trg_vocab.sentences_to_ids, bos=True, eos=True), + } + test_data = build_dataset( + dataset_type="stream" if task == "MT" else "speech_stream", + path=None, + src_lang=src_lang if task == "MT" else "src", + trg_lang=trg_lang if task == "MT" else "trg", + split="test", + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + task=task, + ) + + config = { + "device": device, + "n_gpu": n_gpu, + "fp16": fp16, + "cfg": cfg, + "num_workers": num_workers, + "normalization": normalization, + } + return config, test_data, model + + +class TranslatorHubInterface(nn.Module): + """ + PyTorch Hub interface for generating sequences from a pre-trained + encoder-decoder model. + """ + + def __init__(self, config: Dict, dataset: BaseDataset, model: Model): + super().__init__() + self.cfg = config["cfg"] + self.device = config["device"] + self.n_gpu = config["n_gpu"] + self.fp16 = config["fp16"] + self.num_workers = config["num_workers"] + self.normalization = config["normalization"] + self.dataset = dataset + self.model = model + if self.device.type == "cuda": + self.model.to(self.device) + self.model.eval() + + def score(self, + src: List[str], + trg: Optional[List[str]] = None, + **kwargs) -> List[PredictionOutput]: + assert isinstance(src, list), "Please provide a list of sentences!" + assert len( + src + ) <= 64, "For big dataset, please use `test` function instead of `translate`!" 
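+ # request token-level probabilities from predict(): probabilities of the given references if `trg` is provided, otherwise of the model's own hypotheses; attention weights are returned as well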
+ kwargs["return_prob"] = "hyp" if trg is None else "ref" + kwargs["return_attention"] = True + + if trg is not None and self.model.loss_function is None: + self.model.loss_function = ( # need to instantiate loss func + self.cfg["training"].get("loss", "crossentropy"), + self.cfg["training"].get("label_smoothing", 0.1), + ) + + scores, translations, tokens, probs, attention_probs, test_cfg = self._generate( + src, trg, **kwargs) + + beam_size = test_cfg.get("beam_size", 1) + n_best = test_cfg.get("n_best", 1) + + out = [] + for i in range(len(src)): + offset = i * n_best + out.append( + PredictionOutput( + translation=trg[i] if trg else translations[offset:offset + n_best], + tokens=tokens[offset:offset + n_best], + token_probs=[p.tolist() + for p in probs[offset:offset + + n_best]] if beam_size == 1 else None, + sequence_probs=[p[0] + for p in probs[offset:offset + + n_best]] if beam_size > 1 else None, + attention_probs=attention_probs[offset:offset + n_best] + if attention_probs else None, + )) + if trg: + out, scores # pylint:disable=pointless-statement + return out + + def generate(self, src: List[str], **kwargs) -> List[str]: + assert isinstance(src, list), "Please provide a list of sentences!" + assert len( + src + ) <= 64, "for big dataset, please use `test` function instead of `translate`!" + kwargs["return_prob"] = kwargs.get("return_prob", "none") + + scores, translations, tokens, probs, _, _ = self._generate(src, **kwargs) + + if kwargs["return_prob"] != "none": + return scores, translations, tokens, probs + return translations + + def _generate(self, + src: List[str], + trg: Optional[List[str]] = None, + **kwargs) -> List[str]: + + # overwrite config + test_cfg = self.cfg['testing'].copy() + test_cfg.update(kwargs) + + assert self.dataset.__class__.__name__ in [ + "StreamDataset", "SpeechStreamDataset" + ], self.dataset + test_cfg["batch_type"] = "sentence" + test_cfg["batch_size"] = len(src) + + self.dataset.reset_cache() # reset cache + if trg is not None: + assert len(src) == len(trg), "src and trg must have the same length!" 
+ self.dataset.has_trg = True + test_cfg["n_best"] = 1 + test_cfg["beam_size"] = 1 + test_cfg["return_prob"] = "ref" + for src_sent, trg_sent in zip(src, trg): + self.dataset.set_item(src_sent, trg_sent) + else: + self.dataset.has_trg = False + for sentence in src: + self.dataset.set_item(sentence) + + assert len(self.dataset) > 0 + + scores, _, translations, tokens, probs, attention_probs = predict( + model=self.model, + data=self.dataset, + compute_loss=(trg is not None), + device=self.device, + n_gpu=self.n_gpu, + normalization=self.normalization, + num_workers=self.num_workers, + cfg=test_cfg, + fp16=self.fp16, + ) + if translations: + assert len(src) * test_cfg.get("n_best", 1) == len(translations) + self.dataset.reset_cache() # reset cache + + return scores, translations, tokens, probs, attention_probs, test_cfg + + def plot_attention(self, src: str, trg: str, attention_scores: np.ndarray) -> None: + # preprocess and tokenize sentences + self.dataset.reset_cache() # reset cache + self.dataset.has_trg = True + self.dataset.set_item(src, trg) + src_tokens = self.dataset.get_item(idx=0, + lang=self.dataset.src_lang, + is_train=False) + trg_tokens = self.dataset.get_item(idx=0, + lang=self.dataset.trg_lang, + is_train=False) + self.dataset.reset_cache() # reset cache + + # plot attention scores + fig = plot_heatmap( + scores=np.array(attention_scores).T, + column_labels=trg_tokens + [EOS_TOKEN], + row_labels=src_tokens + [EOS_TOKEN], + output_path=None, + dpi=50, + ) + + dummy = plt.figure() + new_manager = dummy.canvas.manager + new_manager.canvas.figure = fig + fig.set_canvas(new_manager.canvas) + fig.show() diff --git a/joeynmt/initialization.py b/joeynmt/initialization.py new file mode 100644 index 0000000..ff46d6a --- /dev/null +++ b/joeynmt/initialization.py @@ -0,0 +1,228 @@ +# coding: utf-8 +""" +Implements custom initialization +""" +import logging +import math +from typing import Dict + +import torch +from torch import Tensor, nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from joeynmt.embeddings import Embeddings +from joeynmt.helpers import ConfigurationError + +logger = logging.getLogger(__name__) + + +def orthogonal_rnn_init_(cell: nn.RNNBase, gain: float = 1.0) -> None: + """ + Orthogonal initialization of recurrent weights + RNN parameters contain 3 or 4 matrices in one parameter, so we slice it. + """ + with torch.no_grad(): + for _, hh, _, _ in cell.all_weights: + for i in range(0, hh.size(0), cell.hidden_size): + nn.init.orthogonal_(hh.data[i:i + cell.hidden_size], gain=gain) + + +def lstm_forget_gate_init_(cell: nn.RNNBase, value: float = 1.0) -> None: + """ + Initialize LSTM forget gates with `value`. + + :param cell: LSTM cell + :param value: initial value, default: 1 + """ + with torch.no_grad(): + for _, _, ih_b, hh_b in cell.all_weights: + length = len(ih_b) + ih_b.data[length // 4:length // 2].fill_(value) + hh_b.data[length // 4:length // 2].fill_(value) + + +def xavier_uniform_n_(w: Tensor, gain: float = 1.0, n: int = 4) -> None: + """ + Xavier initializer for parameters that combine multiple matrices in one + parameter for efficiency. This is e.g. used for GRU and LSTM parameters, + where e.g. all gates are computed at the same time by 1 big matrix. 
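+ The fan-out is divided by `n` so that each of the `n` stacked sub-matrices is initialized as if it were a separate layer.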
+ + :param w: parameter + :param gain: default 1 + :param n: default 4 + """ + with torch.no_grad(): + fan_in, fan_out = _calculate_fan_in_and_fan_out(w) + assert fan_out % n == 0, "fan_out should be divisible by n" + fan_out //= n + std = gain * math.sqrt(2.0 / (fan_in + fan_out)) + a = math.sqrt(3.0) * std + nn.init.uniform_(w, -a, a) + + +def compute_alpha_beta(num_enc_layers: int, num_dec_layers: int) -> Dict[str, Dict]: + """ + DeepNet: compute alpha/beta value suggested in https://arxiv.org/abs/2203.00555 + """ + return { + "alpha": { + "encoder": 0.81 * (num_enc_layers**4 * num_dec_layers)**(1 / 16), + "decoder": (3 * num_dec_layers)**(1 / 4), + }, + "beta": { + "encoder": 0.87 * (num_enc_layers**4 * num_dec_layers)**(-1 / 16), + "decoder": (12 * num_dec_layers)**(-1 / 4), + }, + } + + +def initialize_model(model: nn.Module, cfg: dict, src_padding_idx: int, + trg_padding_idx: int) -> None: + """ + This initializes a model based on the provided config. + + All initializer configuration is part of the `model` section of the configuration + file. For an example, see e.g. `https://github.com/joeynmt/joeynmt/blob/main/ + configs/iwslt14_ende_spm.yaml`. + + The main initializer is set using the `initializer` key. Possible values are + `xavier_uniform`, `uniform`, `normal` or `zeros`. (`xavier_uniform` is the default). + + When an initializer is set to `uniform`, then `init_weight` sets the range for + the values (-init_weight, init_weight). + + When an initializer is set to `normal`, then `init_weight` sets the standard + deviation for the weights (with mean 0). + + The word embedding initializer is set using `embed_initializer` and takes the same + values. The default is `normal` with `embed_init_weight = 0.01`. + + Biases are initialized separately using `bias_initializer`. The default is `zeros`, + but you can use the same initializers as the main initializer. + + Set `init_rnn_orthogonal` to True if you want RNN orthogonal initialization (for + recurrent matrices). Default is False. + + `lstm_forget_gate` controls how the LSTM forget gate is initialized. Default is `1`. + + :param model: model to initialize + :param cfg: the model configuration + :param src_padding_idx: index of source padding token + :param trg_padding_idx: index of target padding token + """ + # pylint: disable=too-many-branches,too-many-statements + # defaults: xavier gain 1.0, embeddings: normal 0.01, biases: zeros, no orthogonal + gain = float(cfg.get("init_gain", 1.0)) # for xavier + init = cfg.get("initializer", "xavier_uniform") + if init == "xavier": + init = "xavier_uniform" + logger.warning( + "`xavier` option is obsolete. Please use `xavier_uniform`, instead.") + init_weight = float(cfg.get("init_weight", 0.01)) + + embed_init = cfg.get("embed_initializer", "xavier_uniform") + if embed_init == "xavier": + embed_init = "xavier_uniform" + logger.warning( + "`xavier` option is obsolete. 
Please use `xavier_uniform`, instead.") + embed_init_weight = float(cfg.get("embed_init_weight", 0.01)) + embed_gain = float(cfg.get("embed_init_gain", 1.0)) # for xavier + + bias_init = cfg.get("bias_initializer", "zeros") + bias_init_weight = float(cfg.get("bias_init_weight", 0.01)) + + deepnet = {} + if (init == "xavier_normal" + and cfg["encoder"]["type"] == cfg["decoder"]["type"] == "transformer"): + # apply `alpha`: weight factor for residual connection + deepnet["xavier_normal"] = compute_alpha_beta(cfg["encoder"]["num_layers"], + cfg["decoder"]["num_layers"]) + + for layer in model.encoder.layers: + layer.alpha = deepnet["xavier_normal"]["alpha"]["encoder"] + layer.feed_forward.alpha = deepnet["xavier_normal"]["alpha"]["encoder"] + for layer in model.decoder.layers: + layer.alpha = deepnet["xavier_normal"]["alpha"]["decoder"] + layer.feed_forward.alpha = deepnet["xavier_normal"]["alpha"]["decoder"] + + def _parse_init(s: str, scale: float, _gain: float): + # pylint: disable=no-else-return,unnecessary-lambda + scale = float(scale) + assert scale > 0.0, "incorrect init_weight" + if s.lower() == "xavier_uniform": + return lambda p: nn.init.xavier_uniform_(p, gain=_gain) + elif s.lower() == "xavier_normal": + return lambda p: nn.init.xavier_normal_(p, gain=_gain) + elif s.lower() == "uniform": + return lambda p: nn.init.uniform_(p, a=-scale, b=scale) + elif s.lower() == "normal": + return lambda p: nn.init.normal_(p, mean=0.0, std=scale) + elif s.lower() == "zeros": + return lambda p: nn.init.zeros_(p) + else: + raise ConfigurationError("Unknown initializer.") + + init_fn_ = _parse_init(init, init_weight, gain) + embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain) + bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain) + + with torch.no_grad(): + for name, p in model.named_parameters(): + + if "embed" in name: + embed_init_fn_(p) + + elif "bias" in name: + bias_init_fn_(p) + + elif len(p.size()) > 1: + + # RNNs combine multiple matrices is one, which messes up + # xavier initialization + if init == "xavier_uniform" and "rnn" in name: + n = 1 + if "encoder" in name: + n = 4 if isinstance(model.encoder.rnn, nn.LSTM) else 3 + elif "decoder" in name: + n = 4 if isinstance(model.decoder.rnn, nn.LSTM) else 3 + xavier_uniform_n_(p.data, gain=gain, n=n) + + elif init == "xavier_normal" and init in deepnet: + # use beta value suggested in https://arxiv.org/abs/2203.00555 + beta = 1.0 + if ("pwff_layer" in name or "v_layer" in name + or "output_layer" in name): + if "encoder" in name: + beta = deepnet[init]["beta"]["encoder"] + elif "decoder" in name: + beta = deepnet[init]["beta"]["decoder"] + nn.init.xavier_normal_(p, gain=beta) + + else: + init_fn_(p) + + # zero out paddings + if isinstance(model.src_embed, Embeddings): + model.src_embed.lut.weight.data[src_padding_idx].zero_() + model.trg_embed.lut.weight.data[trg_padding_idx].zero_() + + orthogonal = cfg.get("init_rnn_orthogonal", False) + lstm_forget_gate = cfg.get("lstm_forget_gate", 1.0) + + # encoder rnn orthogonal initialization & LSTM forget gate + if hasattr(model.encoder, "rnn"): + + if orthogonal: + orthogonal_rnn_init_(model.encoder.rnn) + + if isinstance(model.encoder.rnn, nn.LSTM): + lstm_forget_gate_init_(model.encoder.rnn, lstm_forget_gate) + + # decoder rnn orthogonal initialization & LSTM forget gate + if hasattr(model.decoder, "rnn"): + + if orthogonal: + orthogonal_rnn_init_(model.decoder.rnn) + + if isinstance(model.decoder.rnn, nn.LSTM): + lstm_forget_gate_init_(model.decoder.rnn, 
lstm_forget_gate) diff --git a/joeynmt/loss.py b/joeynmt/loss.py new file mode 100644 index 0000000..69d7202 --- /dev/null +++ b/joeynmt/loss.py @@ -0,0 +1,166 @@ +# coding: utf-8 +""" +Loss functions +""" +import logging +from typing import Tuple + +import torch +from torch import Tensor, nn +from torch.autograd import Variable +from torch.nn.modules.loss import _Loss + +logger = logging.getLogger(__name__) + + +class XentLoss(nn.Module): + """ + Cross-Entropy Loss with optional label smoothing + """ + + def __init__(self, pad_index: int, smoothing: float = 0.0): + super().__init__() + self.smoothing = smoothing + self.pad_index = pad_index + self.criterion: _Loss # (type annotation) + if self.smoothing <= 0.0: + # standard xent loss + self.criterion = nn.NLLLoss(ignore_index=self.pad_index, reduction="sum") + else: + # custom label-smoothed loss, computed with KL divergence loss + self.criterion = nn.KLDivLoss(reduction="sum") + + self.require_ctc_layer = False + + def _smooth_targets(self, targets: Tensor, vocab_size: int) -> Variable: + """ + Smooth target distribution. All non-reference words get uniform + probability mass according to "smoothing". + + :param targets: target indices, batch*seq_len + :param vocab_size: size of the output vocabulary + :return: smoothed target distributions, batch*seq_len x vocab_size + """ + # batch*seq_len x vocab_size + smooth_dist = targets.new_zeros((targets.size(0), vocab_size)).float() + # fill distribution uniformly with smoothing + smooth_dist.fill_(self.smoothing / (vocab_size - 2)) + # assign true label the probability of 1-smoothing ("confidence") + smooth_dist.scatter_(1, targets.unsqueeze(1).data, 1.0 - self.smoothing) + # give padding probability of 0 everywhere + smooth_dist[:, self.pad_index] = 0 + # masking out padding area (sum of probabilities for padding area = 0) + padding_positions = torch.nonzero(targets.data == self.pad_index, + as_tuple=False) + if len(padding_positions) > 0: + smooth_dist.index_fill_(0, padding_positions.squeeze(), 0.0) + return Variable(smooth_dist, requires_grad=False) + + def _reshape(self, log_probs: Tensor, targets: Tensor) -> Tensor: + vocab_size = log_probs.size(-1) + + # reshape log_probs to (batch*seq_len x vocab_size) + log_probs_flat = log_probs.contiguous().view(-1, vocab_size) + + if self.smoothing > 0: + targets_flat = self._smooth_targets(targets=targets.contiguous().view(-1), + vocab_size=vocab_size) + # targets: distributions with batch*seq_len x vocab_size + assert log_probs_flat.size() == targets_flat.size(), ( + log_probs.size(), + targets_flat.size(), + ) + else: + # targets: indices with batch*seq_len + targets_flat = targets.contiguous().view(-1) + assert log_probs_flat.size(0) == targets_flat.size(0), ( + log_probs.size(0), + targets_flat.size(0), + ) + + return log_probs_flat, targets_flat + + def forward(self, log_probs: Tensor, **kwargs) -> Tuple[Tensor]: + """ + Compute the cross-entropy between logits and targets. + + If label smoothing is used, target distributions are not one-hot, but + "1-smoothing" for the correct target token and the rest of the + probability mass is uniformly spread across the other tokens. 
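+ With smoothing value s and vocabulary size V, every other non-padding token receives probability mass s / (V - 2) (see _smooth_targets above).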
+ + :param log_probs: log probabilities as predicted by model + :return: logits + """ + assert "trg" in kwargs + log_probs, targets = self._reshape(log_probs, kwargs["trg"]) + + # compute loss + logits = self.criterion(log_probs, targets) + return (logits, ) + + def __repr__(self): + return (f"{self.__class__.__name__}(criterion={self.criterion}, " + f"smoothing={self.smoothing})") + + +class XentCTCLoss(XentLoss): + """ + Cross-Entropy + CTC loss with optional label smoothing + """ + + def __init__(self, + pad_index: int, + bos_index: int, + smoothing: float = 0.0, + zero_infinity: bool = True, + ctc_weight: float = 0.3): + super().__init__(pad_index=pad_index, smoothing=smoothing) + + self.require_ctc_layer = True + self.bos_index = bos_index + self.ctc_weight = ctc_weight + self.ctc = nn.CTCLoss(blank=bos_index, + reduction='sum', + zero_infinity=zero_infinity) + + def forward(self, log_probs, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compute the cross-entropy loss and ctc loss + + :param log_probs: log probabilities as predicted by model + shape (batch_size, seq_length, vocab_size) + :return: + - total loss + - xent loss + - ctc loss + """ + assert "trg" in kwargs + assert "trg_length" in kwargs + assert "src_mask" in kwargs + assert "ctc_log_probs" in kwargs + + # reshape tensors for cross_entropy + log_probs_flat, targets_flat = self._reshape(log_probs, kwargs["trg"]) + + # cross_entropy loss + xent_loss = self.criterion(log_probs_flat, targets_flat) + + # ctc_loss + # reshape ctc_log_probs to (seq_length, batch_size, vocab_size) + ctc_loss = self.ctc( + kwargs["ctc_log_probs"].transpose(0, 1).contiguous(), + targets=kwargs["trg"], # (seq_length, batch_size) + input_lengths=kwargs["src_mask"].squeeze(1).sum(dim=1), + target_lengths=kwargs["trg_length"]) + + # interpolation + total_loss = (1.0 - self.ctc_weight) * xent_loss + self.ctc_weight * ctc_loss + + assert not total_loss.isnan(), "loss has to be non-NaN value." + assert total_loss.item() >= 0.0, "loss has to be non-negative." + return total_loss, xent_loss, ctc_loss + + def __repr__(self): + return (f"{self.__class__.__name__}(" + f"criterion={self.criterion}, smoothing={self.smoothing}, " + f"ctc={self.ctc}, ctc_weight={self.ctc_weight})") diff --git a/joeynmt/metrics.py b/joeynmt/metrics.py new file mode 100644 index 0000000..0d860d4 --- /dev/null +++ b/joeynmt/metrics.py @@ -0,0 +1,125 @@ +# coding: utf-8 +""" +Evaluation metrics +""" +import logging +from inspect import getfullargspec +from typing import List + +import editdistance +from sacrebleu.metrics import BLEU, CHRF + +logger = logging.getLogger(__name__) + + +def chrf(hypotheses: List[str], references: List[str], **sacrebleu_cfg) -> float: + """ + Character F-score from sacrebleu + cf. https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py + + :param hypotheses: list of hypotheses (strings) + :param references: list of references (strings) + :return: character f-score (0 <= chf <= 1) + see Breaking Change in sacrebleu v2.0 + """ + kwargs = {} + if sacrebleu_cfg: + valid_keys = getfullargspec(CHRF).args + for k, v in sacrebleu_cfg.items(): + if k in valid_keys: + kwargs[k] = v + + metric = CHRF(**kwargs) + score = metric.corpus_score(hypotheses=hypotheses, references=[references]).score + + # log sacrebleu signature + logger.info(metric.get_signature()) + return score / 100 + + +def bleu(hypotheses: List[str], references: List[str], **sacrebleu_cfg) -> float: + """ + Raw corpus BLEU from sacrebleu (without tokenization) + cf. 
https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/bleu.py + + :param hypotheses: list of hypotheses (strings) + :param references: list of references (strings) + :return: bleu score + """ + kwargs = {} + if sacrebleu_cfg: + valid_keys = getfullargspec(BLEU).args + for k, v in sacrebleu_cfg.items(): + if k in valid_keys: + kwargs[k] = v + + metric = BLEU(**kwargs) + score = metric.corpus_score(hypotheses=hypotheses, references=[references]).score + + # log sacrebleu signature + logger.info(metric.get_signature()) + return score + + +def token_accuracy(hypotheses: List[List[str]], references: List[List[str]]) -> float: + """ + Compute the accuracy of hypothesis tokens: correct tokens / all tokens + Tokens are correct if they appear in the same position in the reference. + We lookup the references before one-hot-encoding, that is, UNK generation in + hypotheses is always evaluated as incorrect. + + :param hypotheses: list of tokenized hypotheses (List[List[str]]) + :param references: list of tokenized references (List[List[str]]) + :return: token accuracy (float) + """ + correct_tokens = 0 + all_tokens = 0 + assert len(hypotheses) == len(references) + for hyp, ref in zip(hypotheses, references): + all_tokens += len(hyp) + for h_i, r_i in zip(hyp, ref): + # min(len(h), len(r)) tokens considered + if h_i == r_i: + correct_tokens += 1 + return (correct_tokens / all_tokens) * 100 if all_tokens > 0 else 0.0 + + +def sequence_accuracy(hypotheses: List[str], references: List[str]) -> float: + """ + Compute the accuracy of hypothesis tokens: correct tokens / all tokens + Tokens are correct if they appear in the same position in the reference. + We lookup the references before one-hot-encoding, that is, hypotheses with UNK + are always evaluated as incorrect. 
+ + :param hypotheses: list of hypotheses (strings) + :param references: list of references (strings) + :return: + """ + assert len(hypotheses) == len(references) + correct_sequences = sum( + [1 for (hyp, ref) in zip(hypotheses, references) if hyp == ref]) + return (correct_sequences / len(hypotheses)) * 100 if hypotheses else 0.0 + + +def wer(hypotheses, references, tokenizer): + """ + Compute word error rate in corpus-level + + :param hypotheses: list of hypotheses (strings) + :param references: list of references (strings) + :param tokenizer: tokenize function (callable) + :return: normalized word error rate + """ + numerator = 0.0 + denominator = 0.0 + # sentence-level wer + # for hyp, ref in zip(hypotheses, references): + # wer = editdistance.eval(tokenizer(hyp), + # tokenizer(ref)) / len(tokenizer(ref)) + # numerator += max(wer, 1.0) # can be `wer > 1` if `len(hyp) > len(ref)` + # denominator += 1.0 + # corpus-level wer + for hyp, ref in zip(hypotheses, references): + numerator += editdistance.eval(tokenizer(hyp), tokenizer(ref)) + denominator += len(tokenizer(ref)) + return (numerator / denominator) * 100 if denominator else 0.0 diff --git a/joeynmt/model.py b/joeynmt/model.py new file mode 100644 index 0000000..1fac86d --- /dev/null +++ b/joeynmt/model.py @@ -0,0 +1,432 @@ +# coding: utf-8 +""" +Module to represents whole models +""" +import logging +from pathlib import Path +from typing import Callable, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from joeynmt.decoders import Decoder, RecurrentDecoder, TransformerDecoder +from joeynmt.embeddings import Embeddings +from joeynmt.encoders import ( + ConformerEncoder, + Encoder, + RecurrentEncoder, + TransformerEncoder, +) +from joeynmt.helpers import ConfigurationError +from joeynmt.initialization import initialize_model +from joeynmt.loss import XentCTCLoss, XentLoss +from joeynmt.vocabulary import Vocabulary + +logger = logging.getLogger(__name__) + + +class Model(nn.Module): + """ + Base Model class + """ + + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + encoder: Encoder, + decoder: Decoder, + src_embed: Embeddings, + trg_embed: Embeddings, + src_vocab: Vocabulary, + trg_vocab: Vocabulary, + task: str = "MT", + ) -> None: + """ + Create a new encoder-decoder model + + :param encoder: encoder + :param decoder: decoder + :param src_embed: source embedding + :param trg_embed: target embedding + :param src_vocab: source vocabulary + :param trg_vocab: target vocabulary + """ + super().__init__() + + self.src_embed = src_embed # nn.Identity() if task == "S2T" + self.trg_embed = trg_embed + self.encoder = encoder + self.decoder = decoder + self.src_vocab = src_vocab + self.trg_vocab = trg_vocab + self.pad_index = self.trg_vocab.pad_index + self.bos_index = self.trg_vocab.bos_index + self.eos_index = self.trg_vocab.eos_index + self.unk_index = self.trg_vocab.unk_index + self._loss_function: Callable = None # set by the TrainManager + self.task = task + + if self.task == "S2T": + assert isinstance(self.encoder, TransformerEncoder) + assert isinstance(self.decoder, TransformerDecoder) + + @property + def loss_function(self): + return self._loss_function + + @loss_function.setter + def loss_function(self, cfg: Tuple): + loss_type, label_smoothing, ctc_weight = cfg + if loss_type == "crossentropy-ctc": + loss_function = XentCTCLoss( + pad_index=self.pad_index, + bos_index=self.bos_index, # bos -> blank + smoothing=label_smoothing, + 
ctc_weight=ctc_weight) + elif loss_type == "crossentropy": + loss_function = XentLoss(pad_index=self.pad_index, + smoothing=label_smoothing) + self.decoder.ctc_output_layer = None + self._loss_function = loss_function + + def forward(self, + return_type: str = None, + **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Interface for multi-gpu + + For DataParallel, We need to encapsulate all model call: `model.encode()`, + `model.decode()`, and `model.encode_decode()` by `model.__call__()`. + `model.__call__()` triggers model.forward() together with pre hooks and post + hooks, which takes care of multi-gpu distribution. + + :param return_type: one of {"loss", "encode", "decode"} + """ + if return_type is None: + raise ValueError("Please specify return_type: {`loss`, `loss_probs`, " + "`encode`, `decode`, `decode_ctc`}.") + + if return_type.startswith("loss"): + assert self.loss_function is not None + assert "trg" in kwargs and "trg_mask" in kwargs # need trg to compute loss + return_tuple = [None, None, None, None] + + out, ctc_out, src_mask = self._encode_decode(**kwargs) + + # compute log probs + log_probs = F.log_softmax(out, dim=-1) + + # compute batch loss + if self.loss_function.require_ctc_layer and isinstance(ctc_out, Tensor): + kwargs["src_mask"] = src_mask # pass through subsampled mask + kwargs["ctc_log_probs"] = F.log_softmax(ctc_out, dim=-1) + # pylint: disable=not-callable + batch_loss = self.loss_function(log_probs, **kwargs) + assert isinstance(batch_loss, tuple) and 1 <= len(batch_loss) <= 3 + + # return batch loss + # = sum over all elements in batch that are not pad + for i, loss in enumerate(list(batch_loss)): + return_tuple[i] = loss + + # count correct tokens before decoding (for accuracy) + trg_mask = kwargs["trg_mask"].squeeze(1) + assert kwargs["trg"].size() == trg_mask.size() + n_correct = torch.sum( + log_probs.argmax(-1).masked_select(trg_mask).eq( + kwargs["trg"].masked_select(trg_mask))) + return_tuple[-1] = n_correct + + if return_type == "loss_probs": + return_tuple[1] = log_probs + return_tuple[2] = kwargs.get("ctc_log_probs", None) + + elif return_type == "encode": + encoder_output, encoder_hidden, src_mask = self._encode(**kwargs) + + # return encoder outputs + return_tuple = (encoder_output, encoder_hidden, src_mask, None) + + elif return_type == "decode": + outputs, hidden, att_probs, att_vectors, _ = self._decode(**kwargs) + + # return decoder outputs + return_tuple = (outputs, hidden, att_probs, att_vectors) + + elif return_type == "decode_ctc": + outputs, hidden, att_probs, _, ctc_out = self._decode(**kwargs) + + # return decoder outputs with ctc + return_tuple = (outputs, hidden, att_probs, ctc_out) + + return tuple(return_tuple) + + def _encode_decode( + self, + src: Tensor, + trg_input: Tensor, + src_mask: Tensor, + src_length: Tensor, + trg_mask: Tensor = None, + **kwargs, + ) -> Tensor: + """ + First encodes the source sentence. + Then produces the target one word at a time. 
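+ (During training this is done with teacher forcing: the full gold prefix trg_input is passed to the decoder in a single pass of unroll_steps = trg_input.size(1) positions.)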
+ + :param src: source input + :param trg_input: target input + :param src_mask: source mask + :param src_length: length of source inputs + :param trg_mask: target mask + :return: + - decoder output + - ctc output + - src mask + """ + encoder_output, encoder_hidden, src_mask = self._encode(src=src, + src_length=src_length, + src_mask=src_mask, + **kwargs) + + unroll_steps = trg_input.size(1) + + decoder_output, _, _, _, ctc_output = self._decode( + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + src_mask=src_mask, + trg_input=trg_input, + unroll_steps=unroll_steps, + trg_mask=trg_mask, + **kwargs) + + return decoder_output, ctc_output, src_mask + + def _encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor, + **_kwargs) -> Tuple[Tensor, Tensor, Tensor]: + """ + Encodes the source sentence. + Note: this is called after DataParallel + + :param src: spectrogram if task == "S2T" else one-hot encoded sequence + :param src_length: + :param src_mask: None if task == "S2T" else bool tensor in shape + (batch_size, 1, seq_len) + :return: + - encoder_outputs + - hidden_concat + - src_mask + """ + assert _kwargs["task"] == self.task, (_kwargs["task"], self.task) # batch type + return self.encoder(self.src_embed(src), src_length, src_mask, **_kwargs) + + def _decode( + self, + encoder_output: Tensor, + encoder_hidden: Tensor, + src_mask: Tensor, + trg_input: Tensor, + unroll_steps: int, + decoder_hidden: Tensor = None, + att_vector: Tensor = None, + trg_mask: Tensor = None, + **_kwargs, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Decode, given an encoded source sentence. + + :param encoder_output: encoder states for attention computation + :param encoder_hidden: last encoder state for decoder initialization + :param src_mask: source mask, 1 at valid tokens + :param trg_input: target inputs + :param unroll_steps: number of steps to unroll the decoder for + :param decoder_hidden: decoder hidden state (optional) + :param att_vector: previous attention vector (optional) + :param trg_mask: mask for target steps + :return: decoder outputs + - decoder_output + - decoder_hidden + - att_prob + - att_vector + - ctc_output + """ + return self.decoder( + trg_embed=self.trg_embed(trg_input), + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + src_mask=src_mask, + unroll_steps=unroll_steps, + hidden=decoder_hidden, + prev_att_vector=att_vector, + trg_mask=trg_mask, + **_kwargs, + ) + + def __repr__(self) -> str: + """ + String representation: a description of encoder, decoder and embeddings + + :return: string representation + """ + return (f"{self.__class__.__name__}(task={self.task},\n" + f"\tencoder={self.encoder},\n" + f"\tdecoder={self.decoder},\n" + f"\tsrc_embed={self.src_embed},\n" + f"\ttrg_embed={self.trg_embed},\n" + f"\tloss_function={self.loss_function})") + + def log_parameters_list(self) -> None: + """ + Write all model parameters (name, shape) to the log. 
+ """ + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + n_params = sum([np.prod(p.size()) for p in model_parameters]) + logger.info("Total params: %d", n_params) + trainable_params = [n for (n, p) in self.named_parameters() if p.requires_grad] + logger.debug("Trainable parameters: %s", sorted(trainable_params)) + assert trainable_params + + +class _DataParallel(nn.DataParallel): + """DataParallel wrapper to pass through the model attributes""" + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + +def build_model(cfg: dict = None, + src_vocab: Vocabulary = None, + trg_vocab: Vocabulary = None) -> Model: + """ + Build and initialize the model according to the configuration. + + :param cfg: dictionary configuration containing model specifications + :param src_vocab: source vocabulary + :param trg_vocab: target vocabulary + :return: built and initialized model + """ + # pylint: disable=too-many-branches + logger.info("Building an encoder-decoder model...") + enc_cfg = cfg["encoder"] + dec_cfg = cfg["decoder"] + + task = "MT" if src_vocab is not None else "S2T" + + trg_pad_index = trg_vocab.pad_index + src_pad_index = src_vocab.pad_index if task == "MT" else trg_pad_index + + if task == "MT": + src_embed = Embeddings(**enc_cfg["embeddings"], + vocab_size=len(src_vocab), + padding_idx=src_pad_index) + else: + src_embed = nn.Identity() + + # this ties source and target embeddings for softmax layer tying, see further below + if cfg.get("tied_embeddings", False): + if task == "MT" and src_vocab == trg_vocab: + trg_embed = src_embed # share embeddings for src and trg + else: + raise ConfigurationError( + "Embedding cannot be tied since vocabularies differ.") + else: + trg_embed = Embeddings(**dec_cfg["embeddings"], + vocab_size=len(trg_vocab), + padding_idx=trg_pad_index) + + # build encoder + enc_dropout = enc_cfg.get("dropout", 0.0) + enc_emb_dropout = enc_cfg["embeddings"].get("dropout", enc_dropout) + enc_type = enc_cfg.get("type", "transformer") + if enc_type not in ["recurrent", "transformer", "conformer"]: + raise ConfigurationError("Invalid encoder type. Valid options: " + "{`recurrent`, `transformer`, `conformer`}.") + if enc_type == "transformer": + if task == "MT": + assert enc_cfg["embeddings"]["embedding_dim"] == enc_cfg["hidden_size"], ( + "for transformer, emb_size must be the same as hidden_size.") + emb_size = src_embed.embedding_dim + elif task == "S2T": + emb_size = enc_cfg["embeddings"]["embedding_dim"] + # TODO: check if emb_size == num_freq + encoder = TransformerEncoder(**enc_cfg, + emb_size=emb_size, + emb_dropout=enc_emb_dropout, + pad_index=src_pad_index) + elif enc_type == "conformer": + assert task == "S2T", "Conformer model can be used only for S2T task." + emb_size = enc_cfg["embeddings"]["embedding_dim"] + # TODO: check if emb_size == num_freq + encoder = ConformerEncoder(**enc_cfg, + emb_size=emb_size, + emb_dropout=enc_emb_dropout, + pad_index=src_pad_index) + else: + assert task == "MT", "RNN model not supported for s2t task. use transformer." + encoder = RecurrentEncoder(**enc_cfg, + vemb_size=src_embed.embedding_dim, + emb_dropout=enc_emb_dropout) + + # build decoder + dec_dropout = dec_cfg.get("dropout", 0.0) + dec_emb_dropout = dec_cfg["embeddings"].get("dropout", dec_dropout) + dec_type = dec_cfg.get("type", "transformer") + if dec_type not in ["recurrent", "transformer"]: + raise ConfigurationError( + "Invalid decoder type. 
Valid options: {`transformer`, `recurrent`}.") + if dec_type == "transformer": + if task == "S2T": + # pylint: disable=protected-access + dec_cfg["encoder_output_size_for_ctc"] = encoder._output_size + decoder = TransformerDecoder(**dec_cfg, + encoder=encoder, + vocab_size=len(trg_vocab), + emb_size=trg_embed.embedding_dim, + emb_dropout=dec_emb_dropout) + else: + decoder = RecurrentDecoder(**dec_cfg, + encoder=encoder, + vocab_size=len(trg_vocab), + emb_size=trg_embed.embedding_dim, + emb_dropout=dec_emb_dropout) + + model = Model( + encoder=encoder, + decoder=decoder, + src_embed=src_embed, + trg_embed=trg_embed, + src_vocab=src_vocab, + trg_vocab=trg_vocab, + task=task, + ) + + # tie softmax layer with trg embeddings + if cfg.get("tied_softmax", False): + if trg_embed.lut.weight.shape == model.decoder.output_layer.weight.shape: + # (also) share trg embeddings and softmax layer: + model.decoder.output_layer.weight = trg_embed.lut.weight + else: + raise ConfigurationError( + "For tied_softmax, the decoder embedding_dim and decoder hidden_size " + "must be the same. The decoder must be a Transformer.") + + # custom initialization of model parameters + initialize_model(model, cfg, src_pad_index, trg_pad_index) + + # initialize embeddings from file + enc_embed_path = enc_cfg["embeddings"].get("load_pretrained", None) + dec_embed_path = dec_cfg["embeddings"].get("load_pretrained", None) + if enc_embed_path and task == "MT": + logger.info("Loading pretrained src embeddings...") + model.src_embed.load_from_file(Path(enc_embed_path), src_vocab) + if dec_embed_path and not cfg.get("tied_embeddings", False): + logger.info("Loading pretrained trg embeddings...") + model.trg_embed.load_from_file(Path(dec_embed_path), trg_vocab) + + logger.info("Enc-dec model built.") + return model diff --git a/joeynmt/plotting.py b/joeynmt/plotting.py new file mode 100644 index 0000000..5d074ea --- /dev/null +++ b/joeynmt/plotting.py @@ -0,0 +1,86 @@ +# coding: utf-8 +""" +Plot attentions +""" +from typing import List, Optional + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import rcParams +from matplotlib.backends.backend_pdf import PdfPages +from matplotlib.figure import Figure + +matplotlib.use("Agg") +# matplotlib.font_manager.fontManager.addfont("ipaexg.ttf") + + +def plot_heatmap( + scores: np.ndarray, + column_labels: List[str], + row_labels: List[str], + output_path: Optional[str] = None, + dpi: int = 300, +) -> Figure: + """ + Plotting function that can be used to visualize (self-)attention. + Plots are saved if `output_path` is specified, in format that this file + ends with ('pdf' or 'png'). + + :param scores: attention scores + :param column_labels: labels for columns (e.g. target tokens) + :param row_labels: labels for rows (e.g. 
source tokens) + :param output_path: path to save to + :param dpi: set resolution for matplotlib + :return: pyplot figure + """ + + if output_path is not None: + assert output_path.endswith(".png") or output_path.endswith(".pdf"), \ + "output path must have .png or .pdf extension" + + x_sent_len = len(column_labels) + y_sent_len = len(row_labels) + scores = scores[:y_sent_len, :x_sent_len] + # check that cut off part didn't have any attention + assert np.sum(scores[y_sent_len:, :x_sent_len]) == 0 + + # automatic label size + labelsize = 25 * (10 / max(x_sent_len, y_sent_len)) + + # font config + rcParams["xtick.labelsize"] = labelsize + rcParams["ytick.labelsize"] = labelsize + # rcParams['font.family'] = "IPAexGothic" # support CJK + + fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi) + plt.imshow( + scores, + cmap="viridis", + aspect="equal", + origin="upper", + vmin=0.0, + vmax=1.0, + ) + ax.xaxis.tick_top() + ax.set_xticks(np.arange(scores.shape[1]) + 0, minor=False) + ax.set_yticks(np.arange(scores.shape[0]) + 0, minor=False) + + ax.set_xticklabels(column_labels, minor=False, rotation="vertical") + ax.set_yticklabels(row_labels, minor=False) + + plt.tight_layout() + + if output_path is not None: + if output_path.endswith(".pdf"): + pp = PdfPages(output_path) + pp.savefig(fig) + pp.close() + else: + if not output_path.endswith(".png"): + output_path += ".png" + plt.savefig(output_path) + + plt.close() + + return fig diff --git a/joeynmt/prediction.py b/joeynmt/prediction.py new file mode 100644 index 0000000..c252792 --- /dev/null +++ b/joeynmt/prediction.py @@ -0,0 +1,634 @@ +# coding: utf-8 +""" +This modules holds methods for generating predictions from a model. +""" +import logging +import math +import sys +import time +from functools import partial +from itertools import zip_longest +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from joeynmt.data import load_data +from joeynmt.datasets import build_dataset +from joeynmt.helpers import ( + check_version, + expand_reverse_index, + load_checkpoint, + load_config, + make_logger, + parse_test_args, + parse_train_args, + resolve_ckpt_path, + set_seed, + store_attention_plots, + write_list_to_file, +) +from joeynmt.helpers_for_pose import pad_features +from joeynmt.metrics import bleu, chrf, sequence_accuracy, token_accuracy, wer +from joeynmt.model import Model, _DataParallel, build_model +from joeynmt.search import search +from joeynmt.tokenizers import EvaluationTokenizer, build_tokenizer +from joeynmt.vocabulary import build_vocab + +logger = logging.getLogger(__name__) + + +def predict( + model: Model, + data: Dataset, + device: torch.device, + n_gpu: int, + compute_loss: bool = False, + normalization: str = "batch", + num_workers: int = 0, + cfg: Dict = None, + fp16: bool = False, +) -> Tuple[Dict[str, float], List[str], List[str], List[List[str]], List[np.ndarray], + List[np.ndarray]]: + """ + Generates translations for the given data. + If `compute_loss` is True and references are given, also computes the loss. 
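+ Hypotheses are returned in the original data order, with n_best entries per example.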
+ + :param model: model module + :param data: dataset for validation + :param device: torch device + :param n_gpu: number of GPUs + :param compute_loss: whether to computes a scalar loss for given inputs and targets + :param normalization: one of {`batch`, `tokens`, `none`} + :param num_workers: number of workers for `collate_fn()` in data iterator + :param cfg: `testing` section in yaml config file + :param fp16: whether to use fp16 + :return: + - valid_scores: (dict) current validation scores, + - valid_ref: (list) validation references, + - valid_hyp: (list) validation hypotheses, + - decoded_valid: (list) token-level validation hypotheses (before post-process), + - valid_sequence_scores: (list) log probabilities for validation hypotheses + - valid_attention_scores: (list) attention scores for validation hypotheses + """ + # pylint: disable=too-many-branches,too-many-statements + # parse test cfg + ( + eval_batch_size, + eval_batch_type, + max_output_length, + min_output_length, + eval_metrics, + sacrebleu_cfg, + beam_size, + beam_alpha, + n_best, + return_attention, + return_prob, + generate_unk, + repetition_penalty, + no_repeat_ngram_size, + ) = parse_test_args(cfg) + + if return_prob == "ref": # no decoding needed + decoding_description = "" + else: + decoding_description = ( # write the decoding strategy in the log + " (Greedy decoding with " if beam_size < 2 else + f" (Beam search with beam_size={beam_size}, beam_alpha={beam_alpha}, " + f"n_best={n_best}, ") + decoding_description += ( + f"min_output_length={min_output_length}, " + f"max_output_length={max_output_length}, " + f"return_prob='{return_prob}', generate_unk={generate_unk}, " + f"repetition_penalty={repetition_penalty}, " + f"no_repeat_ngram_size={no_repeat_ngram_size})") + logger.info("Predicting %d example(s)...%s", len(data), decoding_description) + + assert eval_batch_size >= n_gpu, "`batch_size` must be bigger than `n_gpu`." + # **CAUTION:** a batch will be expanded to batch.nseqs * beam_size, and it might + # cause an out-of-memory error. + # if batch_size > beam_size: + # batch_size //= beam_size + + valid_iter = data.make_iter( + batch_size=eval_batch_size, + batch_type=eval_batch_type, + shuffle=False, + num_workers=num_workers, + pad_index=model.pad_index, + device=device, + ) + + # disable dropout + model.eval() + + # place holders for scores + valid_scores = {"loss": float("nan"), "acc": float("nan"), "ppl": float("nan")} + all_outputs = [] + valid_attention_scores = [] + valid_sequence_scores = [] + total_loss = 0 + total_nseqs = 0 + total_ntokens = 0 + total_n_correct = 0 + output, ref_scores, hyp_scores, attention_scores = None, None, None, None + disable_tqdm = data.__class__.__name__ == "StreamDataset" + + gen_start_time = time.time() + with tqdm(total=len(data), disable=disable_tqdm, desc="Predicting...") as pbar: + for batch in valid_iter: + total_nseqs += batch.nseqs # number of sentences in the current batch + + # sort batch now by src length and keep track of order + reverse_index = batch.sort_by_src_length() + sort_reverse_index = expand_reverse_index(reverse_index, n_best) + batch_size = len(sort_reverse_index) # = batch.nseqs * n_best + + # run as during training to get validation loss (e.g. 
xent) + if compute_loss and batch.has_trg: + assert model.loss_function is not None + + # don't track gradients during validation + with torch.no_grad(): + batch_loss, log_probs, attn, n_correct = model( + return_type="loss", + return_attention=return_attention, + **vars(batch)) + # sum over multiple gpus + batch_loss = batch.normalize(batch_loss, "sum", n_gpu=n_gpu) + n_correct = batch.normalize(n_correct, "sum", n_gpu=n_gpu) + if return_prob == "ref": + ref_scores = batch.score(log_probs) + attention_scores = attn.detach().cpu().numpy() + output = batch.trg + + total_loss += batch_loss.item() # cast Tensor to float + total_n_correct += n_correct.item() # cast Tensor to int + total_ntokens += batch.ntokens + + # if return_prob == "ref", then no search needed. + # (just look up the prob of the ground truth.) + if return_prob != "ref": + # run search as during inference to produce translations + output, hyp_scores, attention_scores = search( + model=model, + batch=batch, + beam_size=beam_size, + beam_alpha=beam_alpha, + max_output_length=max_output_length, + n_best=n_best, + return_attention=return_attention, + return_prob=return_prob, + generate_unk=generate_unk, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + fp16=fp16, + ) + + # sort outputs back to original order + all_outputs.extend(output[sort_reverse_index]) # either hyp or ref + valid_attention_scores.extend(attention_scores[sort_reverse_index] + if attention_scores is not None else []) + valid_sequence_scores.extend( + ref_scores[sort_reverse_index] \ + if ref_scores is not None and ref_scores.shape[0] == batch_size + else hyp_scores[sort_reverse_index] \ + if hyp_scores is not None and hyp_scores.shape[0] == batch_size + else []) + + pbar.update(batch.nseqs) + + gen_duration = time.time() - gen_start_time + + assert len(valid_iter.dataset) == total_nseqs == len(data), \ + (len(valid_iter.dataset), total_nseqs, len(data)) + assert len(all_outputs) == len(data) * n_best, (len(all_outputs), len(data), n_best) + + if compute_loss: + if normalization == "batch": + normalizer = total_nseqs + elif normalization == "tokens": + normalizer = total_ntokens + elif normalization == "none": + normalizer = 1 + + # avoid zero division + assert normalizer > 0 + assert total_ntokens > 0 + + # normalized loss + valid_scores["loss"] = total_loss / normalizer + # accuracy before decoding + valid_scores["acc"] = total_n_correct / total_ntokens + # exponent of token-level negative log likelihood + valid_scores["ppl"] = math.exp(total_loss / total_ntokens) + + # decode ids back to str symbols (cut-off AFTER eos; eos itself is included.) 
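+ # (illustrative) a hypothesis like [w1, w2, eos, w3] comes back as the tokens
+ # for [w1, w2, eos]; everything after the first eos is dropped.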
+ decoded_valid, valid_sequence_scores = model.trg_vocab.arrays_to_sentences( + arrays=all_outputs, score_arrays=valid_sequence_scores, cut_at_eos=True) + # TODO: `valid_sequence_scores` should have the same seq length as `decoded_valid` + # -> needed to be cut-off at eos synchronously + + if return_prob == "ref": # no evaluation needed + logger.info( + "Evaluation result (scoring) %s, duration: %.4f[sec]", + ", ".join([ + f"{eval_metric}: {valid_scores[eval_metric]:6.2f}" + for eval_metric in ["loss", "ppl", "acc"] + ]), + gen_duration, + ) + return ( + valid_scores, + None, # valid_ref + None, # valid_hyp + decoded_valid, + valid_sequence_scores, + valid_attention_scores, + ) + + # retrieve detokenized hypotheses and references + valid_hyp = [ + data.tokenizer[data.trg_lang].post_process(s, generate_unk=generate_unk) + for s in decoded_valid + ] + # references are not length-filtered, not duplicated for n_best > 1 + valid_ref = [data.tokenizer[data.trg_lang].post_process(s) for s in data.trg] + + # if references are given, evaluate 1best generation against them + if data.has_trg: + valid_hyp_1best = (valid_hyp if n_best == 1 else + [valid_hyp[i] for i in range(0, len(valid_hyp), n_best)]) + assert len(valid_hyp_1best) == len(valid_ref), (valid_hyp_1best, valid_ref) + + eval_start_time = time.time() + + # evaluate with metrics on full dataset + for eval_metric in eval_metrics: + if eval_metric == "bleu": + valid_scores[eval_metric] = bleu(valid_hyp_1best, valid_ref, + **sacrebleu_cfg) # detokenized ref + elif eval_metric == "chrf": + valid_scores[eval_metric] = chrf(valid_hyp_1best, valid_ref, + **sacrebleu_cfg) # detokenized ref + elif eval_metric == "token_accuracy": + decoded_valid_1best = (decoded_valid if n_best == 1 else [ + decoded_valid[i] for i in range(0, len(decoded_valid), n_best) + ]) + valid_scores[eval_metric] = token_accuracy( + decoded_valid_1best, + data.get_list(lang=data.trg_lang, tokenized=True), # tokenized ref + ) + elif eval_metric == "sequence_accuracy": + valid_scores[eval_metric] = sequence_accuracy( + valid_hyp_1best, valid_ref) + elif eval_metric == "wer": + if "eval" not in data.tokenizer: # better to handle this in data.py? + data.tokenizer["eval"] = EvaluationTokenizer( + lowercase=sacrebleu_cfg.get("lowercase", False), + tokenize=sacrebleu_cfg.get("tokenize", "13a"), + no_punc=sacrebleu_cfg.get("no_punc", False)) # WER w/o punc + valid_scores[eval_metric] = wer(valid_hyp_1best, valid_ref, + data.tokenizer["eval"]) + + eval_duration = time.time() - eval_start_time + score_str = ", ".join([ + f"{eval_metric}: {valid_scores[eval_metric]:6.2f}" + for eval_metric in eval_metrics + ["loss", "ppl", "acc"] + if not math.isnan(valid_scores[eval_metric]) + ]) + logger.info( + "Evaluation result (%s) %s, generation: %.4f[sec], evaluation: %.4f[sec]", + "beam search" if beam_size > 1 else "greedy", + score_str, + gen_duration, + eval_duration, + ) + else: + logger.info("Generation took %.4f[sec]. (No references given)", gen_duration) + + return ( + valid_scores, + valid_ref, + valid_hyp, + decoded_valid, + valid_sequence_scores, + valid_attention_scores, + ) + + +def test( + cfg_file, + ckpt: str, + output_path: str = None, + datasets: dict = None, + save_attention: bool = False, + save_scores: bool = False, +) -> None: + """ + Main test function. Handles loading a model from checkpoint, generating + translations, storing them, and plotting attention. 
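+ Both the dev and the test split are decoded, when available.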
+ + :param cfg_file: path to configuration file + :param ckpt: path to checkpoint to load + :param output_path: path to output + :param datasets: datasets to predict + :param save_attention: whether to save attention visualizations + :param save_scores: whether to save scores + """ + # pylint: disable=too-many-branches + cfg = load_config(Path(cfg_file)) + # parse train cfg + ( + model_dir, + load_model, + device, + n_gpu, + num_workers, + normalization, + fp16, + ) = parse_train_args(cfg["training"], mode="prediction") + + if len(logger.handlers) == 0: + pkg_version = make_logger(model_dir, mode="test") # version string returned + if "joeynmt_version" in cfg: + check_version(pkg_version, cfg["joeynmt_version"]) + + # load the data + if datasets is None: + src_vocab, trg_vocab, _, dev_data, test_data = load_data( + data_cfg=cfg["data"], datasets=["dev", "test"]) + data_to_predict = {"dev": dev_data, "test": test_data} + else: # avoid to load data again + data_to_predict = {"dev": datasets["dev"], "test": datasets["test"]} + src_vocab = datasets["src_vocab"] + trg_vocab = datasets["trg_vocab"] + + # build model and load parameters into it + model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab) + + # check options + if save_attention: + if cfg["model"]["decoder"]["type"] == "transformer": + assert cfg["testing"].get("beam_size", 1) == 1, ( + "Attention plots can be saved with greedy decoding only. Please set " + "`beam_size: 1` in the config.") + cfg["testing"]["return_attention"] = True + return_prob = cfg["testing"].get("return_prob", "none") + if save_scores: + assert output_path, "Please specify --output_path for saving scores." + if return_prob == "none": + logger.warning("Please specify prob type: {`ref` or `hyp`} in the config. " + "Scores will not be saved.") + save_scores = False + elif return_prob == "ref": + assert cfg["testing"].get("beam_size", 1) == 1, ( + "Scores of given references can be computed with greedy decoding only." 
+ "Please set `beam_size: 1` in the config.") + model.loss_function = ( # need to instantiate loss func to compute scores + cfg["training"].get("loss", "crossentropy"), + cfg["training"].get("label_smoothing", 0.1), + cfg["training"].get("ctc_weight", 0.3), + ) + + # when checkpoint is not specified, take latest (best) from model dir + load_model = load_model if ckpt is None else Path(ckpt) + ckpt = resolve_ckpt_path(load_model, model_dir) + + # load model checkpoint + model_checkpoint = load_checkpoint(ckpt, device=device) + + # restore model and optimizer parameters + model.load_state_dict(model_checkpoint["model_state"]) + if device.type == "cuda": + model.to(device) + + # multi-gpu eval + if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): + model = _DataParallel(model) + logger.info(model) + + # set the random seed + set_seed(seed=cfg["training"].get("random_seed", 42)) + + for data_set_name, data_set in data_to_predict.items(): + if data_set is not None: + data_set.reset_random_subset() # no subsampling in evaluation + + logger.info( + "%s on %s set...", + "Scoring" if return_prob == "ref" else "Decoding", + data_set_name, + ) + _, _, hypotheses, hypotheses_raw, seq_scores, att_scores, = predict( + model=model, + data=data_set, + compute_loss=return_prob == "ref", + device=device, + n_gpu=n_gpu, + num_workers=num_workers, + normalization=normalization, + cfg=cfg["testing"], + fp16=fp16, + ) + + if save_attention: + if att_scores: + attention_file_name = f"{data_set_name}.{ckpt.stem}.att" + attention_file_path = (model_dir / attention_file_name).as_posix() + logger.info("Saving attention plots. This might take a while..") + store_attention_plots( + attentions=att_scores, + targets=hypotheses_raw, + sources=data_set.get_list(lang=data_set.src_lang, + tokenized=True), + indices=range(len(hypotheses)), + output_prefix=attention_file_path, + ) + logger.info("Attention plots saved to: %s", attention_file_path) + else: + logger.warning( + "Attention scores could not be saved. Note that attention " + "scores are not available when using beam search. " + "Set beam_size to 1 for greedy decoding.") + + if output_path is not None: + if save_scores and seq_scores is not None: + # save scores + output_path_scores = Path(f"{output_path}.{data_set_name}.scores") + write_list_to_file(output_path_scores, seq_scores) + # save tokens + output_path_tokens = Path(f"{output_path}.{data_set_name}.tokens") + write_list_to_file(output_path_tokens, hypotheses_raw) + logger.info( + "Scores and corresponding tokens saved to: %s.{scores|tokens}", + f"{output_path}.{data_set_name}", + ) + if hypotheses is not None: + # save translations + output_path_set = Path(f"{output_path}.{data_set_name}") + write_list_to_file(output_path_set, hypotheses) + logger.info("Translations saved to: %s.", output_path_set) + + +def translate( + cfg_file: str, + ckpt: str = None, + output_path: str = None, +) -> None: + """ + Interactive translation function. + Loads model from checkpoint and translates either the stdin input or asks for + input to translate interactively. Translations and scores are printed to stdout. + Note: The input sentences don't have to be pre-tokenized. 
+ + :param cfg_file: path to configuration file + :param ckpt: path to checkpoint to load + :param output_path: path to output file + """ + + # pylint: disable=too-many-branches + def _translate_data(test_data, cfg): + """Translates given dataset, using parameters from outer scope.""" + _, _, hypotheses, trg_tokens, trg_scores, _ = predict( + model=model, + data=test_data, + compute_loss=False, + device=device, + n_gpu=n_gpu, + normalization="none", + num_workers=num_workers, + cfg=cfg, + fp16=fp16, + ) + return hypotheses, trg_tokens, trg_scores + + cfg = load_config(Path(cfg_file)) + # parse and validate cfg + model_dir, load_model, device, n_gpu, num_workers, _, fp16 = parse_train_args( + cfg["training"], mode="prediction") + test_cfg = cfg["testing"] + src_cfg = cfg["data"]["src"] + trg_cfg = cfg["data"]["trg"] + task = cfg["data"].get("task", "MT").upper() + + pkg_version = make_logger(model_dir, mode="translate") # version string returned + if "joeynmt_version" in cfg: + check_version(pkg_version, cfg["joeynmt_version"]) + + # when checkpoint is not specified, take latest (best) from model dir + load_model = load_model if ckpt is None else Path(ckpt) + ckpt = resolve_ckpt_path(load_model, model_dir) + + # read vocabs + src_vocab, trg_vocab = build_vocab(cfg["data"], model_dir=model_dir) + + # build model and load parameters into it + model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab) + + # load model state from disk + model_checkpoint = load_checkpoint(ckpt, device=device) + model.load_state_dict(model_checkpoint["model_state"]) + + if device.type == "cuda": + model.to(device) + + tokenizer = build_tokenizer(cfg["data"]) + if task == "MT": + sequence_encoder = { + src_cfg["lang"]: partial(src_vocab.sentences_to_ids, bos=False, eos=True), + trg_cfg["lang"]: None, + } + elif task == "S2T": + sequence_encoder = { + "src": partial(pad_features, embed_size=tokenizer["src"].num_freq), + "trg": None, + } + test_data = build_dataset( + dataset_type="stream" if task == "MT" else "speech_stream", + path=None, + src_lang=src_cfg["lang"] if task == "MT" else "src", + trg_lang=trg_cfg["lang"] if task == "MT" else "trg", + split="test", + tokenizer=tokenizer, + sequence_encoder=sequence_encoder, + task=task, + ) + + # set the random seed + set_seed(seed=cfg["training"].get("random_seed", 42)) + + n_best = test_cfg.get("n_best", 1) + beam_size = test_cfg.get("beam_size", 1) + return_prob = test_cfg.get("return_prob", "none") + if not sys.stdin.isatty(): # pylint: disable=too-many-nested-blocks + # input stream given + for i, line in enumerate(sys.stdin.readlines()): + if not line.strip(): + # skip empty lines and print warning + logger.warning("The sentence in line %d is empty. 
Skip to load.", i) + continue + test_data.set_item(line.rstrip()) + all_hypotheses, tokens, scores = _translate_data(test_data, test_cfg) + assert len(all_hypotheses) == len(test_data) * n_best + + if output_path is not None: + # write to outputfile if given + out_file = Path(output_path).expanduser() + + if n_best > 1: + for n in range(n_best): + write_list_to_file( + out_file.parent / f"{out_file.stem}-{n}.{out_file.suffix}", + [ + all_hypotheses[i] + for i in range(n, len(all_hypotheses), n_best) + ], + ) + else: + write_list_to_file(out_file, all_hypotheses) + + logger.info("Translations saved to: %s.", out_file) + + else: + # print to stdout + for hyp in all_hypotheses: + print(hyp) + + else: + # enter interactive mode + test_cfg["batch_size"] = 1 # CAUTION: this will raise an error if n_gpus > 1 + test_cfg["batch_type"] = "sentence" + np.set_printoptions(linewidth=sys.maxsize) # for printing scores in stdout + while True: + try: + src_input = input("\nPlease enter a source sentence:\n") + if not src_input.strip(): + break + + # every line has to be made into dataset + test_data.set_item(src_input.rstrip()) + hypotheses, tokens, scores = _translate_data(test_data, test_cfg) + + print("JoeyNMT:") + for i, (hyp, token, + score) in enumerate(zip_longest(hypotheses, tokens, scores)): + assert hyp is not None, (i, hyp, token, score) + print(f"#{i + 1}: {hyp}") + if return_prob in ["hyp"]: + if beam_size > 1: # beam search: sequence-level scores + print(f"\ttokens: {token}\n\tsequence score: {score[0]}") + else: # greedy: token-level scores + assert len(token) == len(score), (token, score) + print(f"\ttokens: {token}\n\tscores: {score}") + + # reset cache + test_data.reset_cache() + + except (KeyboardInterrupt, EOFError): + print("\nBye.") + break diff --git a/joeynmt/search.py b/joeynmt/search.py new file mode 100644 index 0000000..534d4d8 --- /dev/null +++ b/joeynmt/search.py @@ -0,0 +1,840 @@ +# coding: utf-8 +""" +Search module +""" +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + +from joeynmt.batch import Batch +from joeynmt.decoders import RecurrentDecoder, TransformerDecoder +from joeynmt.helpers import tile +from joeynmt.model import Model + +__all__ = ["greedy", "beam_search", "search"] + + +def greedy( + src_mask: Tensor, + max_output_length: int, + model: Model, + encoder_output: Tensor, + encoder_hidden: Tensor, + **kwargs, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Greedy decoding. Select the token word highest probability at each time step. + This function is a wrapper that calls recurrent_greedy for recurrent decoders and + transformer_greedy for transformer decoders. 
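+ Any other decoder type raises a NotImplementedError.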
+ + :param src_mask: mask for source inputs, 0 for positions after + :param max_output_length: maximum length for the hypotheses + :param model: model to use for greedy decoding + :param encoder_output: encoder hidden states for attention + :param encoder_hidden: encoder last state for decoder initialization + :return: + - stacked_output: output hypotheses (2d array of indices), + - stacked_scores: scores (2d array of token-wise log probabilities), + - stacked_attention_scores: attention scores (3d array) + """ + # pylint: disable=no-else-return + if isinstance(model.decoder, TransformerDecoder): + return transformer_greedy( + src_mask, + max_output_length, + model, + encoder_output, + encoder_hidden, + **kwargs, + ) + elif isinstance(model.decoder, RecurrentDecoder): + return recurrent_greedy(src_mask, max_output_length, model, encoder_output, + encoder_hidden, **kwargs) + else: + raise NotImplementedError( + f"model.decoder({model.decoder.__class__.__name__}) not supported.") + + +def recurrent_greedy( + src_mask: Tensor, + max_output_length: int, + model: Model, + encoder_output: Tensor, + encoder_hidden: Tensor, + **kwargs, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Greedy decoding: in each step, choose the word that gets highest score. + Version for recurrent decoder. + + :param src_mask: mask for source inputs, 0 for positions after + :param max_output_length: maximum length for the hypotheses + :param model: model to use for greedy decoding + :param encoder_output: encoder hidden states for attention + :param encoder_hidden: encoder last state for decoder initialization + :return: + - stacked_output: output hypotheses (2d array of indices), + - stacked_scores: scores (2d array of token-wise log probabilities), + - stacked_attention_scores: attention scores (3d array) + """ + bos_index = model.bos_index + eos_index = model.eos_index + unk_index = model.unk_index + batch_size = src_mask.size(0) + min_output_length: int = kwargs.get("min_output_length", 1) + generate_unk: bool = kwargs.get("generate_unk", True) # whether to generate UNK + return_prob: bool = kwargs.get("return_prob", "none") == "hyp" + prev_y = src_mask.new_full((batch_size, 1), fill_value=bos_index, dtype=torch.long) + + output = [] + scores = [] + attention_scores = [] + hidden = None + prev_att_vector = None + finished = src_mask.new_zeros((batch_size, 1)).byte() + device = encoder_output.device + fp16 = kwargs.get("fp16", False) + + for step in range(max_output_length): + # decode one single step + with torch.autocast(device_type=device.type, enabled=fp16): + with torch.no_grad(): + out, hidden, att_probs, prev_att_vector = model( + return_type="decode", + trg_input=prev_y, + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + src_mask=src_mask, + unroll_steps=1, + decoder_hidden=hidden, + att_vector=prev_att_vector, + ) + # out: batch x time=1 x vocab (logits) + + if return_prob: + out = F.log_softmax(out, dim=-1) + + if not generate_unk: + out[:, :, unk_index] = float("-inf") + + # don't generate EOS until we reached min_output_length + if step < min_output_length: + out[:, :, eos_index] = float("-inf") + + # greedy decoding: choose arg max over vocabulary in each step + prob, next_word = torch.max(out, dim=-1) # batch x time=1 + output.append(next_word.squeeze(1).detach().cpu()) + if return_prob: + scores.append(prob.squeeze(1).detach().cpu()) + prev_y = next_word + attention_scores.append(att_probs.squeeze(1).detach().cpu()) + # shape: (batch_size, max_src_length) + + # check if previous symbol 
was + is_eos = torch.eq(next_word, eos_index) + finished += is_eos + # stop predicting if reached for all elements in batch + if (finished >= 1).sum() == batch_size: + break + + stacked_output = torch.stack(output, dim=1).long() # batch, time + stacked_scores = torch.stack(scores, dim=1).float() if return_prob else None + stacked_attention_scores = torch.stack(attention_scores, dim=1).float() + return stacked_output, stacked_scores, stacked_attention_scores + + +def transformer_greedy( + src_mask: Tensor, + max_output_length: int, + model: Model, + encoder_output: Tensor, + encoder_hidden: Tensor, + **kwargs, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Special greedy function for transformer, since it works differently. + The transformer remembers all previous states and attends to them. + + :param src_mask: mask for source inputs, 0 for positions after + :param max_output_length: maximum length for the hypotheses + :param model: model to use for greedy decoding + :param encoder_output: encoder hidden states for attention + :param encoder_hidden: encoder final state (unused in Transformer) + :return: + - stacked_output: output hypotheses (2d array of indices), + - stacked_scores: scores (2d array of token-wise log probabilities), + - stacked_attention_scores: attention scores (3d array) + """ + # pylint: disable=unused-argument + bos_index = model.bos_index + eos_index = model.eos_index + unk_index = model.unk_index + pad_index = model.pad_index + batch_size, _, src_len = src_mask.size() + device = encoder_output.device + fp16: bool = kwargs.get("fp16", False) + + # options to control generation + generate_unk: bool = kwargs.get("generate_unk", True) # whether to generate UNK + return_attention: bool = kwargs.get("return_attention", False) + return_prob: bool = kwargs.get("return_prob", "none") == "hyp" + min_output_length: int = kwargs.get("min_output_length", 1) + repetition_penalty: float = kwargs.get("repetition_penalty", -1) + no_repeat_ngram_size: int = kwargs.get("no_repeat_ngram_size", -1) + encoder_input: Tensor = kwargs.get("encoder_input", None) # for repetition blocker + compute_softmax: bool = (return_prob or repetition_penalty > 0 + or no_repeat_ngram_size > 0 or encoder_input is not None) + + # start with BOS-symbol for each sentence in the batch + ys = encoder_output.new_full((batch_size, 1), bos_index, dtype=torch.long) + + # placeholder for scores + yv = ys.new_zeros((batch_size, 1), dtype=torch.float) if return_prob else None + + # placeholder for attentions + yt = ys.new_zeros((batch_size, 1, src_len), dtype=torch.float) \ + if return_attention else None + + # a subsequent mask is intersected with this in decoder forward pass + trg_mask = src_mask.new_ones([1, 1, 1]) + if isinstance(model, torch.nn.DataParallel): + trg_mask = torch.stack([src_mask.new_ones([1, 1]) for _ in model.device_ids]) + + finished = src_mask.new_zeros(batch_size).byte() + + for step in range(max_output_length): + with torch.autocast(device_type=device.type, enabled=fp16): + with torch.no_grad(): + out, _, att, _ = model( + return_type="decode", + trg_input=ys, # model.trg_embed(ys) # embed the previous tokens + encoder_output=encoder_output, + encoder_hidden=None, + src_mask=src_mask, + unroll_steps=None, + decoder_hidden=None, + trg_mask=trg_mask, + return_attention=return_attention, + ) + + out = out[:, -1] # logits + if not generate_unk: + out[:, unk_index] = float("-inf") + + # don't generate EOS until we reached min_output_length + if step < min_output_length: + out[:, eos_index] = 
float("-inf") + + if compute_softmax: + out = F.log_softmax(out, dim=-1) + + # ngram blocker + if no_repeat_ngram_size > 1: + out = block_repeat_ngrams( + ys, + out, + no_repeat_ngram_size, + step, + src_tokens=encoder_input, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + + # repetition_penalty + if repetition_penalty > 1.0: + out = penalize_repetition( + ys, + out, + repetition_penalty, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + if encoder_input is not None: + out = penalize_repetition( + encoder_input, + out, + repetition_penalty, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + + # take the most likely token + prob, next_word = torch.max(out, dim=1) + next_word = next_word.data + ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) + if return_prob: + prob = prob.data + yv = torch.cat([yv, prob.unsqueeze(-1)], dim=1) + if return_attention: + assert att is not None + att = att.data[:, -1, :].unsqueeze(1) # take last trg token only + yt = torch.cat([yt, att], dim=1) # (batch_size, trg_len, src_len) + + # check if previous symbol was + is_eos = torch.eq(next_word, eos_index) + finished += is_eos + # stop predicting if reached for all elements in batch + if (finished >= 1).sum() == batch_size: + break + + # remove BOS-symbol + + output = ys[:, 1:].detach().cpu().long() + scores = yv[:, 1:].detach().cpu().float() if return_prob else None + attention = yt[:, 1:, :].detach().cpu().float() if return_attention else None + assert output.shape[0] == batch_size, (output.shape, batch_size) + return output, scores, attention + + +def beam_search( + model: Model, + beam_size: int, + encoder_output: Tensor, + encoder_hidden: Tensor, + src_mask: Tensor, + max_output_length: int, + alpha: float, + n_best: int = 1, + **kwargs, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Beam search with size k. In each decoding step, find the k most likely partial + hypotheses. Inspired by OpenNMT-py, adapted for Transformer. + + :param model: + :param beam_size: size of the beam + :param encoder_output: + :param encoder_hidden: + :param src_mask: + :param max_output_length: + :param alpha: `alpha` factor for length penalty + :param n_best: return this many hypotheses, <= beam (currently only 1) + :return: + - stacked_output: output hypotheses (2d array of indices), + - stacked_scores: scores (2d array of sequence-wise log probabilities), + - stacked_attention_scores: attention scores (3d array) + """ + # pylint: disable=too-many-statements,too-many-branches + assert beam_size > 0, "Beam size must be >0." + assert n_best <= beam_size, f"Can only return {beam_size} best hypotheses." + + # Take the best 2 x {beam_size} predictions so as to avoid duplicates in generation. + # yet, only return {n_best} hypotheses. 
+ # beam_size = 2 * beam_size + + # init + bos_index = model.bos_index + eos_index = model.eos_index + pad_index = model.pad_index + unk_index = model.unk_index + batch_size = src_mask.size(0) + + generate_unk: bool = kwargs.get("generate_unk", True) # whether to generate UNK + return_prob: bool = kwargs.get("return_prob", "none") == "hyp" + min_output_length: int = kwargs.get("min_output_length", 1) + repetition_penalty: float = kwargs.get("repetition_penalty", -1) + no_repeat_ngram_size: int = kwargs.get("no_repeat_ngram_size", -1) + encoder_input: Tensor = kwargs.get("encoder_input", None) # for repetition blocker + + trg_vocab_size = model.decoder.output_size + device = encoder_output.device + fp16: bool = kwargs.get("fp16", False) + is_transformer = isinstance(model.decoder, TransformerDecoder) + + att_vectors = None # for RNN only, not used for Transformer + hidden = None # for RNN only, not used for Transformer + trg_mask = None # for Transformer only, not used for RNN + + # Recurrent models only: initialize RNN hidden state + if not is_transformer: + # pylint: disable=protected-access + # tile encoder states and decoder initial states beam_size times + # `hidden` shape: (layers, batch_size * beam_size, dec_hidden_size) + hidden = model.decoder._init_hidden(encoder_hidden) + hidden = tile(hidden, beam_size, dim=1) + # DataParallel splits batch along the 0th dim. + # Place back the batch_size to the 1st dim here. + # `hidden` shape: (batch_size * beam_size, layers, dec_hidden_size) + if isinstance(hidden, tuple): + h, c = hidden + hidden = (h.permute(1, 0, 2), c.permute(1, 0, 2)) + else: + hidden = hidden.permute(1, 0, 2) + + # `encoder_output` shape: (batch_size * beam_size, src_len, enc_hidden_size) + encoder_output = tile(encoder_output.contiguous(), beam_size, dim=0) + # `src_mask` shape: (batch_size * beam_size, 1, src_len) + src_mask = tile(src_mask, beam_size, dim=0) + # `encoder_input` shape: (batch_size * beam_size, src_len) + if encoder_input is not None: # used in src-side repetition blocker + encoder_input = tile(encoder_input.contiguous(), beam_size, + dim=0).view(batch_size * beam_size, -1) + assert encoder_input.size(0) == batch_size * beam_size, ( + encoder_input.size(0), + batch_size * beam_size, + ) + + # Transformer only: create target mask + if is_transformer: + trg_mask = src_mask.new_ones([1, 1, 1]) + if isinstance(model, torch.nn.DataParallel): + trg_mask = torch.stack( + [src_mask.new_ones([1, 1]) for _ in model.device_ids]) + + # numbering elements in the batch + batch_offset = torch.arange(batch_size, dtype=torch.long, device=device) + + # numbering elements in the extended batch, i.e. k copies of each batch element + beam_offset = torch.arange(0, + batch_size * beam_size, + step=beam_size, + dtype=torch.long, + device=device) + + # keeps track of the top beam size hypotheses to expand for each element in the + # batch to be further decoded (that are still "alive") + # `alive_seq` shape: (batch_size * beam_size, hyp_len) ... now hyp_len = 1 + alive_seq = torch.full((batch_size * beam_size, 1), + bos_index, + dtype=torch.long, + device=device) + + # Give full probability (=zero in log space) to the first beam on the first step. + # `topk_log_probs` shape: (batch_size, beam_size) + topk_log_probs = torch.zeros(batch_size, beam_size, device=device) + topk_log_probs[:, 1:] = float("-inf") + + # Structure that holds finished hypotheses. 
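+ # hypotheses[b] collects (score, prediction) tuples for batch element b,
+ # filled in once beam candidates reach eos or the maximum output length.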
+ hypotheses = [[] for _ in range(batch_size)] + + results = { + "predictions": [[] for _ in range(batch_size)], + "scores": [[] for _ in range(batch_size)], + } + + # indicator if the generation is finished + # `is_finished` shape: (batch_size, beam_size) + is_finished = torch.full((batch_size, beam_size), + False, + dtype=torch.bool, + device=device) + + for step in range(max_output_length): + if is_transformer: + # For Transformer, we feed the complete predicted sentence so far. + decoder_input = alive_seq + + # decode one single step + with torch.autocast(device_type=device.type, enabled=fp16): + with torch.no_grad(): + logits, _, _, _ = model( # logits before final softmax + return_type="decode", + encoder_output=encoder_output, + encoder_hidden=None, # only for initializing decoder_hidden + src_mask=src_mask, + trg_input=decoder_input, # trg_embed = embed(decoder_input) + decoder_hidden=None, # don't need to keep it for transformer + att_vector=None, # don't need to keep it for transformer + unroll_steps=1, + trg_mask=trg_mask, # subsequent mask for Transformer only + ) + + # For the Transformer we made predictions for all time steps up to this + # point, so we only want to know about the last time step. + logits = logits[:, -1] + hidden = None + else: + # For Recurrent models, only feed the previous trg word prediction + decoder_input = alive_seq[:, -1].view(-1, 1) # only the last word + + with torch.autocast(device_type=device.type, enabled=fp16): + with torch.no_grad(): + # pylint: disable=unused-variable + logits, hidden, att_scores, att_vectors = model( + return_type="decode", + encoder_output=encoder_output, + encoder_hidden=None, # only for initializing decoder_hidden + src_mask=src_mask, + trg_input=decoder_input, # trg_embed = embed(decoder_input) + decoder_hidden=hidden, + att_vector=att_vectors, + unroll_steps=1, + trg_mask=None, # subsequent mask for Transformer only + ) + + # compute log probability distribution over trg vocab + # `log_probs` shape: (remaining_batch_size * beam_size, trg_vocab) + log_probs = F.log_softmax(logits, dim=-1).squeeze(1) + if not generate_unk: + log_probs[:, unk_index] = float("-inf") + + # don't generate EOS until we reached min_output_length + if step < min_output_length: + log_probs[:, eos_index] = float("-inf") + + # block repetitions + if no_repeat_ngram_size > 0: + log_probs = block_repeat_ngrams( + alive_seq, + log_probs, + no_repeat_ngram_size, + step, + src_tokens=encoder_input, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + + if repetition_penalty > 1.0: + log_probs = penalize_repetition( + alive_seq, + log_probs, + repetition_penalty, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + if encoder_input is not None: # src + log_probs = penalize_repetition( + encoder_input, + log_probs, + repetition_penalty, + exclude_tokens=[bos_index, eos_index, unk_index, pad_index], + ) + + # multiply probs by the beam probability (=add logprobs) + # `log_probs` shape: (remaining_batch_size * beam_size, trg_vocab) + log_probs += topk_log_probs.view(-1).unsqueeze(1) # add column-wise + curr_scores = log_probs.clone() + + # compute length penalty + if alpha > 0: + length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha + curr_scores /= length_penalty + + # flatten log_probs into a list of possibilities + # `curr_scores` shape: (remaining_batch_size, beam_size * trg_vocab_size) + curr_scores = curr_scores.reshape(-1, beam_size * trg_vocab_size) + + # pick currently best top k hypotheses (flattened order) + # 
`topk_scores` and `topk_ids` shape: (remaining_batch_size, beam_size) + topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) + + if alpha > 0: + # recover original log probs + topk_log_probs = topk_scores * length_penalty + else: + topk_log_probs = topk_scores.clone() + + # reconstruct beam origin and true word ids from flattened order + topk_beam_index = topk_ids.div(trg_vocab_size, rounding_mode="floor") + topk_ids = topk_ids.fmod(trg_vocab_size) # resolve true word ids + + # map topk_beam_index to batch_index in the flat representation + # `batch_index` shape: (remaining_batch_size, beam_size) + batch_index = topk_beam_index + beam_offset[:topk_ids.size(0)].unsqueeze(1) + select_indices = batch_index.view(-1) + + # append latest prediction + # `alive_seq` shape: (remaining_batch_size * beam_size, hyp_len) + alive_seq = torch.cat( + [alive_seq.index_select(0, select_indices), + topk_ids.view(-1, 1)], + -1, + ) + # `is_finished` shape: (remaining_batch_size, beam_size) + is_finished = topk_ids.eq(eos_index) | is_finished | topk_scores.eq(-np.inf) + if step + 1 == max_output_length: + is_finished.fill_(True) + + # end condition is whether all beam candidates in each example are finished + end_condition = is_finished.all(-1) # shape: (remaining_batch_size,) + + # save finished hypotheses + if is_finished.any(): + # `predictions` shape: (remaining_batch_size, beam_size, hyp_len) + predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) + + for i in range(is_finished.size(0)): # loop over remaining examples + b = batch_offset[i].item() # index of that example in the batch + if end_condition[i]: + is_finished[i].fill_(True) + # indices of finished beam candidates for this example (1d tensor) + # i.e. finished_hyp = [0, 1] means 0th and 1st candidates reached eos + finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1) + for j in finished_hyp: # loop over finished beam candidates + n_eos = (predictions[i, j, 1:] == eos_index).count_nonzero().item() + if n_eos > 1: # pylint: disable=no-else-continue + # If the prediction has more than one EOS, it means that the + # prediction should have already been added to the hypotheses, + # so we don't add them again. + continue + elif (n_eos == 0 and step + 1 == max_output_length) or ( + n_eos == 1 and predictions[i, j, -1] == eos_index): + # If the prediction has no EOS, it means we reached max length. + # If the prediction has exactly one EOS, it should be the last + # token of the sequence. Then we add it to the hypotheses. + # We exclude the candidate which has one EOS but some other + # token was appended after EOS. 
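+ # Position 0 of the prediction holds the bos symbol and is stripped here.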
+ hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:])) + + # if all nbest candidates of the i-th example reached the end, save them + if end_condition[i]: + best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) + for n, (score, pred) in enumerate(best_hyp): + if n >= n_best: + break + if len(pred) < max_output_length: + assert ( + pred[-1] == eos_index + ), f"adding a candidate which doesn't end with eos: {pred}" + results["scores"][b].append(score) + results["predictions"][b].append(pred) + + # batch indices of the examples which contain unfinished candidates + unfinished = end_condition.eq(False).nonzero(as_tuple=False).view(-1) + # if all examples are translated, no need to go further + if len(unfinished) == 0: + break + # remove finished examples for the next step + # shape: (remaining_batch_size, beam_size) + batch_index = batch_index.index_select(0, unfinished) + topk_log_probs = topk_log_probs.index_select(0, unfinished) + is_finished = is_finished.index_select(0, unfinished) + batch_offset = batch_offset.index_select(0, unfinished) + + # **CAUTION:** `alive_seq` still can contain finished beam candidates + # because we only remove finished examples. For instance, beam_size = 3, + # 2 sents remain in batch and all 3 candidates of the 1st sent is finished. + # end_condition = [True, False] + # unfinished = [1] # 2nd sent (idx 1) remaining in batch is unfinished + # Say, the first and the second beam candidate of the second example are + # finished but the third candidate of the second example is still alive. + # Then we include all three candidates of the second example in `alive_seq`, + # even though the 1st and 2nd candidates of the second example are finished. + # alive_seq = [ + # [5, 8, 7, 3, 1], # eos_index = 3; already finished in prev step + # [4, 9, 6, 5, 3], # eos_index = 3; finished in this step + # [4, 9, 5, 6, 7], # not finished yet + # ] + # Yet, we won't add already finished candidates to the `hypotheses` list, + # but only the candidates that finished in the very current time step. + # TODO: release the space of finished ones, explore more unfinished ones. + # `alive_seq` shape: (remaining_batch_size * beam_size, hyp_len) + alive_seq = predictions.index_select(0, unfinished).view( + -1, alive_seq.size(-1)) + if encoder_input is not None: + src_len = encoder_input.size(-1) + encoder_input = encoder_input.view(-1, beam_size, src_len) \ + .index_select(0, unfinished).view(-1, src_len) + assert encoder_input.size(0) == alive_seq.size(0) + + # reorder indices, outputs and masks + select_indices = batch_index.view(-1) + encoder_output = encoder_output.index_select(0, select_indices) + src_mask = src_mask.index_select(0, select_indices) + + if hidden is not None and not is_transformer: + if isinstance(hidden, tuple): + # for LSTMs, states are tuples of tensors + h, c = hidden + h = h.index_select(0, select_indices) + c = c.index_select(0, select_indices) + hidden = (h, c) + else: + # for GRUs, states are single tensors + hidden = hidden.index_select(0, select_indices) + + if att_vectors is not None: + att_vectors = att_vectors.index_select(0, select_indices) + + # if num_predictions < n_best, fill the results list up with UNK. 
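+ # Beam search may terminate with fewer than n_best finished candidates for an
+ # example; padding with a one-token unk prediction (score -1) keeps the output
+ # at a fixed batch_size * n_best shape.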
+ for b in range(batch_size): + num_predictions = len(results["predictions"][b]) + num_scores = len(results["scores"][b]) + assert num_predictions == num_scores + for _ in range(n_best - num_predictions): + results["predictions"][b].append(torch.tensor([unk_index]).long()) + results["scores"][b].append(torch.tensor([-1]).float()) + assert len(results["predictions"][b]) == n_best + assert len(results["scores"][b]) == n_best + + def pad_and_stack_hyps(hyps: List[np.ndarray]): + max_len = max([hyp.shape[0] for hyp in hyps]) + filled = torch.ones((len(hyps), max_len), dtype=torch.int64) * pad_index + for j, h in enumerate(hyps): + for k, i in enumerate(h): + filled[j, k] = i + return filled + + # from results to stacked outputs + # `final_outputs`: shape (batch_size * n_best, hyp_len) + predictions_list = [u.cpu().float() for r in results["predictions"] for u in r] + final_outputs = pad_and_stack_hyps(predictions_list) + + # sequence-wise log probabilities (summed up over the sequence) + # `scores`: shape (batch_size * n_best, 1) + scores = torch.tensor([[u.item()] for r in results["scores"] for u in r]) \ + if return_prob else None + + assert final_outputs.shape[0] == batch_size * n_best + return final_outputs, scores, None + + +def search( + model: Model, + batch: Batch, + max_output_length: int, + beam_size: int, + beam_alpha: float, + n_best: int = 1, + **kwargs, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get outputs and attentions scores for a given batch. + + :param model: Model class + :param batch: batch to generate hypotheses for + :param max_output_length: maximum length of hypotheses + :param beam_size: size of the beam for beam search, if 0 use greedy + :param beam_alpha: alpha value for beam search + :param n_best: candidates to return + :returns: + - stacked_output: hypotheses for batch, + - stacked_scores: log probabilities for batch, + - stacked_attention_scores: attention scores for batch + """ + device = batch.src.device + fp16: bool = kwargs.get("fp16", False) + with torch.autocast(device_type=device.type, enabled=fp16): + with torch.no_grad(): + encoder_output, encoder_hidden, src_mask, _ = model(return_type="encode", + **vars(batch)) + src_mask = src_mask if batch.src_mask is None else batch.src_mask + assert src_mask is not None + + # if maximum output length is not globally specified, adapt to src len + if max_output_length < 0: + max_output_length = int(max(batch.src_length.cpu().numpy()) * 1.5) + + # block src-side repetition (to avoid untranslated copy in trg) + if (kwargs.get("no_repeat_ngram_size", -1) > 1 + or kwargs.get("repetition_penalty", -1) > 1): + kwargs["encoder_input"] = batch.src + + # decoding + if beam_size < 2: # greedy + stacked_output, stacked_scores, stacked_attention_scores = greedy( + src_mask=src_mask, + max_output_length=max_output_length, + model=model, + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + **kwargs, + ) + + else: # beam search + stacked_output, stacked_scores, stacked_attention_scores = beam_search( + model=model, + beam_size=beam_size, + encoder_output=encoder_output, + encoder_hidden=encoder_hidden, + src_mask=src_mask, + max_output_length=max_output_length, + alpha=beam_alpha, + n_best=n_best, + **kwargs, + ) + + # cast to numpy nd.array + def _to_numpy(t: Tensor): + if torch.is_tensor(t): + return t.detach().cpu().numpy() + return t # if t is None + + return ( + _to_numpy(stacked_output), + _to_numpy(stacked_scores), + _to_numpy(stacked_attention_scores), + ) + + +def block_repeat_ngrams(tokens: 
Tensor, scores: Tensor, no_repeat_ngram_size: int, + step: int, **kwargs) -> Tensor: + """ + For each hypothesis, check a list of previous ngrams and set associated log probs + to -inf. Taken from fairseq's NGramRepeatBlock. + + :param tokens: target tokens generated so far + :param scores: log probabilities of the next token to generate in this time step + :param no_repeat_ngram_size: ngram size to prohibit + :param step: generation step (= length of hypotheses so far) + """ + hyp_size = tokens.size(0) + banned_batch_tokens = [set([]) for _ in range(hyp_size)] + + trg_tokens = tokens.cpu().tolist() + check_end_pos = step + 2 - no_repeat_ngram_size + offset = no_repeat_ngram_size - 1 + + src_tokens = kwargs.get("src_tokens", None) + if src_tokens is not None: + src_length = src_tokens.size(-1) + assert src_tokens.size(0) == hyp_size, (src_tokens.size(), hyp_size) + src_tokens = src_tokens.cpu().tolist() + exclude_tokens = kwargs.get("exclude_tokens", []) + + # get repeated ngrams + for hyp_idx in range(hyp_size): + if len(trg_tokens[hyp_idx]) > no_repeat_ngram_size: + # (n-1) token prefix at the time step + # 0 1 2 3 4 <- step + # if tokens[hyp_idx] = [2, 5, 5, 6, 5] at step 4 with ngram_size = 3, + # ^ ^ ^ + # then ngram_to_check = [6, 5], and set the token in the next position to + # -inf, if there are ngrams starts with [6, 5]. + ngram_to_check = trg_tokens[hyp_idx][-offset:] + + for i in range(1, check_end_pos): # ignore BOS + if ngram_to_check == trg_tokens[hyp_idx][i:i + offset]: + banned_batch_tokens[hyp_idx].add(trg_tokens[hyp_idx][i + offset]) + + # src_tokens + if src_tokens is not None: + check_end_pos_src = src_length + 1 - no_repeat_ngram_size + for i in range(check_end_pos_src): # no BOS in src + if ngram_to_check == src_tokens[hyp_idx][i:i + offset]: + banned_batch_tokens[hyp_idx].add(src_tokens[hyp_idx][i + + offset]) + + # set the score of the banned tokens to -inf + for i, banned_tokens in enumerate(banned_batch_tokens): + banned_tokens = set(banned_tokens) - set(exclude_tokens) + scores[i, list(banned_tokens)] = float("-inf") + return scores + + +def penalize_repetition(tokens: Tensor, + scores: Tensor, + penalty: float, + exclude_tokens: List[int] = None) -> Tensor: + """ + Reduce probability of the given tokens. + Taken from Huggingface's RepetitionPenaltyLogitsProcessor. 
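+ Example: with penalty=1.2, a (negative) log prob of -1.0 becomes -1.2;
+ a positive score would instead be divided by the penalty.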
+ + :param tokens: token ids to penalize + :param scores: log probabilities of the next token to generate + :param penalty: penalty value, bigger value implies less probability + :param exclude_tokens: list of token ids to exclude from penalizing + """ + scores_before = scores if exclude_tokens else None + score = torch.gather(scores, 1, tokens) + + # if score < 0 then repetition penalty has to be multiplied + # to reduce the previous token probability + score = torch.where(score < 0, score * penalty, score / penalty) + + scores.scatter_(1, tokens, score) + + # exclude special tokens + if exclude_tokens: + for token in exclude_tokens: + # pylint: disable=unsubscriptable-object + scores[:, token] = scores_before[:, token] + return scores diff --git a/joeynmt/tokenizers.py b/joeynmt/tokenizers.py new file mode 100644 index 0000000..0685c9d --- /dev/null +++ b/joeynmt/tokenizers.py @@ -0,0 +1,561 @@ +# coding: utf-8 +""" +Tokenizer module +""" +import logging +import shutil +from pathlib import Path +from typing import Callable, Dict, List, Union + +import numpy as np +import sentencepiece as sp +from sacrebleu.metrics.bleu import _get_tokenizer +from subword_nmt import apply_bpe + +from joeynmt.constants import BOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN +from joeynmt.data_augmentation import CMVN, SpecAugment +from joeynmt.helpers import ( + ConfigurationError, + remove_extra_spaces, + remove_punctuation, + unicode_normalize, +) +from joeynmt.helpers_for_pose import get_features + +logger = logging.getLogger(__name__) + + +class BasicTokenizer: + SPACE = chr(32) # ' ': half-width white space (ascii) + SPACE_ESCAPE = chr(9601) # '▁': sentencepiece default + SPECIALS = [BOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN] + + def __init__( + self, + level: str = "word", + lowercase: bool = False, + normalize: bool = False, + max_length: int = -1, + min_length: int = -1, + **kwargs, + ): + # pylint: disable=unused-argument + self.level = level + self.lowercase = lowercase + self.normalize = normalize + + # filter by length + self.max_length = max_length + self.min_length = min_length + + # pretokenizer + self.pretokenizer = kwargs.get("pretokenizer", "none").lower() + assert self.pretokenizer in ["none", "moses"], \ + "Currently, we support moses tokenizer only." + # sacremoses + if self.pretokenizer == "moses": + try: + from sacremoses import ( # pylint: disable=import-outside-toplevel + MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer, + ) + + # sacremoses package has to be installed. + # https://github.com/alvations/sacremoses + except ImportError as e: + logger.error(e) + raise ImportError from e + + self.lang = kwargs.get("lang", "en") + self.moses_tokenizer = MosesTokenizer(lang=self.lang) + self.moses_detokenizer = MosesDetokenizer(lang=self.lang) + if self.normalize: + self.moses_normalizer = MosesPunctNormalizer() + + def pre_process(self, raw_input: str) -> str: + """ + Pre-process text + - ex.) Lowercase, Normalize, Remove emojis, + Pre-tokenize(add extra white space before punc) etc. + - applied for all inputs both in training and inference. + """ + assert isinstance(raw_input, str) and raw_input.strip() != "", \ + "The input sentence is empty! Please make sure " \ + "that you are feeding a valid input." 
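+ # Order of operations: unicode/whitespace normalization, optional Moses
+ # punctuation normalization and pretokenization, then lowercasing.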
+ + if self.normalize: + raw_input = remove_extra_spaces(unicode_normalize(raw_input)) + + if self.pretokenizer == "moses": + if self.normalize: + raw_input = self.moses_normalizer.normalize(raw_input) + raw_input = self.moses_tokenizer.tokenize(raw_input, return_str=True) + + if self.lowercase: + raw_input = raw_input.lower() + + # ensure the string is not empty. + assert raw_input is not None and len(raw_input) > 0, raw_input + return raw_input + + def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: + """Tokenize single sentence""" + if self.level == "word": + sequence = raw_input.split(self.SPACE) + elif self.level == "char": + sequence = list(raw_input.replace(self.SPACE, self.SPACE_ESCAPE)) + + if is_train and self._filter_by_length(len(sequence)): + return None + return sequence + + def _filter_by_length(self, length: int) -> bool: + """ + Check if the given seq length is out of the valid range. + + :param length: (int) number of tokens + :return: True if the length is invalid(= to be filtered out), False if valid. + """ + return length > self.max_length > 0 or self.min_length > length > 0 + + def _remove_special(self, sequence: List[str], generate_unk: bool = False): + specials = self.SPECIALS[:-1] if generate_unk else self.SPECIALS + return [token for token in sequence if token not in specials] + + def post_process(self, + sequence: Union[List[str], str], + generate_unk: bool = True) -> str: + """Detokenize""" + if isinstance(sequence, list): + sequence = self._remove_special(sequence, generate_unk=generate_unk) + if self.level == "word": + if self.pretokenizer == "moses": + sequence = self.moses_detokenizer.detokenize(sequence) + else: + sequence = self.SPACE.join(sequence) + elif self.level == "char": + sequence = "".join(sequence).replace(self.SPACE_ESCAPE, self.SPACE) + + # Remove extra spaces + if self.normalize: + sequence = remove_extra_spaces(sequence) + + # ensure the string is not empty. + assert sequence is not None and len(sequence) > 0, sequence + return sequence + + def set_vocab(self, itos: List[str]) -> None: + """ + Set vocab + :param itos: (list) indices-to-symbols mapping + """ + pass # pylint: disable=unnecessary-pass + + def __repr__(self): + return (f"{self.__class__.__name__}(level={self.level}, " + f"lowercase={self.lowercase}, normalize={self.normalize}, " + f"filter_by_length=({self.min_length}, {self.max_length}), " + f"pretokenizer={self.pretokenizer})") + + +class SentencePieceTokenizer(BasicTokenizer): + + def __init__( + self, + level: str = "bpe", + lowercase: bool = False, + normalize: bool = False, + max_length: int = -1, + min_length: int = -1, + **kwargs, + ): + super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs) + assert self.level == "bpe" + + self.model_file: Path = Path(kwargs["model_file"]) + assert self.model_file.is_file(), f"model file {self.model_file} not found." 
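+
+ # Note: nbest_size / alpha (set below) control sentencepiece subword
+ # regularization: when alpha > 0 the segmentation is sampled during
+ # training; otherwise, and at inference, the single best segmentation
+ # is used (see __call__).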
+ + self.spm = sp.SentencePieceProcessor() + self.spm.load(kwargs["model_file"]) + + self.nbest_size: int = kwargs.get("nbest_size", 5) + self.alpha: float = kwargs.get("alpha", 0.0) + + def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: + """Tokenize""" + if is_train and self.alpha > 0: + tokenized = self.spm.sample_encode_as_pieces( + raw_input, + nbest_size=self.nbest_size, + alpha=self.alpha, + ) + else: + tokenized = self.spm.encode(raw_input, out_type=str) + + if is_train and self._filter_by_length(len(tokenized)): + return None + return tokenized + + def post_process(self, + sequence: Union[List[str], str], + generate_unk: bool = True) -> str: + """Detokenize""" + if isinstance(sequence, list): + sequence = self._remove_special(sequence, generate_unk=generate_unk) + + # Decode back to str + sequence = self.spm.decode(sequence) + sequence = sequence.replace(self.SPACE_ESCAPE, self.SPACE).strip() + + # Apply moses detokenizer + if self.pretokenizer == "moses": + sequence = self.moses_detokenizer.detokenize(sequence.split()) + + # Remove extra spaces + if self.normalize: + sequence = remove_extra_spaces(sequence) + + # ensure the string is not empty. + assert sequence is not None and len(sequence) > 0, sequence + return sequence + + def set_vocab(self, itos: List[str]) -> None: + """Set vocab""" + self.spm.SetVocabulary(itos) + + def copy_cfg_file(self, model_dir: Path) -> None: + """Copy confg file to model_dir""" + if (model_dir / self.model_file.name).is_file(): + logger.warning( + "%s already exists. Stop copying.", + (model_dir / self.model_file.name).as_posix(), + ) + shutil.copy2(self.model_file, (model_dir / self.model_file.name).as_posix()) + + def __repr__(self): + return (f"{self.__class__.__name__}(level={self.level}, " + f"lowercase={self.lowercase}, normalize={self.normalize}, " + f"filter_by_length=({self.min_length}, {self.max_length}), " + f"pretokenizer={self.pretokenizer}, " + f"tokenizer={self.spm.__class__.__name__}, " + f"nbest_size={self.nbest_size}, alpha={self.alpha})") + + +class SubwordNMTTokenizer(BasicTokenizer): + + def __init__( + self, + level: str = "bpe", + lowercase: bool = False, + normalize: bool = False, + max_length: int = -1, + min_length: int = -1, + **kwargs, + ): + super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs) + assert self.level == "bpe" + + self.codes: Path = Path(kwargs["codes"]) + assert self.codes.is_file(), f"codes file {self.codes} not found." + + self.separator: str = kwargs.get("separator", "@@") + bpe_parser = apply_bpe.create_parser() + bpe_args = bpe_parser.parse_args( + ["--codes", kwargs["codes"], "--separator", self.separator]) + self.bpe = apply_bpe.BPE( + bpe_args.codes, + bpe_args.merges, + bpe_args.separator, + None, + bpe_args.glossaries, + ) + self.dropout: float = kwargs.get("dropout", 0.0) + + def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: + """Tokenize""" + dropout = self.dropout if is_train else 0.0 + tokenized = self.bpe.process_line(raw_input, dropout).strip().split() + if is_train and self._filter_by_length(len(tokenized)): + return None + return tokenized + + def post_process(self, + sequence: Union[List[str], str], + generate_unk: bool = True) -> str: + """Detokenize""" + if isinstance(sequence, list): + sequence = self._remove_special(sequence, generate_unk=generate_unk) + + # Remove separators, join with spaces + sequence = self.SPACE.join(sequence).replace(self.separator + self.SPACE, + "") + # Remove final merge marker. 
+ if sequence.endswith(self.separator): + sequence = sequence[:-len(self.separator)] + + # Moses detokenizer + if self.pretokenizer == "moses": + sequence = self.moses_detokenizer.detokenize(sequence.split()) + + # Remove extra spaces + if self.normalize: + sequence = remove_extra_spaces(sequence) + + # ensure the string is not empty. + assert sequence is not None and len(sequence) > 0, sequence + return sequence + + def set_vocab(self, itos: List[str]) -> None: + """Set vocab""" + vocab = set(itos) - set(self.SPECIALS) + self.bpe.vocab = vocab + + def copy_cfg_file(self, model_dir: Path) -> None: + """Copy confg file to model_dir""" + shutil.copy2(self.codes, (model_dir / self.codes.name).as_posix()) + + def __repr__(self): + return (f"{self.__class__.__name__}(level={self.level}, " + f"lowercase={self.lowercase}, normalize={self.normalize}, " + f"filter_by_length=({self.min_length}, {self.max_length}), " + f"pretokenizer={self.pretokenizer}, " + f"tokenizer={self.bpe.__class__.__name__}, " + f"separator={self.separator}, dropout={self.dropout})") + + +class FastBPETokenizer(SubwordNMTTokenizer): + + def __init__( + self, + level: str = "bpe", + lowercase: bool = False, + normalize: bool = False, + max_length: int = -1, + min_length: int = -1, + **kwargs, + ): + try: + import fastBPE # pylint: disable=import-outside-toplevel + except ImportError as e: + logger.error(e) + raise ImportError from e + super(SubwordNMTTokenizer, self).__init__(level, lowercase, normalize, + max_length, min_length, **kwargs) + assert self.level == "bpe" + + # set codes file path + self.codes: Path = Path(kwargs["codes"]) + assert self.codes.is_file(), f"codes file {self.codes} not found." + + # instantiate fastBPE object + self.bpe = fastBPE.fastBPE(self.codes.as_posix()) + self.separator = "@@" + self.dropout = 0.0 + + def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: + # fastBPE.apply() + tokenized = self.bpe.apply([raw_input]) + tokenized = tokenized[0].strip().split() + + # check if the input sequence length stays within the valid length range + if is_train and self._filter_by_length(len(tokenized)): + return None + return tokenized + + def set_vocab(self, itos: List[str]) -> None: + pass + + +class SpeechProcessor: + """SpeechProcessor""" + + def __init__( + self, + level: str = "frame", + num_freq: int = 80, + normalize: bool = False, + max_length: int = -1, + min_length: int = -1, + **kwargs, + ): + self.level = level + self.num_freq = num_freq + self.normalize = normalize + + # filter by length + self.max_length = max_length + self.min_length = min_length + + self.specaugment: Callable = SpecAugment(**kwargs["specaugment"]) \ + if "specaugment" in kwargs else None + self.cmvn: Callable = CMVN(**kwargs["cmvn"]) if "cmvn" in kwargs else None + self.root_path = "" # assigned later by dataset.__init__() + + def __call__(self, line: str, is_train: bool = False) -> np.ndarray: + """ + get features + + :param line: path to audio file or pre-extracted features + :param is_train: + + :return: spectrogram in shape (num_frames, num_freq) + """ + # lookup + item = get_features(self.root_path, line) # shape = (num_frames, num_freq) + + num_frames, num_freq = item.shape + assert num_freq == self.num_freq + + if self._filter_too_short_item(num_frames): + # A too short sequence cannot be convolved! + # -> filter out anyway even in test-dev set. + return None + if self._filter_too_long_item(num_frames): + # Don't use too long sequence in training. 
+ if is_train: # pylint: disable=no-else-return + return None + else: # in test, truncate the sequence + item = item[:self.max_length, :] + num_frames = item.shape[0] + assert num_frames <= self.max_length + + # cmvn / specaugment + # pylint: disable=not-callable + if self.cmvn and self.cmvn.before: + item = self.cmvn(item) + if is_train and self.specaugment: + item = self.specaugment(item) + if self.cmvn and not self.cmvn.before: + item = self.cmvn(item) + return item + + def _filter_too_short_item(self, length: int) -> bool: + return self.min_length > length > 0 + + def _filter_too_long_item(self, length: int) -> bool: + return length > self.max_length > 0 + + def __repr__(self): + return (f"{self.__class__.__name__}(" + f"level={self.level}, normalize={self.normalize}, " + f"filter_by_length=({self.min_length}, {self.max_length}), " + f"cmvn={self.cmvn}, specaugment={self.specaugment})") + + +class EvaluationTokenizer(BasicTokenizer): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers in + sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are applied + after sacreBLEU tokenization. + + :param level: (str) tokenization level. {"word", "bpe", "char"} + :param lowercase: (bool) lowercase the text. + :param tokenize: (str) the type of sacreBLEU tokenizer to apply. + """ + ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] + + def __init__(self, lowercase: bool = False, tokenize: str = "13a", **kwargs): + super().__init__(level="word", + lowercase=lowercase, + normalize=False, + max_length=-1, + min_length=-1) + + assert tokenize in self.ALL_TOKENIZER_TYPES, f"`{tokenize}` not supported." + self.tokenizer = _get_tokenizer(tokenize)() + self.no_punc = kwargs.get("no_punc", False) + + def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: + tokenized = self.tokenizer(raw_input) + + if self.lowercase: + tokenized = tokenized.lower() + + # Remove punctuation (apply this after tokenization!) 
+ if self.no_punc: + tokenized = remove_punctuation(tokenized, space=self.SPACE) + return tokenized.split() + + def __repr__(self): + return (f"{self.__class__.__name__}(level={self.level}, " + f"lowercase={self.lowercase}, " + f"tokenizer={self.tokenizer}, " + f"no_punc={self.no_punc})") + + +def _build_tokenizer(cfg: Dict) -> BasicTokenizer: + """Builds tokenizer.""" + tokenizer = None + tokenizer_cfg = cfg.get("tokenizer_cfg", {}) + + # assign lang for moses tokenizer + if tokenizer_cfg.get("pretokenizer", "none") == "moses": + tokenizer_cfg["lang"] = cfg["lang"] + + if cfg["level"] in ["word", "char"]: + tokenizer = BasicTokenizer( + level=cfg["level"], + lowercase=cfg.get("lowercase", False), + normalize=cfg.get("normalize", False), + max_length=cfg.get("max_length", -1), + min_length=cfg.get("min_length", -1), + **tokenizer_cfg, + ) + elif cfg["level"] == "bpe": + tokenizer_type = cfg.get("tokenizer_type", cfg.get("bpe_type", "sentencepiece")) + if tokenizer_type == "sentencepiece": + assert "model_file" in tokenizer_cfg + tokenizer = SentencePieceTokenizer( + level=cfg["level"], + lowercase=cfg.get("lowercase", False), + normalize=cfg.get("normalize", False), + max_length=cfg.get("max_length", -1), + min_length=cfg.get("min_length", -1), + **tokenizer_cfg, + ) + elif tokenizer_type == "subword-nmt": + assert "codes" in tokenizer_cfg + tokenizer = SubwordNMTTokenizer( + level=cfg["level"], + lowercase=cfg.get("lowercase", False), + normalize=cfg.get("normalize", False), + max_length=cfg.get("max_length", -1), + min_length=cfg.get("min_length", -1), + **tokenizer_cfg, + ) + elif tokenizer_type == "fastbpe": + assert "codes" in tokenizer_cfg + tokenizer = FastBPETokenizer( + level=cfg["level"], + lowercase=cfg.get("lowercase", False), + normalize=cfg.get("normalize", False), + max_length=cfg.get("max_length", -1), + min_length=cfg.get("min_length", -1), + **tokenizer_cfg, + ) + else: + raise ConfigurationError(f"{tokenizer_type}: Unknown tokenizer type.") + elif cfg["level"] == "frame": + tokenizer = SpeechProcessor( + level=cfg["level"], + num_freq=cfg["num_freq"], + normalize=cfg.get("normalize", False), + max_length=cfg.get("max_length", -1), + min_length=cfg.get("min_length", -1), + **tokenizer_cfg, + ) + else: + raise ConfigurationError(f"{cfg['level']}: Unknown tokenization level.") + return tokenizer + + +def build_tokenizer(data_cfg: Dict) -> Dict[str, BasicTokenizer]: + task = data_cfg.get("task", "MT").upper() + src_lang = data_cfg["src"]["lang"] if task == "MT" else "src" + trg_lang = data_cfg["trg"]["lang"] if task == "MT" else "trg" + tokenizer = { + src_lang: _build_tokenizer(data_cfg["src"]), + trg_lang: _build_tokenizer(data_cfg["trg"]), + } + logger.info("%s Tokenizer: %s", src_lang, tokenizer[src_lang]) + logger.info("%s Tokenizer: %s", trg_lang, tokenizer[trg_lang]) + return tokenizer diff --git a/joeynmt/training.py b/joeynmt/training.py new file mode 100644 index 0000000..93f955b --- /dev/null +++ b/joeynmt/training.py @@ -0,0 +1,894 @@ +# coding: utf-8 +""" +Training module +""" +import argparse +import heapq +import logging +import math +import shutil +import time +from collections import OrderedDict +from pathlib import Path +from typing import List, Tuple + +import torch +from torch.utils.data import Dataset +from torch.utils.tensorboard import SummaryWriter + +from joeynmt.batch import Batch +from joeynmt.builders import build_gradient_clipper, build_optimizer, build_scheduler +from joeynmt.data import load_data +from joeynmt.helpers import ( + check_version, 
+ delete_ckpt, + load_checkpoint, + load_config, + log_cfg, + make_logger, + make_model_dir, + parse_train_args, + set_seed, + store_attention_plots, + symlink_update, + write_list_to_file, +) +from joeynmt.model import Model, _DataParallel, build_model +from joeynmt.prediction import predict, test + +logger = logging.getLogger(__name__) + + +class TrainManager: + """ + Manages training loop, validations, learning rate scheduling + and early stopping. + """ + + # pylint: disable=too-many-instance-attributes + + def __init__(self, model: Model, cfg: dict) -> None: + """ + Creates a new TrainManager for a model, specified as in configuration. + Note: no need to pass batch_class here. see make_data_iter() + + :param model: torch module defining the model + :param cfg: dictionary containing the training configurations + """ + self.task = cfg["data"]["task"].upper() + + ( # pylint: disable=unbalanced-tuple-unpacking + model_dir, + load_model, + load_encoder, + load_decoder, + loss_type, + ctc_weight, + label_smoothing, + normalization, + learning_rate_min, + keep_best_ckpts, + logging_freq, + validation_freq, + log_valid_sents, + early_stopping_metric, + seed, + shuffle, + epochs, + max_updates, + batch_size, + batch_type, + batch_multiplier, + device, + n_gpu, + num_workers, + fp16, + reset_best_ckpt, + reset_scheduler, + reset_optimizer, + reset_iter_state, + ) = parse_train_args(cfg["training"]) + + # logging and storing + self.model_dir = model_dir + self.tb_writer = SummaryWriter(log_dir=(model_dir / "tensorboard").as_posix()) + self.logging_freq = logging_freq + self.validation_freq = validation_freq + self.log_valid_sents = log_valid_sents + + # model + self.model = model + self.model.log_parameters_list() + self.model.loss_function = (loss_type, label_smoothing, ctc_weight) + logger.info(self.model) + + # CPU / GPU + self.device = device + self.n_gpu = n_gpu + self.num_workers = num_workers + if self.device.type == "cuda": + self.model.to(self.device) + + # optimization + self.clip_grad_fun = build_gradient_clipper(config=cfg["training"]) + self.optimizer = build_optimizer(config=cfg["training"], + parameters=self.model.parameters()) + + # fp16 + self.fp16: bool = fp16 # True or False for scaler + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + if self.fp16: + self.dtype = torch.float16 if self.device.type == "cuda" else torch.bfloat16 + else: + self.dtype = torch.get_default_dtype() + + # save/delete checkpoints + self.num_ckpts: int = keep_best_ckpts + self.ckpt_queue: List[Tuple[float, Path]] = [] # heap queue + + # early_stopping + self.early_stopping_metric = early_stopping_metric + # early_stopping_metric decides on how to find the early stopping point: ckpts + # are written when there's a new high/low score for this metric. If we schedule + # after loss/ppl, we want to minimize the score, else we want to maximize it. 
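+ # e.g. with early_stopping_metric="wer", a checkpoint counts as a new best
+ # whenever the validation WER drops below the best value seen so far.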
+ if self.early_stopping_metric in ["ppl", "loss", "wer"]: # lower is better + self.minimize_metric = True + elif self.early_stopping_metric in ["acc", "bleu", "chrf"]: # higher is better + self.minimize_metric = False + + # learning rate scheduling + self.scheduler, self.scheduler_step_at = build_scheduler( + config=cfg["training"], + scheduler_mode="min" if self.minimize_metric else "max", + optimizer=self.optimizer, + hidden_size=cfg["model"]["encoder"]["hidden_size"], + ) + + # data & batch handling + self.seed = seed + self.shuffle = shuffle + self.epochs = epochs + self.max_updates = max_updates + self.max_updates = max_updates + self.batch_size = batch_size + self.batch_type = batch_type + self.learning_rate_min = learning_rate_min + self.batch_multiplier = batch_multiplier + self.normalization = normalization + + # Placeholder so that we can use the train_iter in other functions. + self.train_iter, self.train_iter_state = None, None + + # initialize training statistics + self.stats = self.TrainStatistics( + steps=0, + is_min_lr=False, + is_max_update=False, + total_tokens=0, + best_ckpt_iter=0, + best_ckpt_score=float("inf") if self.minimize_metric else float("-inf"), + minimize_metric=self.minimize_metric, + total_correct=0, + ) + + # load model parameters + if load_model is not None: + self.init_from_checkpoint( + load_model, + reset_best_ckpt=reset_best_ckpt, + reset_scheduler=reset_scheduler, + reset_optimizer=reset_optimizer, + reset_iter_state=reset_iter_state, + ) + for layer_name, load_path in [ + ("encoder", load_encoder), + ("decoder", load_decoder), + ]: + if load_path is not None: + self.init_layers(path=load_path, layer=layer_name) + + # gpu training + if self.n_gpu > 1: + self.model = _DataParallel(self.model) + + # config for generation + self.valid_cfg = cfg["testing"].copy() + self.valid_cfg["beam_size"] = 1 # greedy decoding during train loop + # in greedy decoding, we use the same batch_size as the one in training + self.valid_cfg["batch_size"] = self.batch_size + self.valid_cfg["batch_type"] = self.batch_type + # no further exploration during training + self.valid_cfg["n_best"] = 1 + # self.valid_cfg["return_attention"] = False # don't override this param + self.valid_cfg["return_prob"] = "none" + self.valid_cfg["generate_unk"] = True + self.valid_cfg["repetition_penalty"] = -1 # turn off + self.valid_cfg["no_repeat_ngram_size"] = -1 # turn off + + def _save_checkpoint(self, new_best: bool, score: float) -> None: + """ + Save the model's current parameters and the training state to a checkpoint. + + The training state contains the total number of training steps, the total number + of training tokens, the best checkpoint score and iteration so far, and + optimizer and scheduler states. + + :param new_best: This boolean signals which symlink we will use for the new + checkpoint. If it is true, we update best.ckpt. + :param score: Validation score which is used as key of heap queue. if score is + float('nan'), the queue won't be updated. 
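+
+ Example (assumed setting, for illustration): with keep_best_ckpts=2 and
+ WER as the metric, the heap keeps the two checkpoints with the lowest
+ WER seen so far; when a better one is saved, the worst of the stored
+ checkpoints is deleted.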
+ """ + model_path = Path(self.model_dir) / f"{self.stats.steps}.ckpt" + model_state_dict = (self.model.module.state_dict() if isinstance( + self.model, torch.nn.DataParallel) else self.model.state_dict()) + train_iter_state = self.train_iter.batch_sampler.sampler.generator.get_state() \ + if hasattr(self.train_iter.batch_sampler.sampler, 'generator') else None + # yapf: disable + state = { + "steps": self.stats.steps, + "total_tokens": self.stats.total_tokens, + "best_ckpt_score": self.stats.best_ckpt_score, + "best_ckpt_iteration": self.stats.best_ckpt_iter, + "model_state": model_state_dict, + "optimizer_state": self.optimizer.state_dict(), + "scaler_state": self.scaler.state_dict(), + "scheduler_state": (self.scheduler.state_dict() + if self.scheduler is not None else None), + "train_iter_state": train_iter_state, + "total_correct": self.stats.total_correct, + } + torch.save(state, model_path.as_posix()) + + # update symlink + symlink_target = Path(f"{self.stats.steps}.ckpt") + # last symlink + last_path = Path(self.model_dir) / "latest.ckpt" + prev_path = symlink_update(symlink_target, last_path) # update always + # best symlink + best_path = Path(self.model_dir) / "best.ckpt" + if new_best: + prev_path = symlink_update(symlink_target, best_path) + assert best_path.resolve().stem == str(self.stats.best_ckpt_iter) + + # push to and pop from the heap queue + to_delete = None + if not math.isnan(score) and self.num_ckpts > 0: + if len(self.ckpt_queue) < self.num_ckpts: # no pop, push only + heapq.heappush(self.ckpt_queue, (score, model_path)) + else: # push + pop the worst one in the queue + if self.minimize_metric: + # pylint: disable=protected-access + heapq._heapify_max(self.ckpt_queue) + to_delete = heapq._heappop_max(self.ckpt_queue) + heapq.heappush(self.ckpt_queue, (score, model_path)) + # pylint: enable=protected-access + else: + to_delete = heapq.heappushpop(self.ckpt_queue, (score, model_path)) + + if to_delete is not None: + assert to_delete[1] != model_path # don't delete the last ckpt + if to_delete[1].stem != best_path.resolve().stem: + delete_ckpt(to_delete[1]) # don't delete the best ckpt + + assert len(self.ckpt_queue) <= self.num_ckpts + + # remove old symlink target if not in queue after push/pop + if prev_path is not None and prev_path.stem not in [ + c[1].stem for c in self.ckpt_queue + ]: + delete_ckpt(prev_path) + + def init_from_checkpoint( + self, + path: Path, + reset_best_ckpt: bool = False, + reset_scheduler: bool = False, + reset_optimizer: bool = False, + reset_iter_state: bool = False, + ) -> None: + """ + Initialize the trainer from a given checkpoint file. + + This checkpoint file contains not only model parameters, but also + scheduler and optimizer states, see `self._save_checkpoint`. + + :param path: path to checkpoint + :param reset_best_ckpt: reset tracking of the best checkpoint, + use for domain adaptation with a new dev + set or when using a new metric for fine-tuning. + :param reset_scheduler: reset the learning rate scheduler, and do not + use the one stored in the checkpoint. + :param reset_optimizer: reset the optimizer, and do not use the one + stored in the checkpoint. + :param reset_iter_state: reset the sampler's internal state and do not + use the one stored in the checkpoint. 
+ """ + logger.info("Loading model from %s", path) + model_checkpoint = load_checkpoint(path=path, device=self.device) + + # restore model and optimizer parameters + self.model.load_state_dict(model_checkpoint["model_state"]) + + if not reset_optimizer: + self.optimizer.load_state_dict(model_checkpoint["optimizer_state"]) + if "scaler_state" in model_checkpoint: + self.scaler.load_state_dict(model_checkpoint["scaler_state"]) + else: + logger.info("Reset optimizer.") + + if not reset_scheduler: + if (model_checkpoint["scheduler_state"] is not None + and self.scheduler is not None): + self.scheduler.load_state_dict(model_checkpoint["scheduler_state"]) + else: + logger.info("Reset scheduler.") + + if not reset_best_ckpt: + self.stats.best_ckpt_score = model_checkpoint["best_ckpt_score"] + self.stats.best_ckpt_iter = model_checkpoint["best_ckpt_iteration"] + else: + logger.info("Reset tracking of the best checkpoint.") + + if not reset_iter_state: + # restore counters + assert "train_iter_state" in model_checkpoint + self.stats.steps = model_checkpoint["steps"] + self.stats.total_tokens = model_checkpoint["total_tokens"] + self.stats.total_correct = model_checkpoint.get("total_correct", 0) + self.train_iter_state = model_checkpoint["train_iter_state"] + else: + # reset counters if explicitly 'train_iter_state: True' in config + logger.info("Reset data iterator (random seed: {%d}).", self.seed) + + # move to gpu + if self.device.type == "cuda": + self.model.to(self.device) + + def init_layers(self, path: Path, layer: str) -> None: + """ + Initialize encoder decoder layers from a given checkpoint file. + + :param path: path to checkpoint + :param layer: layer name; 'encoder' or 'decoder' expected + """ + assert path is not None + layer_state_dict = OrderedDict() + logger.info("Loading %s laysers from %s", layer, path) + ckpt = load_checkpoint(path=path, device=self.device) + for k, v in ckpt["model_state"].items(): + if k.startswith(layer): + layer_state_dict[k] = v + self.model.load_state_dict(layer_state_dict, strict=False) + + def train_and_validate(self, train_data: Dataset, valid_data: Dataset) -> None: + """ + Train the model and validate it from time to time on the validation set. + + :param train_data: training data + :param valid_data: validation data + """ + # pylint: disable=too-many-branches,too-many-statements + self.train_iter = train_data.make_iter( + batch_size=self.batch_size, + batch_type=self.batch_type, + seed=self.seed, + shuffle=self.shuffle, + num_workers=self.num_workers, + device=self.device, + pad_index=self.model.pad_index, + ) + + if self.train_iter_state is not None: + self.train_iter.batch_sampler.sampler.generator.set_state( + self.train_iter_state.cpu()) + + ################################################################# + # simplify accumulation logic: + ################################################################# + # for epoch in range(epochs): + # self.model.zero_grad() + # epoch_loss = 0.0 + # batch_loss = 0.0 + # for i, batch in enumerate(self.train_iter): + # + # # gradient accumulation: + # # loss.backward() inside _train_step() + # batch_loss += self._train_step(inputs) + # + # if (i + 1) % self.batch_multiplier == 0: + # self.optimizer.step() # update! + # self.model.zero_grad() # reset gradients + # self.steps += 1 # increment counter + # + # epoch_loss += batch_loss # accumulate batch loss + # batch_loss = 0 # reset batch loss + # + # # leftovers are just ignored. 
+ ################################################################# + + logger.info( + "Train stats:\n" + "\tdevice: %s\n" + "\tn_gpu: %d\n" + "\t16-bits training: %r\n" + "\tgradient accumulation: %d\n" + "\tbatch size per device: %d\n" + "\teffective batch size (w. parallel & accumulation): %d", + self.device.type, # next(self.model.parameters()).device + self.n_gpu, + self.fp16, + self.batch_multiplier, + self.batch_size // self.n_gpu if self.n_gpu > 1 else self.batch_size, + self.batch_size * self.batch_multiplier, + ) + + # detect NaN gradients + torch.autograd.set_detect_anomaly(True) + + try: + for epoch_no in range(self.epochs): + logger.info("EPOCH %d", epoch_no + 1) + + if self.scheduler_step_at == "epoch": + self.scheduler.step(epoch=epoch_no) + + self.model.train() + + # Reset statistics for each epoch. + start = time.time() + total_valid_duration = 0 + start_tokens = self.stats.total_tokens + start_correct = self.stats.total_correct + self.model.zero_grad(set_to_none=True) + epoch_loss = 0 + total_batch_loss = 0 + total_nll_loss = 0 + total_ctc_loss = 0 + total_nseqs = 0 + total_ntokens = 0 + + # subsample train data each epoch + if train_data.random_subset > 0: + try: + train_data.reset_random_subset() + train_data.sample_random_subset(seed=epoch_no) + logger.info( + "Sample random subset from dev set: n=%d, seed=%d", + len(train_data), + epoch_no, + ) + except AssertionError as e: + logger.warning(e) + + batch: Batch # yield a joeynmt Batch object + for i, batch in enumerate(self.train_iter): + # sort batch now by src length and keep track of order + batch.sort_by_src_length() + + # get batch loss + + (norm_batch_loss, norm_nll_loss, norm_ctc_loss, + n_correct) = self._train_step(batch) + total_batch_loss += norm_batch_loss + total_nll_loss += norm_nll_loss if norm_nll_loss is not None else 0 + total_ctc_loss += norm_ctc_loss if norm_ctc_loss is not None else 0 + + # increment seq/token counter + total_nseqs += batch.nseqs + total_ntokens += batch.ntokens + self.stats.total_tokens += batch.ntokens + self.stats.total_correct += n_correct + + # update! 
+ if (i + 1) % self.batch_multiplier == 0: + # clip gradients (in-place) + if self.clip_grad_fun is not None: + self.clip_grad_fun(parameters=self.model.parameters()) + + # make gradient step + self.scaler.step(self.optimizer) + self.scaler.update() + + # decay lr + if self.scheduler_step_at == "step": + self.scheduler.step(self.stats.steps) + + # reset gradients + self.model.zero_grad(set_to_none=True) + + # increment step counter + self.stats.steps += 1 + if self.stats.steps >= self.max_updates: + self.stats.is_max_update = True + + # log learning progress + if self.stats.steps % self.logging_freq == 0: + elapsed = time.time() - start - total_valid_duration + elapsed_tok = self.stats.total_tokens - start_tokens + elapsed_correct = self.stats.total_correct - start_correct + self.tb_writer.add_scalar("train/batch_loss", + total_batch_loss, + self.stats.steps) + if total_nll_loss != 0: + self.tb_writer.add_scalar("train/batch_nll_loss", + total_nll_loss, + self.stats.steps) + if total_ctc_loss != 0: + self.tb_writer.add_scalar("train/batch_ctc_loss", + total_ctc_loss, + self.stats.steps) + self.tb_writer.add_scalar("train/batch_acc", + elapsed_correct / elapsed_tok, + self.stats.steps) + logger.info( + "Epoch %3d, " + "Step: %8d, " + "Batch Loss: %12.6f, " + "Batch Acc: %.6f, " + "Tokens per Sec: %8.0f, " + "Lr: %.6f", + epoch_no + 1, + self.stats.steps, + total_batch_loss, + elapsed_correct / elapsed_tok, + elapsed_tok / elapsed, + self.optimizer.param_groups[0]["lr"], + ) + start = time.time() + total_valid_duration = 0 + start_tokens = self.stats.total_tokens + start_correct = self.stats.total_correct + + # update epoch_loss + epoch_loss += total_batch_loss # accumulate loss + total_batch_loss = 0 # reset batch loss + total_nll_loss = 0 + total_ctc_loss = 0 + + # validate on the entire dev set + if self.stats.steps % self.validation_freq == 0: + valid_duration = self._validate(valid_data) + total_valid_duration += valid_duration + + # check current_lr + current_lr = self.optimizer.param_groups[0]["lr"] + if current_lr < self.learning_rate_min: + self.stats.is_min_lr = True + + self.tb_writer.add_scalar("train/learning_rate", current_lr, + self.stats.steps) + + if self.stats.is_min_lr or self.stats.is_max_update: + break + + if self.stats.is_min_lr or self.stats.is_max_update: + log_str = (f"minimum lr {self.learning_rate_min}" + if self.stats.is_min_lr else + f"maximum num. of updates {self.max_updates}") + logger.info("Training ended since %s was reached.", log_str) + break + + logger.info( + "Epoch %3d: total training loss %.2f, num seqs %d, num tokens %d", + epoch_no + 1, epoch_loss, total_nseqs, total_ntokens) + else: + logger.info("Training ended after %3d epochs.", epoch_no + 1) + logger.info( + "Best validation result (greedy) at step %8d: %6.2f %s.", + self.stats.best_ckpt_iter, + self.stats.best_ckpt_score, + self.early_stopping_metric, + ) + except KeyboardInterrupt: + self._save_checkpoint(False, float("nan")) + + self.tb_writer.close() # close Tensorboard writer + + def _train_step(self, batch: Batch) -> Tuple[float, float, float, float]: + """ + Train the model on one batch: Compute the loss. 
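+
+ Note: the scaled loss is back-propagated here; the optimizer step itself
+ happens in the caller once `batch_multiplier` accumulation steps have
+ been collected.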
+ + :param batch: training batch + :return: + - losses for batch (sum) + - nll loss + - ctc loss + - number of correct tokens for batch (sum) + """ + # reactivate training + self.model.train() + + with torch.autocast(device_type=self.device.type, + dtype=self.dtype, + enabled=self.fp16): + # get loss (run as during training with teacher forcing) + batch_loss, nll_loss, ctc_loss, correct_tokens = self.model( + return_type="loss", **vars(batch)) + + # normalize batch loss + norm_batch_loss = batch.normalize(batch_loss, self.normalization, self.n_gpu, + self.batch_multiplier) + + norm_nll_loss = batch.normalize( + nll_loss, self.normalization, self.n_gpu, + self.batch_multiplier) if nll_loss is not None else None + + norm_ctc_loss = batch.normalize( + ctc_loss, self.normalization, self.n_gpu, + self.batch_multiplier) if ctc_loss is not None else None + + # sum over multiple gpus + sum_correct_tokens = batch.normalize(correct_tokens, "sum", self.n_gpu) + + # accumulate gradients + self.scaler.scale(norm_batch_loss).backward() + + return (norm_batch_loss.item(), + norm_nll_loss.item() if norm_nll_loss is not None else None, + norm_ctc_loss.item() if norm_ctc_loss is not None else None, + sum_correct_tokens.item()) + + def _validate(self, valid_data: Dataset): + if valid_data.random_subset > 0: # subsample validation set each valid step + try: + valid_data.reset_random_subset() + valid_data.sample_random_subset(seed=self.stats.steps) + logger.info( + "Sample random subset from dev set: n=%d, seed=%d", + len(valid_data), + self.stats.steps, + ) + except AssertionError as e: + logger.warning(e) + + valid_start_time = time.time() + ( + valid_scores, + valid_references, + valid_hypotheses, + valid_hypotheses_raw, + valid_sequence_scores, # pylint: disable=unused-variable + valid_attention_scores, + ) = predict( + model=self.model, + data=valid_data, + compute_loss=True, + device=self.device, + n_gpu=self.n_gpu, + normalization=self.normalization, + cfg=self.valid_cfg, + fp16=self.fp16, + ) + valid_duration = time.time() - valid_start_time + + # for eval_metric in ['loss', 'ppl', 'acc'] + self.eval_metrics: + for eval_metric, score in valid_scores.items(): + if not math.isnan(score): + self.tb_writer.add_scalar(f"valid/{eval_metric}", score, + self.stats.steps) + + ckpt_score = valid_scores[self.early_stopping_metric] + + if self.scheduler_step_at == "validation": + self.scheduler.step(metrics=ckpt_score) + + # update new best + new_best = self.stats.is_best(ckpt_score) + if new_best: + self.stats.best_ckpt_score = ckpt_score + self.stats.best_ckpt_iter = self.stats.steps + logger.info( + "Hooray! 
New best validation result [%s]!", + self.early_stopping_metric, + ) + + # save checkpoints + is_better = (self.stats.is_better(ckpt_score, self.ckpt_queue) + if len(self.ckpt_queue) > 0 else True) + if self.num_ckpts < 0 or is_better: + self._save_checkpoint(new_best, ckpt_score) + + # append to validation report + self._add_report(valid_scores=valid_scores, new_best=new_best) + + self._log_examples( + references=valid_references, + hypotheses=valid_hypotheses, + hypotheses_raw=valid_hypotheses_raw, + data=valid_data, + ) + + # store validation set outputs + write_list_to_file(self.model_dir / f"{self.stats.steps}.hyps", + valid_hypotheses) + + # store attention plots for selected valid sentences + if valid_attention_scores: + store_attention_plots( + attentions=valid_attention_scores, + targets=valid_hypotheses_raw, + sources=valid_data.get_list(lang=valid_data.src_lang, tokenized=True), + indices=self.log_valid_sents, + output_prefix=(self.model_dir / f"att.{self.stats.steps}").as_posix(), + tb_writer=self.tb_writer, + steps=self.stats.steps, + ) + + return valid_duration + + def _add_report(self, valid_scores: dict, new_best: bool = False) -> None: + """ + Append a one-line report to validation logging file. + + :param valid_scores: validation evaluation score [eval_metric] + :param new_best: whether this is a new best model + """ + current_lr = self.optimizer.param_groups[0]["lr"] + + valid_file = self.model_dir / "validations.txt" + with valid_file.open("a", encoding="utf-8") as opened_file: + score_str = "\t".join([f"Steps: {self.stats.steps}"] + [ + f"{eval_metric}: {score:.5f}" + for eval_metric, score in valid_scores.items() if not math.isnan(score) + ] + [f"LR: {current_lr:.8f}", "*" if new_best else ""]) + opened_file.write(f"{score_str}\n") + + def _log_examples( + self, + hypotheses: List[str], + references: List[str], + hypotheses_raw: List[List[str]], + data: Dataset, + ) -> None: + """ + Log the first `self.log_valid_sents` sentences from given examples. 
+ + :param hypotheses: decoded hypotheses (list of strings) + :param references: decoded references (list of strings) + :param hypotheses_raw: raw hypotheses (list of list of tokens) + :param data: Dataset + """ + for p in self.log_valid_sents: + if p >= len(hypotheses): + continue + logger.info("Example #%d", p) + + # tokenized text + if self.task == "MT": + tokenized_src = data.get_item(idx=p, lang=data.src_lang) + tokenized_trg = data.get_item(idx=p, lang=data.trg_lang) + + if self.task == "MT": + logger.debug("\tTokenized source: %s", tokenized_src) + logger.debug("\tTokenized reference: %s", tokenized_trg) + logger.debug("\tTokenized hypothesis: %s", hypotheses_raw[p]) + + # detokenized text + detokenized_src = data.tokenizer[data.src_lang].post_process(data.src[p]) \ + if self.task == "MT" else data.src[p] + logger.info("\tSource: %s", detokenized_src) + logger.info("\tReference: %s", references[p]) + logger.info("\tHypothesis: %s", hypotheses[p]) + + class TrainStatistics: + + def __init__( + self, + steps: int = 0, + is_min_lr: bool = False, + is_max_update: bool = False, + total_tokens: int = 0, + best_ckpt_iter: int = 0, + best_ckpt_score: float = float("inf"), + minimize_metric: bool = True, + total_correct: int = 0, + ) -> None: + self.steps = steps # global update step counter + self.is_min_lr = is_min_lr # stop by reaching learning rate minimum + self.is_max_update = is_max_update # stop by reaching max num of updates + self.total_tokens = total_tokens # number of total tokens seen so far + self.best_ckpt_iter = best_ckpt_iter # store iteration point of best ckpt + self.best_ckpt_score = best_ckpt_score # initial values for best scores + self.minimize_metric = minimize_metric # minimize or maximize score + self.total_correct = total_correct # number of correct tokens seen so far + + def is_best(self, score): + if self.minimize_metric: + is_best = score < self.best_ckpt_score + else: + is_best = score > self.best_ckpt_score + return is_best + + def is_better(self, score: float, heap_queue: list): + assert len(heap_queue) > 0 + if self.minimize_metric: + is_better = score < heapq.nlargest(1, heap_queue)[0][0] + else: + is_better = score > heapq.nsmallest(1, heap_queue)[0][0] + return is_better + + +def train(cfg_file: str, skip_test: bool = False) -> None: + """ + Main training function. After training, also test on test data if given. 
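+
+ Example (a minimal sketch; the path is just the argparse default below):
+ train("configs/default.yaml", skip_test=True)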
+ + :param cfg_file: path to configuration yaml file + :param skip_test: whether a test should be run or not after training + """ + # read config file + cfg = load_config(Path(cfg_file)) + + # make logger + model_dir = make_model_dir( + Path(cfg["training"]["model_dir"]), + overwrite=cfg["training"].get("overwrite", False), + ) + pkg_version = make_logger(model_dir, mode="train") + # TODO: save version number in model checkpoints + if "joeynmt_version" in cfg: + check_version(pkg_version, cfg["joeynmt_version"]) + + # write all entries of config to the log + log_cfg(cfg) + + # store copy of original training config in model dir + shutil.copy2(cfg_file, (model_dir / "config.yaml").as_posix()) + + # set the random seed + set_seed(seed=cfg["training"].get("random_seed", 42)) + + # load the data + src_vocab, trg_vocab, train_data, dev_data, test_data = load_data( + data_cfg=cfg["data"]) + + # store the vocabs and tokenizers + if src_vocab is not None: + src_vocab.to_file(model_dir / "src_vocab.txt") + if hasattr(train_data.tokenizer[train_data.src_lang], "copy_cfg_file"): + train_data.tokenizer[train_data.src_lang].copy_cfg_file(model_dir) + trg_vocab.to_file(model_dir / "trg_vocab.txt") + if hasattr(train_data.tokenizer[train_data.trg_lang], "copy_cfg_file"): + train_data.tokenizer[train_data.trg_lang].copy_cfg_file(model_dir) + + # build an encoder-decoder model + model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab) + + # for training management, e.g. early stopping and model selection + trainer = TrainManager(model=model, cfg=cfg) + + # train the model + trainer.train_and_validate(train_data=train_data, valid_data=dev_data) + + if not skip_test: + # predict with the best model on validation and test + # (if test data is available) + + ckpt = model_dir / f"{trainer.stats.best_ckpt_iter}.ckpt" + output_path = model_dir / f"{trainer.stats.best_ckpt_iter:08d}.hyps" + + datasets_to_test = { + "dev": dev_data, + "test": test_data, + "src_vocab": src_vocab, + "trg_vocab": trg_vocab, + } + test( + cfg_file, + ckpt=ckpt.as_posix(), + output_path=output_path.as_posix(), + datasets=datasets_to_test, + ) + else: + logger.info("Skipping test after training.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Joey-NMT") + parser.add_argument( + "config", + default="configs/default.yaml", + type=str, + help="Training configuration file (yaml).", + ) + args = parser.parse_args() + train(cfg_file=args.config) diff --git a/joeynmt/transformer_layers.py b/joeynmt/transformer_layers.py new file mode 100644 index 0000000..959ef06 --- /dev/null +++ b/joeynmt/transformer_layers.py @@ -0,0 +1,544 @@ +# -*- coding: utf-8 -*- +""" +Transformer layers +""" +import math +from typing import Optional + +import torch +from torch import Tensor, nn + + +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention module from "Attention is All You Need" + + Implementation modified from OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, num_heads: int, size: int, dropout: float = 0.1) -> None: + """ + Create a multi-headed attention layer. 
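+
+ Each head attends over a slice of size // num_heads dimensions, e.g.
+ size=256 with num_heads=4 gives a head size of 64.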
+ + :param num_heads: the number of heads + :param size: hidden size (must be divisible by num_heads) + :param dropout: probability of dropping a unit + """ + super().__init__() + + assert size % num_heads == 0 + + self.head_size = head_size = size // num_heads + self.model_size = size + self.num_heads = num_heads + + self.k_layer = nn.Linear(size, num_heads * head_size) + self.v_layer = nn.Linear(size, num_heads * head_size) + self.q_layer = nn.Linear(size, num_heads * head_size) + + self.output_layer = nn.Linear(size, size) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + k: Tensor, + v: Tensor, + q: Tensor, + mask: Optional[Tensor] = None, + return_weights: Optional[bool] = None, + ): + """ + Computes multi-headed attention. + + :param k: keys [batch_size, seq_len, hidden_size] + :param v: values [batch_size, seq_len, hidden_size] + :param q: query [batch_size, seq_len, hidden_size] + :param mask: optional mask [batch_size, 1, seq_len] + :param return_weights: whether to return the attention weights, + averaged over heads. + :return: + - output [batch_size, query_len, hidden_size] + - attention_weights [batch_size, query_len, key_len] + """ + batch_size = k.size(0) + key_len = k.size(1) + query_len = q.size(1) + + # project the queries (q), keys (k), and values (v) + k = self.k_layer(k) + v = self.v_layer(v) + q = self.q_layer(q) + + # reshape q, k, v for our computation to [batch_size, num_heads, ..] + k = k.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2) + q = q.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2) + + # compute scores + q = q / math.sqrt(self.head_size) + + # [batch_size, num_heads, query_len, key_len] + scores = torch.matmul(q, k.transpose(2, 3)) + + # apply the mask (if we have one) + # we add a dimension for the heads to it below: [batch_size, 1, 1, key_len] + if mask is not None: + scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf")) + + # apply attention dropout and compute context vectors. + attention_weights = self.softmax(scores) + attention_probs = self.dropout(attention_weights) + + # get context vector (select values with attention) and reshape + # back to [batch_size, query_len, hidden_size] + context = torch.matmul(attention_probs, v) + context = context.transpose(1, 2).contiguous().view( + batch_size, -1, self.num_heads * self.head_size) + + output = self.output_layer(context) + + if return_weights: + # average attention weights over heads: [batch_size, query_len, key_len] + attention_output_weights = attention_weights.view( + batch_size, self.num_heads, query_len, key_len) + return output, attention_output_weights.sum(dim=1) / self.num_heads + return output, None + + +class PositionwiseFeedForward(nn.Module): + """ + Position-wise Feed-forward layer + Projects to ff_size and then back down to input_size. + """ + + def __init__( + self, + input_size: int, + ff_size: int, + dropout: float = 0.1, + alpha: float = 1.0, + layer_norm: str = "pre", + ) -> None: + """ + Initializes position-wise feed-forward layer. + :param input_size: dimensionality of the input. 
+ :param ff_size: dimensionality of intermediate representation + :param dropout: dropout probability + :param alpha: weight factor for residual connection + :param layer_norm: either "pre" or "post" + """ + super().__init__() + + self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) + self.pwff_layer = nn.Sequential( + nn.Linear(input_size, ff_size), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_size, input_size), + nn.Dropout(dropout), + ) + + self.alpha = alpha + self._layer_norm_position = layer_norm + assert self._layer_norm_position in {"pre", "post"} + + def forward(self, x: Tensor) -> Tensor: + residual = x + if self._layer_norm_position == "pre": + x = self.layer_norm(x) + + x = self.pwff_layer(x) + self.alpha * residual + + if self._layer_norm_position == "post": + x = self.layer_norm(x) + return x + + +class PositionalEncoding(nn.Module): + """ + Pre-compute position encodings (PE). + In forward pass, this adds the position-encodings to the input for as many time + steps as necessary. + + Implementation based on OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, size: int = 0, max_len: int = 5000) -> None: + """ + Positional Encoding with maximum length + + :param size: embeddings dimension size + :param max_len: maximum sequence length + """ + if size % 2 != 0: + raise ValueError( + f"Cannot use sin/cos positional encoding with odd dim (got dim={size})") + pe = torch.zeros(max_len, size) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + (torch.arange(0, size, 2, dtype=torch.float) * -(math.log(10000.0) / size))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) # shape: (1, max_len, size) + super().__init__() + self.register_buffer("pe", pe) + self.dim = size + + def forward(self, emb: Tensor) -> Tensor: + """ + Embed inputs. + + :param emb: (Tensor) Sequence of word embeddings vectors + shape (seq_len, batch_size, dim) + :return: positionally encoded word embeddings + """ + # Add position encodings + return emb + self.pe[:, :emb.size(1)] + + +class TransformerEncoderLayer(nn.Module): + """ + One Transformer encoder layer has a Multi-head attention layer plus a position-wise + feed-forward layer. + """ + + def __init__( + self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1, + alpha: float = 1.0, + layer_norm: str = "pre", + ) -> None: + """ + A single Transformer encoder layer. + + Note: don't change the name or the order of members! + otherwise pretrained models cannot be loaded correctly. + + :param size: model dimensionality + :param ff_size: size of the feed-forward intermediate layer + :param num_heads: number of heads + :param dropout: dropout to apply to input + :param alpha: weight factor for residual connection + :param layer_norm: either "pre" or "post" + """ + super().__init__() + + self.layer_norm = nn.LayerNorm(size, eps=1e-6) + self.src_src_att = MultiHeadedAttention(num_heads, size, dropout=dropout) + + self.feed_forward = PositionwiseFeedForward( + size, + ff_size=ff_size, + dropout=dropout, + alpha=alpha, + layer_norm=layer_norm, + ) + + self.dropout = nn.Dropout(dropout) + self.size = size + + self.alpha = alpha + self._layer_norm_position = layer_norm + assert self._layer_norm_position in {"pre", "post"} + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + Forward pass for a single transformer encoder layer. 
+ First applies self attention, then dropout with residual connection (adding + the input to the result), then layer norm, and then a position-wise + feed-forward layer. + + :param x: layer input + :param mask: input mask + :return: output tensor + """ + residual = x + if self._layer_norm_position == "pre": + x = self.layer_norm(x) + + x, _ = self.src_src_att(x, x, x, mask) + x = self.dropout(x) + self.alpha * residual + + if self._layer_norm_position == "post": + x = self.layer_norm(x) + + out = self.feed_forward(x) + return out + + +class TransformerDecoderLayer(nn.Module): + """ + Transformer decoder layer. + + Consists of self-attention, source-attention, and feed-forward. + """ + + def __init__( + self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1, + alpha: float = 1.0, + layer_norm: str = "pre", + ) -> None: + """ + Represents a single Transformer decoder layer. + It attends to the source representation and the previous decoder states. + + Note: don't change the name or the order of members! + otherwise pretrained models cannot be loaded correctly. + + :param size: model dimensionality + :param ff_size: size of the feed-forward intermediate layer + :param num_heads: number of heads + :param dropout: dropout to apply to input + :param alpha: weight factor for residual connection + :param layer_norm: either "pre" or "post" + """ + super().__init__() + self.size = size + + self.trg_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout) + self.src_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout) + + self.feed_forward = PositionwiseFeedForward( + size, + ff_size=ff_size, + dropout=dropout, + alpha=alpha, + layer_norm=layer_norm, + ) + + self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) + self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) + + self.dropout = nn.Dropout(dropout) + self.alpha = alpha + + self._layer_norm_position = layer_norm + assert self._layer_norm_position in {"pre", "post"} + + def forward( + self, + x: Tensor, + memory: Tensor, + src_mask: Tensor, + trg_mask: Tensor, + return_attention: bool = False, + ) -> Tensor: + """ + Forward pass of a single Transformer decoder layer. + + First applies target-target self-attention, dropout with residual connection + (adding the input to the result), and layer norm. + + Second computes source-target cross-attention, dropout with residual connection + (adding the self-attention to the result), and layer norm. + + Finally goes through a position-wise feed-forward layer. + + :param x: inputs + :param memory: source representations + :param src_mask: source mask + :param trg_mask: target mask (so as to not condition on future steps) + :param return_attention: whether to return the attention weights + :return: + - output tensor + - attention weights + """ + # 1. target-target self-attention + residual = x + if self._layer_norm_position == "pre": + x = self.x_layer_norm(x) + + h1, _ = self.trg_trg_att(x, x, x, mask=trg_mask) + h1 = self.dropout(h1) + self.alpha * residual + + if self._layer_norm_position == "post": + h1 = self.x_layer_norm(h1) + + # 2. source-target cross-attention + h1_residual = h1 + if self._layer_norm_position == "pre": + h1 = self.dec_layer_norm(h1) + + h2, att = self.src_trg_att(memory, + memory, + h1, + mask=src_mask, + return_weights=return_attention) + h2 = self.dropout(h2) + self.alpha * h1_residual + + if self._layer_norm_position == "post": + h2 = self.dec_layer_norm(h2) + + # 3. 
final position-wise feed-forward layer + out = self.feed_forward(h2) + + if return_attention: + return out, att + return out, None + + +class ConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__( + self, + hidden_size: int, + channels: int, + depthwise_kernel_size: int, + dropout: float, + ): + """ + Args: + hidden_size: hidden dimension + channels: Number of channels in depthwise conv layers + depthwise_kernel_size: Depthwise conv layer kernel size + dropout: dropout value + """ + super().__init__() + assert (depthwise_kernel_size - + 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" + self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.pointwise_conv1 = nn.Conv1d( + hidden_size, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=channels, + ) + self.batch_norm = nn.BatchNorm1d(channels) + self.swish = nn.Hardswish() + self.pointwise_conv2 = nn.Conv1d( + channels, + hidden_size, + kernel_size=1, + stride=1, + padding=0, + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x) -> torch.Tensor: + x = self.layer_norm(x) + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = self.glu(x) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.batch_norm(x) + x = self.swish(x) + + x = self.pointwise_conv2(x) + x = self.dropout(x) + return x.transpose(1, 2) + + +class ConformerEncoderLayer(torch.nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__( + self, + size: int = 512, + ff_size: int = 2048, + num_heads: int = 4, + dropout: float = 0.1, + depthwise_conv_kernel_size: int = 31, + alpha: float = 1.0, + layer_norm: str = "pre", + ): + super().__init__() + + self.initial_feed_forward = PositionwiseFeedForward( + size, + ff_size=ff_size, + dropout=dropout, + alpha=alpha, + layer_norm=layer_norm, + ) + + self.src_att_layer_norm = nn.LayerNorm(size, eps=1e-6) + self.src_att_dropout = nn.Dropout(dropout) + self.src_src_att = MultiHeadedAttention(num_heads, size, dropout=dropout) + + self.conv_module = ConvolutionModule( + hidden_size=size, + channels=size, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + ) + + self.final_feed_forward = PositionwiseFeedForward( + size, + ff_size=ff_size, + dropout=dropout, + alpha=alpha, + layer_norm=layer_norm, + ) + self.final_layer_norm = nn.LayerNorm(size, eps=1e-6) + + self.alpha = alpha + self.size = size + + self._layer_norm_position = layer_norm + assert self._layer_norm_position in {"pre", "post"} + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + Forward pass for a single conformer encoder layer. 
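+
+ Sub-module order, as implemented below: half-step feed-forward,
+ self-attention, convolution module, half-step feed-forward, each
+ wrapped in a residual connection.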
+ + :param x: layer input + :param mask: input mask + :return: output tensor + """ + residual = x + x = self.initial_feed_forward(x) + x = 0.5 * x + residual + + residual = x + if self._layer_norm_position == "pre": + x = self.src_att_layer_norm(x) + + x, _ = self.src_src_att(x, x, x, mask) + x = self.src_att_dropout(x) + self.alpha * residual + + if self._layer_norm_position == "post": + x = self.src_att_layer_norm(x) + + residual = x + x = x.transpose(0, 1) # [T, B, C] to [B, T, C] + x = self.conv_module(x) + x = x.transpose(0, 1) # [B, T, C] to [T, B, C] + x = x + self.alpha * residual + + # feed forward layer + residual = x + if self._layer_norm_position == "pre": + x = self.final_layer_norm(x) + + x = self.final_feed_forward(x) + x = 0.5 * x + residual + + if self._layer_norm_position == "post": + x = self.final_layer_norm(x) + return x diff --git a/joeynmt/validation.py b/joeynmt/validation.py new file mode 100644 index 0000000..e69de29 diff --git a/joeynmt/vocabulary.py b/joeynmt/vocabulary.py new file mode 100644 index 0000000..83f68b8 --- /dev/null +++ b/joeynmt/vocabulary.py @@ -0,0 +1,298 @@ +# coding: utf-8 +""" +Vocabulary module +""" +import logging +import sys +from collections import Counter +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np + +from joeynmt.constants import ( + BOS_ID, + BOS_TOKEN, + EOS_ID, + EOS_TOKEN, + PAD_ID, + PAD_TOKEN, + UNK_ID, + UNK_TOKEN, +) +from joeynmt.datasets import BaseDataset +from joeynmt.helpers import flatten, read_list_from_file, write_list_to_file + +logger = logging.getLogger(__name__) + + +class Vocabulary: + """Vocabulary represents mapping between tokens and indices.""" + + def __init__(self, tokens: List[str]) -> None: + """ + Create vocabulary from list of tokens. + Special tokens are added if not already in list. + + :param tokens: list of tokens + """ + # warning: stoi grows with unknown tokens, don't use for saving or size + + # special symbols + self.specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN] + + # don't allow to access _stoi and _itos outside of this class + self._stoi: Dict[str, int] = {} # string to index + self._itos: List[str] = [] # index to string + + # construct + self.add_tokens(tokens=self.specials + tokens) + assert len(self._stoi) == len(self._itos) + + # assign after stoi is built + self.pad_index = self.lookup(PAD_TOKEN) + self.bos_index = self.lookup(BOS_TOKEN) + self.eos_index = self.lookup(EOS_TOKEN) + self.unk_index = self.lookup(UNK_TOKEN) + assert self.pad_index == PAD_ID + assert self.bos_index == BOS_ID + assert self.eos_index == EOS_ID + assert self.unk_index == UNK_ID + assert self._itos[UNK_ID] == UNK_TOKEN + + def add_tokens(self, tokens: List[str]) -> None: + """ + Add list of tokens to vocabulary + + :param tokens: list of tokens to add to the vocabulary + """ + for t in tokens: + new_index = len(self._itos) + # add to vocab if not already there + if t not in self._itos: + self._itos.append(t) + self._stoi[t] = new_index + + def to_file(self, file: Path) -> None: + """ + Save the vocabulary to a file, by writing token with index i in line i. + + :param file: path to file where the vocabulary is written + """ + write_list_to_file(file, self._itos) + + def is_unk(self, token: str) -> bool: + """ + Check whether a token is covered by the vocabulary + + :param token: + :return: True if covered, False otherwise + """ + return self.lookup(token) == UNK_ID + + def lookup(self, token: str) -> int: + """ + look up the encoding dictionary. 
(needed for multiprocessing) + + :param token: surface str + :return: token id + """ + return self._stoi.get(token, UNK_ID) + + def __len__(self) -> int: + return len(self._itos) + + def __eq__(self, other) -> bool: + if isinstance(other, Vocabulary): + return self._itos == other._itos + return False + + def array_to_sentence(self, + array: np.ndarray, + score_array: np.ndarray = None, + cut_at_eos: bool = True, + skip_pad: bool = True) -> Tuple[List[str], List[float]]: + """ + Converts an array of IDs to a sentence, optionally cutting the result off at the + end-of-sequence token. + + :param array: 1D array containing indices + :param score_array: 1D array containing float scores + :param cut_at_eos: cut the decoded sentences at the first <eos> token + :param skip_pad: skip generated <pad> tokens + :return: + - list of strings (tokens) + - list of floats (scores) + """ + if score_array is None: + score_array = [None for _ in range(len(array))] + assert len(array) == len(score_array) + + sentence, scores = [], [] + for i, s in zip(array, score_array): + t = self._itos[i] + s = float('NaN') if s is None else s + if skip_pad and t == PAD_TOKEN: + continue + sentence.append(t) + scores.append(s) + # break at the position AFTER eos + if cut_at_eos and t == EOS_TOKEN: + break + assert len(sentence) == len(scores) + return sentence, scores + + def arrays_to_sentences( + self, + arrays: np.ndarray, + score_arrays: np.ndarray = None, + cut_at_eos: bool = True, + skip_pad: bool = True) -> Tuple[List[List[str]], List[List[float]]]: + """ + Convert multiple arrays containing sequences of token IDs to their sentences, + optionally cutting them off at the end-of-sequence token. + + :param arrays: 2D array containing indices + :param score_arrays: 2D array containing float scores + :param cut_at_eos: cut the decoded sentences at the first <eos> token + :param skip_pad: skip generated <pad> tokens + :return: + - list of list of strings (tokens) + - list of list of floats (scores) + """ + if score_arrays is None or len(score_arrays) == 0: + score_arrays = [None for _ in range(len(arrays))] + assert len(arrays) == len(score_arrays) + + sentences, scores = [], [] + for array, score_array in zip(arrays, score_arrays): + sent, score = self.array_to_sentence(array=array, + score_array=score_array, + cut_at_eos=cut_at_eos, + skip_pad=skip_pad) + sentences.append(sent) + scores.append(score) + assert len(sentences) == len(scores) + return sentences, scores + + def sentences_to_ids(self, + sentences: List[List[str]], + bos: bool = True, + eos: bool = True) -> Tuple[List[List[int]], List[int]]: + """ + Encode sentences to indices and pad sequences to the maximum length of the + sentences given + + :param sentences: list of tokenized sentences + :return: + - padded ids + - original lengths before padding + """ + max_len = max([len(sent) for sent in sentences]) + if bos: + max_len += 1 + if eos: + max_len += 1 + padded, lengths = [], [] + for sent in sentences: + encoded = [self.lookup(s) for s in sent] + if bos: + encoded = [self.bos_index] + encoded + if eos: + encoded = encoded + [self.eos_index] + offset = max(0, max_len - len(encoded)) + padded.append(encoded + [self.pad_index] * offset) + lengths.append(len(encoded)) + return padded, lengths + + def log_vocab(self, k: int) -> str: + """first k vocab entities""" + return " ".join(f"({i}) {t}" for i, t in enumerate(self._itos[:k])) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(len={self.__len__()}, " + f"specials={self.specials})") + + +def sort_and_cut(counter: Counter, +
max_size: int = sys.maxsize, + min_freq: int = -1) -> List[str]: + """ + Cut counter to most frequent, sorted numerically and alphabetically + :param counter: flattened token list in Counter object + :param max_size: maximum size of vocabulary + :param min_freq: minimum frequency for an item to be included + :return: list of valid tokens + """ + # filter counter by min frequency + if min_freq > -1: + counter = Counter({t: c for t, c in counter.items() if c >= min_freq}) + + # sort by frequency, then alphabetically + tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) + tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) + + # cut off + vocab_tokens = [i[0] for i in tokens_and_frequencies[:max_size]] + assert len(vocab_tokens) <= max_size, (len(vocab_tokens), max_size) + return vocab_tokens + + +def _build_vocab(cfg: Dict, dataset: BaseDataset = None) -> Vocabulary: + """ + Builds vocabulary either from file or sentences. + + :param cfg: data cfg + :param dataset: dataset object which contains preprocessed sentences + :return: Vocabulary created from either `tokens` or `vocab_file` + """ + vocab_file = cfg.get("voc_file", None) + min_freq = cfg.get("voc_min_freq", 1) # min freq for an item to be included + max_size = int(cfg.get("voc_limit", sys.maxsize)) # max size of vocabulary + assert max_size > 0 + + if vocab_file is not None: + # load it from file (not to apply `sort_and_cut()`) + unique_tokens = read_list_from_file(Path(vocab_file)) + + elif dataset is not None: + # tokenize sentences + sents = dataset.get_list(lang=cfg["lang"], tokenized=True) + + # newly create unique token list (language-wise) + counter = Counter(flatten(sents)) + unique_tokens = sort_and_cut(counter, max_size, min_freq) + else: + raise Exception("Please provide a vocab file path or dataset.") + + vocab = Vocabulary(unique_tokens) + assert len(vocab) <= max_size + len(vocab.specials), (len(vocab), max_size) + + # check for all except for UNK token whether they are OOVs + for s in vocab.specials[1:]: + assert not vocab.is_unk(s) + + return vocab + + +def build_vocab(cfg: Dict, + dataset: BaseDataset = None, + model_dir: Path = None) -> Tuple[Vocabulary, Vocabulary]: + task = cfg["task"].upper() + # use the vocab file saved in model_dir + if task == "MT": + if model_dir is not None and cfg["src"].get("voc_file", None) is None: + assert (model_dir / "src_vocab.txt").is_file() + cfg["src"]["voc_file"] = (model_dir / "src_vocab.txt").as_posix() + if model_dir is not None and cfg["trg"].get("voc_file", None) is None: + assert (model_dir / "trg_vocab.txt").is_file() + cfg["trg"]["voc_file"] = (model_dir / "trg_vocab.txt").as_posix() + + src_vocab = _build_vocab(cfg["src"], dataset) if task == "MT" else None + trg_vocab = _build_vocab(cfg["trg"], dataset) + + if task == "MT": + assert src_vocab.pad_index == trg_vocab.pad_index + assert src_vocab.bos_index == trg_vocab.bos_index + assert src_vocab.eos_index == trg_vocab.eos_index + return src_vocab, trg_vocab diff --git a/poseData_utils.py b/poseData_utils.py new file mode 100644 index 0000000..8c6baf8 --- /dev/null +++ b/poseData_utils.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# coding: utf-8 + +# Adapted from +# https://github.com/pytorch/fairseq/blob/master/examples/speech_to_text/data_utils.py + +import csv +import io +import zipfile +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +import sentencepiece as sp +from tqdm import tqdm + +from joeynmt.constants import ( + BOS_ID, + 
BOS_TOKEN, + EOS_ID, + EOS_TOKEN, + PAD_ID, + PAD_TOKEN, + UNK_ID, + UNK_TOKEN, +) +from joeynmt.helpers_for_pose import _is_npy_data + + +def get_zip_manifest(zip_path: Path, npy_root: Optional[Path] = None): + manifest = {} + with zipfile.ZipFile(zip_path, mode="r") as f: + info = f.infolist() + # retrieve offsets + for i in tqdm(info): + utt_id = Path(i.filename).stem + offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size + with zip_path.open("rb") as f: + f.seek(offset) + data = f.read(file_size) + assert len(data) > 1 and _is_npy_data(data), (utt_id, len(data)) + manifest[utt_id] = f"{zip_path.name}:{offset}:{file_size}" + # sanity check + if npy_root is not None: + byte_data = np.load(io.BytesIO(data)) + npy_data = np.load((npy_root / f"{utt_id}.npy").as_posix()) + assert np.allclose(byte_data, npy_data) + return manifest + + +def create_zip(data_root: Path, zip_path: Path): + paths = list(data_root.glob("*.npy")) + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f: + for path in tqdm(paths): + try: + f.write(path, arcname=path.name) + except Exception as e: # pylint: disable=broad-except + raise Exception(f"{path}") from e + + +def save_tsv(df: pd.DataFrame, path: Path, header: bool = True) -> None: + df.to_csv(path.as_posix(), + sep="\t", + header=header, + index=False, + encoding="utf-8", + escapechar="\\", + quoting=csv.QUOTE_NONE) + + +def load_tsv(path: Path): + return pd.read_csv(path.as_posix(), + sep="\t", + header=0, + encoding="utf-8", + escapechar="\\", + quoting=csv.QUOTE_NONE, + na_filter=False) + + +def build_sp_model(input_path: Path, model_path_prefix: Path, **kwargs): + """ + Build sentencepiece model + """ + # Train SentencePiece Model + arguments = [ + f"--input={input_path.as_posix()}", + f"--model_prefix={model_path_prefix.as_posix()}", + f"--model_type={kwargs.get('model_type', 'unigram')}", + f"--vocab_size={kwargs.get('vocab_size', 5000)}", + f"--character_coverage={kwargs.get('character_coverage', 1.0)}", + f"--num_threads={kwargs.get('num_workers', 1)}", + f"--unk_piece={UNK_TOKEN}", + f"--bos_piece={BOS_TOKEN}", + f"--eos_piece={EOS_TOKEN}", + f"--pad_piece={PAD_TOKEN}", + f"--unk_id={UNK_ID}", + f"--bos_id={BOS_ID}", + f"--eos_id={EOS_ID}", + f"--pad_id={PAD_ID}", + "--vocabulary_output_piece_score=false", + ] + if 'user_defined_symbols' in kwargs: + arguments.append(f"--user_defined_symbols={kwargs['user_defined_symbols']}") + sp.SentencePieceTrainer.Train(" ".join(arguments)) \ No newline at end of file diff --git a/prepare_poses.py b/prepare_poses.py new file mode 100644 index 0000000..22f73ea --- /dev/null +++ b/prepare_poses.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python +# coding: utf-8 +""" +Prepare poses + +expected dir structure: + Output/ # <- point --data_root here + └── Poses/ + ├── fbank534/ + │ ├── test1.npy + │ ├── test2.npy + │ ├── test3.npy + ├── fbank534.zip + ├── poses_all_data.tsv + ├── spm_bpe40.model / spm_bpe40.vocab + ├── train.tsv / train.txt + ├── dev.tsv / dev.txt + └── test.tsv / test.txt +""" + +import argparse +from pathlib import Path + +import numpy as np +import pandas as pd +from poseData_utils import build_sp_model, create_zip, get_zip_manifest, save_tsv +from datasets_pose import load_dataset, extract_to_fbank + +from joeynmt.helpers import write_list_to_file + +COLUMNS = ["id", "src", "n_frames", "trg"] + +SEED = 123 +N_MEL_FILTERS = 534 +N_WORKERS = 4 # cpu_count() +SP_MODEL_TYPE = "bpe" # one of ["bpe", "unigram", "char"] +VOCAB_SIZE = 40 # joint vocab +EXPENDED_DATASET = 1000 # pump the dataset up to at least this many samples + + +def
process(data_root, name, pumping: bool = False): + root = Path(data_root).absolute() + cur_root = root / name + + # dir for filterbank (shared across splits) + feature_root = cur_root / f"fbank{N_MEL_FILTERS}" + feature_root.mkdir(parents=True, exist_ok=True) + + # Extract features + print(f"Create {name} pose dataset.") + + print("Fetching train split ...") + dataset = load_dataset("DataSet") # NOTE: the dataset directory is hard-coded here + + print("Extracting log mel filter bank features ...") + for instance in dataset: + utt_id = instance[0] + extract_to_fbank(instance[1], feature_root / f'{utt_id}.npy', overwrite=False) + + # Pack features into ZIP + print("ZIPing features...") + create_zip(feature_root, feature_root.with_suffix(".zip")) + + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(feature_root.with_suffix(".zip")) + + # Generate TSV manifest + print("Generating manifest...") + all_data = [] + + for instance in dataset: + utt_id = instance[0] + n_frames = np.load(feature_root / f'{utt_id}.npy').shape[0] + all_data.append({ + "id": utt_id, + "src": zip_manifest[str(utt_id)], + "n_frames": n_frames, + "trg": instance[2] + }) + + if EXPENDED_DATASET > len(all_data) and pumping: + print("Pumping dataset...") + for i in range(EXPENDED_DATASET - len(all_data)): + utt_id = all_data[i % len(all_data)]["id"] + n_frames = all_data[i % len(all_data)]["n_frames"] + trg = all_data[i % len(all_data)]["trg"] + src = all_data[i % len(all_data)]["src"] + all_data.append({ + "id": f'{utt_id}({i})', # unique id + "src": src, + "n_frames": n_frames, + "trg": trg + }) + + all_df = pd.DataFrame.from_records(all_data) + save_tsv(all_df, cur_root / "poses_all_data.tsv") + + # Split the data into train, dev and test sets and save the splits in tsv + np.random.seed(SEED) + probs = np.random.rand(len(all_df)) + mask = {} + mask['train'] = probs < 0.995 + mask['dev'] = (probs >= 0.995) & (probs < 0.998) + mask['test'] = probs >= 0.998 + + for split in ['train', 'dev', 'test']: + split_df = all_df[mask[split]] + # save tsv + save_tsv(split_df, cur_root / f"{split}.tsv") + # save plain txt + write_list_to_file(cur_root / f"{split}.txt", split_df['trg'].to_list()) + print(split, len(split_df)) + + # Generate joint vocab + print("Building joint vocab...") + kwargs = { + 'model_type': SP_MODEL_TYPE, + 'vocab_size': VOCAB_SIZE, + 'character_coverage': 1.0, + 'num_workers': N_WORKERS + } + build_sp_model(cur_root / "train.txt", cur_root / f"spm_bpe{VOCAB_SIZE}", **kwargs) + print("Done!") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data_root", "-d", required=True, type=str) + parser.add_argument("--dataset_name", required=True, type=str) + args = parser.parse_args() + + process(args.data_root, args.dataset_name, pumping=True) + + +if __name__ == "__main__": + main() diff --git a/swu_representation.py b/swu_representation.py new file mode 100644 index 0000000..732333f --- /dev/null +++ b/swu_representation.py @@ -0,0 +1,90 @@ +import re +from typing import List + +re_swu = { + 'symbol': r'[\U00040001-\U0004FFFF]', + 'coord': r'[\U0001D80C-\U0001DFFF]{2}', + 'sort': r'\U0001D800', + 'box': r'\U0001D801-\U0001D804' +} +re_swu['prefix'] = rf"(?:{re_swu['sort']}(?:{re_swu['symbol']})+)" +re_swu['spatial'] = rf"{re_swu['symbol']}{re_swu['coord']}" +re_swu['signbox'] = rf"{re_swu['box']}{re_swu['coord']}(?:{re_swu['spatial']})*" +re_swu['sign'] = rf"{re_swu['prefix']}?{re_swu['signbox']}" +re_swu['sortable'] = rf"{re_swu['prefix']}{re_swu['signbox']}" + + +def fsw2data(fswText: str) -> str: +
# FSW strings look like AS18711S20500M514x517S18711490x483S20500486x506: + # an optional sorting prefix (A plus symbol keys), a box marker (M/B/L/R) with its + # position, then one S-symbol per sign followed by an x/y coordinate pair. + match = fswText[re.search('[MBLR]', fswText).start():] # drop the sorting prefix + match = match[re.search('[S]', match).start():] # drop the box marker and its position (e.g. M514x517) + match = match.split('S')[1:] # split into the individual symbol entries + match = "x".join(match).split('x') # separate symbols from coordinates + # entries longer than 5 chars are a symbol key fused with its first coordinate: split them in two + data = [] + for i in range(len(match)): + if len(match[i]) > 5: + data.append(match[i][:5]) + data.append(match[i][5:]) + else: + data.append(match[i]) + return " ".join(data) + + +def data2fsw(dataText: str) -> str: + signs = ['A'] + dataText = dataText.split(' ') + for i in list(dataText): # iterate over a copy so that removing items is safe + if len(i) >= 5: + signs.append(i) + dataText.remove(i) + fsw = "S".join(signs) + "M514x517" # fixed default box position + for index, val in enumerate(signs[1:]): + fsw += "S" + val + dataText[2 * index] + "x" + dataText[2 * index + 1] + return fsw + +# from signbank-plus +def swu2key(swuSym: str) -> str: + symcode = ord(swuSym) - 0x40001 + base = symcode // 96 + fill = (symcode - (base * 96)) // 16 + rotation = symcode - (base * 96) - (fill * 16) + return f'S{hex(base + 0x100)[2:]}{hex(fill)[2:]}{hex(rotation)[2:]}' + +# from signbank-plus +def swu2coord(swuCoord: str) -> List[int]: + return [swu2num(swuCoord[0]), swu2num(swuCoord[1])] + +# from signbank-plus +def swu2num(swuNum: str) -> int: + return ord(swuNum) - 0x1D80C + 250 + +# from signbank-plus +def swu2fsw(swuText: str) -> str: + if not swuText: + return '' + + # Initial replacements + fsw = swuText.replace("𝠀", "A").replace("𝠁", "B").replace("𝠂", "L").replace("𝠃", "M").replace("𝠄", "R") + + # SWU symbols to FSW keys + syms = re.findall(re_swu['symbol'], fsw) + if syms: + for sym in syms: + fsw = fsw.replace(sym, swu2key(sym)) + + # SWU coordinates to FSW coordinates + coords = re.findall(re_swu['coord'], fsw) + if coords: + for coord in coords: + fsw = fsw.replace(coord, 'x'.join(map(str, swu2coord(coord)))) + + return fsw + +# from signbank-plus +def swu2data(swuText: str) -> str: + return fsw2data(swu2fsw(swuText)) + + +if __name__ == "__main__": + print(data2fsw("18711 490 483 20500 486 506")) + print(fsw2data("AS18711S20500M514x517S18711490x483S20500486x506").split(' '))
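
A quick way to sanity-check the convolution block added in this patch is to push a random batch through it and confirm that it preserves the (batch, time, hidden) shape, since it only mixes information along the time axis. This is a minimal sketch, not part of the patch itself; the import path is an assumption, adjust it to wherever the patch actually defines ConvolutionModule:

import torch

# assumed import path for the class added in this patch
from joeynmt.encoders import ConvolutionModule

conv = ConvolutionModule(hidden_size=256, channels=256,
                         depthwise_kernel_size=31, dropout=0.1)
x = torch.randn(4, 120, 256)   # (batch, time, hidden)
out = conv(x)
assert out.shape == x.shape    # pointwise convs map hidden -> channels -> hidden, time length is kept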
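
The Vocabulary class is easiest to see in a round trip: build it from a plain token list, encode tokenized sentences with sentences_to_ids (which adds the BOS/EOS markers and pads to the longest sequence), then decode with arrays_to_sentences. A minimal sketch, assuming the patched joeynmt package is importable; the token strings are arbitrary:

import numpy as np

from joeynmt.vocabulary import Vocabulary

vocab = Vocabulary(tokens=["18711", "490", "483", "20500"])  # specials are added automatically

ids, lengths = vocab.sentences_to_ids([["18711", "490", "483"], ["20500"]], bos=True, eos=True)
print(ids)      # padded id matrix, one row per sentence
print(lengths)  # true lengths including BOS/EOS

tokens, _ = vocab.arrays_to_sentences(np.array(ids), cut_at_eos=True, skip_pad=True)
print(tokens)   # tokens wrapped in the BOS/EOS strings from joeynmt.constants, padding skipped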
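
get_zip_manifest encodes each feature matrix as a "file.zip:offset:length" string, and prepare_poses writes that string into the src column of the TSVs. The reader side is not part of this patch (it presumably lives in joeynmt's speech data pipeline), but the inverse operation follows directly from the manifest format. A sketch under that assumption; the helper name, directory layout and example entry are illustrative:

import io
from pathlib import Path

import numpy as np


def load_feature(src: str, data_dir: Path) -> np.ndarray:
    """Load one feature matrix from a 'file.zip:offset:length' manifest entry."""
    zip_name, offset, length = src.split(":")
    with (data_dir / zip_name).open("rb") as f:
        f.seek(int(offset))           # offset already points past the zip member header
        raw = f.read(int(length))
    return np.load(io.BytesIO(raw))   # same trick as the sanity check in get_zip_manifest


# e.g. feats = load_feature("fbank534.zip:1234:56789", Path("Output/Poses"))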
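
The train/dev/test split in prepare_poses is a single uniform draw per row compared against the thresholds 0.995 and 0.998, i.e. roughly 99.5% train, 0.3% dev and 0.2% test. A standalone sketch of the same masking logic; the row count is illustrative:

import numpy as np

np.random.seed(123)             # SEED used in prepare_poses
probs = np.random.rand(10_000)  # one uniform draw per dataset row

train = probs < 0.995
dev = (probs >= 0.995) & (probs < 0.998)
test = probs >= 0.998

print(train.sum(), dev.sum(), test.sum())  # roughly 9950 / 30 / 20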