From d782167fe8303135dbe6a73341cce5a47305ba93 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 31 Oct 2024 13:51:02 +0100 Subject: [PATCH 1/8] _apply ghstack-source-id: ece2c18bcaf513233da227f1cd8640fdae3b1b29 ghstack-comment-id: 2441645633 Pull Request resolved: https://github.com/PTB-MR/mrpro/pull/489 --- src/mrpro/data/AcqInfo.py | 14 +++---- src/mrpro/data/MoveDataMixin.py | 66 +++++++++++++++++++++++------- src/mrpro/data/SpatialDimension.py | 15 ------- 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 83f752a57..a66224de1 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,6 +1,6 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass import ismrmrd @@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return ( - SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion) - ) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -251,10 +247,10 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..49b410ff4 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import ClassVar, TypeAlias +from typing import ClassVar, TypeAlias, cast import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -219,26 +223,60 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) return new + def apply_( + self: Self, function: Callable[[T], T] | None = None, memo: dict[int, Any] | None = None, recurse: bool = True + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all tensors. None is interpreted as a no-op. + + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index b5f3dfd27..17ebf6375 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -134,21 +134,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z in-place. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - if func is not None: - self.z = func(self.z) - self.y = func(self.y) - self.x = func(self.x) - return self - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: """Apply function to each of x,y,z. From 4d006e005255de4a2c88e095cb49b17b2457fd44 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 10 Nov 2024 02:52:41 +0100 Subject: [PATCH 2/8] add apply test --- tests/data/test_movedatamixin.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 06f55a4dc..3feb091de 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' From 8b3095d2fdd339779001d83dd5dcd3d337db79ae Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 10 Nov 2024 13:15:11 +0100 Subject: [PATCH 3/8] docstring --- src/mrpro/data/MoveDataMixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 49b410ff4..610d91b63 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -247,7 +247,7 @@ def apply_( Parameters ---------- function - The function to apply to all tensors. None is interpreted as a no-op. + The function to apply to all fields. None is interpreted as a no-op. memo A dictionary to keep track of objects that the function has already been applied to, From bb6518ddab49269ec88599c1c198ce77071764ff Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 10 Nov 2024 15:18:08 +0100 Subject: [PATCH 4/8] change apply --- src/mrpro/data/MoveDataMixin.py | 6 +++++- src/mrpro/data/SpatialDimension.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 610d91b63..ce0a9a0aa 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -240,7 +240,11 @@ def _convert(data: T) -> T: return new def apply_( - self: Self, function: Callable[[T], T] | None = None, memo: dict[int, Any] | None = None, recurse: bool = True + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, ) -> Self: """Apply a function to all children in-place. diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index 17ebf6375..8c3c98960 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -3,14 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from copy import deepcopy from dataclasses import dataclass from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike -from typing_extensions import Any, Protocol, TypeVar, overload +from typing_extensions import Protocol, Self, TypeVar, overload import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin @@ -109,6 +108,10 @@ def from_array_zyx( return SpatialDimension(z, y, x) + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to the fields of the dataclass.""" + return super().apply_(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" From c9ea2e8f8eab278e8110e8f5bb127de68215f235 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 10 Nov 2024 15:47:55 +0100 Subject: [PATCH 5/8] fix --- src/mrpro/data/SpatialDimension.py | 29 +--------------------------- tests/data/test_spatial_dimension.py | 23 ---------------------- 2 files changed, 1 insertion(+), 51 deletions(-) diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index 8c3c98960..0bc768c7b 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -110,7 +110,7 @@ def from_array_zyx( def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: """Apply a function to the fields of the dataclass.""" - return super().apply_(function) + return super(SpatialDimension, self).apply_(function) @property def zyx(self) -> tuple[T_co, T_co, T_co]: @@ -137,33 +137,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - - def func_(x: Any) -> T_co: # noqa: ANN401 - if isinstance(x, torch.Tensor): - # use clone for autograd - x = x.clone() - else: - x = deepcopy(x) - if func is None: - return x - else: - return func(x) - - return self.__class__(func_(self.z), func_(self.y), func_(self.x)) - - def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: - """Return a deep copy of the SpatialDimension.""" - return self.apply() - @overload def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index cd46854b9..afafece04 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) -def test_spatial_dimension_apply(): - """Test apply (out of place)""" - - def conversion(x: torch.Tensor) -> torch.Tensor: - assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' - return x.swapaxes(0, 1).square() - - xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) - spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion) - - assert spatial_dimension_outofplace is not spatial_dimension - - assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) - - x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_outofplace.x, x) - assert torch.equal(spatial_dimension_outofplace.y, y) - assert torch.equal(spatial_dimension_outofplace.z, z) - - def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) From 3df31c33b96505bdf003932a79c6e1ec620570d1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 10 Nov 2024 15:50:29 +0100 Subject: [PATCH 6/8] docstring --- src/mrpro/data/SpatialDimension.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index 0bc768c7b..46b1db89a 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -109,7 +109,13 @@ def from_array_zyx( return SpatialDimension(z, y, x) def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: - """Apply a function to the fields of the dataclass.""" + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ return super(SpatialDimension, self).apply_(function) @property From f7cc9129f1b6ec95eecae4f08a45a830c94d255e Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 01:53:49 +0100 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Christoph Kolbitsch --- src/mrpro/data/MoveDataMixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index ce0a9a0aa..ad33ea51b 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -252,9 +252,8 @@ def apply_( ---------- function The function to apply to all fields. None is interpreted as a no-op. - memo - A dictionary to keep track of objects that the function has already been applied to, + A dictionary to keep track of objects that the function has already been applied to, to avoid multiple applications. This is useful if the object has a circular reference. recurse From 81cd6f8cdb7d1a32812b84e8fd7b590ec9430b57 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 11:23:03 +0100 Subject: [PATCH 8/8] Update src/mrpro/data/MoveDataMixin.py Co-authored-by: Christoph Kolbitsch --- src/mrpro/data/MoveDataMixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index ad33ea51b..8d977d0a6 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -255,7 +255,6 @@ def apply_( memo A dictionary to keep track of objects that the function has already been applied to, to avoid multiple applications. This is useful if the object has a circular reference. - recurse If True, the function will be applied to all children that are MoveDataMixin instances. """