Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use latest version of ndx pose for DeepLabCutInterface #1128

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
## Bug Fixes

## Features
* Use the latest version of ndx-pose for `DeepLabCutInterface` [PR #1128](https://github.com/catalystneuro/neuroconv/pull/1128)

## Improvements


# v0.6.9 (Upcoming)
Small fixes should be here.

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ sleap = [
"sleap-io>=0.0.2; python_version>='3.9'",
]
deeplabcut = [
"ndx-pose==0.1.1",
"ndx-pose>=0.2",
"tables; platform_system != 'Darwin'",
"tables>=3.10.1; platform_system == 'Darwin' and python_version >= '3.10'",
]
Expand All @@ -128,7 +128,7 @@ video = [
"opencv-python-headless>=4.8.1.78",
]
lightningpose = [
"ndx-pose==0.1.1",
"ndx-pose>=0.1.1",
"neuroconv[video]",
]
medpc = [
Expand Down
78 changes: 56 additions & 22 deletions src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import pickle
import warnings
from pathlib import Path
Expand Down Expand Up @@ -93,7 +92,7 @@ def _get_cv2_timestamps(file_path: Union[Path, str]):
return timestamps


def _get_movie_timestamps(movie_file, VARIABILITYBOUND=1000, infer_timestamps=True):
def _get_video_timestamps(movie_file, VARIABILITYBOUND=1000, infer_timestamps=True):
"""
Return numpy array of the timestamps for a video.

Expand Down Expand Up @@ -263,13 +262,52 @@ def _write_pes_to_nwbfile(
exclude_nans,
pose_estimation_container_kwargs: Optional[dict] = None,
):

from ndx_pose import PoseEstimation, PoseEstimationSeries
"""
Updated version of _write_pes_to_nwbfile to work with ndx-pose v0.2.0+
"""
from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons
from pynwb.file import Subject

pose_estimation_container_kwargs = pose_estimation_container_kwargs or dict()
pose_estimation_name = pose_estimation_container_kwargs.get("name", "PoseEstimationDeepLabCut")

# Create a subject if it doesn't exist
if nwbfile.subject is None:
subject = Subject(subject_id=animal)
nwbfile.subject = subject
else:
subject = nwbfile.subject

# Create skeleton from the keypoints
keypoints = df_animal.columns.get_level_values("bodyparts").unique()
animal = animal if animal else ""
subject = subject if animal == subject.subject_id else None
skeleton_name = f"Skeleton{pose_estimation_name}_{animal.capitalize()}"
skeleton = Skeleton(
name=skeleton_name,
nodes=list(keypoints),
edges=np.array(paf_graph) if paf_graph else None, # Convert paf_graph to numpy array
subject=subject,
)

# Create Skeletons container
if "behavior" not in nwbfile.processing:
behavior_processing_module = nwbfile.create_processing_module(
name="behavior", description="processed behavioral data"
pauladkisson marked this conversation as resolved.
Show resolved Hide resolved
)
skeletons = Skeletons(skeletons=[skeleton])
behavior_processing_module.add(skeletons)
else:
behavior_processing_module = nwbfile.processing["behavior"]
if "Skeletons" not in behavior_processing_module.data_interfaces:
skeletons = Skeletons(skeletons=[skeleton])
behavior_processing_module.add(skeletons)
else:
skeletons = behavior_processing_module["Skeletons"]
skeletons.add_skeletons(skeleton)

pose_estimation_series = []
for keypoint in df_animal.columns.get_level_values("bodyparts").unique():
for keypoint in keypoints:
data = df_animal.xs(keypoint, level="bodyparts", axis=1).to_numpy()

if exclude_nans:
Expand All @@ -292,35 +330,31 @@ def _write_pes_to_nwbfile(
)
pose_estimation_series.append(pes)

deeplabcut_version = None
is_deeplabcut_installed = importlib.util.find_spec(name="deeplabcut") is not None
if is_deeplabcut_installed:
deeplabcut_version = importlib.metadata.version(distribution_name="deeplabcut")
camera_name = pose_estimation_name
if camera_name not in nwbfile.devices:
camera = nwbfile.create_device(
name=camera_name,
description="Camera used for behavioral recording and pose estimation.",
)
else:
camera = nwbfile.devices[camera_name]

# TODO, taken from the original implementation, improve it if the video is passed
# Create PoseEstimation container with updated arguments
dimensions = [list(map(int, image_shape.split(",")))[1::2]]
dimensions = np.array(dimensions, dtype="uint32")
pose_estimation_default_kwargs = dict(
pose_estimation_series=pose_estimation_series,
description="2D keypoint coordinates estimated using DeepLabCut.",
original_videos=[video_file_path],
original_videos=[video_file_path] if video_file_path else None,
dimensions=dimensions,
devices=[camera],
scorer=scorer,
source_software="DeepLabCut",
source_software_version=deeplabcut_version,
nodes=[pes.name for pes in pose_estimation_series],
edges=paf_graph if paf_graph else None,
**pose_estimation_container_kwargs,
skeleton=skeleton,
)
pose_estimation_default_kwargs.update(pose_estimation_container_kwargs)
pose_estimation_container = PoseEstimation(**pose_estimation_default_kwargs)

if "behavior" in nwbfile.processing: # TODO: replace with get_module
behavior_processing_module = nwbfile.processing["behavior"]
else:
behavior_processing_module = nwbfile.create_processing_module(
name="behavior", description="processed behavioral data"
)
behavior_processing_module.add(pose_estimation_container)

return nwbfile
Expand Down Expand Up @@ -387,7 +421,7 @@ def _add_subject_to_nwbfile(
if video_file_path is None:
timestamps = df.index.tolist() # setting timestamps to dummy
else:
timestamps = _get_movie_timestamps(video_file_path, infer_timestamps=True)
timestamps = _get_video_timestamps(video_file_path, infer_timestamps=True)

# Fetch the corresponding metadata pickle file, we extract the edges graph from here
# TODO: This is the original implementation way to extract the file name but looks very brittle. Improve it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DeepLabCutInterface(BaseTemporalAlignmentInterface):
"""Data interface for DeepLabCut datasets."""

display_name = "DeepLabCut"
keywords = ("DLC",)
keywords = ("DLC", "DeepLabCut", "pose estimation", "behavior")
associated_suffixes = (".h5", ".csv")
info = "Interface for handling data from DeepLabCut."

Expand Down Expand Up @@ -62,6 +62,8 @@ def __init__(
self.config_dict = _read_config(config_file_path=config_file_path)
self.subject_name = subject_name
self.verbose = verbose
self.pose_estimation_container_kwargs = dict()

super().__init__(file_path=file_path, config_file_path=config_file_path)

def get_metadata(self):
Expand Down Expand Up @@ -101,7 +103,7 @@ def add_to_nwbfile(
self,
nwbfile: NWBFile,
metadata: Optional[dict] = None,
container_name: str = "PoseEstimation",
container_name: str = "PoseEstimationDeepLabCut",
):
"""
Conversion from DLC output files to nwb. Derived from dlc2nwb library.
Expand All @@ -112,8 +114,9 @@ def add_to_nwbfile(
nwb file to which the recording information is to be added
metadata: dict
metadata info for constructing the nwb file (optional).
container_name: str, default: "PoseEstimation"
Name of the container to store the pose estimation.
container_name: str, default: "PoseEstimationDeepLabCut"
name of the PoseEstimation container in the nwb

"""
from ._dlc_utils import _add_subject_to_nwbfile

Expand All @@ -123,5 +126,5 @@ def add_to_nwbfile(
individual_name=self.subject_name,
config_file=self.source_data["config_file_path"],
timestamps=self._timestamps,
pose_estimation_container_kwargs=dict(name=container_name),
pose_estimation_container_kwargs=self.pose_estimation_container_kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,22 @@ def __init__(
verbose : bool, default: True
controls verbosity. ``True`` by default.
"""
from importlib.metadata import version

# This import is to assure that the ndx_pose is in the global namespace when an pynwb.io object is created
# For more detail, see https://github.com/rly/ndx-pose/issues/36
import ndx_pose # noqa: F401
from packaging import version as version_parse

from neuroconv.datainterfaces.behavior.video.video_utils import (
VideoCaptureContext,
)

ndx_pose_version = version("ndx-pose")

if version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2.0"):
raise ImportError("The ndx-pose version must be less than 0.2.0.")

self._vc = VideoCaptureContext

self.file_path = Path(file_path)
Expand Down
32 changes: 25 additions & 7 deletions tests/test_on_data/behavior/test_behavior_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,16 @@
except ImportError:
from setup_paths import BEHAVIOR_DATA_PATH, OUTPUT_PATH

from importlib.metadata import version

from packaging import version as version_parse

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down Expand Up @@ -156,6 +165,9 @@ def check_read_nwb(self, nwbfile_path: str):
assert_array_equal(pose_estimation_series.data[:], test_data[["x", "y"]].values)


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down Expand Up @@ -363,17 +375,19 @@ def check_renaming_instance(self, nwbfile_path: str):
with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io:
nwbfile = io.read()
assert "behavior" in nwbfile.processing
assert "PoseEstimation" not in nwbfile.processing["behavior"].data_interfaces
assert "PoseEstimationDeepLabCut" not in nwbfile.processing["behavior"].data_interfaces
assert custom_container_name in nwbfile.processing["behavior"].data_interfaces

def check_read_nwb(self, nwbfile_path: str):
with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io:
nwbfile = io.read()
assert "behavior" in nwbfile.processing
processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces
assert "PoseEstimation" in processing_module_interfaces
assert "PoseEstimationDeepLabCut" in processing_module_interfaces

pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series
pose_estimation_series_in_nwb = processing_module_interfaces[
"PoseEstimationDeepLabCut"
].pose_estimation_series
expected_pose_estimation_series = ["ind1_leftear", "ind1_rightear", "ind1_snout", "ind1_tailbase"]

expected_pose_estimation_series_are_in_nwb_file = [
Expand Down Expand Up @@ -449,9 +463,11 @@ def check_read_nwb(self, nwbfile_path: str):
nwbfile = io.read()
assert "behavior" in nwbfile.processing
processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces
assert "PoseEstimation" in processing_module_interfaces
assert "PoseEstimationDeepLabCut" in processing_module_interfaces

pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series
pose_estimation_series_in_nwb = processing_module_interfaces[
"PoseEstimationDeepLabCut"
].pose_estimation_series
expected_pose_estimation_series = ["ind1_leftear", "ind1_rightear", "ind1_snout", "ind1_tailbase"]

expected_pose_estimation_series_are_in_nwb_file = [
Expand Down Expand Up @@ -500,9 +516,11 @@ def check_custom_timestamps(self, nwbfile_path: str):
nwbfile = io.read()
assert "behavior" in nwbfile.processing
processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces
assert "PoseEstimation" in processing_module_interfaces
assert "PoseEstimationDeepLabCut" in processing_module_interfaces

pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series
pose_estimation_series_in_nwb = processing_module_interfaces[
"PoseEstimationDeepLabCut"
].pose_estimation_series

for pose_estimation in pose_estimation_series_in_nwb.values():
pose_timestamps = pose_estimation.timestamps
Expand Down
9 changes: 9 additions & 0 deletions tests/test_on_data/behavior/test_lightningpose_converter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import shutil
import tempfile
from datetime import datetime
from importlib.metadata import version
from pathlib import Path
from warnings import warn

import pytest
from hdmf.testing import TestCase
from packaging import version
from packaging import version as version_parse
from pynwb import NWBHDF5IO
from pynwb.image import ImageSeries

Expand All @@ -15,7 +19,12 @@

from ..setup_paths import BEHAVIOR_DATA_PATH

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseConverter(TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down
Loading