Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform method to Detections class and documentation for the transform method #1779

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions docs/detection/detection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
### `transform`

Transform detections to match the dataset's class names and IDs.

This method performs the following steps:

1. **Remaps class names** using the provided `class_mapping` dictionary.
2. **Filters out predictions** whose class name is not among the dataset's classes.
3. **Remaps class IDs** to match the dataset's class IDs.

#### Parameters

- **dataset**: The dataset object containing class names and IDs.
- **class_mapping** (`Optional[Dict[str, str]]`): A dictionary to map model class names to dataset class names. If `None`, no remapping is performed.

#### Returns

- **Detections**: A new `Detections` object with transformed class names and IDs.

#### Raises

- **ValueError**: If `class_mapping` is provided but the detections have no class IDs or class names, or if a remapped class name is not among the dataset's classes.

#### Example

```python
import numpy as np

from supervision.detection.core import Detections


# Example dataset with class names
class DatasetMock:
    def __init__(self):
        self.classes = ["animal", "bird"]


# Example detections
detections = Detections(
    xyxy=np.array([[10, 10, 50, 50], [60, 60, 100, 100]]),
    confidence=np.array([0.9, 0.8]),
    class_id=np.array([0, 1]),
    data={"class_name": ["dog", "eagle"]},
)

# Class mapping
class_mapping = {"dog": "animal", "eagle": "bird"}

# Transform detections
transformed_detections = detections.transform(DatasetMock(), class_mapping)

# Both values are NumPy arrays, so they print without commas.
print(transformed_detections.class_id)  # Output: [0 1]
print(transformed_detections.data["class_name"])  # Output: ['animal' 'bird']
```
62 changes: 62 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,3 +1530,65 @@ def validate_fields_both_defined_or_none(
f"Field '{attribute}' should be consistently None or not None in both "
"Detections."
)


def transform(
    self, dataset, class_mapping: Optional[Dict[str, str]] = None
) -> Detections:
    """
    Transform detections to match the dataset's class names and IDs.

    The method works on a deep copy, so the detections it is called on (and
    their `data` dictionary) are left untouched. It then:

    1. Remaps class names through `class_mapping` (names without an entry are
       kept as they are).
    2. Drops every detection whose class name is not in `dataset.classes`.
    3. Rewrites `class_id` so each ID is the position of the class name in
       `dataset.classes`.

    Steps 2 and 3 only run when both `class_id` and the class-name data field
    are present; otherwise the copy is returned unchanged.

    Args:
        dataset: Object exposing a `classes` sequence of class names.
        class_mapping (Optional[Dict[str, str]]): A dictionary to map model class
            names to dataset class names. If None, no remapping is performed.

    Returns:
        Detections: A new Detections object with transformed class names and IDs.

    Raises:
        ValueError: If `class_mapping` is given but `class_id` or the class-name
            data field is missing, or if a remapped class name is not one of
            the dataset's classes.
    """
    # Local import keeps this block self-contained; `copy` is standard library.
    import copy

    # Work on a copy so the caller's object is never modified, as documented.
    result = copy.deepcopy(self)
    if result.is_empty():
        return result

    dataset_classes = list(dataset.classes)
    # Name -> first position, matching what `list.index` would return, but
    # built once instead of scanning the list for every detection.
    class_index = {}
    for position, name in enumerate(dataset_classes):
        class_index.setdefault(name, position)

    if class_mapping is not None:
        if result.class_id is None or result.data.get(CLASS_NAME_DATA_FIELD) is None:
            raise ValueError(
                "Class names must be available in the data field for remapping."
            )

        remapped_class_names = np.array(
            [
                class_mapping.get(name, name)
                for name in result.data[CLASS_NAME_DATA_FIELD]
            ]
        )

        if not all(name in class_index for name in np.unique(remapped_class_names)):
            raise ValueError("All mapped class names must be in the dataset's classes.")

        result.data[CLASS_NAME_DATA_FIELD] = remapped_class_names

    # Without both class IDs and class names there is nothing to filter or remap.
    if result.class_id is None or result.data.get(CLASS_NAME_DATA_FIELD) is None:
        return result

    keep = np.isin(result.data[CLASS_NAME_DATA_FIELD], dataset_classes)
    kept_positions = np.where(keep)[0]

    for field_name in ("xyxy", "mask", "confidence", "class_id", "tracker_id"):
        field_value = getattr(result, field_name)
        if field_value is not None:
            setattr(result, field_name, field_value[keep])

    # Snapshot the items so values can be reassigned while looping.
    for key, value in list(result.data.items()):
        if isinstance(value, np.ndarray):
            result.data[key] = value[keep]
        elif isinstance(value, list):
            result.data[key] = [value[i] for i in kept_positions]

    # Explicit integer dtype: an empty result would otherwise become float64.
    result.class_id = np.array(
        [class_index[name] for name in result.data[CLASS_NAME_DATA_FIELD]],
        dtype=int,
    )

    return result
37 changes: 37 additions & 0 deletions test/detection/test_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest

import numpy as np

# The class lives in supervision/detection/core.py (singular "detection").
from supervision.detection.core import Detections


class TestDetectionsTransform(unittest.TestCase):
    """Checks that `Detections.transform` remaps class names and class IDs."""

    def test_transform(self):
        """Model classes mapped onto dataset classes get the dataset's IDs."""

        # Mock dataset: only the `classes` attribute is read by `transform`.
        class DatasetMock:
            def __init__(self):
                self.classes = ["animal", "bird"]

        # Example detections
        detections = Detections(
            xyxy=np.array([[10, 10, 50, 50], [60, 60, 100, 100]]),
            confidence=np.array([0.9, 0.8]),
            class_id=np.array([0, 1]),
            data={"class_name": ["dog", "eagle"]},
        )

        # Class mapping
        class_mapping = {"dog": "animal", "eagle": "bird"}

        # Transform detections
        transformed_detections = detections.transform(DatasetMock(), class_mapping)

        # Verify results
        self.assertEqual(transformed_detections.class_id.tolist(), [0, 1])
        self.assertEqual(
            transformed_detections.data["class_name"].tolist(), ["animal", "bird"]
        )


if __name__ == "__main__":
    unittest.main()