Skip to content

Commit

Permalink
_apply
Browse files Browse the repository at this point in the history
ghstack-source-id: ece2c18bcaf513233da227f1cd8640fdae3b1b29
ghstack-comment-id: 2441645633
Pull Request resolved: #489
  • Loading branch information
fzimmermann89 committed Nov 10, 2024
1 parent c268ad2 commit d782167
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 38 deletions.
14 changes: 5 additions & 9 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Acquisition information dataclass."""

from collections.abc import Callable, Sequence
from collections.abc import Sequence
from dataclasses import dataclass

import ismrmrd
Expand Down Expand Up @@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor:
data_tensor = data_tensor[None, None]
return data_tensor

def spatialdimension_2d(
data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None
) -> SpatialDimension[torch.Tensor]:
def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
# Ensure spatial dimension is (k1*k2*other, 1, 3)
if data.ndim != 2:
raise ValueError('Spatial dimension is expected to be of shape (N,3)')
data = data[:, None, :]
# all spatial dimensions are float32
return (
SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion)
)
return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32)))

acq_idx = AcqIdx(
k1=tensor(idx['kspace_encode_step_1']),
Expand Down Expand Up @@ -251,10 +247,10 @@ def spatialdimension_2d(
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
phase_dir=spatialdimension_2d(headers['phase_dir']),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position'], mm_to_m),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
read_dir=spatialdimension_2d(headers['read_dir']),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
Expand Down
66 changes: 52 additions & 14 deletions src/mrpro/data/MoveDataMixin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""MoveDataMixin."""

import dataclasses
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from copy import copy as shallowcopy
from copy import deepcopy
from typing import ClassVar, TypeAlias
from typing import ClassVar, TypeAlias, cast

import torch
from typing_extensions import Any, Protocol, Self, overload, runtime_checkable
from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable


class InconsistentDeviceError(ValueError): # noqa: D101
Expand All @@ -22,6 +22,9 @@ class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


T = TypeVar('T')


class MoveDataMixin:
"""Move dataclass fields to cpu/gpu and convert dtypes."""

Expand Down Expand Up @@ -151,7 +154,6 @@ def _to(
copy: bool = False,
memo: dict | None = None,
) -> Self:
new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self
"""Move data to device and convert dtype if necessary.
This method is called by .to(), .cuda(), .cpu(), .double(), and so on.
Expand Down Expand Up @@ -179,6 +181,8 @@ def _to(
memo
A dictionary to keep track of already converted objects to avoid multiple conversions.
"""
new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self

if memo is None:
memo = {}

Expand Down Expand Up @@ -219,26 +223,60 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin:
memo=memo,
)

converted: Any
for name, data in new._items():
if id(data) in memo:
object.__setattr__(new, name, memo[id(data)])
continue
def _convert(data: T) -> T:
converted: Any # https://github.com/python/mypy/issues/10817
if isinstance(data, torch.Tensor):
converted = _tensor_to(data)
elif isinstance(data, MoveDataMixin):
converted = _mixin_to(data)
elif isinstance(data, torch.nn.Module):
converted = _module_to(data)
elif copy:
converted = deepcopy(data)
else:
converted = data
memo[id(data)] = converted
# this works even if new is frozen
object.__setattr__(new, name, converted)
return cast(T, converted)

# manual recursion allows us to do the copy only once
new.apply_(_convert, memo=memo, recurse=False)
return new

def apply_(
self: Self, function: Callable[[T], T] | None = None, memo: dict[int, Any] | None = None, recurse: bool = True
) -> Self:
"""Apply a function to all children in-place.
Parameters
----------
function
The function to apply to all tensors. None is interpreted as a no-op.
memo
A dictionary to keep track of objects that the function has already been applied to,
to avoid multiple applications. This is useful if the object has a circular reference.
recurse
If True, the function will be applied to all children that are MoveDataMixin instances.
"""
applied: Any

if memo is None:
memo = {}

if function is None:
return self

for name, data in self._items():
if id(data) in memo:
# this works even if self is frozen
object.__setattr__(self, name, memo[id(data)])
continue
if recurse and isinstance(data, MoveDataMixin):
applied = data.apply_(function, memo=memo)
else:
applied = function(data)
memo[id(data)] = applied
object.__setattr__(self, name, applied)
return self

def cuda(
self,
device: torch.device | str | int | None = None,
Expand Down
15 changes: 0 additions & 15 deletions src/mrpro/data/SpatialDimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe
self.y[idx] = other.y
self.x[idx] = other.x

def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]:
"""Apply function to each of x,y,z in-place.
Parameters
----------
func
function to apply to each of x,y,z
None is interpreted as the identity function.
"""
if func is not None:
self.z = func(self.z)
self.y = func(self.y)
self.x = func(self.x)
return self

def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]:
"""Apply function to each of x,y,z.
Expand Down

0 comments on commit d782167

Please sign in to comment.