Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Jan 7, 2025
1 parent a146b84 commit 1776900
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 27 deletions.
47 changes: 29 additions & 18 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

import ast
import dataclasses
import inspect
import os
import sys

from sphinx_pyproject import SphinxConfig
from urllib.parse import quote

from mrpro import __version__ as project_version
config = SphinxConfig('../../pyproject.toml', globalns=globals(), config_overrides = {'version': project_version})
from mrpro import __version__ as project_version

config = SphinxConfig('../../pyproject.toml', globalns=globals(), config_overrides={'version': project_version})
sys.path.insert(0, os.path.abspath('../../src')) # Source code dir relative to this file

# -- Project information -----------------------------------------------------
Expand All @@ -33,15 +37,13 @@
'sphinx.ext.doctest',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
# 'sphinx.ext.linkcode',
'sphinx_github_style',
'sphinx_github_style',
'sphinx.ext.napoleon',
'myst_nb',
'sphinx.ext.mathjax',
'sphinx-mathjax-offline',
'sphinx.ext.intersphinx',
# 'sphinx_autodoc_typehints',

'sphinx_autodoc_typehints',
]
intersphinx_mapping = {
'torch': ('https://pytorch.org/docs/stable/', None),
Expand All @@ -58,17 +60,20 @@
}

napoleon_use_param = True
napoleon_use_rtype = False
typehints_defaults = 'comma'
typehints_use_signature = True
typehints_use_signature_return = True
typehints_use_rtype = False
autosummary_generate = True
autosummary_imported_members = False
autosummary_ignore_module_all = False
autodoc_member_order = 'groupwise'
autodoc_preserve_defaults = True
autodoc_class_signature = 'separated'
templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
source_suffix = {'.rst': 'restructuredtext', '.txt': 'restructuredtext', '.md': 'markdown'}
autodoc_class_signature = 'separated'
myst_enable_extensions = [
'amsmath',
'dollarmath',
Expand Down Expand Up @@ -96,24 +101,30 @@
'github_version': 'main',
}
linkcode_blob = html_context['github_version']
autodoc_preserve_defaults=True


import inspect, ast,dataclasses
def get_lambda_source(obj):
"""Convert lambda to source code."""
source = inspect.getsource(obj)
for node in ast.walk(ast.parse(source.strip())):
if isinstance(node, ast.Lambda):
return ast.unparse(node.body)
class DefaultValue():


class DefaultValue:
"""Used to store default values of dataclass fields with default factory."""

def __init__(self, value):
self.value = value

def __repr__(self):
"""This is called by sphinx when rendering the default value."""
return self.value


def rewrite_dataclass_init_default_factories(app, obj, bound_method) -> None:
if (not "init" in str(obj)
if (
not 'init' in str(obj)
or not getattr(obj, '__defaults__', None)
or not any(isinstance(d, dataclasses._HAS_DEFAULT_FACTORY_CLASS) for d in obj.__defaults__)
):
Expand All @@ -127,16 +138,16 @@ def rewrite_dataclass_init_default_factories(app, obj, bound_method) -> None:
if field.default_factory is not dataclasses.MISSING:
if not field.name in parameters:
continue
if field.default_factory.__name__ == "<lambda>":
defaults[field.name] = DefaultValue(get_lambda_source(field.default_factory))
if field.default_factory.__name__ == '<lambda>':
defaults[field.name] = DefaultValue(get_lambda_source(field.default_factory))
else:
defaults[field.name] = DefaultValue(field.default_factory.__name__ + "()")
new_defaults = tuple(defaults.get(name,param.default) for name, param in parameters.items() if param.default!=inspect._empty)
defaults[field.name] = DefaultValue(field.default_factory.__name__ + '()')
new_defaults = tuple(defaults.get(name, param.default) for name, param in parameters.items() if param.default != inspect._empty)
obj.__defaults__ = new_defaults


def setup(app):
# forces mathjax on all pages
app.set_html_assets_policy('always')
# rewrite dataclass init signature
app.connect("autodoc-before-process-signature", rewrite_dataclass_init_default_factories)

app.connect('autodoc-before-process-signature', rewrite_dataclass_init_default_factories)
4 changes: 2 additions & 2 deletions src/mrpro/data/KTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class KTrajectory(MoveDataMixin):
"""Trajectory in z direction / phase encoding direction k2 if Cartesian. Shape `(*other,k2,k1,k0)`"""

ky: torch.Tensor
"""Trajectory in y direction / phase encoding direction k1 if Cartesian. Shape (*other,k2,k1,k0)"""
"""Trajectory in y direction / phase encoding direction k1 if Cartesian. Shape `(*other,k2,k1,k0)`"""

kx: torch.Tensor
"""Trajectory in x direction / phase encoding direction k0 if Cartesian. Shape (*other,k2,k1,k0)"""
"""Trajectory in x direction / phase encoding direction k0 if Cartesian. Shape `(*other,k2,k1,k0)`"""

grid_detection_tolerance: float = 1e-3
"""tolerance of how close trajectory positions have to be to integer grid points."""
Expand Down
4 changes: 2 additions & 2 deletions src/mrpro/data/Rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def as_matrix(self) -> torch.Tensor:
"""Represent as rotation matrix.
3D rotations can be represented using rotation matrices, which
are 3 x 3 real orthogonal matrices with determinant equal to +1 [ROTa]_
are 3 x 3 real orthogonal matrices with determinant equal to +1 [ROT]_
for proper rotations and -1 for improper rotations.
Returns
Expand All @@ -916,7 +916,7 @@ def as_matrix(self) -> torch.Tensor:
References
----------
.. [ROTa] Rotation matrix https://en.wikipedia.org/wiki/Rotation_matrix#In_three_dimensions
.. [ROT] Rotation matrix https://en.wikipedia.org/wiki/Rotation_matrix#In_three_dimensions
"""
quaternions = self._quaternions
matrix = _quaternion_to_matrix(quaternions)
Expand Down
10 changes: 5 additions & 5 deletions src/mrpro/operators/LinearOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ def operator_norm(
the function throws an `ValueError`.
dim
The dimensions of the tensors on which the operator operates. The choice of `dim` determines how
the operator norm is inperpreted.
For example, for a matrix-vector multiplication with a batched matrix tensor of shape
`(batch1, batch2, row, column)` and a batched input tensor of shape `(batch1, batch2, row)`:
the operator norm is inperpreted. For example, for a matrix-vector multiplication with a batched matrix
tensor of shape `(batch1, batch2, row, column)` and a batched input tensor of shape `(batch1, batch2, row)`:
- If `dim=None`, the operator is considered as a block diagonal matrix with batch1*batch2 blocks
and the result is a tensor containing a single norm value (shape `(1, 1, 1)`).
and the result is a tensor containing a single norm value (shape `(1, 1, 1)`).
- If `dim=(-1)`, `batch1*batch2` matrices are considered, and for each a separate operator norm is computed.
- If `dim=(-1,-2)`, `batch1` matrices with `batch2` blocks are considered, and for each matrix a
separate operator norm is computed.
separate operator norm is computed.
Thus, the choice of `dim` determines implicitly determines the domain of the operator.
max_iterations
maximum number of iterations
Expand Down

0 comments on commit 1776900

Please sign in to comment.