Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ndx-pose for Lightning pose #1170

Open
wants to merge 6 commits into
base: use_latest_version_of_ndx_pose
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ video = [
"opencv-python-headless>=4.8.1.78",
]
lightningpose = [
"ndx-pose>=0.1.1",
"ndx-pose>=0.2",
"neuroconv[video]",
]
medpc = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def get_metadata_schema(self) -> dict:
description=dict(type="string"),
scorer=dict(type="string"),
source_software=dict(type="string", default="LightningPose"),
camera_name=dict(type="string", default="CameraPoseEstimation"),
),
patternProperties={
"^(?!(name|description|scorer|source_software)$)[a-zA-Z0-9_]+$": dict(
"^(?!(name|description|scorer|source_software|camera_name)$)[a-zA-Z0-9_]+$": dict(
title="PoseEstimationSeries",
type="object",
properties=dict(name=dict(type="string"), description=dict(type="string")),
Expand Down Expand Up @@ -80,22 +81,15 @@ def __init__(
verbose : bool, default: True
controls verbosity. ``True`` by default.
"""
from importlib.metadata import version

# This import is to assure that the ndx_pose is in the global namespace when an pynwb.io object is created
# For more detail, see https://github.com/rly/ndx-pose/issues/36
import ndx_pose # noqa: F401
from packaging import version as version_parse

from neuroconv.datainterfaces.behavior.video.video_utils import (
VideoCaptureContext,
)

ndx_pose_version = version("ndx-pose")

if version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2.0"):
raise ImportError("The ndx-pose version must be less than 0.2.0.")

self._vc = VideoCaptureContext

self.file_path = Path(file_path)
Expand Down Expand Up @@ -170,6 +164,7 @@ def get_metadata(self) -> DeepDict:
description="Contains the pose estimation series for each keypoint.",
scorer=self.scorer_name,
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)
for keypoint_name in self.keypoint_names:
keypoint_name_without_spaces = keypoint_name.replace(" ", "")
Expand Down Expand Up @@ -206,7 +201,7 @@ def add_to_nwbfile(
The description of how the confidence was computed, e.g., 'Softmax output of the deep neural network'.
stub_test : bool, default: False
"""
from ndx_pose import PoseEstimation, PoseEstimationSeries
from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons

metadata_copy = deepcopy(metadata)

Expand All @@ -223,15 +218,14 @@ def add_to_nwbfile(
original_video_name = str(self.original_video_file_path)
else:
original_video_name = metadata_copy["Behavior"]["Videos"][0]["name"]

pose_estimation_kwargs = dict(
name=pose_estimation_metadata["name"],
description=pose_estimation_metadata["description"],
source_software=pose_estimation_metadata["source_software"],
scorer=pose_estimation_metadata["scorer"],
original_videos=[original_video_name],
dimensions=[self.dimension],
)
camera_name = pose_estimation_metadata["camera_name"]
if camera_name in nwbfile.devices:
camera = nwbfile.devices[camera_name]
else:
camera = nwbfile.create_device(
name=camera_name,
description="Camera used for behavioral recording and pose estimation.",
)

pose_estimation_data = self.pose_estimation_data if not stub_test else self.pose_estimation_data.head(n=10)
timestamps = self.get_timestamps(stub_test=stub_test)
Expand Down Expand Up @@ -263,8 +257,28 @@ def add_to_nwbfile(

pose_estimation_series.append(PoseEstimationSeries(**pose_estimation_series_kwargs))

pose_estimation_kwargs.update(
# Add Skeleton(s)
nodes = [keypoint_name.replace(" ", "") for keypoint_name in self.keypoint_names]
subject = nwbfile.subject if nwbfile.subject is not None else None
name = f"Skeleton{pose_estimation_name}"
skeleton = Skeleton(name=name, nodes=nodes, subject=subject)
if "Skeletons" in behavior.data_interfaces:
skeletons = behavior.data_interfaces["Skeletons"]
skeletons.add_skeletons(skeleton)
else:
skeletons = Skeletons(skeletons=[skeleton])
behavior.add(skeletons)

pose_estimation_kwargs = dict(
name=pose_estimation_metadata["name"],
description=pose_estimation_metadata["description"],
source_software=pose_estimation_metadata["source_software"],
scorer=pose_estimation_metadata["scorer"],
original_videos=[original_video_name],
dimensions=[self.dimension],
pose_estimation_series=pose_estimation_series,
devices=[camera],
skeleton=skeleton,
)

if self.source_data["labeled_video_file_path"]:
Expand Down
9 changes: 1 addition & 8 deletions tests/test_on_data/behavior/test_behavior_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,9 @@

from importlib.metadata import version

from packaging import version as version_parse

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down Expand Up @@ -94,6 +89,7 @@ def setup_metadata(self, request):
description="Contains the pose estimation series for each keypoint.",
scorer="heatmap_tracker",
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)
)
cls.expected_metadata[cls.pose_estimation_name].update(
Expand Down Expand Up @@ -165,9 +161,6 @@ def check_read_nwb(self, nwbfile_path: str):
assert_array_equal(pose_estimation_series.data[:], test_data[["x", "y"]].values)


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down
10 changes: 1 addition & 9 deletions tests/test_on_data/behavior/test_lightningpose_converter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import shutil
import tempfile
from datetime import datetime
from importlib.metadata import version
from pathlib import Path
from warnings import warn

import pytest
from hdmf.testing import TestCase
from packaging import version
from packaging import version as version_parse
from pynwb import NWBHDF5IO
from pynwb.image import ImageSeries

Expand All @@ -19,12 +15,7 @@

from ..setup_paths import BEHAVIOR_DATA_PATH

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseConverter(TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -73,6 +64,7 @@ def setUpClass(cls) -> None:
description="Contains the pose estimation series for each keypoint.",
scorer="heatmap_tracker",
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)

cls.pose_estimation_metadata.update(
Expand Down
Loading