Skip to content

Commit

Permalink
_apply
Browse files Browse the repository at this point in the history
ghstack-source-id: 3d530a5c1136c4475a446b28a70b8ee716e8bdf5
ghstack-comment-id: 2441645633
Pull Request resolved: #489
  • Loading branch information
fzimmermann89 committed Oct 28, 2024
1 parent ff9ade1 commit 2637c16
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions src/mrpro/data/MoveDataMixin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""MoveDataMixin."""

import dataclasses
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from copy import copy as shallowcopy
from copy import deepcopy
from typing import ClassVar, TypeAlias

import torch
from typing_extensions import Any, Protocol, Self, overload, runtime_checkable
from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable


class InconsistentDeviceError(ValueError): # noqa: D101
Expand All @@ -22,6 +22,9 @@ class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


T = TypeVar('T')


class MoveDataMixin:
"""Move dataclass fields to cpu/gpu and convert dtypes."""

Expand Down Expand Up @@ -151,7 +154,6 @@ def _to(
copy: bool = False,
memo: dict | None = None,
) -> Self:
new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self
"""Move data to device and convert dtype if necessary.
This method is called by .to(), .cuda(), .cpu(), .double(), and so on.
Expand Down Expand Up @@ -179,6 +181,8 @@ def _to(
memo
A dictionary to keep track of already converted objects to avoid multiple conversions.
"""
new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self

if memo is None:
memo = {}

Expand Down Expand Up @@ -208,37 +212,48 @@ def _module_to(data: torch.nn.Module) -> torch.nn.Module:
data = deepcopy(data)
return data._apply(_tensor_to, recurse=True)

def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin:
return obj._to(
device=device,
dtype=dtype,
non_blocking=non_blocking,
memory_format=memory_format,
shared_memory=shared_memory,
copy=copy,
memo=memo,
)

converted: Any
for name, data in new._items():
if id(data) in memo:
object.__setattr__(new, name, memo[id(data)])
continue
def _convert(data: T) -> T:
converted: T
if isinstance(data, torch.Tensor):
converted = _tensor_to(data)
elif isinstance(data, MoveDataMixin):
converted = _mixin_to(data)
elif isinstance(data, torch.nn.Module):
converted = _module_to(data)
elif copy:
converted = deepcopy(data)
else:
converted = data
memo[id(data)] = converted
# this works even if new is frozen
object.__setattr__(new, name, converted)
return converted

new.apply_(_convert, memo=memo)
return new

def apply_(
    self: Self,
    function: Callable[[T], T],
    memo: dict[int, Any] | None = None,
    recurse: bool = True,
) -> Self:
    """Apply a function to all children in-place.

    Every ``(name, data)`` pair yielded by ``self._items()`` is replaced by the
    result of the function (or, when recursing, by the child after the function
    has been applied to the child's own fields).

    Parameters
    ----------
    function
        The function to apply to each child. It receives the child and
        returns the object that replaces it.
    memo
        A dictionary mapping ``id(child)`` to the already computed replacement.
        A child whose id is already in the dictionary is not passed to the
        function again; the stored replacement is reused instead, e.g. if the
        same object is referenced by several fields.
    recurse
        If True, children that are `MoveDataMixin` instances are not passed to
        the function directly. Instead, their own `apply_` is called with the
        same function and memo.

    Returns
    -------
        The object itself, after its children have been replaced in-place.
    """
    if memo is None:
        memo = {}

    for name, data in self._items():
        key = id(data)
        if key in memo:
            # object.__setattr__ works even if self is frozen
            object.__setattr__(self, name, memo[key])
            continue

        applied: Any
        if recurse and isinstance(data, MoveDataMixin):
            # the nested call returns the child itself, modified in-place
            applied = data.apply_(function, memo=memo, recurse=recurse)
        else:
            applied = function(data)

        memo[key] = applied
        object.__setattr__(self, name, applied)

    # returning self honours the declared return type; without it the recursive
    # branch above would overwrite nested children with None
    return self

def cuda(
self,
device: torch.device | str | int | None = None,
Expand Down

0 comments on commit 2637c16

Please sign in to comment.