diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..b3a5f41ed 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy from typing import ClassVar, TypeAlias import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -208,37 +212,48 @@ def _module_to(data: torch.nn.Module) -> torch.nn.Module: data = deepcopy(data) return data._apply(_tensor_to, recurse=True) - def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: - return obj._to( - device=device, - dtype=dtype, - non_blocking=non_blocking, - memory_format=memory_format, - shared_memory=shared_memory, - copy=copy, - memo=memo, - ) - - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: T if isinstance(data, torch.Tensor): converted = _tensor_to(data) - elif isinstance(data, MoveDataMixin): - converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return converted + + new.apply_(_convert, memo=memo) return new + def apply_(self: Self, function: Callable[T, T], memo: dict[int, Any] | None = None, recurse: bool = True) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all tensors. + + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + if memo is None: + memo = {} + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + def cuda( self, device: torch.device | str | int | None = None,