Commit
* NUFFT
* created NonLinearOperator template
* first lbfgs with test
* changed input of lbfgs to list of params
* lbfgs raises error for complex-valued tensors
* fixed lbfgs; added adam with sat-recov example for ellipse phantom
* fixed tests; removed warning
* addressed review points
* fixed docstring in mse class
* test
* fertig
* addressed final reviews
* fixed mypy complain
* removed pytestwarning comment
* deleted phantom and commented pytest warning
* fixed mse type hint; updated adam and lbfgs docstrings
* add test for invalid bounds and address review commnts
* rename contraintop test file
* fix error introduced by me...
* fix typo in comment
* allow neginf/posinf as bounds

---------

Co-authored-by: koflera <[email protected]>
Co-authored-by: koflera <[email protected]>
Co-authored-by: Felix <[email protected]>
1 parent 9bf426e · commit 88b59d2
Showing 13 changed files with 751 additions and 0 deletions.

@@ -1,2 +1,4 @@
from mrpro.algorithms._prewhiten_kspace import prewhiten_kspace
from mrpro.algorithms._remove_readout_os import remove_readout_os
from mrpro.algorithms._lbfgs import lbfgs
from mrpro.algorithms._adam import adam

@@ -0,0 +1,103 @@
"""ADAM for solving non-linear minimization problems."""

# Copyright 2024 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.optim import Adam

from mrpro.operators import Operator


def adam(
    f: Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor]],
    params: list,
    max_iter: int,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    amsgrad: bool = False,
    foreach: bool | None = None,
    maximize: bool = False,
    capturable: bool = False,
    differentiable: bool = False,
    fused: bool | None = None,
) -> list[torch.Tensor]:
    """Adam for non-linear minimization problems.

    Parameters
    ----------
    f
        scalar-valued function to be minimized
    params
        list of parameters to be optimized.
        Note that these parameters will not be changed. Instead, we create a copy and
        leave the initial values untouched.
    max_iter
        maximum number of iterations
    lr, optional
        learning rate, by default 1e-3
    betas, optional
        coefficients used for computing running averages of the gradient and its square,
        by default (0.9, 0.999)
    eps, optional
        term added to the denominator to improve numerical stability, by default 1e-8
    weight_decay, optional
        weight decay (L2 penalty), by default 0
    amsgrad, optional
        whether to use the AMSGrad variant of this algorithm from the paper
        `On the Convergence of Adam and Beyond`, by default False
    foreach, optional
        whether the `foreach` implementation of the optimizer is used, by default None
    maximize, optional
        maximize the objective with respect to the params instead of minimizing, by default False
    capturable, optional
        whether this instance is safe to capture in a CUDA graph. Passing True can impair ungraphed
        performance, so if you don't intend to graph-capture this instance, leave it False, by default False
    differentiable, optional
        whether autograd should occur through the optimizer step in training. Otherwise, the step() function
        runs in a torch.no_grad() context. Setting to True can impair performance, so leave it False if you
        don't intend to run autograd through this instance, by default False
    fused, optional
        whether the fused implementation (CUDA only) is used. Currently torch.float64, torch.float32,
        torch.float16 and torch.bfloat16 are supported, by default None

    Returns
    -------
        list of optimized parameters
    """
    # work on a copy of the parameters so that the initial values are left untouched
    parameters = [p.detach().clone().requires_grad_(True) for p in params]

    # define Adam routine
    optim = Adam(
        params=parameters,
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
        foreach=foreach,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        fused=fused,
    )

    def closure():
        # evaluate the objective on the copied parameters and accumulate their gradients
        optim.zero_grad()
        (objective,) = f(*parameters)
        objective.backward()
        return objective

    # run adam
    for _ in range(max_iter):
        optim.step(closure)

    return parameters
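
A minimal usage sketch for the adam helper added above, assuming a plain callable can stand in for an mrpro Operator (at runtime adam only requires that f(*params) return a one-element tuple containing the scalar objective); all names and values below are illustrative only:

import torch

from mrpro.algorithms import adam

target = torch.tensor([1.0, -2.0, 3.0])


def f(x: torch.Tensor) -> tuple[torch.Tensor]:
    # toy objective: squared distance to a fixed target, returned as a one-element tuple
    return (((x - target) ** 2).sum(),)


x0 = torch.zeros(3)  # initial guess; adam leaves this tensor unchanged
(x_opt,) = adam(f, [x0], max_iter=500, lr=1e-1)
print(x_opt)  # expected to be close to target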

@@ -0,0 +1,94 @@
"""LBFGS for solving non-linear minimization problems."""

# Copyright 2024 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.optim import LBFGS

from mrpro.operators import Operator


def lbfgs(
    f: Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor]],
    params: list,
    lr: float = 1.0,
    max_iter: int = 100,
    max_eval: int | None = 100,
    tolerance_grad: float = 1e-07,
    tolerance_change: float = 1e-09,
    history_size: int = 10,
    line_search_fn: str | None = 'strong_wolfe',
) -> list[torch.Tensor]:
    """LBFGS for non-linear minimization problems.

    Parameters
    ----------
    f
        scalar function to be minimized
    params
        list of parameters to be optimized.
        Note that these parameters will not be changed. Instead, we create a copy and
        leave the initial values untouched.
    lr, optional
        learning rate, by default 1.0
    max_iter, optional
        maximum number of iterations, by default 100
    max_eval, optional
        maximum number of evaluations of f per optimization step, by default 100
    tolerance_grad, optional
        termination tolerance on first-order optimality, by default 1e-07
    tolerance_change, optional
        termination tolerance on function value/parameter changes, by default 1e-09
    history_size, optional
        update history size, by default 10
    line_search_fn, optional
        line search algorithm, either 'strong_wolfe' or None, by default 'strong_wolfe'

    Returns
    -------
        list of optimized parameters
    """
    # TODO: remove after new pytorch release
    if any(torch.is_complex(p) for p in params):
        raise ValueError(
            "at least one tensor in 'params' is complex-valued;"
            '\ncomplex-valued tensors will be allowed for lbfgs in future torch versions'
        )

    # work on a copy of the parameters so that the initial values are left untouched
    parameters = [p.detach().clone().requires_grad_(True) for p in params]

    # define lbfgs routine
    optim = LBFGS(
        params=parameters,
        lr=lr,
        history_size=history_size,
        max_iter=max_iter,
        max_eval=max_eval,
        tolerance_grad=tolerance_grad,
        tolerance_change=tolerance_change,
        line_search_fn=line_search_fn,
    )

    def closure():
        # evaluate the objective on the copied parameters and accumulate their gradients
        optim.zero_grad()
        (objective,) = f(*parameters)
        objective.backward()
        return objective

    # run lbfgs; a single step performs up to max_iter iterations internally
    optim.step(closure)

    return parameters
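
A corresponding sketch for lbfgs with two real-valued parameters; as above, a plain function stands in for an mrpro Operator, and the test function, names and values are illustrative only (complex-valued tensors would be rejected by the check above):

import torch

from mrpro.algorithms import lbfgs


def rosenbrock(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor]:
    # classic non-convex test function with its minimum at (1, 1)
    return ((1 - a) ** 2 + 100 * (b - a**2) ** 2,)


a0, b0 = torch.tensor(-1.0), torch.tensor(2.0)  # initial guesses stay unchanged
a_opt, b_opt = lbfgs(rosenbrock, [a0, b0], max_iter=200, max_eval=400)
print(a_opt, b_opt)  # expected to be close to (1.0, 1.0)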

@@ -0,0 +1,154 @@
"""Operator enforcing constraints by variable transformations."""

# Copyright 2024 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

from mrpro.operators import Operator


class ConstraintsOp(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]):
    """Transformation to map real-valued tensors to certain ranges."""

    def __init__(
        self,
        bounds: tuple[tuple[float | None, float | None], ...],
        beta_sigmoid: float = 1.0,
        beta_softplus: float = 1.0,
    ) -> None:
        super().__init__()

        if beta_sigmoid <= 0:
            raise ValueError(f'parameter beta_sigmoid must be greater than zero; given {beta_sigmoid}')
        if beta_softplus <= 0:
            raise ValueError(f'parameter beta_softplus must be greater than zero; given {beta_softplus}')

        self.beta_sigmoid = beta_sigmoid
        self.beta_softplus = beta_softplus

        self.lower_bounds = [bound[0] for bound in bounds]
        self.upper_bounds = [bound[1] for bound in bounds]

        for lb, ub in bounds:
            if lb is not None and ub is not None:
                if torch.isnan(torch.tensor(lb)) or torch.isnan(torch.tensor(ub)):
                    raise ValueError('"nan" is not a valid lower or upper bound;' f'\nbound tuple {lb, ub} is invalid')

                if lb >= ub:
                    raise ValueError(
                        'bounds should be ((a1,b1), (a2,b2), ...) with ai < bi if neither ai nor bi is None;'
                        f'\nbound tuple {lb, ub} is invalid'
                    )

    @staticmethod
    def sigmoid(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Scaled sigmoid; maps the real line to the open interval (0, 1)."""
        return F.sigmoid(beta * x)

    @staticmethod
    def sigmoid_inverse(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Inverse of the scaled sigmoid transformation."""
        return torch.logit(x) / beta

    @staticmethod
    def softplus(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Scaled softplus; maps the real line to the open interval (0, infty)."""
        return -(1 / beta) * torch.nn.functional.logsigmoid(-beta * x)
    @staticmethod
    def softplus_inverse(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Inverse of the scaled softplus transformation."""
        # solve y = (1/beta) * log(1 + exp(beta * x)) for x
        return x + (1 / beta) * torch.log(-torch.expm1(-beta * x))

    def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Transform tensors to the chosen range.

        Parameters
        ----------
        x
            tensors to be transformed

        Returns
        -------
            tensors transformed to the range defined by the chosen bounds
        """
        # iterate over the tensors and constrain them if necessary according to the chosen bounds
        xc = []
        for i in range(len(self.lower_bounds)):
            lb, ub = self.lower_bounds[i], self.upper_bounds[i]

            # distinguish cases
            if (lb is not None and not torch.isneginf(torch.tensor(lb))) and (
                ub is not None and not torch.isposinf(torch.tensor(ub))
            ):
                # case (a,b) with a<b and a,b finite
                xc.append(lb + (ub - lb) * self.sigmoid(x[i], beta=self.beta_sigmoid))

            elif lb is not None and (ub is None or torch.isposinf(torch.tensor(ub))):
                # case (a,None); corresponds to (a, infty)
                xc.append(lb + self.softplus(x[i], beta=self.beta_softplus))

            elif (lb is None or torch.isneginf(torch.tensor(lb))) and ub is not None:
                # case (None,b); corresponds to (-infty, b)
                xc.append(ub - self.softplus(-x[i], beta=self.beta_softplus))

            elif (lb is None or torch.isneginf(torch.tensor(lb))) and (ub is None or torch.isposinf(torch.tensor(ub))):
                # case (None,None); corresponds to (-infty, infty), i.e. no transformation
                xc.append(x[i])

        return tuple(xc)

    def inverse(self, *xc: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Reverse the variable transformation.

        Parameters
        ----------
        xc
            transformed tensors with values in the range defined by the bounds

        Returns
        -------
            tensors in the domain with no bounds
        """
        # iterate over the tensors and undo the transformation according to the chosen bounds
        x = []
        for i in range(len(self.lower_bounds)):
            lb, ub = self.lower_bounds[i], self.upper_bounds[i]

            # distinguish cases
            if (lb is not None and not torch.isneginf(torch.tensor(lb))) and (
                ub is not None and not torch.isposinf(torch.tensor(ub))
            ):
                # case (a,b) with a<b and a,b finite
                x.append(self.sigmoid_inverse((xc[i] - lb) / (ub - lb), beta=self.beta_sigmoid))

            elif lb is not None and (ub is None or torch.isposinf(torch.tensor(ub))):
                # case (a,None); corresponds to (a, infty)
                x.append(self.softplus_inverse(xc[i] - lb, beta=self.beta_softplus))

            elif (lb is None or torch.isneginf(torch.tensor(lb))) and ub is not None:
                # case (None,b); corresponds to (-infty, b)
                x.append(-self.softplus_inverse(-(xc[i] - ub), beta=self.beta_softplus))

            elif (lb is None or torch.isneginf(torch.tensor(lb))) and (ub is None or torch.isposinf(torch.tensor(ub))):
                # case (None,None); corresponds to (-infty, infty), i.e. no transformation
                x.append(xc[i])

        return tuple(x)
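
A short sketch of how ConstraintsOp is meant to be combined with the unconstrained optimizers above: an optimizer works on real-valued latent tensors, and the operator maps them into the admissible ranges. The import path and the bound values are assumptions for illustration; the operator may need to be imported from its submodule instead.

import torch

# assumed re-export from the operators package; adjust the import path if needed
from mrpro.operators import ConstraintsOp

# one (lower, upper) tuple per tensor passed to forward():
# first tensor constrained to (0, 5), second to (0, inf), third left unconstrained
constraints_op = ConstraintsOp(bounds=((0.0, 5.0), (0.0, None), (None, None)))

# unconstrained real-valued variables, e.g. the quantities adam/lbfgs would optimize
x1, x2, x3 = torch.randn(16, 16), torch.randn(16, 16), torch.randn(16, 16)

# map into the admissible ranges before feeding them to a signal model
c1, c2, c3 = constraints_op.forward(x1, x2, x3)
print(c1.min(), c1.max())  # expected to lie strictly inside (0, 5)

# inverse() maps constrained values back to the unconstrained domain (up to floating-point error)
y1, y2, y3 = constraints_op.inverse(c1, c2, c3)
print((y1 - x1).abs().max())  # expected to be close to zero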

@@ -0,0 +1 @@
from mrpro.operators.functionals._mse_data_discrepancy import mse_data_discrepancy