Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pose to signwriting segmentation #7

Merged
merged 10 commits into from
Feb 26, 2024
18 changes: 15 additions & 3 deletions signwriting_transcription/pose_to_signwriting/bin.py
Original file line number Diff line number Diff line change
def get_args():
    """Parse and return the command-line arguments for the transcription CLI.

    Required: --pose (input pose file) and --elan (ELAN annotation file).
    Optional: --model checkpoint name and the segmentation --strategies mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pose', required=True, type=str,
                        help='path to input pose file')
    parser.add_argument('--elan', required=True, type=str,
                        help='path to elan file')
    parser.add_argument('--model', type=str, default='bc2de71.ckpt',
                        help='model to use')
    parser.add_argument('--strategies', required=False, type=str, default='tight',
                        choices=['tight', 'wide'], help='strategy to use')
    args = parser.parse_args()
    return args


Expand Down Expand Up @@ -66,12 +68,22 @@ def main():

print('Preprocessing pose.....')
temp_pose_path = temp_dir / 'pose.pose'
preprocess_single_file(args.pose, temp_pose_path, normalization=False)
preprocess_single_file(args.pose, temp_pose_path, normalization=True)

print('Predicting signs...')
temp_files = []
start = 0
for index, segment in tqdm(enumerate(sign_annotations)):
np_pose = pose_to_matrix(temp_pose_path, segment[0], segment[1]).filled(fill_value=0)
end = sign_annotations[index + 1][0] if index + 1 < len(sign_annotations) else None
RotemZilberman marked this conversation as resolved.
Show resolved Hide resolved
if args.strategies == 'wide':
end = (end + segment[1]) // 2 if index + 1 < len(sign_annotations) else None
np_pose = pose_to_matrix(temp_pose_path, start, end).filled(fill_value=0)
start = end
else:
end = end if end is not None else segment[1] + 300
np_pose = pose_to_matrix(temp_pose_path, segment[0] - (segment[0] - start) * 0.25
AmitMY marked this conversation as resolved.
Show resolved Hide resolved
, segment[1] + (end - segment[1]) * 0.25).filled(fill_value=0)
start = segment[1]
pose_path = temp_dir / f'{index}.npy'
np.save(pose_path, np_pose)
temp_files.append(pose_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@ def ms2frame(ms, frame_rate) -> int:
return int(ms / 1000 * frame_rate)


def pose_to_matrix(file_path, start_ms=0, end_ms=None):
    """Load a pose file and return its body data flattened to (frames, features).

    The result is optionally sliced to the [start_ms, end_ms) window; when
    end_ms is None the slice runs to the end of the recording.
    """
    with open(file_path, "rb") as file:
        pose = Pose.read(file.read())
    # NOTE(review): this one file appears to carry a wrong fps header, so its
    # true frame rate is hard-coded — presumably 29.97 NTSC footage; confirm.
    if file_path == '19097be0e2094c4aa6b2fdc208c8231e.pose':
        frame_rate = 29.97003
    else:
        frame_rate = pose.body.fps
    data = pose.body.data
    data = data.reshape(len(data), -1)
    first_frame = ms2frame(start_ms, frame_rate)
    if end_ms is None:
        return data[first_frame:]
    return data[first_frame:ms2frame(end_ms, frame_rate)]


def load_dataset(folder_name):
with open(f'{folder_name}/target.csv', 'r', encoding='utf-8') as csvfile:
def load_dataset(target_folder, data_folder):
with open(f'{target_folder}/target.csv', 'r', encoding='utf-8') as csvfile:
reader = csv.DictReader(csvfile)
dataset = []
for line in reader:
try:
pose = pose_to_matrix(f"{folder_name}/{line['pose']}", line['start'], line['end'])
pose = pose_to_matrix(f"{data_folder}/{line['pose']}", line['start'], line['end'])
except FileNotFoundError:
continue
pose = pose.filled(fill_value=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def get_split_data(dataset, feature_root, pumping):

def process(args):
# pylint: disable=too-many-locals
dataset_root, data_root, name, tokenizer_type, pumping = (
args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.pumping)
dataset_root, data_root, name, tokenizer_type, data_segment, pumping = (
args.dataset_root, args.data_root, args.dataset_name, args.tokenizer_type, args.data_segment, args.pumping)
cur_root = Path(data_root).absolute()
cur_root = cur_root / name

Expand All @@ -96,7 +96,17 @@ def process(args):
print(f"Create pose {name} dataset.")

print("Fetching train split ...")
dataset = load_dataset(dataset_root)
dataset = load_dataset(dataset_root, dataset_root)
if data_segment:
segment_dataset = load_dataset(data_segment, dataset_root)
modified_segment_dataset = []
for instance in segment_dataset:
instance = list(instance)
instance[0] = f"seg_{instance[0]}"
instance[3] = 'train' if instance[3] == 'test' else instance[3]
RotemZilberman marked this conversation as resolved.
Show resolved Hide resolved
dataset.extend(segment_dataset)
modified_segment_dataset.append(tuple(instance))
dataset.extend(modified_segment_dataset)

print("Extracting pose features ...")
for instance in dataset:
Expand Down Expand Up @@ -141,6 +151,7 @@ def main():
parser.add_argument("--dataset-root", required=True, type=str)
parser.add_argument("--dataset-name", required=True, type=str)
parser.add_argument("--tokenizer-type", required=True, type=str)
parser.add_argument("--data-segment", required=False, type=str, default=None)
parser.add_argument("--pumping", required=False, type=str, default=True)
args = parser.parse_args()
process(args)
Expand Down
Loading