feat(): add training and inference code
AmitMY committed May 11, 2024
1 parent f2d7997 commit 17243d7
Showing 9 changed files with 630 additions and 8 deletions.
86 changes: 78 additions & 8 deletions README.md
@@ -2,6 +2,9 @@

This project trains and serves models to translate SignWriting into spoken language text and vice versa.

Ideally, we would use [Bergamot](https://github.com/mozilla/firefox-translations-training),
but we could not get it to work.

## Usage

```bash
@@ -15,16 +18,83 @@
signwriting_to_text --spoken-language="en" --signed-language="ase" --input="M525
text_to_signwriting --spoken-language="en" --signed-language="ase" --input="Sign Language"
```
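
The same entry points are available from Python, using the functions defined in `signwriting_translation/bin.py` (shown later in this commit). A minimal sketch, assuming a trained model is available under the default HuggingFace repo id:

```python
from signwriting_translation.bin import load_sockeye_translator, translate

# Loads the Sockeye model and text tokenizer once (cached with lru_cache).
# A non-directory path is resolved as a HuggingFace repo id via snapshot_download.
translator, tokenizer = load_sockeye_translator("sign/sockeye-text-to-factored-signwriting")

# Build the model input the way signwriting_to_text() does:
# language tags, then the tokenized text.
tokens = " ".join(tokenizer.encode("My name is John.").tokens)
print(translate(translator, [f"$en $ase {tokens}"])[0])
```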

## Data

We use the SignBank+ dataset from [signbank-plus](https://github.com/sign-language-processing/signbank-plus).

After finding a large mismatch in distribution between the `validation` and `test` sets,
we decided to use the `test` set as the `validation` set, without training multiple models.
This would be a *bad* decision if we were to train and compare multiple models, or if we wanted to improve the model in the future.
To change it, edit the `sockeye.train` command in the `train_sockeye_model.sh` script.

## Steps
```bash
# 0. Setup the environment.
conda create --name sockeye python=3.11 -y
conda activate sockeye

cd signwriting_translation

MODEL_DIR=/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation
DATA_DIR=/home/amoryo/sign-language/signwriting-translation/parallel
DIRECTION="spoken-to-signed"

# 1. Download and tokenize the signbank-plus dataset
sbatch prepare_data.sh

# 2. Train a spoken-to-signed translation model
# (Train without factors)
sbatch train_sockeye_model.sh \
--data_dir="$DATA_DIR/$DIRECTION" \
--model_dir="$MODEL_DIR/$DIRECTION/no-factors" \
--optimized_metric="signwriting-similarity" \
--partition lowprio
# (Train with factors)
sbatch train_sockeye_model.sh \
--data_dir="$DATA_DIR/$DIRECTION" \
--model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \
--optimized_metric="signwriting-similarity" \
--use_target_factors=true \
--partition lowprio
# (Fine tune model on cleaned data)
sbatch train_sockeye_model.sh \
--data_dir="$DATA_DIR-clean/$DIRECTION" \
--model_dir="$MODEL_DIR/$DIRECTION/target-factors-tuned" \
--base_model_dir="$MODEL_DIR/$DIRECTION/target-factors-v4" \
--optimized_metric="signwriting-similarity" \
--use_target_factors=true \
--partition lowprio

# 2.1 (Optional) See the validation metrics
cat "$MODEL_DIR/$DIRECTION/no-factors/model/metrics" | grep "signwriting-similarity"
cat "$MODEL_DIR/$DIRECTION/target-factors/model/metrics" | grep "signwriting-similarity"

# 3. Test it yourself
python -m signwriting_translation.bin \
--model="$MODEL_DIR/$DIRECTION/target-factors-tuned/model" \
--spoken-language="en" \
--signed-language="ase" \
--input="My name is John."
```
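
Step 2.1 just greps Sockeye's per-checkpoint `metrics` file. A rough Python equivalent, assuming only that the metric name appears as a substring on the relevant lines (paths mirror `$MODEL_DIR/$DIRECTION` above):

```python
from pathlib import Path

# Hypothetical local path; matches MODEL_DIR/DIRECTION from the steps above.
model_dir = Path("/shares/volk.cl.uzh/amoryo/checkpoints/signwriting-translation")
metrics_file = model_dir / "spoken-to-signed" / "no-factors" / "model" / "metrics"

for line in metrics_file.read_text(encoding="utf-8").splitlines():
    if "signwriting-similarity" in line:  # keep only the metric of interest
        print(line)
```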

## Upload to HuggingFace

```bash
# Copy the model files to a new directory
SE_MODEL_PATH="$MODEL_DIR/$DIRECTION/target-factors-tuned"
HF_MODEL_PATH="$MODEL_DIR/$DIRECTION/huggingface/target-factors-tuned"

mkdir -p "$HF_MODEL_PATH"
cp tokenizer.json "$HF_MODEL_PATH/tokenizer.json"
cp "$SE_MODEL_PATH/model/params.best" "$HF_MODEL_PATH/params.best"
cp "$SE_MODEL_PATH/model/version" "$HF_MODEL_PATH/version"
cp "$SE_MODEL_PATH/model/metrics" "$HF_MODEL_PATH/metrics"
cp "$SE_MODEL_PATH/model/config" "$HF_MODEL_PATH/config"
cp "$SE_MODEL_PATH/model/args.yaml" "$HF_MODEL_PATH/args.yaml"
cp "$SE_MODEL_PATH/model/vocab."* "$HF_MODEL_PATH"

# Upload to HuggingFace
huggingface-cli login
huggingface-cli upload sign/sockeye-text-to-factored-signwriting "$HF_MODEL_PATH" .
```
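
For the download direction, `load_sockeye_translator` in `bin.py` resolves a repo id through `huggingface_hub.snapshot_download`; fetching the uploaded model is roughly equivalent to this sketch:

```python
from huggingface_hub import snapshot_download

# Downloads the uploaded files into the local HuggingFace cache and
# returns the local directory path, which Sockeye then loads with "-m".
model_path = snapshot_download(repo_id="sign/sockeye-text-to-factored-signwriting")
print(model_path)
```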
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -8,8 +8,12 @@ authors = [
readme = "README.md"
dependencies = [
    "signwriting @ git+https://github.com/sign-language-processing/signwriting",
    "sockeye @ git+https://github.com/sign-language-processing/sockeye",
    "tokenizers",
    "huggingface-hub"
]

[project.optional-dependencies]
dev = [
    "pytest",
    "pylint"
@@ -26,6 +30,7 @@
disable = [
    "C0115", # Missing class docstring
    "C0116", # Missing function or method docstring
]
good-names = ["i", "f", "x", "y"]

[tool.setuptools]
packages = [
63 changes: 63 additions & 0 deletions signwriting_translation/bin.py
@@ -1,10 +1,65 @@
#!/usr/bin/env python

import argparse
import time
from functools import lru_cache
from pathlib import Path
from typing import List

from signwriting.tokenizer import SignWritingTokenizer
from sockeye.inference import TranslatorOutput
from tokenizers import Tokenizer

sw_tokenizer = SignWritingTokenizer()


def process_translation_output(output: TranslatorOutput):
    # Zip the base-token stream with the factor streams back into full symbols,
    # then collapse the default box "M c0 r0" into a bare "M".
    all_factors = [output.tokens] + output.factor_tokens
    symbols = [" ".join(f).replace("M c0 r0", "M") for f in list(zip(*all_factors))]
    return sw_tokenizer.tokens_to_text((" ".join(symbols)).split(" "))


@lru_cache(maxsize=None)
def load_sockeye_translator(model_path: str, log_timing: bool = False):
    # A path that is not a local directory is treated as a HuggingFace repo id.
    if not Path(model_path).is_dir():
        from huggingface_hub import snapshot_download
        model_path = snapshot_download(repo_id=model_path)

    from sockeye.translate import parse_translation_arguments, load_translator_from_args

    now = time.time()
    args = parse_translation_arguments([
        "-m", model_path,
        "--beam-size", "5",
    ])
    translator = load_translator_from_args(args, True)
    if log_timing:
        print("Loaded sockeye translator in", time.time() - now, "seconds")

    tokenizer = Tokenizer.from_file(str(Path(model_path) / 'tokenizer.json'))

    return translator, tokenizer


def translate(translator, texts: List[str], log_timing: bool = False):
    from sockeye.inference import make_input_from_plain_string

    inputs = [make_input_from_plain_string(sentence_id=i, string=s)
              for i, s in enumerate(texts)]

    now = time.time()
    outputs = translator.translate(inputs)
    translation_time = time.time() - now
    avg_time = translation_time / len(texts)
    if log_timing:
        print("Translated", len(texts), "texts in", translation_time, "seconds",
              f"({avg_time:.2f} seconds per text)")
    return [process_translation_output(output) for output in outputs]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='Path to trained model',
                        default="sign/sockeye-text-to-factored-signwriting")
    parser.add_argument('--spoken-language', required=True, type=str, help='spoken language code')
    parser.add_argument('--signed-language', required=True, type=str, help='signed language code')
    parser.add_argument('--input', required=True, type=str, help='input text or signwriting sequence')
@@ -15,11 +70,19 @@
def signwriting_to_text():
    # pylint: disable=unused-variable
    args = get_args()

    translator, tokenizer = load_sockeye_translator(args.model)
    tokenized_text = " ".join(tokenizer.encode(args.input).tokens)
    model_input = f"${args.spoken_language} ${args.signed_language} {tokenized_text}"
    outputs = translate(translator, [model_input])
    print(outputs[0])


def text_to_signwriting():
    # pylint: disable=unused-variable
    args = get_args()

    # translate() expects a loaded translator, so load it from the model path
    # and build the tagged, tokenized input as above.
    translator, tokenizer = load_sockeye_translator(args.model)
    tokenized_text = " ".join(tokenizer.encode(args.input).tokens)
    model_input = f"${args.spoken_language} ${args.signed_language} {tokenized_text}"
    print(translate(translator, [model_input])[0])


if __name__ == '__main__':
    signwriting_to_text()
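
For intuition, this is what `process_translation_output` does with a factored hypothesis; the model output below is made up for illustration:

```python
# Hypothetical factored output: base tokens plus column/row/x/y factor streams.
tokens = ["M", "S2e7", "S1f5"]
factor_tokens = [
    ["c0", "c4", "c2"],        # column factor
    ["r0", "r8", "r0"],        # row factor
    ["p500", "p483", "p515"],  # x position factor
    ["p500", "p485", "p465"],  # y position factor
]

all_factors = [tokens] + factor_tokens
symbols = [" ".join(f).replace("M c0 r0", "M") for f in zip(*all_factors)]
print(symbols)
# ['M p500 p500', 'S2e7 c4 r8 p483 p485', 'S1f5 c2 r0 p515 p465']
```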
161 changes: 161 additions & 0 deletions signwriting_translation/create_parallel_data.py
@@ -0,0 +1,161 @@
import argparse
import csv
from itertools import chain
from pathlib import Path

from signwriting.formats.fsw_to_sign import fsw_to_sign
from signwriting.tokenizer import SignWritingTokenizer, normalize_signwriting
from tqdm import tqdm

from spoken_language_tokenizer import tokenize_spoken_text

csv.field_size_limit(int(1e6))


def load_csv(data_path: Path):
    with open(data_path, 'r', encoding="utf-8") as f:
        reader = csv.DictReader(f)
        return list(reader)


# How many times each partition is repeated in the training data,
# per translation direction (oversampling the cleaner partitions).
DIRECTIONS = {
    "spoken-to-signed": {
        "expanded": 1,
        "more": 2,
        "cleaned": 2,
    },
    "signed-to-spoken": {
        "expanded": 1,
        "more": 3,
        "cleaned": 4,
    }
}

CLEAN_DIRECTIONS = {
    "spoken-to-signed": {
        "more": 1,
        "cleaned": 1,
    },
    "signed-to-spoken": {
        "more": 1,
        "cleaned": 1,
    }
}

sw_tokenizer = SignWritingTokenizer()


def process_row(row, files, spoken_direction, repeats=1):
    lang_token_1, lang_token_2, *signs = row["source"].split(" ")
    # if not (lang_token_1 == "<en>" and lang_token_2 == "<ase>"):
    #     return

    signs = normalize_signwriting(" ".join(signs)).split(" ")

    tokenized_signs = [sw_tokenizer.text_to_tokens(sign) for sign in signs]
    signed_tokens = chain.from_iterable(tokenized_signs)
    signed = " ".join(signed_tokens)

    # sign language factors
    signs = [fsw_to_sign(sign) for sign in signs]
    for sign in signs:  # override box position same as the tokenizer does
        sign["box"]["position"] = (500, 500)
    units = list(chain.from_iterable([[sign["box"]] + sign["symbols"] for sign in signs]))

    spoken = tokenize_spoken_text(row["target"])

    if spoken_direction == "source":
        spoken = f"{lang_token_1} {lang_token_2} {spoken}"
    else:
        signed = f"{lang_token_1} {lang_token_2} {signed}"
        units.insert(0, {"symbol": lang_token_1, "position": [0, 0]})
        units.insert(0, {"symbol": lang_token_2, "position": [0, 0]})

    factors = [
        [s["symbol"][:4] for s in units],
        ["c" + (s["symbol"][4] if len(s["symbol"]) > 4 else '0') for s in units],
        ["r" + (s["symbol"][5] if len(s["symbol"]) > 5 else '0') for s in units],
        ["p" + str(s["position"][0]) for s in units],
        ["p" + str(s["position"][1]) for s in units],
    ]

    # Keep the factor files line-aligned with the signed file by writing
    # every stream once per repeat.
    for _ in range(repeats):
        files["spoken"].write(spoken + "\n")
        files["signed"].write(signed + "\n")

        for i, factor_file in files["signed_factors"].items():
            factor_file.write(" ".join(factors[i]) + "\n")


def create_files(split_dir, spoken_d, signed_d):
    return {
        "spoken": open(split_dir / f"{spoken_d}.txt", "w", encoding="utf-8"),
        "signed": open(split_dir / f"{signed_d}.txt", "w", encoding="utf-8"),
        "signed_factors": {
            i: open(split_dir / f"{signed_d}_{i}.txt", "w", encoding="utf-8")
            for i in range(5)
        }
    }


# pylint: disable=too-many-locals
def create_parallel_data(data_dir: Path, output_dir: Path, clean_only=False):
    directions_obj = CLEAN_DIRECTIONS if clean_only else DIRECTIONS

    for direction, partitions in directions_obj.items():
        direction_output_dir = output_dir / direction
        direction_output_dir.mkdir(parents=True, exist_ok=True)

        train_dir = direction_output_dir / "train"
        train_dir.mkdir(parents=True, exist_ok=True)
        dev_dir = direction_output_dir / "dev"
        dev_dir.mkdir(parents=True, exist_ok=True)
        test_dir = direction_output_dir / "test"
        test_dir.mkdir(parents=True, exist_ok=True)

        spoken_d = "source" if direction == "spoken-to-signed" else "target"
        signed_d = "target" if direction == "spoken-to-signed" else "source"

        split_files = {
            "train": create_files(train_dir, spoken_d, signed_d),
            "dev": create_files(dev_dir, spoken_d, signed_d),
        }
        for partition, repeats in partitions.items():
            for split, files in split_files.items():
                split_path = data_dir / partition / f"{split}.csv"
                if not split_path.exists():
                    if split == "train":
                        raise FileNotFoundError(f"File {split_path} does not exist")
                    continue
                with open(split_path, 'r', encoding="utf-8") as f:
                    reader = csv.DictReader(f)
                    for row in tqdm(reader):
                        process_row(row, files, spoken_d, repeats)

        test_files = create_files(test_dir, spoken_d, signed_d)
        test_path = data_dir / "test" / "all.csv"
        with open(test_path, 'r', encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in tqdm(reader):
                process_row(row, test_files, spoken_d, repeats=1)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, help='Path to data directory')
    parser.add_argument('--output-dir', type=str, help='Path to output directory',
                        default="parallel")
    parser.add_argument('--clean-only', action='store_true', help='Use only cleaned data')
    args = parser.parse_args()

    data_dir = Path(args.data_dir)

    # create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    create_parallel_data(data_dir, output_dir, clean_only=args.clean_only)


if __name__ == "__main__":
    main()
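
As a worked example of the factor construction in `process_row`, a single symbol unit decomposes into the five streams like this:

```python
# One unit as produced by fsw_to_sign: an FSW symbol and its position.
unit = {"symbol": "S2e748", "position": (500, 500)}

factors = [
    unit["symbol"][:4],                                             # 'S2e7' base symbol
    "c" + (unit["symbol"][4] if len(unit["symbol"]) > 4 else '0'),  # 'c4' from the 5th character
    "r" + (unit["symbol"][5] if len(unit["symbol"]) > 5 else '0'),  # 'r8' from the 6th character
    "p" + str(unit["position"][0]),                                 # 'p500' x coordinate
    "p" + str(unit["position"][1]),                                 # 'p500' y coordinate
]
print(factors)  # ['S2e7', 'c4', 'r8', 'p500', 'p500']
```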