diff --git a/README.md b/README.md
index 713bfd5..fd3699b 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,9 @@
This project trains and serves models to translate SignWriting into spoken language text and vice versa.
+Ideally, we would like to use [Bergamot](https://github.com/mozilla/firefox-translations-training),
+but we could not get it to work.
+
## Usage
```bash
@@ -15,16 +18,83 @@ signwriting_to_text --spoken-language="en" --signed-language="ase" --input="M525
text_to_signwriting --spoken-language="en" --signed-language="ase" --input="Sign Language"
```
-### Examples
+## Data
-(These examples are taken from the DSGS Vokabeltrainer)
+We use the SignBank+ Dataset from [signbank-plus](https://github.com/sign-language-processing/signbank-plus).
-| | 00004 | 00007 | 00015 |
-|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------:|
-| SignWriting | | | |
-| Video | | | |
+After finding a large mismatch in distribution between the `validation` and `test` sets,
+we decided to use the `test` set as the `validation` set, without training multiple models.
+This is a *bad* decision if we were to train and compare multiple models, or want to improve the model in the future.
+To change this, change the `sockeye.train` command in the `train_sockeye_model.sh` script.
-## Data
+## Steps
-We use the SignBank+ Dataset from [signbank-plus](https://github.com/sign-language-processing/signbank-plus).
+```bash
+# 0. Setup the environment.
+conda create --name sockeye python=3.11 -y
+conda activate sockeye
+
+cd signwriting_translation
+
+MODEL_DIR=/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation
+DATA_DIR=/home/amoryo/sign-language/signwriting-translation/parallel
+DIRECTION="spoken-to-signed"
+
+# 1. Download and tokenize the signbank-plus dataset
+sbatch prepare_data.sh
+
+# 2. Train a spoken-to-signed translation model
+# (Train without factors)
+sbatch train_sockeye_model.sh \
+ --data_dir="$DATA_DIR/$DIRECTION" \
+ --model_dir="$MODEL_DIR/$DIRECTION/no-factors" \
+ --optimized_metric="signwriting-similarity" \
+ --partition lowprio
+# (Train with factors)
+sbatch train_sockeye_model.sh \
+ --data_dir="$DATA_DIR/$DIRECTION" \
+ --model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \
+ --optimized_metric="signwriting-similarity" \
+ --use_target_factors=true \
+ --partition lowprio
+# (Fine tune model on cleaned data)
+sbatch train_sockeye_model.sh \
+ --data_dir="$DATA_DIR-clean/$DIRECTION" \
+ --model_dir="$MODEL_DIR/$DIRECTION/target-factors-tuned" \
+ --base_model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \
+ --optimized_metric="signwriting-similarity" \
+ --use_target_factors=true \
+ --partition lowprio
+
+# 2.1 (Optional) See the validation metrics
+cat "$MODEL_DIR/$DIRECTION/no-factors/model/metrics" | grep "signwriting-similarity"
+cat "$MODEL_DIR/$DIRECTION/target-factors/model/metrics" | grep "signwriting-similarity"
+
+# 3. Test it yourself
+python -m signwriting_translation.bin \
+ --model="$MODEL_DIR/$DIRECTION/target-factors-tuned/model" \
+ --spoken-language="en" \
+ --signed-language="ase" \
+ --input="My name is John."
+```
+
+## Upload to HuggingFace
+
+```bash
+# Copy the model files to a new directory
+SE_MODEL_PATH="$MODEL_DIR/$DIRECTION/target-factors-tuned"
+HF_MODEL_PATH="$MODEL_DIR/$DIRECTION/huggingface/target-factors-tuned"
+
+mkdir -p "$HF_MODEL_PATH"
+cp tokenizer.json "$HF_MODEL_PATH/tokenizer.json"
+cp "$SE_MODEL_PATH/model/params.best" "$HF_MODEL_PATH/params.best"
+cp "$SE_MODEL_PATH/model/version" "$HF_MODEL_PATH/version"
+cp "$SE_MODEL_PATH/model/metrics" "$HF_MODEL_PATH/metrics"
+cp "$SE_MODEL_PATH/model/config" "$HF_MODEL_PATH/config"
+cp "$SE_MODEL_PATH/model/args.yaml" "$HF_MODEL_PATH/args.yaml"
+cp "$SE_MODEL_PATH/model/vocab."* "$HF_MODEL_PATH"
+# Upload to HuggingFace
+huggingface-cli login
+huggingface-cli upload sign/sockeye-text-to-factored-signwriting "$HF_MODEL_PATH" .
+```
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index d4d2a7b..75bba90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,12 @@ authors = [
readme = "README.md"
dependencies = [
"signwriting @ git+https://github.com/sign-language-processing/signwriting",
+ "sockeye @ git+https://github.com/sign-language-processing/sockeye",
+ "tokenizers",
+ "huggingface-hub"
]
+[project.optional-dependencies]
dev = [
"pytest",
"pylint"
@@ -26,6 +30,7 @@ disable = [
"C0115", # Missing class docstring
"C0116", # Missing function or method docstring
]
+good-names = ["i", "f", "x", "y"]
[tool.setuptools]
packages = [
diff --git a/signwriting_translation/bin.py b/signwriting_translation/bin.py
index 7ca3e5e..2b2f914 100644
--- a/signwriting_translation/bin.py
+++ b/signwriting_translation/bin.py
@@ -1,10 +1,65 @@
#!/usr/bin/env python
import argparse
+import time
+from functools import lru_cache
+from pathlib import Path
+from typing import List
+
+from signwriting.tokenizer import SignWritingTokenizer
+from sockeye.inference import TranslatorOutput
+from tokenizers import Tokenizer
+
+sw_tokenizer = SignWritingTokenizer()
+
+
+def process_translation_output(output: TranslatorOutput):
+ all_factors = [output.tokens] + output.factor_tokens
+ symbols = [" ".join(f).replace("M c0 r0", "M") for f in list(zip(*all_factors))]
+ return sw_tokenizer.tokens_to_text((" ".join(symbols)).split(" "))
+
+
+@lru_cache(maxsize=None)
+def load_sockeye_translator(model_path: str, log_timing: bool = False):
+ if not Path(model_path).is_dir():
+ from huggingface_hub import snapshot_download
+ model_path = snapshot_download(repo_id=model_path)
+
+ from sockeye.translate import parse_translation_arguments, load_translator_from_args
+
+ now = time.time()
+ args = parse_translation_arguments([
+ "-m", model_path,
+ "--beam-size", "5",
+ ])
+ translator = load_translator_from_args(args, True)
+ if log_timing:
+ print("Loaded sockeye translator in", time.time() - now, "seconds")
+
+ tokenizer = Tokenizer.from_file(str(Path(model_path) / 'tokenizer.json'))
+
+ return translator, tokenizer
+
+
+def translate(translator, texts: List[str], log_timing: bool = False):
+ from sockeye.inference import make_input_from_plain_string
+
+ inputs = [make_input_from_plain_string(sentence_id=i, string=s)
+ for i, s in enumerate(texts)]
+
+ now = time.time()
+ outputs = translator.translate(inputs)
+ translation_time = time.time() - now
+ avg_time = translation_time / len(texts)
+ if log_timing:
+ print("Translated", len(texts), "texts in", translation_time, "seconds", f"({avg_time:.2f} seconds per text)")
+ return [process_translation_output(output) for output in outputs]
def get_args():
parser = argparse.ArgumentParser()
+ parser.add_argument('--model', type=str, help='Path to trained model',
+ default="sign/sockeye-text-to-factored-signwriting")
parser.add_argument('--spoken-language', required=True, type=str, help='spoken language code')
parser.add_argument('--signed-language', required=True, type=str, help='signed language code')
parser.add_argument('--input', required=True, type=str, help='input text or signwriting sequence')
@@ -15,11 +70,19 @@ def signwriting_to_text():
# pylint: disable=unused-variable
args = get_args()
+ translator, tokenizer = load_sockeye_translator(args.model)
+ tokenized_text = " ".join(tokenizer.encode(args.input).tokens)
+ model_input = f"${args.spoken_language} ${args.signed_language} {tokenized_text}"
+ outputs = translate(translator, [model_input])
+ print(outputs[0])
+
def text_to_signwriting():
# pylint: disable=unused-variable
args = get_args()
+ return translate(args.model, [args.input])
+
if __name__ == '__main__':
signwriting_to_text()
diff --git a/signwriting_translation/create_parallel_data.py b/signwriting_translation/create_parallel_data.py
new file mode 100644
index 0000000..c9da9ac
--- /dev/null
+++ b/signwriting_translation/create_parallel_data.py
@@ -0,0 +1,161 @@
+import argparse
+import csv
+from itertools import chain
+from pathlib import Path
+
+from signwriting.formats.fsw_to_sign import fsw_to_sign
+from signwriting.tokenizer import SignWritingTokenizer, normalize_signwriting
+from tqdm import tqdm
+
+from spoken_language_tokenizer import tokenize_spoken_text
+
+csv.field_size_limit(int(1e6))
+
+
+def load_csv(data_path: Path):
+ with open(data_path, 'r', encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ return list(reader)
+
+
+DIRECTIONS = {
+ "spoken-to-signed": {
+ "expanded": 1,
+ "more": 2,
+ "cleaned": 2,
+ },
+ "signed-to-spoken": {
+ "expanded": 1,
+ "more": 3,
+ "cleaned": 4,
+ }
+}
+
+CLEAN_DIRECTIONS = {
+ "spoken-to-signed": {
+ "more": 1,
+ "cleaned": 1,
+ },
+ "signed-to-spoken": {
+ "more": 1,
+ "cleaned": 1,
+ }
+}
+
+sw_tokenizer = SignWritingTokenizer()
+
+
+def process_row(row, files, spoken_direction, repeats=1):
+ lang_token_1, lang_token_2, *signs = row["source"].split(" ")
+ # if not (lang_token_1 == "" and lang_token_2 == ""):
+ # return
+
+ signs = normalize_signwriting(" ".join(signs)).split(" ")
+
+ tokenized_signs = [sw_tokenizer.text_to_tokens(sign) for sign in signs]
+ signed_tokens = chain.from_iterable(tokenized_signs)
+ signed = " ".join(signed_tokens)
+
+ # sign language factors
+ signs = [fsw_to_sign(sign) for sign in signs]
+ for sign in signs: # override box position same as the tokenizer does
+ sign["box"]["position"] = (500, 500)
+ units = list(chain.from_iterable([[sign["box"]] + sign["symbols"] for sign in signs]))
+
+ spoken = tokenize_spoken_text(row["target"])
+
+ if spoken_direction == "source":
+ spoken = f"{lang_token_1} {lang_token_2} {spoken}"
+ else:
+ signed = f"{lang_token_1} {lang_token_2} {signed}"
+ units.insert(0, {"symbol": lang_token_1, "position": [0, 0]})
+ units.insert(0, {"symbol": lang_token_2, "position": [0, 0]})
+
+ factors = [
+ [s["symbol"][:4] for s in units],
+ ["c" + (s["symbol"][4] if len(s["symbol"]) > 4 else '0') for s in units],
+ ["r" + (s["symbol"][5] if len(s["symbol"]) > 5 else '0') for s in units],
+ ["p" + str(s["position"][0]) for s in units],
+ ["p" + str(s["position"][1]) for s in units],
+ ]
+
+ for _ in range(repeats):
+ files["spoken"].write(spoken + "\n")
+ files["signed"].write(signed + "\n")
+
+ for i, factor_file in files["signed_factors"].items():
+ factor_file.write(" ".join(factors[i]) + "\n")
+
+
+def create_files(split_dir, spoken_d, signed_d):
+ return {
+ "spoken": open(split_dir / f"{spoken_d}.txt", "w", encoding="utf-8"),
+ "signed": open(split_dir / f"{signed_d}.txt", "w", encoding="utf-8"),
+ "signed_factors": {
+ i: open(split_dir / f"{signed_d}_{i}.txt", "w", encoding="utf-8")
+ for i in range(5)
+ }
+ }
+
+
+# pylint: disable=too-many-locals
+def create_parallel_data(data_dir: Path, output_dir: Path, clean_only=False):
+ directions_obj = CLEAN_DIRECTIONS if clean_only else DIRECTIONS
+
+ for direction, partitions in directions_obj.items():
+ direction_output_dir = output_dir / direction
+ direction_output_dir.mkdir(parents=True, exist_ok=True)
+
+ train_dir = direction_output_dir / "train"
+ train_dir.mkdir(parents=True, exist_ok=True)
+ dev_dir = direction_output_dir / "dev"
+ dev_dir.mkdir(parents=True, exist_ok=True)
+ test_dir = direction_output_dir / "test"
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ spoken_d = "source" if direction == "spoken-to-signed" else "target"
+ signed_d = "target" if direction == "spoken-to-signed" else "source"
+
+ split_files = {
+ "train": create_files(train_dir, spoken_d, signed_d),
+ "dev": create_files(dev_dir, spoken_d, signed_d),
+ }
+ for partition, repeats in partitions.items():
+ for split, files in split_files.items():
+ split_path = data_dir / partition / f"{split}.csv"
+ if not split_path.exists():
+ if split == "train":
+ raise FileNotFoundError(f"File {split_path} does not exist")
+ continue
+ with open(split_path, 'r', encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ for row in tqdm(reader):
+ process_row(row, files, spoken_d, repeats)
+
+ test_files = create_files(test_dir, spoken_d, signed_d)
+ test_path = data_dir / "test" / "all.csv"
+ with open(test_path, 'r', encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ for row in tqdm(reader):
+ process_row(row, test_files, spoken_d, repeats=1)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data-dir', type=str, help='Path to data directory')
+ parser.add_argument('--output-dir', type=str, help='Path to output directory',
+ default="parallel")
+ parser.add_argument('--clean-only', action='store_true', help='Use only cleaned data')
+ args = parser.parse_args()
+
+ data_dir = Path(args.data_dir)
+
+ # create output directory
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ create_parallel_data(data_dir, output_dir, clean_only=args.clean_only)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/signwriting_translation/deteokenize_signwriting.py b/signwriting_translation/deteokenize_signwriting.py
new file mode 100644
index 0000000..f059f7b
--- /dev/null
+++ b/signwriting_translation/deteokenize_signwriting.py
@@ -0,0 +1,32 @@
+from signwriting.tokenizer import SignWritingTokenizer
+
+tokenizer = SignWritingTokenizer()
+
+
+def print_prediction(file_path: str):
+ with open(file_path, 'r', encoding="utf-8") as f:
+ lines = f.readlines()
+ for line in lines:
+ print(tokenizer.tokens_to_text(line.strip().split(" ")))
+
+
+def print_factored_prediction(factors_file_template: str):
+ files_rows = []
+ for i in range(5):
+ factors_file_path = factors_file_template.format(i)
+ with open(factors_file_path, 'r', encoding="utf-8") as f:
+ lines = f.readlines()
+ files_rows.append([line.strip().split(" ") for line in lines])
+
+ rows_factors = list(zip(*files_rows))
+ for row in rows_factors:
+ unfactored = list(zip(*row))
+ symbols = [" ".join(f).replace("M c0 r0", "M") for f in unfactored]
+
+ print(tokenizer.tokens_to_text((" ".join(symbols)).split(" ")))
+
+
+print_prediction("/home/amoryo/sign-language/signwriting-translation/parallel/spoken-to-signed/test/target.txt")
+print_factored_prediction(
+ # pylint: disable=line-too-long
+ "/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation/spoken-to-signed/target-factors/model/decode.output.{}.00332")
diff --git a/signwriting_translation/graph_metrics.py b/signwriting_translation/graph_metrics.py
new file mode 100644
index 0000000..cd5440a
--- /dev/null
+++ b/signwriting_translation/graph_metrics.py
@@ -0,0 +1,40 @@
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+
+MODELS_DIR = "/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation/"
+DIRECTION = "spoken-to-signed"
+MODELS = [
+ "no-factors",
+ "target-factors",
+ "target-factors-v2",
+ "target-factors-v4",
+ "target-factors-tuned"
+]
+
+if __name__ == "__main__":
+ models_metrics = defaultdict(lambda: defaultdict(list))
+
+ for model_name in MODELS:
+ metrics_file = f"{MODELS_DIR}{DIRECTION}/{model_name}/model/metrics"
+ with open(metrics_file, 'r', encoding="utf-8") as f:
+ lines = f.readlines()
+ for line in lines:
+ for metric in line.strip().split("\t")[1:]:
+ name, value = metric.split("=")
+ try:
+ models_metrics[model_name][name].append(float(value))
+ except ValueError:
+ pass
+
+ for metric in ['chrf', 'signwriting-similarity']:
+ plt.figure(figsize=(10, 5))
+
+ plt.grid(axis='y', linestyle='--', linewidth=0.5)
+ for model_name, metrics in models_metrics.items():
+ plt.plot(metrics[f"{metric}-val"], label=model_name)
+ if metric == 'signwriting-similarity':
+ plt.ylim(0.35, None)
+ plt.legend(loc='lower right')
+ plt.savefig(f"{metric}.png")
+ plt.close()
diff --git a/signwriting_translation/prepare_data.sh b/signwriting_translation/prepare_data.sh
new file mode 100644
index 0000000..c2b562a
--- /dev/null
+++ b/signwriting_translation/prepare_data.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+#SBATCH --job-name=prepare-data
+#SBATCH --time=24:00:00
+#SBATCH --mem=32G
+#SBATCH --output=prepare_data.out
+
+#SBATCH --ntasks=1
+
+set -e # exit on error
+set -x # echo commands
+
+module load anaconda3
+source activate sockeye
+
+
+# Download the SignBank repository if not exists
+SIGNBANK_DIR="/home/amoryo/sign-language/signbank-annotation/signbank-plus"
+[ ! -d "$SIGNBANK_DIR" ] && \
+git clone https://github.com/sign-language-processing/signbank-plus.git "$SIGNBANK_DIR"
+
+# Process data for machine translation if not exists
+[ ! -d "$SIGNBANK_DIR/data/parallel/cleaned" ] && \
+python "$SIGNBANK_DIR/signbank_plus/prep_nmt.py"
+
+# Train a tokenizer
+python spoken_language_tokenizer.py \
+ --files $SIGNBANK_DIR/data/parallel/cleaned/train.target $SIGNBANK_DIR/data/parallel/more/train.target \
+ --output="tokenizer.json"
+
+# Prepare the parallel corpus (with source/target-factors)
+python create_parallel_data.py \
+ --data-dir="$SIGNBANK_DIR/data/parallel/" \
+ --output-dir="../parallel"
+
+# Prepare the clean parallel corpus (with source/target-factors)
+python create_parallel_data.py \
+ --data-dir="$SIGNBANK_DIR/data/parallel/" \
+ --output-dir="../parallel-clean" \
+ --clean-only
+
diff --git a/signwriting_translation/spoken_language_tokenizer.py b/signwriting_translation/spoken_language_tokenizer.py
new file mode 100644
index 0000000..4c903a3
--- /dev/null
+++ b/signwriting_translation/spoken_language_tokenizer.py
@@ -0,0 +1,46 @@
+import argparse
+from functools import lru_cache
+from typing import List
+
+import tokenizers
+from tokenizers import Tokenizer, pre_tokenizers, normalizers, decoders, trainers
+from tokenizers.models import BPE
+
+
+@lru_cache(maxsize=None)
+def load_tokenizer(tokenizer_file: str = 'tokenizer.json'):
+ return Tokenizer.from_file(tokenizer_file)
+
+
+@lru_cache(maxsize=int(1e7))
+def tokenize_spoken_text(text: str, tokenizer_file: str = 'tokenizer.json'):
+ tokenizer = load_tokenizer(tokenizer_file)
+ encoding = tokenizer.encode(text)
+ return " ".join(encoding.tokens)
+
+
+def train(files: List[str], target_file: str):
+ tokenizer = Tokenizer(BPE())
+ tokenizer.normalizer = normalizers.NFKD()
+ # Take the pre tokenizer setting from GPT-4, https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+ pre_tokenizers.Split(pattern=tokenizers.Regex(
+ "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"),
+ behavior="removed", invert=True),
+ # For non ascii characters, it gets completely unreadable, but it works nonetheless!
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
+ ])
+ tokenizer.decoder = decoders.ByteLevel()
+ trainer = trainers.BpeTrainer(vocab_size=8000)
+ tokenizer.train(files=files, trainer=trainer)
+
+ tokenizer.save(target_file)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--files", nargs='+', type=str, help="Files to train tokenizer on")
+ parser.add_argument("--output", type=str, help="Output file for tokenizer model")
+ args = parser.parse_args()
+
+ train(args.files, args.output)
diff --git a/signwriting_translation/train_sockeye_model.sh b/signwriting_translation/train_sockeye_model.sh
new file mode 100644
index 0000000..528b5e8
--- /dev/null
+++ b/signwriting_translation/train_sockeye_model.sh
@@ -0,0 +1,164 @@
+#!/bin/bash
+
+#SBATCH --job-name=train-sockeye
+#SBATCH --time=48:00:00
+#SBATCH --mem=16G
+#SBATCH --output=train-%j.out
+
+#SBATCH --ntasks=1
+#SBATCH --gres gpu:1
+#SBATCH --constraint=GPUMEM80GB
+
+set -e # exit on error
+set -x # echo commands
+
+module load anaconda3
+source activate sockeye
+
+
+# Parse command line arguments
+for arg in "$@"; do
+ # if stripped arg is empty or (space, new line), skip
+ [ -z "${arg// }" ] && continue
+
+ # Split the argument into name and value parts
+ key="${arg%%=*}" # Extract everything before '='
+ value="${arg#*=}" # Extract everything after '='
+
+ # Remove leading '--' from the key name
+ key="${key##--}"
+
+ # Declare variable dynamically and assign value
+ declare "$key"="$value"
+done
+
+
+mkdir -p $model_dir
+
+# e.g., "signwriting-similarity", "chrf" (default)
+optimized_metric=${optimized_metric:-"chrf"}
+
+# Flags for source and target factors
+use_source_factors=${use_source_factors:-"false"}
+use_target_factors=${use_target_factors:-"false"}
+
+# Clone sockeye if doesn't exist
+#[ ! -d sockeye ] && git clone https://github.com/sign-language-processing/sockeye.git
+#pip install ./sockeye
+#
+## Install SignWriting evaluation package for optimized metric
+#pip install git+https://github.com/sign-language-processing/signwriting
+#pip install git+https://github.com/sign-language-processing/signwriting-evaluation
+#pip install tensorboard
+
+function find_source_files() {
+ local directory=$1
+ find "$directory" -type f -name 'source_[1-9]*.txt' -printf "$directory/%f\n" | sort | tr '\n' ' '
+}
+
+function find_target_files() {
+ local directory=$1
+ find "$directory" -type f -name 'target_[1-9]*.txt' -printf "$directory/%f\n" | sort | tr '\n' ' '
+}
+
+
+function translation_files() {
+ local name=$1
+ local type=$2 # e.g., "source" or "target"
+ local split=$3 # e.g., "train", "dev", or "test"
+ local use_factors=$4 # Pass 'true' or 'false' to use factors
+
+ # Determine the file finder function based on the type
+ local find_function="find_${type}_files"
+
+ if [[ "$use_factors" == "true" ]]; then
+ echo "--${name} ${split}/${type}_0.txt --${name}-factors $($find_function "$split")"
+ else
+ echo "--${name} ${split}/${type}.txt"
+ fi
+}
+
+function find_vocabulary_factors_files() {
+ local directory=$1
+ local type_short=$2
+ find "$directory" -type f -name "vocab.${type_short}.[1-9]*.json" -printf "$directory/%f\n" | sort | tr '\n' ' '
+}
+
+function vocabulary_files() {
+ local base_model_dir=$1
+ local type=$2 # e.g., "src" or "trg"
+ local type_short=$3 # e.g., "src" or "trg"
+ local use_factors=$4 # Pass 'true' or 'false' to use factors
+
+ if [ -z "$base_model_dir" ]; then
+ return
+ fi
+
+ echo "--${type}-vocab $base_model_dir/model/vocab.${type_short}.0.json "
+
+ if [[ "$use_factors" == "true" ]]; then
+ echo "--${type}-factor-vocabs $(find_vocabulary_factors_files $base_model_dir/model $type_short)"
+ fi
+}
+
+# max seq len based on factor usage
+max_seq_len=2048
+[ "$use_source_factors" == "true" ] && max_seq_len=512
+[ "$use_target_factors" == "true" ] && max_seq_len=512
+
+# Prepare data
+TRAIN_DATA_DIR="$model_dir/train_data"
+[ ! -f "$TRAIN_DATA_DIR/data.version" ] && \
+python -m sockeye.prepare_data \
+ --max-seq-len $max_seq_len:$max_seq_len \
+ $(vocabulary_files "$base_model_dir" "source" "src" $use_source_factors) \
+ $(translation_files "source" "source" "$data_dir/train" $use_source_factors) \
+ $(vocabulary_files "$base_model_dir" "target" "trg" $use_target_factors) \
+ $(translation_files "target" "target" "$data_dir/train" $use_target_factors) \
+ --output $TRAIN_DATA_DIR \
+
+cp tokenizer.json $model_dir/tokenizer.json
+
+MODEL_DIR="$model_dir/model"
+rm -rf $MODEL_DIR
+
+# batch size refers to number of target tokens, has to be larger than max tokens set in prepare_data
+batch_size=$((max_seq_len * 2 + 1))
+extra_arguments=""
+# params is set --params $base_model_dir/model/params.best if $base_model_dir is set
+if [ -n "$base_model_dir" ]; then
+ extra_arguments="${extra_arguments} --params $base_model_dir/model/params.best"
+fi
+
+# From https://aclanthology.org/2023.findings-eacl.127.pdf
+# target-factors-weight 0.2
+# weight-tying-type "trg_softmax"
+# learning-rate-reduce-factor 0.7
+# label-smoothing 0.2
+# embed-dropout 0.5
+# transformer-dropout is double than the default, but less than 0.5 from the paper
+python -m sockeye.train \
+ -d $TRAIN_DATA_DIR \
+ --weight-tying-type "trg_softmax" \
+ --max-seq-len $max_seq_len:$max_seq_len \
+ --batch-size $batch_size \
+ --source-factors-combine sum \
+ --target-factors-combine sum \
+ --target-factors-weight 0.2 \
+ $(translation_files "validation-source" "source" "$data_dir/test" $use_source_factors) \
+ $(translation_files "validation-target" "target" "$data_dir/test" $use_target_factors) \
+ --optimized-metric "$optimized_metric" \
+ --learning-rate-warmup 1000 \
+ --learning-rate-reduce-factor 0.7 \
+ --decode-and-evaluate 1000 \
+ --checkpoint-interval 1000 \
+ --max-num-checkpoint-not-improved 50 \
+ --embed-dropout 0.5 \
+ --transformer-dropout-prepost 0.2 \
+ --transformer-dropout-act 0.2 \
+ --transformer-dropout-attention 0.2 \
+ --label-smoothing 0.2 \
+ --label-smoothing-impl torch \
+ --no-bucketing \
+ $extra_arguments \
+ --output $MODEL_DIR