diff --git a/signwriting_transcription/pose_to_signwriting/bin.py b/signwriting_transcription/pose_to_signwriting/bin.py
index 96d7e6c..59b179c 100644
--- a/signwriting_transcription/pose_to_signwriting/bin.py
+++ b/signwriting_transcription/pose_to_signwriting/bin.py
@@ -7,22 +7,26 @@ import numpy as np
 import pympi
+from pose_format import Pose
 from tqdm import tqdm
 
 from signwriting_transcription.pose_to_signwriting.data.config import create_test_config
-from signwriting_transcription.pose_to_signwriting.data.datasets_pose import pose_to_matrix
+from signwriting_transcription.pose_to_signwriting.data.datasets_pose import pose_to_matrix, frame2ms
 from signwriting_transcription.pose_to_signwriting.data.pose_data_utils import build_pose_vocab
 from signwriting_transcription.pose_to_signwriting.data.preprocessing import preprocess_single_file
 from signwriting_transcription.pose_to_signwriting.joeynmt_pose.prediction import translate
 
 HUGGINGFACE_REPO_ID = "ohadlanger/signwriting_transcription"
+PADDING_FACTOR = 0.25  # padding factor for the tight strategy: 25% padding on each side of the segment
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--pose', required=True, type=str, help='path to input pose file')
     parser.add_argument('--elan', required=True, type=str, help='path to elan file')
-    parser.add_argument('--model', type=str, default='359a544.ckpt', help='model to use')
+    parser.add_argument('--model', type=str, default='bc2de71.ckpt', help='model to use')
+    parser.add_argument('--strategy', type=str, default='tight',
+                        choices=['tight', 'wide'], help='segmentation strategy to use')
     return parser.parse_args()
 
 
@@ -49,14 +53,38 @@ def download_model(experiment_dir: Path, model_name: str):
     create_test_config(str(experiment_dir), str(experiment_dir))
 
 
+def preprocessing_signs(preprocessed_pose: Pose, sign_annotations: list, strategy: str, temp_dir: str):
+    temp_files = []  # list of temporary files
+    start_point = 0
+    temp_path = Path(temp_dir)
+    # get the pose length in ms
+    pose_length = frame2ms(len(preprocessed_pose.body.data), preprocessed_pose.body.fps)
+    for index, (sign_start, sign_end, _) in tqdm(enumerate(sign_annotations)):
+        if index + 1 < len(sign_annotations):
+            end_point = sign_annotations[index + 1][0]
+        else:
+            end_point = pose_length
+        if strategy == 'wide':  # wide strategy: split the whole pose between the segments
+            end_point = (end_point + sign_start) // 2
+            np_pose = pose_to_matrix(preprocessed_pose, start_point, end_point).filled(fill_value=0)
+            start_point = end_point
+        else:  # tight strategy: pad the tight segment by PADDING_FACTOR
+            # pad the segment proportionally to the distance from the neighboring segments
+            np_pose = pose_to_matrix(preprocessed_pose, sign_start - (sign_start - start_point) * PADDING_FACTOR,
+                                     sign_end + (end_point - sign_end) * PADDING_FACTOR).filled(fill_value=0)
+            start_point = sign_end
+        pose_path = temp_path / f'{index}.npy'
+        np.save(pose_path, np_pose)
+        temp_files.append(pose_path)
+    return temp_files
+
+
 def main():
     args = get_args()
 
     experiment_dir = Path('experiment')
     experiment_dir.mkdir(exist_ok=True)
-    temp_dir = Path(tempfile.TemporaryDirectory().name)
-
     print('Downloading model...')
     download_model(experiment_dir, args.model)
 
@@ -68,25 +96,15 @@ def main():
     preprocessed_pose = preprocess_single_file(args.pose, normalization=False)
 
     print('Predicting signs...')
-    temp_files = []
-    for index, segment in tqdm(enumerate(sign_annotations)):
-        np_pose = pose_to_matrix(preprocessed_pose, segment[0], segment[1]).filled(fill_value=0)
-        pose_path = temp_dir / f'{index}.npy'
-        np.save(pose_path, np_pose)
-        temp_files.append(pose_path)
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_files = preprocessing_signs(preprocessed_pose, sign_annotations, args.strategy, temp_dir)
+        hyp_list = translate('experiment/config.yaml', temp_files)
 
-    hyp_list = translate('experiment/config.yaml', temp_files)
-
-    for index, segment in enumerate(sign_annotations):
-        eaf.remove_annotation('SIGN', segment[0])
-        eaf.add_annotation('SIGN', segment[0], segment[1], hyp_list[index])
+    for index, (start, end, _) in enumerate(sign_annotations):
+        eaf.remove_annotation('SIGN', start)
+        eaf.add_annotation('SIGN', start, end, hyp_list[index])
 
     eaf.to_file(args.elan)
 
-    print('Cleaning up...')
-    for temp_file in temp_files:
-        temp_file.unlink()
-    temp_dir.rmdir()
-
 
 if __name__ == '__main__':
     main()
diff --git a/signwriting_transcription/pose_to_signwriting/data/datasets_pose.py b/signwriting_transcription/pose_to_signwriting/data/datasets_pose.py
index 506ab35..bc2b301 100644
--- a/signwriting_transcription/pose_to_signwriting/data/datasets_pose.py
+++ b/signwriting_transcription/pose_to_signwriting/data/datasets_pose.py
@@ -19,7 +19,11 @@ def ms2frame(ms, frame_rate) -> int:
     return int(ms / 1000 * frame_rate)
 
 
-def pose_to_matrix(file_path_or_pose: Union[str, Pose], start_ms, end_ms):
+def frame2ms(frame, frame_rate) -> int:
+    return int(frame * 1000 / frame_rate)
+
+
+def pose_to_matrix(file_path_or_pose: Union[str, Pose], start_ms, end_ms=None):
     if isinstance(file_path_or_pose, str):
         with open(file_path_or_pose, "rb") as file:
             pose = Pose.read(file.read())
@@ -28,17 +32,19 @@ def pose_to_matrix(file_path_or_pose: Union[str, Pose], start_ms, end_ms):
     frame_rate = 29.97003 if file_path_or_pose == '19097be0e2094c4aa6b2fdc208c8231e.pose' else pose.body.fps
     pose = pose.body.data
     pose = pose.reshape(len(pose), -1)
-    pose = pose[ms2frame(start_ms, frame_rate):ms2frame(end_ms, frame_rate)]
+    start_frame = ms2frame(start_ms, frame_rate)
+    end_frame = ms2frame(end_ms, frame_rate) if end_ms is not None else None
+    pose = pose[start_frame:end_frame]
     return pose
 
 
-def load_dataset(folder_name):
-    with open(f'{folder_name}/target.csv', 'r', encoding='utf-8') as csvfile:
+def load_dataset(target_folder, data_folder):
+    with open(f'{target_folder}/target.csv', 'r', encoding='utf-8') as csvfile:
         reader = csv.DictReader(csvfile)
         dataset = []
         for line in reader:
             try:
-                pose = pose_to_matrix(f"{folder_name}/{line['pose']}", line['start'], line['end'])
+                pose = pose_to_matrix(f"{data_folder}/{line['pose']}", line['start'], line['end'])
             except FileNotFoundError:
                 continue
             pose = pose.filled(fill_value=0)
diff --git a/signwriting_transcription/pose_to_signwriting/data/pose_data_utils.py b/signwriting_transcription/pose_to_signwriting/data/pose_data_utils.py
index 224149e..dad83f1 100644
--- a/signwriting_transcription/pose_to_signwriting/data/pose_data_utils.py
+++ b/signwriting_transcription/pose_to_signwriting/data/pose_data_utils.py
@@ -35,9 +35,9 @@ def get_zip_manifest(zip_path: Path, npy_root: Optional[Path] = None):
     with zipfile.ZipFile(zip_path, mode="r") as file:
         info = file.infolist()
     # retrieve offsets
-    for i in tqdm(info):
-        utt_id = Path(i.filename).stem
-        offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
+    for index in tqdm(info):
+        utt_id = Path(index.filename).stem
+        offset, file_size = index.header_offset + 30 + len(index.filename), index.file_size
         with zip_path.open("rb") as file:
             file.seek(offset)
             data = file.read(file_size)
diff --git a/signwriting_transcription/pose_to_signwriting/data/prepare_poses.py b/signwriting_transcription/pose_to_signwriting/data/prepare_poses.py
index 8e0d344..773ffdc 100644
--- a/signwriting_transcription/pose_to_signwriting/data/prepare_poses.py
+++ b/signwriting_transcription/pose_to_signwriting/data/prepare_poses.py
@@ -65,14 +65,14 @@ def get_split_data(dataset, feature_root, pumping):
     if EXPANDED_DATASET > len(all_data) and pumping:
         print("Pumping dataset...")
         backup = all_data.copy()
-        for i in range(EXPANDED_DATASET - len(backup)):
-            utt_id = backup[i % len(backup)]["id"]
-            n_frames = backup[i % len(backup)]["n_frames"]
-            trg = backup[i % len(backup)]["trg"]
-            src = backup[i % len(backup)]["src"]
-            split = backup[i % len(backup)]["split"]
+        for index in range(EXPANDED_DATASET - len(backup)):
+            utt_id = backup[index % len(backup)]["id"]
+            n_frames = backup[index % len(backup)]["n_frames"]
+            trg = backup[index % len(backup)]["trg"]
+            src = backup[index % len(backup)]["src"]
+            split = backup[index % len(backup)]["split"]
             all_data.append({
-                "id": f'{utt_id}({i})',  # unique id
+                "id": f'{utt_id}({index})',  # unique id
                 "src": src,
                 "n_frames": n_frames,
                 "trg": trg,
@@ -83,8 +83,8 @@ def get_split_data(dataset, feature_root, pumping):
 
 def process(args):
     # pylint: disable=too-many-locals
-    dataset_root, data_root, name, tokenizer_type, pumping = (
-        args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.pumping)
+    dataset_root, data_root, name, tokenizer_type, data_segment, pumping = (
+        args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.data_segment, args.pumping)
     cur_root = Path(data_root).absolute()
     cur_root = cur_root / name
@@ -96,7 +96,16 @@ def process(args):
     print(f"Create pose {name} dataset.")
     print("Fetching train split ...")
-    dataset = load_dataset(dataset_root)
+    dataset = load_dataset(dataset_root, dataset_root)
+    if data_segment:
+        segment_dataset = load_dataset(data_segment, dataset_root)
+        modified_segment_dataset = []
+        for instance in segment_dataset:
+            instance = list(instance)
+            if instance[3] != 'test':
+                instance[0] = f"seg_{instance[0]}"
+            modified_segment_dataset.append(tuple(instance))
+        dataset.extend(modified_segment_dataset)
 
     print("Extracting pose features ...")
     for instance in dataset:
@@ -141,6 +150,7 @@ def main():
     parser.add_argument("--dataset-root", required=True, type=str)
     parser.add_argument("--dataset-name", required=True, type=str)
    parser.add_argument("--tokenizer-type", required=True, type=str)
+    parser.add_argument("--data-segment", required=False, type=str, default=None)
     parser.add_argument("--pumping", required=False, type=str, default=True)
     args = parser.parse_args()
     process(args)
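
Reviewer note: the snippet below is a minimal, self-contained sketch of the windowing arithmetic that the new preprocessing_signs function in bin.py applies. It mirrors the patch's tight/wide boundary logic on plain (start_ms, end_ms, label) tuples; segment_windows is a hypothetical helper used only for illustration and is not part of this patch.

PADDING_FACTOR = 0.25  # same 25% padding constant the patch adds to bin.py


def segment_windows(annotations, pose_length_ms, strategy='tight'):
    """Return the (start_ms, end_ms) window cut for each annotated sign."""
    windows = []
    start_point = 0
    for index, (sign_start, sign_end, _) in enumerate(annotations):
        # the free span after this sign ends at the next sign's start,
        # or at the end of the pose for the last sign
        if index + 1 < len(annotations):
            end_point = annotations[index + 1][0]
        else:
            end_point = pose_length_ms
        if strategy == 'wide':
            # wide: consume the whole gap, cutting midway between this
            # sign's start and the next sign's start
            end_point = (end_point + sign_start) // 2
            windows.append((start_point, end_point))
            start_point = end_point
        else:
            # tight: keep the sign itself plus 25% of each neighboring gap
            windows.append((sign_start - (sign_start - start_point) * PADDING_FACTOR,
                            sign_end + (end_point - sign_end) * PADDING_FACTOR))
            start_point = sign_end
    return windows


# two signs, 1000-2000 ms and 3000-4000 ms, in a 5000 ms pose
annotations = [(1000, 2000, 'sign-1'), (3000, 4000, 'sign-2')]
print(segment_windows(annotations, 5000, 'tight'))  # [(750.0, 2250.0), (2750.0, 4250.0)]
print(segment_windows(annotations, 5000, 'wide'))   # [(0, 2000), (2000, 4000)]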