Skip to content

Commit

Permalink
Update the code to work with segmentation and add binning strategies to bin …
Browse files Browse the repository at this point in the history
…file
  • Loading branch information
RotemZilberman committed Feb 21, 2024
1 parent cbbf536 commit 32fa45e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
14 changes: 11 additions & 3 deletions signwriting_transcription/pose_to_signwriting/bin.py
Original file line number Diff line number Diff line change
def get_args():
    """Parse command-line arguments for the pose-to-signwriting binary.

    Returns:
        argparse.Namespace with: pose (input .pose path), elan (ELAN file
        path), model (checkpoint name), strategies ('tight' or 'wide').
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pose', required=True, type=str, help='path to input pose file')
    parser.add_argument('--elan', required=True, type=str, help='path to elan file')
    parser.add_argument('--model', type=str, default='bc2de71.ckpt', help='model to use')
    # BUG FIX: argparse's keyword for restricting values is `choices`, not
    # `options` — `options=` raises TypeError the moment this line runs.
    parser.add_argument('--strategies', required=False, type=str, default='tight',
                        choices=['tight', 'wide'], help='strategy to use')
    return parser.parse_args()


Expand Down Expand Up @@ -66,12 +68,18 @@ def main():

print('Preprocessing pose.....')
temp_pose_path = temp_dir / 'pose.pose'
preprocess_single_file(args.pose, temp_pose_path, normalization=False)
preprocess_single_file(args.pose, temp_pose_path, normalization=True)

print('Predicting signs...')
temp_files = []
start = 0
for index, segment in tqdm(enumerate(sign_annotations)):
np_pose = pose_to_matrix(temp_pose_path, segment[0], segment[1]).filled(fill_value=0)
if args.strategies == 'wide':
end = (sign_annotations[index + 1][0] + segment[1]) // 2 if index + 1 < len(sign_annotations) else None
np_pose = pose_to_matrix(temp_pose_path, start, end).filled(fill_value=0)
start = end
else:
np_pose = pose_to_matrix(temp_pose_path, segment[0], segment[1]).filled(fill_value=0)
pose_path = temp_dir / f'{index}.npy'
np.save(pose_path, np_pose)
temp_files.append(pose_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@ def ms2frame(ms, frame_rate) -> int:
return int(ms / 1000 * frame_rate)


def pose_to_matrix(file_path, start_ms=0, end_ms=None):
    """Load a .pose file and return its body data as a (frames, features) matrix.

    Args:
        file_path: path to a pose file readable by ``Pose.read``.
        start_ms: start of the slice to keep, in milliseconds (default 0).
        end_ms: end of the slice in milliseconds, or None to keep every
            frame through the end of the recording.

    Returns:
        The pose body data reshaped to 2-D and sliced to the frame range
        corresponding to [start_ms, end_ms).
    """
    with open(file_path, "rb") as file:
        pose = Pose.read(file.read())
    # HACK: this one recording carries a wrong fps in its header — override
    # it with the known true rate; every other file trusts its own header.
    frame_rate = 29.97003 if file_path == '19097be0e2094c4aa6b2fdc208c8231e.pose' else pose.body.fps
    pose = pose.body.data
    pose = pose.reshape(len(pose), -1)
    # A None stop index slices through the end, so a single slice covers
    # both the bounded and the open-ended case (no duplicated expression).
    end_frame = ms2frame(end_ms, frame_rate) if end_ms is not None else None
    return pose[ms2frame(start_ms, frame_rate):end_frame]


def load_dataset(folder_name):
with open(f'{folder_name}/target.csv', 'r', encoding='utf-8') as csvfile:
def load_dataset(target_folder, data_folder):
with open(f'{target_folder}/target.csv', 'r', encoding='utf-8') as csvfile:
reader = csv.DictReader(csvfile)
dataset = []
for line in reader:
try:
pose = pose_to_matrix(f"{folder_name}/{line['pose']}", line['start'], line['end'])
pose = pose_to_matrix(f"{data_folder}/{line['pose']}", line['start'], line['end'])
except FileNotFoundError:
continue
pose = pose.filled(fill_value=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def get_split_data(dataset, feature_root, pumping):

def process(args):
# pylint: disable=too-many-locals
dataset_root, data_root, name, tokenizer_type, pumping = (
args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.pumping)
dataset_root, data_root, name, tokenizer_type, data_segment, pumping = (
args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.data_segment, args.pumping)
cur_root = Path(data_root).absolute()
cur_root = cur_root / name

Expand All @@ -96,7 +96,17 @@ def process(args):
print(f"Create pose {name} dataset.")

print("Fetching train split ...")
dataset = load_dataset(dataset_root)
dataset = load_dataset(dataset_root, dataset_root)
if data_segment:
    # Merge the segmentation-derived dataset into the main one, renaming
    # each instance (seg_ prefix) and folding its test split into train.
    segment_dataset = load_dataset(data_segment, dataset_root)
    modified_segment_dataset = []
    for instance in segment_dataset:
        instance = list(instance)
        instance[0] = f"seg_{instance[0]}"
        instance[3] = 'train' if instance[3] == 'test' else instance[3]
        modified_segment_dataset.append(tuple(instance))
    # BUG FIX: the original also ran `dataset.extend(segment_dataset)`
    # INSIDE the loop, appending the whole unmodified segment dataset once
    # per instance (N duplicate copies without the seg_ rename). Only the
    # modified instances are added, exactly once, after the loop.
    dataset.extend(modified_segment_dataset)

print("Extracting pose features ...")
for instance in dataset:
Expand Down Expand Up @@ -141,6 +151,7 @@ def main():
parser.add_argument("--dataset-root", required=True, type=str)
parser.add_argument("--dataset-name", required=True, type=str)
parser.add_argument("--tokenizer-type", required=True, type=str)
parser.add_argument("--data-segment", required=False, type=str, default=None)
parser.add_argument("--pumping", required=False, type=str, default=True)
args = parser.parse_args()
process(args)
Expand Down

0 comments on commit 32fa45e

Please sign in to comment.