Add apply_ to dataclasses #505

Merged: 9 commits, Nov 11, 2024
14 changes: 5 additions & 9 deletions src/mrpro/data/AcqInfo.py
@@ -1,6 +1,6 @@
"""Acquisition information dataclass."""

from collections.abc import Callable, Sequence
from collections.abc import Sequence
from dataclasses import dataclass

import ismrmrd
@@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor:
             data_tensor = data_tensor[None, None]
             return data_tensor

-        def spatialdimension_2d(
-            data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None
-        ) -> SpatialDimension[torch.Tensor]:
+        def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
             # Ensure spatial dimension is (k1*k2*other, 1, 3)
             if data.ndim != 2:
                 raise ValueError('Spatial dimension is expected to be of shape (N,3)')
             data = data[:, None, :]
             # all spatial dimensions are float32
-            return (
-                SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion)
-            )
+            return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32)))

         acq_idx = AcqIdx(
             k1=tensor(idx['kspace_encode_step_1']),
@@ -251,10 +247,10 @@ def spatialdimension_2d(
             flags=tensor_2d(headers['flags']),
             measurement_uid=tensor_2d(headers['measurement_uid']),
             number_of_samples=tensor_2d(headers['number_of_samples']),
-            patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m),
+            patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
             phase_dir=spatialdimension_2d(headers['phase_dir']),
             physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
-            position=spatialdimension_2d(headers['position'], mm_to_m),
+            position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
             read_dir=spatialdimension_2d(headers['read_dir']),
             sample_time_us=tensor_2d(headers['sample_time_us']),
             scan_counter=tensor_2d(headers['scan_counter']),
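The net effect in AcqInfo.py: spatialdimension_2d now only constructs the SpatialDimension, and unit conversion is chained onto the result with the new in-place apply_. A minimal sketch of the before/after call pattern, assuming mm_to_m (defined elsewhere in AcqInfo.py, not shown in this diff) simply divides by 1000:

    import torch

    def mm_to_m(data: torch.Tensor) -> torch.Tensor:
        # assumed definition, for illustration only
        return data / 1000.0

    # before: the conversion callback was threaded through the helper
    # position = spatialdimension_2d(headers['position'], mm_to_m)
    # after: construction and conversion are separate, composable steps
    # position = spatialdimension_2d(headers['position']).apply_(mm_to_m)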
68 changes: 54 additions & 14 deletions src/mrpro/data/MoveDataMixin.py
@@ -1,13 +1,13 @@
"""MoveDataMixin."""

import dataclasses
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from copy import copy as shallowcopy
from copy import deepcopy
from typing import ClassVar, TypeAlias
from typing import ClassVar, TypeAlias, cast

import torch
from typing_extensions import Any, Protocol, Self, overload, runtime_checkable
from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable


class InconsistentDeviceError(ValueError): # noqa: D101
@@ -22,6 +22,9 @@ class DataclassInstance(Protocol):
     __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


+T = TypeVar('T')
+
+
 class MoveDataMixin:
     """Move dataclass fields to cpu/gpu and convert dtypes."""

@@ -151,7 +154,6 @@ def _to(
         copy: bool = False,
         memo: dict | None = None,
     ) -> Self:
-        new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self
         """Move data to device and convert dtype if necessary.

         This method is called by .to(), .cuda(), .cpu(), .double(), and so on.
@@ -179,6 +181,8 @@
         memo
             A dictionary to keep track of already converted objects to avoid multiple conversions.
         """
+        new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self
+
         if memo is None:
             memo = {}

@@ -219,26 +223,62 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin:
                 memo=memo,
             )

-        converted: Any
-        for name, data in new._items():
-            if id(data) in memo:
-                object.__setattr__(new, name, memo[id(data)])
-                continue
+        def _convert(data: T) -> T:
+            converted: Any  # https://github.com/python/mypy/issues/10817
             if isinstance(data, torch.Tensor):
                 converted = _tensor_to(data)
             elif isinstance(data, MoveDataMixin):
                 converted = _mixin_to(data)
             elif isinstance(data, torch.nn.Module):
                 converted = _module_to(data)
             elif copy:
                 converted = deepcopy(data)
             else:
                 converted = data
             memo[id(data)] = converted
-            # this works even if new is frozen
-            object.__setattr__(new, name, converted)
+            return cast(T, converted)
+
+        # manual recursion allows us to do the copy only once
+        new.apply_(_convert, memo=memo, recurse=False)
         return new

+    def apply_(
+        self: Self,
+        function: Callable[[Any], Any] | None = None,
+        *,
+        memo: dict[int, Any] | None = None,
+        recurse: bool = True,
+    ) -> Self:
+        """Apply a function to all children in-place.
+
+        Parameters
+        ----------
+        function
+            The function to apply to all fields. None is interpreted as a no-op.
+        memo
+            A dictionary to keep track of objects that the function has already been applied to,
+            to avoid multiple applications. This is useful if the object has a circular reference.
+        recurse
+            If True, the function will be applied to all children that are MoveDataMixin instances.
+        """
+        applied: Any
+
+        if memo is None:
+            memo = {}
+
+        if function is None:
+            return self
+
+        for name, data in self._items():
+            if id(data) in memo:
+                # this works even if self is frozen
+                object.__setattr__(self, name, memo[id(data)])
+                continue
+            if recurse and isinstance(data, MoveDataMixin):
+                applied = data.apply_(function, memo=memo)
+            else:
+                applied = function(data)
+            memo[id(data)] = applied
+            object.__setattr__(self, name, applied)
+        return self
+
     def cuda(
         self,
         device: torch.device | str | int | None = None,
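The new apply_ is the workhorse: it walks the dataclass fields via _items(), applies the function to each value (recursing into nested MoveDataMixin instances when recurse=True), and records every result in memo keyed by id(), so an object reachable through several fields is transformed only once and the fields stay aliased. _to now reuses it with recurse=False, since _convert already handles recursion itself. A standalone sketch of the memo pattern (hypothetical helper, not mrpro code):

    from typing import Any, Callable

    def apply_once(fields: dict[str, Any], function: Callable[[Any], Any], memo: dict[int, Any]) -> None:
        # memo maps id(original) -> converted, so a value reachable through
        # two fields is converted once and both fields stay aliased
        for name, value in fields.items():
            if id(value) in memo:
                fields[name] = memo[id(value)]
                continue
            converted = function(value)
            memo[id(value)] = converted
            fields[name] = converted

    shared = [1.0, 2.0]
    record = {'a': shared, 'b': shared}
    apply_once(record, lambda v: [x * 2 for x in v], memo={})
    assert record['a'] is record['b']  # sharing preserved, doubled exactly once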
55 changes: 11 additions & 44 deletions src/mrpro/data/SpatialDimension.py
@@ -3,14 +3,13 @@
 from __future__ import annotations

 from collections.abc import Callable
-from copy import deepcopy
 from dataclasses import dataclass
 from typing import Generic, get_args

 import numpy as np
 import torch
 from numpy.typing import ArrayLike
-from typing_extensions import Any, Protocol, TypeVar, overload
+from typing_extensions import Protocol, Self, TypeVar, overload

 import mrpro.utils.typing as type_utils
 from mrpro.data.MoveDataMixin import MoveDataMixin
@@ -109,6 +108,16 @@ def from_array_zyx(

         return SpatialDimension(z, y, x)

+    def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self:
+        """Apply a function to each z, y, x (in-place).
+
+        Parameters
+        ----------
+        function
+            function to apply
+        """
+        return super(SpatialDimension, self).apply_(function)
+
     @property
     def zyx(self) -> tuple[T_co, T_co, T_co]:
         """Return a z,y,x tuple."""
@@ -134,48 +143,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe
         self.y[idx] = other.y
         self.x[idx] = other.x

-    def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]:
-        """Apply function to each of x,y,z in-place.
-
-        Parameters
-        ----------
-        func
-            function to apply to each of x,y,z
-            None is interpreted as the identity function.
-        """
-        if func is not None:
-            self.z = func(self.z)
-            self.y = func(self.y)
-            self.x = func(self.x)
-        return self
-
-    def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]:
-        """Apply function to each of x,y,z.
-
-        Parameters
-        ----------
-        func
-            function to apply to each of x,y,z
-            None is interpreted as the identity function.
-        """
-
-        def func_(x: Any) -> T_co:  # noqa: ANN401
-            if isinstance(x, torch.Tensor):
-                # use clone for autograd
-                x = x.clone()
-            else:
-                x = deepcopy(x)
-            if func is None:
-                return x
-            else:
-                return func(x)
-
-        return self.__class__(func_(self.z), func_(self.y), func_(self.x))
-
-    def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]:
-        """Return a deep copy of the SpatialDimension."""
-        return self.apply()
-
     @overload
     def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ...

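In SpatialDimension, the old hand-written in-place apply_ (explicit per-field assignment) is replaced by a thin override that delegates to MoveDataMixin.apply_, which reaches z, y and x through the dataclass fields; the out-of-place apply and clone are dropped here, with clone available through MoveDataMixin (the test below uses it). A minimal usage sketch of the surviving in-place form:

    import torch
    from mrpro.data.SpatialDimension import SpatialDimension

    sd = SpatialDimension(z=torch.tensor(1.0), y=torch.tensor(2.0), x=torch.tensor(3.0))
    sd.apply_(lambda t: t * 2)  # mutates sd in place and returns it
    print(sd.z, sd.y, sd.x)  # tensor(2.), tensor(4.), tensor(6.)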
19 changes: 19 additions & 0 deletions tests/data/test_movedatamixin.py
@@ -23,6 +23,7 @@ class A(MoveDataMixin):
"""Test class A."""

floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0))
floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0))
complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64))
inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32))
booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True))
@@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype):
     assert original is not new, 'original and new should not be the same object'

     assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared'
+
+
+def test_movedatamixin_apply():
+    """Tests apply_ method of MoveDataMixin."""
+    data = B()
+    # make one of the parameters shared to test memo behavior
+    data.child.floattensor2 = data.child.floattensor
+    original = data.clone()
+
+    def multiply_by_2(obj):
+        if isinstance(obj, torch.Tensor):
+            return obj * 2
+        return obj
+
+    data.apply_(multiply_by_2)
+    torch.testing.assert_close(data.floattensor, original.floattensor * 2)
+    torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2)
+    assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared'
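The last assertion is the one the memo machinery exists for: after apply_, the two fields that aliased one tensor still alias one (doubled) tensor. A naive per-field application without memo (illustrative sketch, not mrpro code) would double both values but break the aliasing:

    import torch

    shared = torch.tensor(1.0)
    fields = {'floattensor': shared, 'floattensor2': shared}
    naive = {name: value * 2 for name, value in fields.items()}
    assert naive['floattensor'] is not naive['floattensor2']  # sharing lost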
23 changes: 0 additions & 23 deletions tests/data/test_spatial_dimension.py
@@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor:
     assert torch.equal(spatial_dimension_inplace.z, z)


-def test_spatial_dimension_apply():
-    """Test apply (out of place)"""
-
-    def conversion(x: torch.Tensor) -> torch.Tensor:
-        assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor'
-        return x.swapaxes(0, 1).square()
-
-    xyz = RandomGenerator(0).float32_tensor((1, 2, 3))
-    spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy())
-    spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion)
-
-    assert spatial_dimension_outofplace is not spatial_dimension
-
-    assert isinstance(spatial_dimension_outofplace.x, torch.Tensor)
-    assert isinstance(spatial_dimension_outofplace.y, torch.Tensor)
-    assert isinstance(spatial_dimension_outofplace.z, torch.Tensor)
-
-    x, y, z = conversion(xyz).unbind(-1)
-    assert torch.equal(spatial_dimension_outofplace.x, x)
-    assert torch.equal(spatial_dimension_outofplace.y, y)
-    assert torch.equal(spatial_dimension_outofplace.z, z)
-
-
 def test_spatial_dimension_zyx():
     """Test the zyx tuple property"""
     z, y, x = (2, 3, 4)