diff --git a/README.md b/README.md index 713bfd5..fd3699b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ This project trains and serves models to translate SignWriting into spoken language text and vice versa. +Ideally, we would like to use [Bergamot](https://github.com/mozilla/firefox-translations-training), +but we could not get it to work. + ## Usage ```bash @@ -15,16 +18,83 @@ signwriting_to_text --spoken-language="en" --signed-language="ase" --input="M525 text_to_signwriting --spoken-language="en" --signed-language="ase" --input="Sign Language" ``` -### Examples +## Data -(These examples are taken from the DSGS Vokabeltrainer) +We use the SignBank+ Dataset from [signbank-plus](https://github.com/sign-language-processing/signbank-plus). -| | 00004 | 00007 | 00015 | -|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------:| -| SignWriting | | | | -| Video | | | | +After finding a large mismatch in distribution between the `validation` and `test` sets, +we decided to use the `test` set as the `validation` set, without training multiple models. +This is a *bad* decision if we were to train and compare multiple models, or want to improve the model in the future. +To change this, change the `sockeye.train` command in the `train_sockeye_model.sh` script. -## Data +## Steps -We use the SignBank+ Dataset from [signbank-plus](https://github.com/sign-language-processing/signbank-plus). +```bash +# 0. Setup the environment. 
+conda create --name sockeye python=3.11 -y +conda activate sockeye + +cd signwriting_translation + +MODEL_DIR=/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation +DATA_DIR=/home/amoryo/sign-language/signwriting-translation/parallel +DIRECTION="spoken-to-signed" + +# 1. Download and tokenize the signbank-plus dataset +sbatch prepare_data.sh + +# 2. Train a spoken-to-signed translation model +# (Train without factors) +sbatch train_sockeye_model.sh \ + --data_dir="$DATA_DIR/$DIRECTION" \ + --model_dir="$MODEL_DIR/$DIRECTION/no-factors" \ + --optimized_metric="signwriting-similarity" \ + --partition lowprio +# (Train with factors) +sbatch train_sockeye_model.sh \ + --data_dir="$DATA_DIR/$DIRECTION" \ + --model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \ + --optimized_metric="signwriting-similarity" \ + --use_target_factors=true \ + --partition lowprio +# (Fine tune model on cleaned data) +sbatch train_sockeye_model.sh \ + --data_dir="$DATA_DIR-clean/$DIRECTION" \ + --model_dir="$MODEL_DIR/$DIRECTION/target-factors-tuned" \ + --base_model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \ + --optimized_metric="signwriting-similarity" \ + --use_target_factors=true \ + --partition lowprio + +# 2.1 (Optional) See the validation metrics +cat "$MODEL_DIR/$DIRECTION/no-factors/model/metrics" | grep "signwriting-similarity" +cat "$MODEL_DIR/$DIRECTION/target-factors/model/metrics" | grep "signwriting-similarity" + +# 3. Test it yourself +python -m signwriting_translation.bin \ + --model="$MODEL_DIR/$DIRECTION/target-factors-tuned/model" \ + --spoken-language="en" \ + --signed-language="ase" \ + --input="My name is John." 
def process_translation_output(output: "TranslatorOutput"):
    """Convert one Sockeye translation output into a SignWriting string.

    The primary token stream and every factor stream are zipped position by
    position, and the sub-tokens of each position are joined with spaces.
    The group ``"M c0 r0"`` is collapsed back to ``"M"`` before the flattened
    token list is handed to ``sw_tokenizer.tokens_to_text``.

    Args:
        output: A Sockeye ``TranslatorOutput`` exposing ``tokens`` and
            ``factor_tokens``.

    Returns:
        Whatever ``sw_tokenizer.tokens_to_text`` returns for the merged tokens.
    """
    # NOTE(review): Sockeye appears to leave ``factor_tokens`` as ``None`` for
    # models trained without target factors - confirm. Treating a missing
    # value as "no factors" avoids ``list + None`` raising a TypeError.
    factor_tokens = output.factor_tokens or []
    all_factors = [output.tokens] + list(factor_tokens)
    symbols = [" ".join(f).replace("M c0 r0", "M") for f in zip(*all_factors)]
    return sw_tokenizer.tokens_to_text(" ".join(symbols).split(" "))


@lru_cache(maxsize=None)
def load_sockeye_translator(model_path: str, log_timing: bool = False):
    """Load a Sockeye translator together with its spoken-language tokenizer.

    Results are memoised per ``(model_path, log_timing)`` pair, so repeated
    calls with the same arguments reuse the loaded objects.

    Args:
        model_path: A local model directory. When it is not an existing
            directory it is passed to ``huggingface_hub.snapshot_download``
            as a repository id and the downloaded snapshot is used instead.
        log_timing: Print how long loading the translator took.

    Returns:
        A ``(translator, tokenizer)`` tuple; the tokenizer is read from
        ``tokenizer.json`` inside the model directory.
    """
    if not Path(model_path).is_dir():
        # Imported lazily so local-only usage does not pay for the import.
        from huggingface_hub import snapshot_download
        model_path = snapshot_download(repo_id=model_path)

    from sockeye.translate import parse_translation_arguments, load_translator_from_args

    now = time.time()
    args = parse_translation_arguments([
        "-m", model_path,
        "--beam-size", "5",
    ])
    translator = load_translator_from_args(args, True)
    if log_timing:
        print("Loaded sockeye translator in", time.time() - now, "seconds")

    tokenizer = Tokenizer.from_file(str(Path(model_path) / 'tokenizer.json'))

    return translator, tokenizer


def translate(translator, texts: List[str], log_timing: bool = False):
    """Translate a batch of pre-tokenized input strings.

    Args:
        translator: The translator returned by ``load_sockeye_translator``.
        texts: Model inputs, one string per sentence.
        log_timing: Print the total and per-text translation time.

    Returns:
        One detokenized output per input, in input order. An empty ``texts``
        yields an empty list without touching the translator.
    """
    # Nothing to translate; returning early also avoids dividing by zero when
    # computing the per-text average below.
    if not texts:
        return []

    from sockeye.inference import make_input_from_plain_string

    inputs = [make_input_from_plain_string(sentence_id=i, string=s)
              for i, s in enumerate(texts)]

    now = time.time()
    outputs = translator.translate(inputs)
    translation_time = time.time() - now
    if log_timing:
        avg_time = translation_time / len(texts)
        print("Translated", len(texts), "texts in", translation_time, "seconds", f"({avg_time:.2f} seconds per text)")
    return [process_translation_output(output) for output in outputs]
def text_to_signwriting():
    """Console entry point: translate spoken-language text into SignWriting.

    Reads ``--model``, ``--spoken-language``, ``--signed-language`` and
    ``--input`` from the command line, tokenizes the input with the model's
    tokenizer, prefixes the ``$<spoken>`` and ``$<signed>`` language tags and
    prints the first translation.

    The result is printed rather than returned: a console-script wrapper
    passes the return value to ``sys.exit``, so returning the output list
    would be reported as a failure.
    """
    # pylint: disable=unused-variable
    args = get_args()

    # ``translate`` needs a loaded translator object, not the model path.
    translator, tokenizer = load_sockeye_translator(args.model)
    tokenized_text = " ".join(tokenizer.encode(args.input).tokens)
    model_input = f"${args.spoken_language} ${args.signed_language} {tokenized_text}"
    outputs = translate(translator, [model_input])
    print(outputs[0])
def create_files(split_dir, spoken_d, signed_d):
    """Open the output files of one split for writing.

    Args:
        split_dir: Directory (a ``Path``) the files are created in.
        spoken_d: File stem for the spoken side ("source" or "target").
        signed_d: File stem for the signed side ("source" or "target").

    Returns:
        A dict with the open handles ``"spoken"`` and ``"signed"`` plus
        ``"signed_factors"``, a dict mapping factor index 0-4 to the handle
        of ``<signed_d>_<index>.txt``. The caller owns the handles and is
        responsible for closing them.
    """
    return {
        "spoken": open(split_dir / f"{spoken_d}.txt", "w", encoding="utf-8"),
        "signed": open(split_dir / f"{signed_d}.txt", "w", encoding="utf-8"),
        "signed_factors": {
            i: open(split_dir / f"{signed_d}_{i}.txt", "w", encoding="utf-8")
            for i in range(5)
        }
    }


def _close_files(files):
    """Close every handle in a mapping produced by ``create_files``."""
    files["spoken"].close()
    files["signed"].close()
    for factor_file in files["signed_factors"].values():
        factor_file.close()


def _write_rows(csv_path, files, spoken_d, repeats):
    """Feed every row of one CSV file through ``process_row``."""
    with open(csv_path, 'r', encoding="utf-8") as f:
        for row in tqdm(csv.DictReader(f)):
            process_row(row, files, spoken_d, repeats)


# pylint: disable=too-many-locals
def create_parallel_data(data_dir: Path, output_dir: Path, clean_only=False):
    """Write the train/dev/test parallel corpora for every direction.

    For each translation direction, ``train`` and ``dev`` are built from the
    ``<partition>/<split>.csv`` files of the configured partitions (each row
    repeated by the partition's weight), and ``test`` from ``test/all.csv``.

    Args:
        data_dir: Root directory holding the partition sub-directories.
        output_dir: Root directory receiving ``<direction>/<split>/`` files.
        clean_only: Use ``CLEAN_DIRECTIONS`` instead of ``DIRECTIONS``.

    Raises:
        FileNotFoundError: A partition has no ``train.csv``. A missing
            ``dev.csv`` is skipped silently.
    """
    directions_obj = CLEAN_DIRECTIONS if clean_only else DIRECTIONS

    for direction, partitions in directions_obj.items():
        direction_output_dir = output_dir / direction
        split_dirs = {split: direction_output_dir / split
                      for split in ("train", "dev", "test")}
        for split_dir in split_dirs.values():
            split_dir.mkdir(parents=True, exist_ok=True)

        spoken_d = "source" if direction == "spoken-to-signed" else "target"
        signed_d = "target" if direction == "spoken-to-signed" else "source"

        # Track every opened file set so all handles are flushed and closed
        # even when a CSV is missing or a row fails to process.
        opened = []
        try:
            split_files = {}
            for split in ("train", "dev"):
                split_files[split] = create_files(split_dirs[split], spoken_d, signed_d)
                opened.append(split_files[split])

            for partition, repeats in partitions.items():
                for split, files in split_files.items():
                    split_path = data_dir / partition / f"{split}.csv"
                    if not split_path.exists():
                        if split == "train":
                            raise FileNotFoundError(f"File {split_path} does not exist")
                        continue
                    _write_rows(split_path, files, spoken_d, repeats)

            test_files = create_files(split_dirs["test"], spoken_d, signed_d)
            opened.append(test_files)
            _write_rows(data_dir / "test" / "all.csv", test_files, spoken_d, repeats=1)
        finally:
            for files in opened:
                _close_files(files)


def main():
    """Parse the command line and build the parallel data."""
    parser = argparse.ArgumentParser()
    # Required: without it ``Path(None)`` below fails with an opaque TypeError.
    parser.add_argument('--data-dir', type=str, required=True, help='Path to data directory')
    parser.add_argument('--output-dir', type=str, help='Path to output directory',
                        default="parallel")
    parser.add_argument('--clean-only', action='store_true', help='Use only cleaned data')
    args = parser.parse_args()

    data_dir = Path(args.data_dir)

    # create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    create_parallel_data(data_dir, output_dir, clean_only=args.clean_only)


if __name__ == "__main__":
    main()
tokenizer = SignWritingTokenizer()


def print_prediction(file_path: str):
    """Print the detokenized SignWriting for every line of a token file.

    Args:
        file_path: Text file with one space-separated token sequence per line.
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        # Stream the file; there is no need to hold every line in memory.
        for line in f:
            print(tokenizer.tokens_to_text(line.strip().split(" ")))


def print_factored_prediction(factors_file_template: str, num_factors: int = 5):
    """Print detokenized SignWriting reassembled from per-factor files.

    Args:
        factors_file_template: ``str.format`` template with one positional
            placeholder that receives the factor index.
        num_factors: Number of factor files to read (indices
            ``0 .. num_factors - 1``).

    Each output row zips the same-numbered line of every factor file, joins
    the factors of each position with spaces, collapses ``"M c0 r0"`` back to
    ``"M"`` and detokenizes the result. Rows or positions beyond the shortest
    file or line are dropped by ``zip``.
    """
    files_rows = []
    for i in range(num_factors):
        with open(factors_file_template.format(i), 'r', encoding="utf-8") as f:
            files_rows.append([line.strip().split(" ") for line in f])

    for row in zip(*files_rows):
        symbols = [" ".join(f).replace("M c0 r0", "M") for f in zip(*row)]
        print(tokenizer.tokens_to_text(" ".join(symbols).split(" ")))


# Only read the hard-coded paths when run as a script, so importing this
# module does not try to open files that exist on a single machine.
if __name__ == "__main__":
    print_prediction("/home/amoryo/sign-language/signwriting-translation/parallel/spoken-to-signed/test/target.txt")
    print_factored_prediction(
        # pylint: disable=line-too-long
        "/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation/spoken-to-signed/target-factors/model/decode.output.{}.00332")
models_metrics[model_name][name].append(float(value)) + except ValueError: + pass + + for metric in ['chrf', 'signwriting-similarity']: + plt.figure(figsize=(10, 5)) + + plt.grid(axis='y', linestyle='--', linewidth=0.5) + for model_name, metrics in models_metrics.items(): + plt.plot(metrics[f"{metric}-val"], label=model_name) + if metric == 'signwriting-similarity': + plt.ylim(0.35, None) + plt.legend(loc='lower right') + plt.savefig(f"{metric}.png") + plt.close() diff --git a/signwriting_translation/prepare_data.sh b/signwriting_translation/prepare_data.sh new file mode 100644 index 0000000..c2b562a --- /dev/null +++ b/signwriting_translation/prepare_data.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +#SBATCH --job-name=prepare-data +#SBATCH --time=24:00:00 +#SBATCH --mem=32G +#SBATCH --output=prepare_data.out + +#SBATCH --ntasks=1 + +set -e # exit on error +set -x # echo commands + +module load anaconda3 +source activate sockeye + + +# Download the SignBank repository if not exists +SIGNBANK_DIR="/home/amoryo/sign-language/signbank-annotation/signbank-plus" +[ ! -d "$SIGNBANK_DIR" ] && \ +git clone https://github.com/sign-language-processing/signbank-plus.git "$SIGNBANK_DIR" + +# Process data for machine translation if not exists +[ ! 
@lru_cache(maxsize=None)
def load_tokenizer(tokenizer_file: str = 'tokenizer.json'):
    """Load a tokenizer from ``tokenizer_file``, memoised per path."""
    return Tokenizer.from_file(tokenizer_file)


@lru_cache(maxsize=int(1e7))
def tokenize_spoken_text(text: str, tokenizer_file: str = 'tokenizer.json'):
    """Return ``text`` as a space-joined string of tokenizer tokens.

    Results are cached per ``(text, tokenizer_file)`` pair.
    """
    tokenizer = load_tokenizer(tokenizer_file)
    encoding = tokenizer.encode(text)
    return " ".join(encoding.tokens)


def train(files: List[str], target_file: str):
    """Train a byte-level BPE tokenizer and save it.

    Args:
        files: Paths of the text files to train on.
        target_file: Path the trained tokenizer is saved to.
    """
    tokenizer = Tokenizer(BPE())
    tokenizer.normalizer = normalizers.NFKD()
    # Take the pre tokenizer setting from GPT-4, https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
    # Raw strings keep ``\p``, ``\s`` and ``\S`` intact for the regex engine
    # instead of relying on Python passing through invalid string escapes.
    split_pattern = (
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
        r"|[^\r\n\p{L}\p{N}]?\p{L}+"
        r"|\p{N}{1,3}"
        r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
        r"|\s*[\r\n]+"
        r"|\s+(?!\S)"
        r"|\s+"
    )
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Split(pattern=tokenizers.Regex(split_pattern),
                             behavior="removed", invert=True),
        # For non ascii characters, it gets completely unreadable, but it works nonetheless!
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
    ])
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(vocab_size=8000)
    tokenizer.train(files=files, trainer=trainer)

    tokenizer.save(target_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs='+', type=str, help="Files to train tokenizer on")
    parser.add_argument("--output", type=str, help="Output file for tokenizer model")
    args = parser.parse_args()

    train(args.files, args.output)
# Print, space-separated and sorted, the factor files "<prefix>_<n>.txt"
# (n starting with 1-9) found under a directory. Paths are printed as
# "<directory>/<file name>".
function find_factor_files() {
    local directory=$1
    local prefix=$2 # e.g., "source" or "target"
    find "$directory" -type f -name "${prefix}_[1-9]*.txt" -printf "$directory/%f\n" | sort | tr '\n' ' '
}

# Kept as thin wrappers so existing callers keep working.
function find_source_files() {
    find_factor_files "$1" "source"
}

function find_target_files() {
    find_factor_files "$1" "target"
}


# Print the sockeye data arguments for one side of one split.
# With factors:    --<name> <split>/<type>_0.txt --<name>-factors <factor files>
# Without factors: --<name> <split>/<type>.txt
function translation_files() {
    local name=$1
    local type=$2 # e.g., "source" or "target"
    local split=$3 # e.g., "train", "dev", or "test"
    local use_factors=$4 # Pass 'true' or 'false' to use factors

    if [[ "$use_factors" == "true" ]]; then
        echo "--${name} ${split}/${type}_0.txt --${name}-factors $(find_factor_files "$split" "$type")"
    else
        echo "--${name} ${split}/${type}.txt"
    fi
}

# Print, space-separated and sorted, the factor vocabularies
# "vocab.<type_short>.<n>.json" (n starting with 1-9) found under a directory.
function find_vocabulary_factors_files() {
    local directory=$1
    local type_short=$2
    find "$directory" -type f -name "vocab.${type_short}.[1-9]*.json" -printf "$directory/%f\n" | sort | tr '\n' ' '
}

# Print the vocabulary arguments that reuse the vocabularies of a base model.
# Prints nothing when no base model directory is given.
function vocabulary_files() {
    local base_model_dir=$1
    local type=$2 # e.g., "source" or "target"
    local type_short=$3 # e.g., "src" or "trg"
    local use_factors=$4 # Pass 'true' or 'false' to use factors

    if [ -z "$base_model_dir" ]; then
        return
    fi

    echo "--${type}-vocab $base_model_dir/model/vocab.${type_short}.0.json "

    if [[ "$use_factors" == "true" ]]; then
        echo "--${type}-factor-vocabs $(find_vocabulary_factors_files "$base_model_dir/model" "$type_short")"
    fi
}
"$use_target_factors" == "true" ] && max_seq_len=512 + +# Prepare data +TRAIN_DATA_DIR="$model_dir/train_data" +[ ! -f "$TRAIN_DATA_DIR/data.version" ] && \ +python -m sockeye.prepare_data \ + --max-seq-len $max_seq_len:$max_seq_len \ + $(vocabulary_files "$base_model_dir" "source" "src" $use_source_factors) \ + $(translation_files "source" "source" "$data_dir/train" $use_source_factors) \ + $(vocabulary_files "$base_model_dir" "target" "trg" $use_target_factors) \ + $(translation_files "target" "target" "$data_dir/train" $use_target_factors) \ + --output $TRAIN_DATA_DIR \ + +cp tokenizer.json $model_dir/tokenizer.json + +MODEL_DIR="$model_dir/model" +rm -rf $MODEL_DIR + +# batch size refers to number of target tokens, has to be larger than max tokens set in prepare_data +batch_size=$((max_seq_len * 2 + 1)) +extra_arguments="" +# params is set --params $base_model_dir/model/params.best if $base_model_dir is set +if [ -n "$base_model_dir" ]; then + extra_arguments="${extra_arguments} --params $base_model_dir/model/params.best" +fi + +# From https://aclanthology.org/2023.findings-eacl.127.pdf +# target-factors-weight 0.2 +# weight-tying-type "trg_softmax" +# learning-rate-reduce-factor 0.7 +# label-smoothing 0.2 +# embed-dropout 0.5 +# transformer-dropout is double than the default, but less than 0.5 from the paper +python -m sockeye.train \ + -d $TRAIN_DATA_DIR \ + --weight-tying-type "trg_softmax" \ + --max-seq-len $max_seq_len:$max_seq_len \ + --batch-size $batch_size \ + --source-factors-combine sum \ + --target-factors-combine sum \ + --target-factors-weight 0.2 \ + $(translation_files "validation-source" "source" "$data_dir/test" $use_source_factors) \ + $(translation_files "validation-target" "target" "$data_dir/test" $use_target_factors) \ + --optimized-metric "$optimized_metric" \ + --learning-rate-warmup 1000 \ + --learning-rate-reduce-factor 0.7 \ + --decode-and-evaluate 1000 \ + --checkpoint-interval 1000 \ + --max-num-checkpoint-not-improved 50 \ + 
--embed-dropout 0.5 \ + --transformer-dropout-prepost 0.2 \ + --transformer-dropout-act 0.2 \ + --transformer-dropout-attention 0.2 \ + --label-smoothing 0.2 \ + --label-smoothing-impl torch \ + --no-bucketing \ + $extra_arguments \ + --output $MODEL_DIR