Skip to content

Commit

Permalink
fix(pose_to_signwriting): make bin able to work with arbitrary file paths
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Feb 19, 2024
1 parent 4db04d3 commit ffa1247
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 33 deletions.
59 changes: 42 additions & 17 deletions signwriting_transcription/pose_to_signwriting/bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from signwriting_transcription.pose_to_signwriting.data.config import create_test_config
from signwriting_transcription.pose_to_signwriting.data.datasets_pose import pose_to_matrix
from signwriting_transcription.pose_to_signwriting.data.pose_data_utils import build_pose_vocab
from signwriting_transcription.pose_to_signwriting.data.preprocessing import preprocess
from signwriting_transcription.pose_to_signwriting.data.preprocessing import preprocess_single_file
from signwriting_transcription.pose_to_signwriting.joeynmt_pose.prediction import translate

HUGGINGFACE_REPO_ID = "ohadlanger/signwriting_transcription"
Expand All @@ -25,45 +25,70 @@ def get_args():
return parser.parse_args()


def main():
args = get_args()
print('Downloading model...')
os.makedirs("experiment", exist_ok=True)
if not os.path.exists(f'experiment/{args.model}'):
def download_model(experiment_dir: Path, model_name: str) -> None:
    """Ensure the model checkpoint, vocabulary and test config exist in ``experiment_dir``.

    Steps, each skipped when its output is already present:

    1. Download ``model_name`` from the ``HUGGINGFACE_REPO_ID`` space.
    2. Point ``best.ckpt`` at the requested checkpoint.
    3. Build the pose vocabulary file ``spm_bpe1182.vocab``.
    4. Create ``config.yaml`` via ``create_test_config``.

    Args:
        experiment_dir: Directory that holds every model artefact. It is
            expected to exist already (``main`` creates it).
        model_name: File name of the checkpoint inside the Hugging Face space.
    """
    model_path = experiment_dir / model_name
    if not model_path.exists():
        # pylint: disable=import-outside-toplevel
        from huggingface_hub import hf_hub_download

        # Download into the directory we were given rather than a hard-coded
        # 'experiment', otherwise the existence check above never succeeds
        # for any other experiment_dir.
        hf_hub_download(
            repo_id=HUGGINGFACE_REPO_ID,
            filename=model_name,
            repo_type='space',
            local_dir=str(experiment_dir),
        )

    # Refresh the link on every call, not only after a download: switching
    # back to an already-downloaded model must not leave best.ckpt pointing
    # at the previously used one.
    # NOTE(review): best.ckpt is presumably the name the prediction code
    # loads by default -- confirm against the generated config.
    target_path = model_path.absolute()
    best_ckpt_path = experiment_dir.absolute() / 'best.ckpt'
    if best_ckpt_path != target_path:  # never replace the model with a link to itself
        # is_symlink() also catches dangling links, which exists() reports as
        # missing and which would make symlink_to() raise FileExistsError.
        if best_ckpt_path.is_symlink() or best_ckpt_path.exists():
            best_ckpt_path.unlink()
        best_ckpt_path.symlink_to(target_path)

    vocab_path = experiment_dir / 'spm_bpe1182.vocab'
    if not vocab_path.exists():
        build_pose_vocab(vocab_path.absolute())

    config_path = experiment_dir / 'config.yaml'
    if not config_path.exists():
        create_test_config(str(experiment_dir), str(experiment_dir))


def main():
    """Transcribe every annotation on the ``SIGN`` tier of an ELAN file.

    Reads ``args.model``, ``args.elan`` and ``args.pose`` from ``get_args()``.
    The pose file is preprocessed into a temporary copy, each ``SIGN``
    annotation is cut out of it and saved as a ``.npy`` matrix, the matrices
    are passed to ``translate``, and each annotation is replaced by the
    corresponding prediction. The ELAN file is overwritten in place.

    Temporary files live in ``experiment/temp`` and are removed even when
    preprocessing or prediction fails.
    """
    args = get_args()

    experiment_dir = Path('experiment')
    experiment_dir.mkdir(exist_ok=True)

    temp_dir = experiment_dir / 'temp'
    temp_dir.mkdir(exist_ok=True)

    print('Downloading model...')
    download_model(experiment_dir, args.model)

    print('Loading ELAN file...')
    eaf = pympi.Elan.Eaf(file_path=args.elan, author="sign-language-processing/signwriting-transcription")
    sign_annotations = eaf.get_annotation_data_for_tier('SIGN')

    temp_pose_path = temp_dir / 'pose.pose'
    temp_files = []
    try:
        print('Preprocessing pose.....')
        preprocess_single_file(args.pose, temp_pose_path, normalization=False)

        print('Predicting signs...')
        # Wrap the list itself (not the enumerate object) so tqdm knows the total.
        for index, segment in enumerate(tqdm(sign_annotations)):
            np_pose = pose_to_matrix(temp_pose_path, segment[0], segment[1]).filled(fill_value=0)
            pose_path = temp_dir / f'{index}.npy'
            # Register before saving so a partially written file is cleaned up too.
            temp_files.append(pose_path)
            np.save(pose_path, np_pose)

        hyp_list = translate(str(experiment_dir / 'config.yaml'), temp_files)
    finally:
        print('Cleaning up...')
        for temp_file in [temp_pose_path, *temp_files]:
            temp_file.unlink(missing_ok=True)
        try:
            temp_dir.rmdir()
        except OSError as error:
            # The directory holds files this run did not create; report it
            # instead of masking an exception raised by the pipeline above.
            print(f'Could not remove {temp_dir}: {error}')

    for segment, hypothesis in zip(sign_annotations, hyp_list):
        eaf.remove_annotation('SIGN', segment[0])
        eaf.add_annotation('SIGN', segment[0], segment[1], hypothesis)
    eaf.to_file(args.elan)


if __name__ == '__main__':
main()
31 changes: 18 additions & 13 deletions signwriting_transcription/pose_to_signwriting/data/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
import argparse
from pathlib import Path

from pose_format.utils.generic import reduce_holistic
from pose_format import Pose
from tqdm import tqdm
from pose_format.utils.generic import reduce_holistic
from sign_vq.data.normalize import pre_process_mediapipe, normalize_mean_std
from tqdm import tqdm


def preprocess_single_file(src_file: Path, trg_file: Path, normalization=True):
    """Read a pose file, preprocess it and write the result to ``trg_file``.

    Args:
        src_file: Path of the pose file to read.
        trg_file: Path the processed pose is written to.
        normalization: When true, the pose goes through
            ``pre_process_mediapipe`` followed by ``normalize_mean_std``;
            otherwise only ``reduce_holistic`` is applied.
    """
    with open(src_file, 'rb') as source:
        raw_bytes = source.read()
    pose = Pose.read(raw_bytes)

    if not normalization:
        pose = reduce_holistic(pose)
    else:
        pose = normalize_mean_std(pre_process_mediapipe(pose))

    with open(trg_file, 'wb') as target:
        pose.write(target)


def preprocess(src_dir: Path, trg_dir: Path, normalization=True):
    """Preprocess every ``*.pose`` file in ``src_dir`` into ``trg_dir``.

    Each file is handled by ``preprocess_single_file`` and written under the
    same file name. ``trg_dir`` is created, including parents, when missing.

    Args:
        src_dir: Directory searched (non-recursively) for ``*.pose`` files.
            Anything accepted by ``Path(...)`` works.
        trg_dir: Directory that receives the processed files.
        normalization: Forwarded to ``preprocess_single_file``.
    """
    src_dir = Path(src_dir)
    trg_dir = Path(trg_dir)
    trg_dir.mkdir(parents=True, exist_ok=True)

    # glob() returns a generator: materialise and sort it so the progress bar
    # knows the total and files are processed in a deterministic order.
    src_files = sorted(src_dir.glob("*.pose"))
    for src_file in tqdm(src_files):
        trg_file = trg_dir / src_file.name
        preprocess_single_file(src_file, trg_file, normalization)


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from functools import partial
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -497,7 +497,7 @@ def test(

def translate(
cfg_file: str,
pose_files:list[str],
pose_files:list[Union[str, Path]],
ckpt: str = None,
) -> List[str]:
"""
Expand Down Expand Up @@ -585,7 +585,7 @@ def _translate_data(test_data, cfg):

n_best = test_cfg.get("n_best", 1)
for pose_file in pose_files:
test_data.set_item(pose_file)
test_data.set_item(str(pose_file))
all_hypotheses, _, _ = _translate_data(test_data, test_cfg)
assert len(all_hypotheses) == len(test_data) * n_best
for hey in all_hypotheses:
Expand Down

0 comments on commit ffa1247

Please sign in to comment.