Add Jacobian of Operator #452

Open · wants to merge 12 commits into main
111 changes: 111 additions & 0 deletions src/mrpro/operators/Jacobian.py
@@ -0,0 +1,111 @@
"""Jacobian."""

from collections.abc import Callable

import torch

from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.Operator import Operator


class Jacobian(LinearOperator):
"""Jacobian of an Operator.

This operator computes the action of the Jacobian of an operator at a given point x0, i.e., a linearization of the operator around x0.
"""

def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: torch.Tensor):
"""Initialize the Jacobian operator.

Parameters
----------
operator
operator to linearize
x0
point at which to linearize the operator
"""
super().__init__()
self._vjp: Callable[[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]] | None = None
self._x0: tuple[torch.Tensor, ...] = x0
self._operator = operator
self._f_x0: tuple[torch.Tensor, ...] | None = None

def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override]
"""Apply the adjoint operator.

Parameters
----------
x
input tensor

Returns
-------
output tensor
"""
if self._vjp is None:
self._f_x0, self._vjp = torch.func.vjp(self._operator, *self._x0)
assert self._vjp is not None # noqa: S101 (hint for mypy)
return (self._vjp(x)[0],)

def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override]
"""Apply the operator.

Parameters
----------
x
input tensor

Returns
-------
output tensor
"""
self._f_x0, jvp = torch.func.jvp(self._operator, self._x0, x)
return jvp

@property
def value_at_x0(self) -> tuple[torch.Tensor, ...]:
"""Value of the operator at x0."""
if self._f_x0 is None:
self._f_x0 = self._operator(*self._x0)
assert self._f_x0 is not None # noqa: S101 (hint for mypy)
return self._f_x0

def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""Taylor approximation of the operator.

Approximate the operator at x by a first-order Taylor expansion around x0.

This is not faster than calling the operator itself, as computing the
Jacobian-vector product requires a full forward pass of the operator.

Parameters
----------
x
input tensor

Returns
-------
Value of the Taylor approximation at x
"""
delta = tuple(ix - ix0 for ix, ix0 in zip(x, self._x0, strict=False))
self._f_x0, jvp = torch.func.jvp(self._operator, self._x0, delta)
assert self._f_x0 is not None # noqa: S101 (hint for mypy)
f_x = tuple(ifx + ijvp for ifx, ijvp in zip(self._f_x0, jvp, strict=False))
return f_x

def gauss_newton(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""Calculate the Gauss-Newton approximation of the Hessian of the operator.

Returns J^T J x, where J is the Jacobian of the operator at x0,
using forward-mode (JVP) and reverse-mode (VJP) automatic differentiation of the operator.

Parameters
----------
x
input tensor

Returns
-------
Gauss-Newton approximation of the Hessian applied to x
"""
return self.adjoint(*self(*x))
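
A minimal usage sketch (not part of the diff), based only on the API introduced above; shapes and values are illustrative:

import torch

from mrpro.operators import Jacobian
from mrpro.operators.functionals import L2NormSquared

x0 = torch.randn(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)  # linearize op at x0
(jx,) = jacobian(torch.randn(3))  # J(x0) x via forward-mode AD
(jty,) = jacobian.adjoint(jx)  # J^T y via reverse-mode AD, y in the output space
(fx,) = jacobian.taylor(x0 + 1e-2 * torch.randn(3))  # f(x0) + J(x0) (x - x0)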
2 changes: 2 additions & 0 deletions src/mrpro/operators/__init__.py
@@ -11,6 +11,7 @@
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.GridSamplingOp import GridSamplingOp
from mrpro.operators.IdentityOp import IdentityOp
from mrpro.operators.Jacobian import Jacobian
from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix
from mrpro.operators.MagnitudeOp import MagnitudeOp
from mrpro.operators.MultiIdentityOp import MultiIdentityOp
@@ -37,6 +38,7 @@
"Functional",
"GridSamplingOp",
"IdentityOp",
"Jacobian",
"LinearOperator",
"LinearOperatorMatrix",
"MagnitudeOp",
51 changes: 51 additions & 0 deletions tests/operators/test_jacobian.py
@@ -0,0 +1,51 @@
import torch
from mrpro.operators import Jacobian
from mrpro.operators.functionals import L2NormSquared

from tests import RandomGenerator
from tests.helper import dotproduct_adjointness_test


def test_jacobian_adjointness():
"""Test adjointness of Jacobian operator."""
rng = RandomGenerator(123)
x = rng.float32_tensor(3)
y = rng.float32_tensor(())
x0 = rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
dotproduct_adjointness_test(jacobian, x, y)


def test_jacobian_taylor():
"""Test Taylor expansion"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
x = x0 + 1e-2 * rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
fx = jacobian.taylor(x)
torch.testing.assert_close(fx, op(x), rtol=1e-3, atol=1e-3)


def test_jacobian_gaussnewton():
"""Test Gauss Newton approximation of the Hessian"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
x = x0 + 1e-2 * rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
(actual,) = jacobian.gauss_newton(x)
expected = torch.vdot(x, x0) * 4 * x0 # analytical solution for L2NormSquared
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)


def test_jacobian_valueatx0():
"""Test value at x0"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
(actual,) = jacobian.value_at_x0
(expected,) = op(x0)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
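
For reference, the expected value in test_jacobian_gaussnewton follows analytically: for f(x) = ||x||^2 the Jacobian at x0 is the row vector 2 x0^T, so J^T J x = 2 x0 (2 x0^T x) = 4 <x0, x> x0. A small standalone check of this identity (illustrative values):

import torch

x0 = torch.randn(3)
x = torch.randn(3)
jx = 2 * torch.vdot(x0, x)  # J(x0) x, with J = 2 x0^T
jtjx = 2 * x0 * jx  # J^T (J x) = 4 <x0, x> x0
torch.testing.assert_close(jtjx, torch.vdot(x, x0) * 4 * x0)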