Skip to content

Commit

Permalink
fix(): include missing utilities to get the training to run
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Jan 6, 2025
1 parent 1891628 commit e3a020b
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.idea/
.idea/
.env
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
readme = "README.md"
dependencies = [
"pose-format>=0.3.2",
"numpy",
"numpy<2.0.0",
"pympi-ling", # Working with ELAN files in CLI
"torch",
]
Expand All @@ -19,7 +19,8 @@ dev = [
"pytest",
"pylint",
"pytorch-lightning",
"sign-language-datasets",
"mediapipe",
"sign_language_datasets @ git+https://github.com/sign-language-processing/datasets",
"wandb",
"matplotlib",
"scikit-learn",
Expand Down
2 changes: 1 addition & 1 deletion sign_language_segmentation/src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import numpy.ma as ma
import torch
from _shared.tfds_dataset import ProcessedPoseDatum, get_tfds_dataset
from .tfds_dataset import ProcessedPoseDatum, get_tfds_dataset
from pose_format import Pose
from pose_format.numpy.representation.distance import DistanceRepresentation
from pose_format.utils.generic import normalize_hands_3d
Expand Down
160 changes: 160 additions & 0 deletions sign_language_segmentation/src/pose_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from typing import Tuple

import numpy as np
from numpy import ma
from pose_format import Pose
from pose_format.numpy import NumPyPoseBody
from pose_format.pose_header import PoseHeader, PoseHeaderDimensions
from pose_format.utils.normalization_3d import PoseNormalizer
from pose_format.utils.openpose import OpenPose_Components


def pose_hide_legs(pose: Pose):
if pose.header.components[0].name == "POSE_LANDMARKS":
point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"]
# pylint: disable=protected-access
points = [
pose.header._get_point_index("POSE_LANDMARKS", side + "_" + n)
for n in point_names
for side in ["LEFT", "RIGHT"]
]
pose.body.data[:, :, points, :] = 0
pose.body.confidence[:, :, points] = 0
elif pose.header.components[0].name == "pose_keypoints_2d":
point_names = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"]
# pylint: disable=protected-access
points = [
pose.header._get_point_index("pose_keypoints_2d", side + n)
for n in point_names
for side in ["L", "R"]
]
pose.body.data[:, :, points, :] = 0
pose.body.confidence[:, :, points] = 0
else:
raise ValueError("Unknown pose header schema for hiding legs")


def pose_shoulders(pose_header: PoseHeader):
if pose_header.components[0].name == "POSE_LANDMARKS":
return ("POSE_LANDMARKS", "RIGHT_SHOULDER"), ("POSE_LANDMARKS", "LEFT_SHOULDER")

if pose_header.components[0].name == "BODY_135":
return ("BODY_135", "RShoulder"), ("BODY_135", "LShoulder")

if pose_header.components[0].name == "pose_keypoints_2d":
return ("pose_keypoints_2d", "RShoulder"), ("pose_keypoints_2d", "LShoulder")

raise ValueError("Unknown pose header schema for normalization")


def hands_indexes(pose_header: PoseHeader):
if pose_header.components[0].name == "POSE_LANDMARKS":
return [pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP")]

if pose_header.components[0].name == "pose_keypoints_2d":
return [pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"),
pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC")]


def pose_normalization_info(pose_header: PoseHeader):
(c1, p1), (c2, p2) = pose_shoulders(pose_header)
return pose_header.normalization_info(p1=(c1, p1), p2=(c2, p2))


def hands_components(pose_header: PoseHeader):
if pose_header.components[0].name in ["POSE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]:
return ("LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"), \
("WRIST", "PINKY_MCP", "INDEX_FINGER_MCP"), \
("WRIST", "MIDDLE_FINGER_MCP")

if pose_header.components[0].name in ["pose_keypoints_2d", "hand_left_keypoints_2d", "hand_right_keypoints_2d"]:
return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), \
("BASE", "P_CMC", "I_CMC"), \
("BASE", "M_CMC")

raise ValueError("Unknown pose header")


def normalize_component_3d(pose, component_name: str, plane: Tuple[str, str, str], line: Tuple[str, str]):
hand_pose = pose.get_components([component_name])
plane = hand_pose.header.normalization_info(p1=(component_name, plane[0]),
p2=(component_name, plane[1]),
p3=(component_name, plane[2]))
line = hand_pose.header.normalization_info(p1=(component_name, line[0]),
p2=(component_name, line[1]))
normalizer = PoseNormalizer(plane=plane, line=line)
normalized_hand = normalizer(hand_pose.body.data)

# Add normalized hand to pose
pose.body.data = ma.concatenate([pose.body.data, normalized_hand], axis=2).astype(np.float32)
pose.body.confidence = np.concatenate([pose.body.confidence, hand_pose.body.confidence], axis=2)


def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True):
(left_hand_component, right_hand_component), plane, line = hands_components(pose.header)
if left_hand:
normalize_component_3d(pose, left_hand_component, plane, line)
if right_hand:
normalize_component_3d(pose, right_hand_component, plane, line)


def fake_pose(num_frames: int, fps=25, dims=2, components=OpenPose_Components):
dimensions = PoseHeaderDimensions(width=1, height=1, depth=1)
header = PoseHeader(version=0.1, dimensions=dimensions, components=components)

total_points = header.total_points()
data = np.random.randn(num_frames, 1, total_points, dims)
confidence = np.random.randn(num_frames, 1, total_points)
masked_data = ma.masked_array(data)

body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence)

return Pose(header, body)


def correct_wrist(pose: Pose, hand: str) -> Pose:
wrist_index = pose.header._get_point_index(f'{hand}_HAND_LANDMARKS', 'WRIST')
wrist = pose.body.data[:, :, wrist_index]
wrist_conf = pose.body.confidence[:, :, wrist_index]

body_wrist_index = pose.header._get_point_index('POSE_LANDMARKS', f'{hand}_WRIST')
body_wrist = pose.body.data[:, :, body_wrist_index]
body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]

new_wrist_data = ma.where(wrist.data == 0, body_wrist, wrist)
new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf)

pose.body.data[:, :, body_wrist_index] = ma.masked_equal(new_wrist_data, 0)
pose.body.confidence[:, :, body_wrist_index] = new_wrist_conf
return pose


def correct_wrists(pose: Pose) -> Pose:
pose = correct_wrist(pose, 'LEFT')
pose = correct_wrist(pose, 'RIGHT')
return pose


def reduce_holistic(pose: Pose) -> Pose:
if pose.header.components[0].name != "POSE_LANDMARKS":
return pose

import mediapipe as mp
points_set = set([p for p_tup in list(mp.solutions.holistic.FACEMESH_CONTOURS) for p in p_tup])
face_contours = [str(p) for p in sorted(points_set)]

ignore_names = [
"EAR", "NOSE", "MOUTH", "EYE", # Face
"THUMB", "PINKY", "INDEX", # Hands
"KNEE", "ANKLE", "HEEL", "FOOT_INDEX" # Feet
]

body_component = [c for c in pose.header.components if c.name == 'POSE_LANDMARKS'][0]
body_no_face_no_hands = [p for p in body_component.points if all([i not in p for i in ignore_names])]

components = [c.name for c in pose.header.components if c.name != 'POSE_WORLD_LANDMARKS']
return pose.get_components(components, {
"FACE_LANDMARKS": face_contours,
"POSE_LANDMARKS": body_no_face_no_hands
})
88 changes: 88 additions & 0 deletions sign_language_segmentation/src/tfds_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import importlib
from typing import Dict, List, TypedDict, Union

import tensorflow_datasets as tfds
from pose_format import Pose
from pose_format.numpy.pose_body import NumPyPoseBody
from pose_format.pose_header import PoseHeader
from pose_format.utils.reader import BufferReader
from sign_language_datasets.datasets.config import SignDatasetConfig
from sign_language_datasets.datasets.dgs_corpus import DgsCorpusConfig
from tqdm import tqdm
import mediapipe as mp

from .pose_utils import pose_hide_legs, pose_normalization_info


mp_holistic = mp.solutions.holistic
FACEMESH_CONTOURS_POINTS = [str(p) for p in sorted(set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup]))]

class ProcessedPoseDatum(TypedDict):
id: str
pose: Union[Pose, Dict[str, Pose]]
tf_datum: dict


def get_tfds_dataset(name,
poses="holistic",
fps=25,
split="train",
components: List[str] = None,
reduce_face=False,
data_dir=None,
version="1.0.0",
filter_func=None):
dataset_module = importlib.import_module("sign_language_datasets.datasets." + name + "." + name)

config_kwargs = dict(
name=poses + "-" + str(fps),
version=version, # Specific version
include_video=False, # Download and load dataset videos
fps=fps, # Load videos at constant fps
include_pose=poses)

# Loading a dataset with custom configuration
if name == "dgs_corpus":
config = DgsCorpusConfig(**config_kwargs, split="3.0.0-uzh-document")
else:
config = SignDatasetConfig(**config_kwargs)

tfds_dataset = tfds.load(name=name, builder_kwargs=dict(config=config), split=split, data_dir=data_dir)

# pylint: disable=protected-access
with open(dataset_module._POSE_HEADERS[poses], "rb") as buffer:
pose_header = PoseHeader.read(BufferReader(buffer.read()))

normalization_info = pose_normalization_info(pose_header)
return [process_datum(datum, pose_header, normalization_info, components, reduce_face)
for datum in tqdm(tfds_dataset, desc="Loading dataset")
if filter_func is None or filter_func(datum)]


def process_datum(datum,
pose_header: PoseHeader,
normalization_info,
components: List[str] = None,
reduce_face=False) -> ProcessedPoseDatum:
tf_poses = {"": datum["pose"]} if "pose" in datum else datum["poses"]
poses = {}
for key, tf_pose in tf_poses.items():
fps = int(tf_pose["fps"].numpy())
pose_body = NumPyPoseBody(fps, tf_pose["data"].numpy(), tf_pose["conf"].numpy())
pose = Pose(pose_header, pose_body)

# Get subset of components if needed
if reduce_face:
pose = pose.get_components(components, {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS})
elif components and len(components) != len(pose_header.components):
pose = pose.get_components(components)

pose = pose.normalize(normalization_info)
pose_hide_legs(pose)
poses[key] = pose

return {
"id": datum["id"].numpy().decode('utf-8'),
"pose": poses[""] if "pose" in datum else poses,
"tf_datum": datum
}

0 comments on commit e3a020b

Please sign in to comment.