From 4dc9f0b5ffada8b983b031013579d8f2a6c6fa74 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 3 Jan 2025 16:17:16 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- src/mrpro/data/KData.py | 300 ++++++++++++++++++- src/mrpro/data/_kdata/KDataProtocol.py | 41 --- src/mrpro/data/_kdata/KDataRearrangeMixin.py | 41 --- src/mrpro/data/_kdata/KDataRemoveOsMixin.py | 75 ----- src/mrpro/data/_kdata/KDataSelectMixin.py | 65 ---- src/mrpro/data/_kdata/KDataSplitMixin.py | 161 ---------- 6 files changed, 293 insertions(+), 390 deletions(-) delete mode 100644 src/mrpro/data/_kdata/KDataProtocol.py delete mode 100644 src/mrpro/data/_kdata/KDataRearrangeMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataRemoveOsMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataSelectMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataSplitMixin.py diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 4b5df6250..47ddc27d2 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -1,23 +1,21 @@ """MR raw data / k-space data class.""" +import copy import dataclasses import datetime import warnings from collections.abc import Callable, Sequence from pathlib import Path from types import EllipsisType +from typing import Literal, cast import h5py import ismrmrd import numpy as np import torch -from einops import rearrange -from typing_extensions import Self +from einops import rearrange, repeat +from typing_extensions import Self, TypeVar -from mrpro.data._kdata.KDataRearrangeMixin import KDataRearrangeMixin -from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin -from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin -from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits @@ -29,6 +27,8 @@ from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) + KDIM_SORT_LABELS = ( 'k1', 'k2', @@ -63,7 +63,9 @@ @dataclasses.dataclass(slots=True, frozen=True) -class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin, MoveDataMixin): +class KData( + MoveDataMixin, +): """MR raw data / k-space data class.""" header: KHeader @@ -366,3 +368,287 @@ def compress_coils( ).permute(*np.argsort(permute_order)) return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) + + def rearrange_k2_k1_into_k1(self: Self) -> Self: + """Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...). + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + + Returns + ------- + K-space data (other coils 1 (k2 k1) k0) + """ + # Rearrange data + kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0') + + # Rearrange trajectory + ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0') + + # Create new header with correct shape + kheader = copy.deepcopy(self.header) + + # Update shape of acquisition info index + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def remove_readout_os(self: Self) -> Self: + """Remove any oversampling along the readout (k0) direction [GAD]_. + + Returns a copy of the data. + + Parameters + ---------- + kdata + K-space data + + Returns + ------- + Copy of K-space data with oversampling removed. + + Raises + ------ + ValueError + If the recon matrix along x is larger than the encoding matrix along x. + + References + ---------- + .. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python + """ + from mrpro.operators.FastFourierOp import FastFourierOp + + # Ratio of k0/x between encoded and recon space + x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x + if x_ratio == 1: + # If the encoded and recon space is the same we don't have to do anything + return self + elif x_ratio > 1: + raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.') + + # Starting and end point of image after removing oversampling + start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2 + end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x + + def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: + # returns a cropped copy + return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone() + + # Transform to image space along readout, crop to reconstruction matrix size and transform back + fourier_k0_op = FastFourierOp(dim=(-1,)) + (cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data))) + + # Adapt trajectory + ks = [self.traj.kz, self.traj.ky, self.traj.kx] + # only cropped ks that are not broadcasted/singleton along k0 + cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks] + cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2]) + + # Adapt header parameters + header = copy.deepcopy(self.header) + header.acq_info.center_sample -= start_cropped_readout + header.acq_info.number_of_samples[:] = cropped_data.shape[-1] + header.encoding_matrix.x = cropped_data.shape[-1] + + header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) + header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) + + return type(self)(header, cropped_data, cropped_traj) + + def select_other_subset( + self: Self, + subset_idx: torch.Tensor, + subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Select a subset from the other dimension of KData. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + subset_idx + Index which elements of the other subset to use, e.g. phase 0,1,2 and 5 + subset_label + Name of the other label, e.g. phase + + Returns + ------- + K-space data (other_subset coils k2 k1 k0) + + Raises + ------ + ValueError + If the subset indices are not available in the data + """ + # Make a copy such that the original kdata.header remains the same + kheader = copy.deepcopy(self.header) + ktraj = self.traj.as_tensor() + + # Verify that the subset_idx is available + label_idx = getattr(kheader.acq_info.idx, subset_label) + if not all(el in torch.unique(label_idx) for el in subset_idx): + raise ValueError('Subset indices are outside of the available index range') + + # Find subset index in acq_info index + other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) + + # Adapt header + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) + + # Select data + kdat = self.data[other_idx, ...] + + # Select ktraj + if ktraj.shape[1] > 1: + ktraj = ktraj[:, other_idx, ...] + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def _split_k2_or_k1_into_other( + self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + split_dir: Literal['k2', 'k1'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + split_idx + 2D index describing the k2 or k1 points in each block to be moved to the other dimension + (other_split, k1_per_split) or (other_split, k2_per_split) + other_label + Label of other dimension, e.g. repetition, phase + split_dir + Dimension to split, either 'k1' or 'k2' + + Returns + ------- + K-space data with new shape + ((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0) + + Raises + ------ + ValueError + Already existing "other_label" can only be of length 1 + """ + # Number of other + n_other = split_idx.shape[0] + + # Verify that the specified label of the other dimension is unused + if getattr(self.header.encoding_limits, other_label).length > 1: + raise ValueError(f'{other_label} is already used to encode different parts of the scan.') + + # Set-up splitting + if split_dir == 'k1': + # Split along k1 dimensions + def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: + return dat_traj[:, :, :, split_idx, :] + + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) + + # Rearrange other_split and k1 dimension + rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' + rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0' + rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...' + + elif split_dir == 'k2': + # Split along k2 dimensions + def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: + return dat_traj[:, :, split_idx, :, :] + + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) + + # Rearrange other_split and k1 dimension + rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' + rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0' + rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...' + + else: + raise ValueError('split_dir has to be "k1" or "k2"') + + # Split data + kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data) + + # First we need to make sure the other dimension is the same as data then we can split the trajectory + ktraj = self.traj.as_tensor() + # Verify that other dimension of trajectory is 1 or matches data + if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]: + raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})') + elif ktraj.shape[1] == 1 and self.data.shape[0] > 1: + ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0]) + ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) + + # Create new header with correct shape + kheader = self.header.clone() + + # Update shape of acquisition info index + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) + + # Update other label limits and acquisition info + setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) + + # acq_info for new other dimensions + acq_info_other_split = repeat( + torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2] + ) + setattr(kheader.acq_info.idx, other_label, acq_info_other_split) + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def split_k1_into_other( + self: Self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + split_idx + 2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split) + other_label + Label of other dimension, e.g. repetition, phase + + Returns + ------- + K-space data with new shape ((other other_split) coils k2 k1_per_split k0) + """ + return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1') + + def split_k2_into_other( + self: Self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + split_idx + 2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split) + other_label + Label of other dimension, e.g. repetition, phase + + Returns + ------- + K-space data with new shape ((other other_split) coils k2_per_split k1 k0) + """ + return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2') diff --git a/src/mrpro/data/_kdata/KDataProtocol.py b/src/mrpro/data/_kdata/KDataProtocol.py deleted file mode 100644 index 485a8fc4d..000000000 --- a/src/mrpro/data/_kdata/KDataProtocol.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Protocol for KData.""" - -from typing import Literal - -import torch -from typing_extensions import Protocol, Self - -from mrpro.data.KHeader import KHeader -from mrpro.data.KTrajectory import KTrajectory - - -class _KDataProtocol(Protocol): - """Protocol for KData used for type hinting in KData mixins. - - Note that the actual KData class can have more properties and methods than those defined here. - - If you want to use a property or method of KData in a new KDataMixin class, - you must add it to this Protocol to make sure that the type hinting works [PRO]_. - - References - ---------- - .. [PRO] Protocols https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols - """ - - @property - def header(self) -> KHeader: ... - - @property - def data(self) -> torch.Tensor: ... - - @property - def traj(self) -> KTrajectory: ... - - def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... - - def _split_k2_or_k1_into_other( - self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - split_dir: Literal['k1', 'k2'], - ) -> Self: ... diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py deleted file mode 100644 index 23a58dea6..000000000 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Rearrange KData.""" - -import copy - -from einops import rearrange -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import rearrange_acq_info_fields - - -class KDataRearrangeMixin(_KDataProtocol): - """Rearrange KData.""" - - def rearrange_k2_k1_into_k1(self: Self) -> Self: - """Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...). - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - - Returns - ------- - K-space data (other coils 1 (k2 k1) k0) - """ - # Rearrange data - kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0') - - # Rearrange trajectory - ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0') - - # Create new header with correct shape - kheader = copy.deepcopy(self.header) - - # Update shape of acquisition info index - kheader.acq_info.apply_( - lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') - ) - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py deleted file mode 100644 index 555f56a39..000000000 --- a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Remove oversampling along readout dimension.""" - -from copy import deepcopy - -import torch -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.KTrajectory import KTrajectory - - -class KDataRemoveOsMixin(_KDataProtocol): - """Remove oversampling along readout dimension.""" - - def remove_readout_os(self: Self) -> Self: - """Remove any oversampling along the readout (k0) direction [GAD]_. - - Returns a copy of the data. - - Parameters - ---------- - kdata - K-space data - - Returns - ------- - Copy of K-space data with oversampling removed. - - Raises - ------ - ValueError - If the recon matrix along x is larger than the encoding matrix along x. - - References - ---------- - .. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python - """ - from mrpro.operators.FastFourierOp import FastFourierOp - - # Ratio of k0/x between encoded and recon space - x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x - if x_ratio == 1: - # If the encoded and recon space is the same we don't have to do anything - return self - elif x_ratio > 1: - raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.') - - # Starting and end point of image after removing oversampling - start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2 - end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x - - def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: - # returns a cropped copy - return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone() - - # Transform to image space along readout, crop to reconstruction matrix size and transform back - fourier_k0_op = FastFourierOp(dim=(-1,)) - (cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data))) - - # Adapt trajectory - ks = [self.traj.kz, self.traj.ky, self.traj.kx] - # only cropped ks that are not broadcasted/singleton along k0 - cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks] - cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2]) - - # Adapt header parameters - header = deepcopy(self.header) - header.acq_info.center_sample -= start_cropped_readout - header.acq_info.number_of_samples[:] = cropped_data.shape[-1] - header.encoding_matrix.x = cropped_data.shape[-1] - - header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) - header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) - - return type(self)(header, cropped_data, cropped_traj) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py deleted file mode 100644 index 8f8a452cf..000000000 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Select subset along other dimensions of KData.""" - -import copy -from typing import Literal - -import torch -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.Rotation import Rotation - - -class KDataSelectMixin(_KDataProtocol): - """Select subset of KData.""" - - def select_other_subset( - self: Self, - subset_idx: torch.Tensor, - subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Select a subset from the other dimension of KData. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - subset_idx - Index which elements of the other subset to use, e.g. phase 0,1,2 and 5 - subset_label - Name of the other label, e.g. phase - - Returns - ------- - K-space data (other_subset coils k2 k1 k0) - - Raises - ------ - ValueError - If the subset indices are not available in the data - """ - # Make a copy such that the original kdata.header remains the same - kheader = copy.deepcopy(self.header) - ktraj = self.traj.as_tensor() - - # Verify that the subset_idx is available - label_idx = getattr(kheader.acq_info.idx, subset_label) - if not all(el in torch.unique(label_idx) for el in subset_idx): - raise ValueError('Subset indices are outside of the available index range') - - # Find subset index in acq_info index - other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) - - # Adapt header - kheader.acq_info.apply_( - lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field - ) - - # Select data - kdat = self.data[other_idx, ...] - - # Select ktraj - if ktraj.shape[1] > 1: - ktraj = ktraj[:, other_idx, ...] - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py deleted file mode 100644 index c28004af4..000000000 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Mixin class to split KData into other subsets.""" - -from typing import Literal, TypeVar, cast - -import torch -from einops import rearrange, repeat -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import rearrange_acq_info_fields -from mrpro.data.EncodingLimits import Limits -from mrpro.data.Rotation import Rotation - -RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) - - -class KDataSplitMixin(_KDataProtocol): - """Split KData into other subsets.""" - - def _split_k2_or_k1_into_other( - self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - split_dir: Literal['k2', 'k1'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - split_idx - 2D index describing the k2 or k1 points in each block to be moved to the other dimension - (other_split, k1_per_split) or (other_split, k2_per_split) - other_label - Label of other dimension, e.g. repetition, phase - split_dir - Dimension to split, either 'k1' or 'k2' - - Returns - ------- - K-space data with new shape - ((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0) - - Raises - ------ - ValueError - Already existing "other_label" can only be of length 1 - """ - # Number of other - n_other = split_idx.shape[0] - - # Verify that the specified label of the other dimension is unused - if getattr(self.header.encoding_limits, other_label).length > 1: - raise ValueError(f'{other_label} is already used to encode different parts of the scan.') - - # Set-up splitting - if split_dir == 'k1': - # Split along k1 dimensions - def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: - return dat_traj[:, :, :, split_idx, :] - - def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: - # cast due to https://github.com/python/mypy/issues/10817 - return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) - - # Rearrange other_split and k1 dimension - rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' - rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0' - rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...' - - elif split_dir == 'k2': - # Split along k2 dimensions - def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: - return dat_traj[:, :, split_idx, :, :] - - def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: - return cast(RotationOrTensor, acq_info[:, split_idx, ...]) - - # Rearrange other_split and k1 dimension - rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' - rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0' - rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...' - - else: - raise ValueError('split_dir has to be "k1" or "k2"') - - # Split data - kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data) - - # First we need to make sure the other dimension is the same as data then we can split the trajectory - ktraj = self.traj.as_tensor() - # Verify that other dimension of trajectory is 1 or matches data - if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]: - raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})') - elif ktraj.shape[1] == 1 and self.data.shape[0] > 1: - ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0]) - ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) - - # Create new header with correct shape - kheader = self.header.clone() - - # Update shape of acquisition info index - kheader.acq_info.apply_( - lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) - if isinstance(field, Rotation | torch.Tensor) - else field - ) - - # Update other label limits and acquisition info - setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) - - # acq_info for new other dimensions - acq_info_other_split = repeat( - torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2] - ) - setattr(kheader.acq_info.idx, other_label, acq_info_other_split) - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) - - def split_k1_into_other( - self: Self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - split_idx - 2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split) - other_label - Label of other dimension, e.g. repetition, phase - - Returns - ------- - K-space data with new shape ((other other_split) coils k2 k1_per_split k0) - """ - return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1') - - def split_k2_into_other( - self: Self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - split_idx - 2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split) - other_label - Label of other dimension, e.g. repetition, phase - - Returns - ------- - K-space data with new shape ((other other_split) coils k2_per_split k1 k0) - """ - return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2')