From 845836f2c6b9a1e33d1e5749d46e6477128375ad Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 20 Oct 2024 02:46:05 +0200 Subject: [PATCH 01/10] jacobian as operator --- src/mrpro/operators/Jacobian.py | 109 +++++++++++++++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_jacobian.py | 26 ++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/mrpro/operators/Jacobian.py create mode 100644 tests/operators/test_jacobian.py diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py new file mode 100644 index 000000000..9c7178776 --- /dev/null +++ b/src/mrpro/operators/Jacobian.py @@ -0,0 +1,109 @@ +"""Jacobian.""" + +import torch + +from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.Operator import Operator + + +class Jacobian(LinearOperator): + """Jacobian of an Operator. + + This operator computes the Jacobian of an operator at a given point x0, i.e. a linearization of the operator. + + + + """ + + def __init__(self, operator: Operator[torch.tensor, tuple[torch.Tensor]], *x0: torch.Tensor): + """Initialize the Jacobian operator. + + Parameters + ---------- + operator + operator to linearize + x0 + point at which to linearize the operator + """ + super().__init__() + self._vjp = None + self._x0 = x0 + self._operator = operator + self._f_x0 = None + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint operator. + + Parameters + ---------- + x + input tensor + + Returns + ------- + output tensor + """ + if self._vjp is None: + self._f_x0, self._vjp = torch.func.vjp(self._operator, *self._x0) + return (self._vjp(x)[0],) + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator. + + Parameters + ---------- + x + input tensor + + Returns + ------- + output tensor + """ + self._f_x0, jvp = torch.func.jvp(self._operator, self._x0, x) + return jvp + + @property + def value_at_x0(self) -> torch.Tensor: + """Value of the operator at x0.""" + if self._f_x0 is None: + self._f_x0 = self._operator(self._x0)[0] + return self._f_x0 + + def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Taylor approximation of the operator. + + Approximate the operator at x by a first order Taylor expansion around x0. + + This is not faster than the forward method of the operator itself, as the calculation of the + jacobian-vector-product requires the forward pass of the operator to be computed. + + Parameters + ---------- + x + input tensor + + Returns + ------- + Value of the Taylor approximation at x + """ + delta = tuple(ix - ix0 for ix, ix0 in zip(x, self._x0, strict=False)) + self._f_x0, jvp = torch.func.jvp(self._operator, self._x0, delta) + f_x = tuple(ifx + ijvp for ifx, ijvp in zip(self._f_x0, jvp, strict=False)) + return f_x + + def gauss_newton(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Calculate the Gauss-Newton approximation of the Hessian of the operator. + + Returns J^T J x, where J is the Jacobian of the operator at x0. + Uses backward and forward automatic differentiation of the operator. 
+ + Parameters + ---------- + x + input tensor + + Returns + ------- + Gauss-Newton approximation of the Hessian applied to x + """ + return self.adjoint(*self(x)) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 4fe58f1e3..c682e9460 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -11,6 +11,7 @@ from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.Jacobian import Jacobian from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp from mrpro.operators.PhaseOp import PhaseOp @@ -38,6 +39,7 @@ "GridSamplingOp", "IdentityOp", "LinearOperator", + "Jacobian", "MagnitudeOp", "Operator", "PhaseOp", diff --git a/tests/operators/test_jacobian.py b/tests/operators/test_jacobian.py new file mode 100644 index 000000000..6cbff1a1e --- /dev/null +++ b/tests/operators/test_jacobian.py @@ -0,0 +1,26 @@ +import torch +from mrpro.operators import Jacobian +from mrpro.operators.functionals import L2NormSquared + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +def test_jacobian_adjointness(): + rng = RandomGenerator(123) + x = rng.float32_tensor(3) + y = rng.float32_tensor(()) + x0 = rng.float32_tensor(3) + op = L2NormSquared() + jacobian = Jacobian(op, x0) + dotproduct_adjointness_test(jacobian, x, y) + + +def test_jacobian_taylor(): + rng = RandomGenerator(123) + x0 = rng.float32_tensor(3) + x = x0 + 1e-2 * rng.float32_tensor(3) + op = L2NormSquared() + jacobian = Jacobian(op, x0) + fx = jacobian.taylor(x) + torch.testing.assert_allclose(fx, op(x), rtol=1e-3, atol=1e-3) From c20c997ac981b65eb9cc8e88550899a1dc4e4aaf Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 20 Oct 2024 17:22:02 +0200 Subject: [PATCH 02/10] some type fixes --- src/mrpro/operators/Jacobian.py | 19 +++++++++---------- tests/operators/test_jacobian.py | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 9c7178776..2507a0d1a 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -10,12 +10,9 @@ class Jacobian(LinearOperator): """Jacobian of an Operator. This operator computes the Jacobian of an operator at a given point x0, i.e. a linearization of the operator. - - - """ - def __init__(self, operator: Operator[torch.tensor, tuple[torch.Tensor]], *x0: torch.Tensor): + def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: torch.Tensor): """Initialize the Jacobian operator. Parameters @@ -26,12 +23,12 @@ def __init__(self, operator: Operator[torch.tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp = None - self._x0 = x0 + self._vjp:Callable[*tuple[torch.Tensor,...],:tuple[torch.Tensor,...]]|None = None + self._x0:tuple[torch.Tensor,...] = x0 self._operator = operator - self._f_x0 = None + self._f_x0:tuple[torch.Tensor,...]|None = None - def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,]: + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,...]: """Apply the adjoint operator. 
Parameters @@ -45,6 +42,7 @@ def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,]: """ if self._vjp is None: self._f_x0, self._vjp = torch.func.vjp(self._operator, *self._x0) + assert self._vjp is not None return (self._vjp(x)[0],) def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -63,10 +61,11 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: return jvp @property - def value_at_x0(self) -> torch.Tensor: + def value_at_x0(self) -> tuple[torch.Tensor,...]: """Value of the operator at x0.""" if self._f_x0 is None: - self._f_x0 = self._operator(self._x0)[0] + self._f_x0 = self._operator(self._x0) + assert self._f_x0 is not None return self._f_x0 def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: diff --git a/tests/operators/test_jacobian.py b/tests/operators/test_jacobian.py index 6cbff1a1e..657eca6b2 100644 --- a/tests/operators/test_jacobian.py +++ b/tests/operators/test_jacobian.py @@ -23,4 +23,4 @@ def test_jacobian_taylor(): op = L2NormSquared() jacobian = Jacobian(op, x0) fx = jacobian.taylor(x) - torch.testing.assert_allclose(fx, op(x), rtol=1e-3, atol=1e-3) + torch.testing.assert_close(fx, op(x), rtol=1e-3, atol=1e-3) From e54d9efad0c57a416284bb697e123f4dd7976a10 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 20 Oct 2024 17:24:16 +0200 Subject: [PATCH 03/10] cont --- src/mrpro/operators/Jacobian.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 2507a0d1a..22c94a321 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -1,4 +1,5 @@ """Jacobian.""" +from typing import Callable import torch @@ -42,7 +43,7 @@ def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,...]: """ if self._vjp is None: self._f_x0, self._vjp = torch.func.vjp(self._operator, *self._x0) - assert self._vjp is not None + assert self._vjp is not None #noqa: S101 (hint for mypy) return (self._vjp(x)[0],) def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -65,7 +66,7 @@ def value_at_x0(self) -> tuple[torch.Tensor,...]: """Value of the operator at x0.""" if self._f_x0 is None: self._f_x0 = self._operator(self._x0) - assert self._f_x0 is not None + assert self._f_x0 is not None #noqa: S101 (hint for mypy) return self._f_x0 def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: From 024a5d29e4b456aac84922aa848909a49cb400e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 15:24:58 +0000 Subject: [PATCH 04/10] [pre-commit] auto fixes from pre-commit hooks --- src/mrpro/operators/Jacobian.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 22c94a321..b07fbcd55 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -1,5 +1,6 @@ """Jacobian.""" -from typing import Callable + +from collections.abc import Callable import torch @@ -24,12 +25,12 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp:Callable[*tuple[torch.Tensor,...],:tuple[torch.Tensor,...]]|None = None - self._x0:tuple[torch.Tensor,...] = x0 + self._vjp: Callable[*tuple[torch.Tensor, ...], : tuple[torch.Tensor, ...]] | None = None + self._x0: tuple[torch.Tensor, ...] 
= x0 self._operator = operator - self._f_x0:tuple[torch.Tensor,...]|None = None + self._f_x0: tuple[torch.Tensor, ...] | None = None - def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,...]: + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Apply the adjoint operator. Parameters @@ -43,7 +44,7 @@ def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor,...]: """ if self._vjp is None: self._f_x0, self._vjp = torch.func.vjp(self._operator, *self._x0) - assert self._vjp is not None #noqa: S101 (hint for mypy) + assert self._vjp is not None # noqa: S101 (hint for mypy) return (self._vjp(x)[0],) def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -62,11 +63,11 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: return jvp @property - def value_at_x0(self) -> tuple[torch.Tensor,...]: + def value_at_x0(self) -> tuple[torch.Tensor, ...]: """Value of the operator at x0.""" if self._f_x0 is None: self._f_x0 = self._operator(self._x0) - assert self._f_x0 is not None #noqa: S101 (hint for mypy) + assert self._f_x0 is not None # noqa: S101 (hint for mypy) return self._f_x0 def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: From 05ee19f25224a4119d87a3277531d291cf625aca Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sun, 20 Oct 2024 18:06:02 +0200 Subject: [PATCH 05/10] cont --- src/mrpro/operators/Jacobian.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index b07fbcd55..d657f5434 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -30,7 +30,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t self._operator = operator self._f_x0: tuple[torch.Tensor, ...] | None = None - def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override """Apply the adjoint operator. Parameters @@ -47,7 +47,7 @@ def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: assert self._vjp is not None # noqa: S101 (hint for mypy) return (self._vjp(x)[0],) - def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override] """Apply the operator. 
Parameters @@ -66,7 +66,7 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: def value_at_x0(self) -> tuple[torch.Tensor, ...]: """Value of the operator at x0.""" if self._f_x0 is None: - self._f_x0 = self._operator(self._x0) + self._f_x0 = self._operator(*self._x0) assert self._f_x0 is not None # noqa: S101 (hint for mypy) return self._f_x0 @@ -89,6 +89,7 @@ def taylor(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """ delta = tuple(ix - ix0 for ix, ix0 in zip(x, self._x0, strict=False)) self._f_x0, jvp = torch.func.jvp(self._operator, self._x0, delta) + assert self._f_x0 is not None # noqa: S101 (hint for mypy) f_x = tuple(ifx + ijvp for ifx, ijvp in zip(self._f_x0, jvp, strict=False)) return f_x @@ -107,4 +108,4 @@ def gauss_newton(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: ------- Gauss-Newton approximation of the Hessian applied to x """ - return self.adjoint(*self(x)) + return self.adjoint(*self(*x)) From 23d301ef3f802926cb68496a4d18bd230257ac67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:07:56 +0000 Subject: [PATCH 06/10] [pre-commit] auto fixes from pre-commit hooks --- src/mrpro/operators/Jacobian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index d657f5434..679bf7798 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -30,7 +30,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t self._operator = operator self._f_x0: tuple[torch.Tensor, ...] | None = None - def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override """Apply the adjoint operator. Parameters @@ -47,7 +47,7 @@ def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[o assert self._vjp is not None # noqa: S101 (hint for mypy) return (self._vjp(x)[0],) - def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override] + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override] """Apply the operator. Parameters From 9df26edb13813a9237419d7a425f4611927aa3ff Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 21 Oct 2024 12:33:05 +0200 Subject: [PATCH 07/10] typos --- src/mrpro/operators/Jacobian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 679bf7798..56ba57f38 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -25,12 +25,12 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp: Callable[*tuple[torch.Tensor, ...], : tuple[torch.Tensor, ...]] | None = None + self._vjp: Callable[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]] | None = None self._x0: tuple[torch.Tensor, ...] = x0 self._operator = operator self._f_x0: tuple[torch.Tensor, ...] | None = None - def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: # type:ignore[override] """Apply the adjoint operator. 
Parameters From e77fdc78231b8c5dbb0e562c1265ae666be2b487 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 14 Nov 2024 01:32:47 +0100 Subject: [PATCH 08/10] fix tpye hint --- src/mrpro/operators/Jacobian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 56ba57f38..da851305b 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -25,7 +25,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp: Callable[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]] | None = None + self._vjp: Callable[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]] | None = None self._x0: tuple[torch.Tensor, ...] = x0 self._operator = operator self._f_x0: tuple[torch.Tensor, ...] | None = None From 6688633308c612749cf92b03495f69ec9c15dfe8 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 14 Nov 2024 01:55:39 +0100 Subject: [PATCH 09/10] add tests --- src/mrpro/operators/Jacobian.py | 3 ++- tests/operators/test_jacobian.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index da851305b..3a54780ba 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -1,6 +1,7 @@ """Jacobian.""" from collections.abc import Callable +from typing import Unpack import torch @@ -25,7 +26,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp: Callable[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]] | None = None + self._vjp: Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor, ...]] | None = None self._x0: tuple[torch.Tensor, ...] = x0 self._operator = operator self._f_x0: tuple[torch.Tensor, ...] 
| None = None diff --git a/tests/operators/test_jacobian.py b/tests/operators/test_jacobian.py index 657eca6b2..15e649cb5 100644 --- a/tests/operators/test_jacobian.py +++ b/tests/operators/test_jacobian.py @@ -7,6 +7,7 @@ def test_jacobian_adjointness(): + """Test adjointness of Jacobian operator.""" rng = RandomGenerator(123) x = rng.float32_tensor(3) y = rng.float32_tensor(()) @@ -17,6 +18,7 @@ def test_jacobian_adjointness(): def test_jacobian_taylor(): + """Test Taylor expansion""" rng = RandomGenerator(123) x0 = rng.float32_tensor(3) x = x0 + 1e-2 * rng.float32_tensor(3) @@ -24,3 +26,26 @@ def test_jacobian_taylor(): jacobian = Jacobian(op, x0) fx = jacobian.taylor(x) torch.testing.assert_close(fx, op(x), rtol=1e-3, atol=1e-3) + + +def test_jacobian_gaussnewton(): + """Test Gauss Newton approximation of the Hessian""" + rng = RandomGenerator(123) + x0 = rng.float32_tensor(3) + x = x0 + 1e-2 * rng.float32_tensor(3) + op = L2NormSquared() + jacobian = Jacobian(op, x0) + (actual,) = jacobian.gauss_newton(x) + expected = torch.vdot(x, x0) * 4 * x0 # analytical solution for L2NormSquared + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) + + +def test_jacobian_valueatx0(): + """Test value at x0""" + rng = RandomGenerator(123) + x0 = rng.float32_tensor(3) + op = L2NormSquared() + jacobian = Jacobian(op, x0) + (actual,) = jacobian.value_at_x0 + (expected,) = op(x0) + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) From 135a455d6faa65410618c056b9956233691bc647 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 25 Nov 2024 16:53:28 +0100 Subject: [PATCH 10/10] mypy --- src/mrpro/operators/Jacobian.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mrpro/operators/Jacobian.py b/src/mrpro/operators/Jacobian.py index 3a54780ba..4e3e2500f 100644 --- a/src/mrpro/operators/Jacobian.py +++ b/src/mrpro/operators/Jacobian.py @@ -1,7 +1,6 @@ """Jacobian.""" from collections.abc import Callable -from typing import Unpack import torch @@ -26,7 +25,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t point at which to linearize the operator """ super().__init__() - self._vjp: Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor, ...]] | None = None + self._vjp: Callable[[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]] | None = None self._x0: tuple[torch.Tensor, ...] = x0 self._operator = operator self._f_x0: tuple[torch.Tensor, ...] | None = None
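
A minimal usage sketch of the new operator, assuming the API exactly as introduced in this series (Jacobian wrapping the L2NormSquared functional, which maps x to (||x||^2,), so its Jacobian at x0 is the row vector 2*x0^T):

import torch

from mrpro.operators import Jacobian
from mrpro.operators.functionals import L2NormSquared

op = L2NormSquared()
x0 = torch.tensor([1.0, 2.0, 3.0])
jacobian = Jacobian(op, x0)  # linearization of op at x0

# Forward: Jacobian-vector product J x, here 2 * <x0, x>
x = torch.tensor([0.1, 0.0, -0.2])
(jx,) = jacobian(x)

# Adjoint: vector-Jacobian product J^T y, here 2 * y * x0
y = torch.tensor(1.0)
(jty,) = jacobian.adjoint(y)

# First-order Taylor approximation op(x0) + J (x - x0)
(fx_approx,) = jacobian.taylor(x0 + 0.01 * x)

# Gauss-Newton approximation of the Hessian applied to x: J^T J x,
# here 4 * <x0, x> * x0, matching the analytical value used in the tests
(gn,) = jacobian.gauss_newton(x)

# Cached value of op at the expansion point x0
(f_x0,) = jacobian.value_at_x0

The forward pass uses torch.func.jvp, while the adjoint lazily builds and caches a torch.func.vjp closure on first use, so repeated adjoint (and gauss_newton) calls reuse the same linearization at x0.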