Skip to content

Add support for optimizer checkpointing #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
pip install mypy
- name: Run mypy
run: |
mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/
mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/ scico/numpy/util.py
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ Version 0.0.7 (unreleased)
• New module ``scico.trace`` for tracing function/method calls.
• New generic functional ``functional.ComposedFunctional`` representing
a functional composed with an orthogonal linear operator.
• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.5.0.
• New optimizer methods ``save_state`` and ``load_state`` supporting
algorithm state checkpointing.
• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.5.3.
• Support ``flax`` versions 0.8.0 to 0.10.2.


Expand Down
11 changes: 10 additions & 1 deletion docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ @Article {chambolle-2010-firstorder
pages = {120--145}
}

@Misc {chandler-2024-closedform,
author = {Edward P. Chandler and Shirin Shoushtari and Brendt
Wohlberg and Ulugbek S. Kamilov},
title = {Closed-Form Approximation of the Total Variation
Proximal Operator},
year = 2024,
eprint = {2412.07718}
}

@Article {clinthorne-1993-preconditioning,
author = {Clinthorne, Neal H. and Pan, Tin-Su and Chiao,
Ping-Chun and Rogers, W. Leslie and Stamos, John A.},
Expand Down Expand Up @@ -764,6 +773,7 @@ @InProceedings {yu-2013-better
year = 2013
}


@Article {zhang-2017-dncnn,
author = {Kai Zhang and Wangmeng Zuo and Yunjin Chen and Deyu
Meng and Lei Zhang},
Expand Down Expand Up @@ -793,7 +803,6 @@ @Article {zhang-2021-plug
pages = {6360--6376}
}


@Article {zhou-2006-adaptive,
author = {Bin Zhou and Li Gao and Yu-Hong Dai},
title = {Gradient Methods with Adaptive Step-Sizes},
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.11.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.13,<=0.5.0
jax>=0.4.13,<=0.5.0
jaxlib>=0.4.13,<=0.5.3
jax>=0.4.13,<=0.5.3
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.10.2
pyabel>=0.9.0
8 changes: 4 additions & 4 deletions scico/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -10,7 +10,7 @@
import re
import warnings
from collections import OrderedDict, namedtuple
from typing import List, Optional, Tuple, Union
from typing import List, NamedTuple, Optional, Tuple, Union


class IterationStats:
Expand Down Expand Up @@ -190,13 +190,13 @@ def end(self):
):
print()

def history(self, transpose: bool = False):
def history(self, transpose: bool = False) -> Union[List[NamedTuple], Tuple[List]]:
"""Retrieve record of all inserted iterations.

Args:
transpose: Flag indicating whether results should be returned
in "transposed" form, i.e. as a namedtuple of lists
rather than a list of namedtuples. Default: False.
rather than a list of namedtuples.

Returns:
list of namedtuple or namedtuple of lists: Record of all
Expand Down
8 changes: 4 additions & 4 deletions scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -48,7 +48,7 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
Args:
x: Input image.
lam: Noise parameter.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Expand Down Expand Up @@ -85,7 +85,7 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
Args:
x: Input image.
lam: Noise parameter.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Expand Down Expand Up @@ -134,7 +134,7 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
Args:
x: Input array.
lam: Noise parameter (ignored).
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Expand Down
6 changes: 3 additions & 3 deletions scico/functional/_dist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -73,7 +73,7 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Expand Down Expand Up @@ -144,7 +144,7 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Expand Down
35 changes: 30 additions & 5 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __call__(self, x: Union[Array, BlockArray]) -> float:
Args:
x: Point at which to evaluate this functional.

Returns:
Result of evaluating the functional at `x`.
"""
# Functionals that can be evaluated should override this method.
raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.")
Expand All @@ -85,9 +87,12 @@ def prox(
Args:
v: Point at which to evaluate prox function.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes. These include `x0`, an initial guess for the
minimizer in the definition of :math:`\prox`.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
# Functionals that have a prox should override this method.
raise NotImplementedError(f"Functional {type(self)} does not have a prox.")
Expand All @@ -112,8 +117,11 @@ def conj_prox(
Args:
v: Point at which to evaluate prox function.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional keyword args, passed directly to
**kwargs: Additional keyword args, passed directly to
`self.prox`.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs)

Expand All @@ -122,6 +130,9 @@ def grad(self, x: Union[Array, BlockArray]):

Args:
x: Point at which to evaluate gradient.

Returns:
The gradient at `x`.
"""
return self._grad(x)

Expand Down Expand Up @@ -169,6 +180,16 @@ def prox(
\prox_{\alpha (\beta f)}(\mb{v}) =
\prox_{(\alpha \beta) f}(\mb{v}) \;.


Args:
v: Point at which to evaluate prox function.
lam: Proximal parameter :math:`\lambda`.
**kwargs: Additional arguments that may be used by derived
classes. These include `x0`, an initial guess for the
minimizer in the definition of :math:`\prox`.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return self.functional.prox(v, lam * self.scale, **kwargs)

Expand Down Expand Up @@ -234,9 +255,11 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
if len(v.shape) == len(self.functional_list):
return snp.blockarray(
Expand All @@ -260,7 +283,7 @@ class ComposedFunctional(Functional):
where :math:`f` is the composed functional, :math:`g` is the
functional from which it is composed, and :math:`A` is an orthogonal
linear operator. Note that the resulting :class:`Functional` can only
be applied (either via evaluation or :method:`prox` calls) to inputs
be applied (either via evaluation or :meth:`prox` calls) to inputs
of shape and dtype corresponding to the input specification of the
linear operator.
"""
Expand Down Expand Up @@ -312,9 +335,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return self.linop.T(self.functional.prox(self.linop(v), lam=lam, **kwargs))

Expand Down
16 changes: 14 additions & 2 deletions scico/functional/_indicator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -59,8 +59,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda` (has no effect).
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return snp.maximum(v, 0)

Expand Down Expand Up @@ -108,5 +111,14 @@ def prox(

.. math::
\mathrm{prox}_{\lambda I}(\mb{v}) = r \frac{\mb{v}}{\norm{\mb{v}}_2}\;.

Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda` (has no effect).
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return self.radius * v / norm(v)
43 changes: 33 additions & 10 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -50,8 +50,11 @@ def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array
Args:
v: Input array :math:`\mb{v}`.
lam: Thresholding parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return snp.where(snp.abs(v) >= lam, v, 0)

Expand Down Expand Up @@ -92,8 +95,11 @@ def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Array:
Args:
v: Input array :math:`\mb{v}`.
lam: Thresholding parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
tmp = snp.abs(v) - lam
tmp = 0.5 * (tmp + snp.abs(tmp))
Expand Down Expand Up @@ -135,8 +141,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return v / (1.0 + 2.0 * lam)

Expand Down Expand Up @@ -176,8 +185,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
norm_v = norm(v)
return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)
Expand Down Expand Up @@ -249,8 +261,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
if isinstance(v, snp.BlockArray) and self.l2_axis is not None:
raise ValueError("Initializer parameter l2_axis must be None for BlockArray input.")
Expand Down Expand Up @@ -332,8 +347,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
alpha = lam
beta = self.beta
Expand Down Expand Up @@ -454,12 +472,14 @@ def prox(

in the separable case.


Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
return self._prox(v, lam=lam, **kwargs)

Expand Down Expand Up @@ -494,8 +514,11 @@ def prox(
Args:
v: Input array :math:`\mb{v}`. Required to be two-dimensional.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
**kwargs: Additional arguments that may be used by derived
classes.

Returns:
Result of evaluating the scaled proximal operator at `v`.
"""
if v.ndim != 2:
raise ValueError("Input array must be two dimensional.")
Expand Down
Loading
Loading