
Commit 032782c

bagibence authored and Johanna Haffner committed
Add implementation of the Zoom linesearch.
1 parent cbdaa4d commit 032782c

11 files changed (+1133 -18 lines)
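For orientation: this commit re-exports the new search as `optimistix.Zoom` (see `optimistix/__init__.py` below) and, per the constructor changes in `quasi_newton.py`, lets solvers receive it via a `search` argument. A minimal usage sketch, assuming the existing public entry points `optx.BFGS` and `optx.minimise`; the quadratic objective is purely illustrative:

import jax.numpy as jnp
import optimistix as optx

def fn(y, args):
    # Toy objective: a smooth quadratic bowl with its minimum at y = 1.
    return jnp.sum((y - 1.0) ** 2)

solver = optx.BFGS(
    rtol=1e-6,
    atol=1e-6,
    search=optx.Zoom(initial_guess_strategy="one"),  # matches the new default below
)
sol = optx.minimise(fn, solver, y0=jnp.zeros(3))
print(sol.value)  # approximately [1., 1., 1.]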

optimistix/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@
     OptaxMinimiser as OptaxMinimiser,
     polak_ribiere as polak_ribiere,
     SteepestDescent as SteepestDescent,
+    Zoom as Zoom,
 )

optimistix/_search.py

Lines changed: 3 additions & 1 deletion
@@ -320,6 +320,8 @@ class AbstractSearch(eqx.Module, Generic[Y, _FnInfo, _FnEvalInfo, SearchState]):
     See [this documentation](./introduction.md) for more information.
     """
 
+    _needs_grad_at_y_eval: ClassVar[bool]
+
     @abc.abstractmethod
     def init(self, y: Y, f_info_struct: _FnInfo) -> SearchState:
         """Is called just once, at the very start of the entire optimisation problem.
@@ -363,7 +365,7 @@ def step(
         - `f_info`: An [`optimistix.FunctionInfo`][] describing information about `f`
           evaluated at `y`, the gradient of `f` at `y`, etc.
         - `f_eval_info`: An [`optimistix.FunctionInfo`][] describing information about
-          `f` evaluated at `y`, the gradient of `f` at `y`, etc.
+          `f` evaluated at `y_eval`, the gradient of `f` at `y_eval`, etc.
         - `state`: the evolving state of the repeated searches.
 
         **Returns:**
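The class variable above is the only interface change to `AbstractSearch` itself: every concrete search now advertises whether its `step` expects the gradient at the trial point `y_eval`. A minimal, hypothetical sketch of a custom search declaring the flag (the `ConstantStep` class and its always-accept behaviour are illustrative and not part of this commit; `optx.AbstractSearch` and `optx.RESULTS` are assumed to be the usual public re-exports):

from typing import ClassVar

import jax.numpy as jnp
import optimistix as optx

class ConstantStep(optx.AbstractSearch):
    """Hypothetical search: always propose the same step size and accept it."""

    # New contract from this commit: `False` means the solver may pass
    # `FunctionInfo.Eval` (no gradient at y_eval); `True` requests
    # `FunctionInfo.EvalGrad`, as the Zoom linesearch does.
    _needs_grad_at_y_eval: ClassVar[bool] = False
    step_size: float = 0.5

    def init(self, y, f_info_struct):
        return None  # no evolving search state

    def step(self, first_step, y, y_eval, f_info, f_eval_info, state):
        # A real search would test a sufficient-decrease condition here.
        accept = jnp.array(True)
        return jnp.asarray(self.step_size), accept, optx.RESULTS.successful, None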

optimistix/_solver/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -48,3 +48,4 @@
     ClassicalTrustRegion as ClassicalTrustRegion,
     LinearTrustRegion as LinearTrustRegion,
 )
+from .zoom import Zoom as Zoom

optimistix/_solver/backtracking.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import cast, TypeAlias
+from typing import cast, ClassVar, TypeAlias
 
 import equinox as eqx
 import jax.numpy as jnp
@@ -29,6 +29,7 @@ class BacktrackingArmijo(AbstractSearch[Y, _FnInfo, _FnEvalInfo, _BacktrackingSt
     decrease_factor: ScalarLike = 0.5
     slope: ScalarLike = 0.1
     step_init: ScalarLike = 1.0
+    _needs_grad_at_y_eval: ClassVar[bool] = False
 
     def __post_init__(self):
         self.decrease_factor = eqx.error_if(

optimistix/_solver/gradient_methods.py

Lines changed: 19 additions & 5 deletions
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from typing import Any, Generic, TypeAlias
+from typing import Any, cast, Generic, TypeAlias
 
 import equinox as eqx
 import jax
@@ -131,7 +131,9 @@ class AbstractGradientDescent(AbstractMinimiser[Y, Aux, _GradientDescentState]):
     norm: AbstractVar[Callable[[PyTree], Scalar]]
     descent: AbstractVar[AbstractDescent[Y, FunctionInfo.EvalGrad, Any]]
     search: AbstractVar[
-        AbstractSearch[Y, FunctionInfo.EvalGrad, FunctionInfo.Eval, Any]
+        AbstractSearch[
+            Y, FunctionInfo.EvalGrad, FunctionInfo.Eval | FunctionInfo.EvalGrad, Any
+        ]
     ]
 
     def init(
@@ -170,19 +172,31 @@ def step(
         f_eval, lin_fn, aux_eval = jax.linearize(
             lambda _y: fn(_y, args), state.y_eval, has_aux=True
         )
+
+        if self.search._needs_grad_at_y_eval:
+            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode)
+            f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
+        else:
+            f_eval_info = FunctionInfo.Eval(f_eval)
+
         step_size, accept, search_result, search_state = self.search.step(
             state.first_step,
             y,
             state.y_eval,
             state.f_info,
-            FunctionInfo.Eval(f_eval),
+            f_eval_info,
             state.search_state,
         )
 
         def accepted(descent_state):
-            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
+            nonlocal f_eval_info
+
+            if not self.search._needs_grad_at_y_eval:
+                grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
+                f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
+
+            f_eval_info = cast(FunctionInfo.EvalGrad, f_eval_info)
 
-            f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
             descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
             y_diff = (state.y_eval**ω - y**ω).ω
             f_diff = (f_eval**ω - state.f_info.f**ω).ω
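Context for the new branch above: a backtracking search only needs the objective value at `y_eval`, but the Zoom linesearch enforces the strong Wolfe conditions, whose curvature test uses the directional derivative at the trial point, hence `_needs_grad_at_y_eval`. A self-contained sketch of those two tests (names and constants are illustrative; this is not the code in `zoom.py`):

import jax.numpy as jnp

def wolfe_conditions(f0, g0, f_t, g_t, direction, t, c1=1e-4, c2=0.9):
    # f0, g0: value and gradient at the current iterate y.
    # f_t, g_t: value and gradient at the trial point y + t * direction.
    slope0 = jnp.vdot(g0, direction)    # directional derivative at y
    slope_t = jnp.vdot(g_t, direction)  # directional derivative at y_eval
    armijo = f_t <= f0 + c1 * t * slope0                  # sufficient decrease
    curvature = jnp.abs(slope_t) <= c2 * jnp.abs(slope0)  # strong Wolfe curvature
    return armijo & curvature

Because the curvature test needs `g_t`, the solver now computes the gradient eagerly before calling `search.step` when the flag is set, and only falls back to computing it inside `accepted` for searches that do not need it.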

optimistix/_solver/learning_rate.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import cast, ClassVar
 
 import equinox as eqx
 import jax.numpy as jnp
@@ -16,6 +16,7 @@ def _typed_asarray(x: ScalarLike) -> Array:
 class LearningRate(AbstractSearch[Y, FunctionInfo, FunctionInfo, None]):
     """Move downhill by taking a step of the fixed size `learning_rate`."""
 
+    _needs_grad_at_y_eval: ClassVar[bool] = False
     learning_rate: ScalarLike = eqx.field(converter=_typed_asarray)
 
     def init(self, y: Y, f_info_struct: FunctionInfo) -> None:

optimistix/_solver/quasi_newton.py

Lines changed: 19 additions & 9 deletions
@@ -1,6 +1,6 @@
 import abc
 from collections.abc import Callable
-from typing import Any, Generic, TypeVar
+from typing import Any, cast, Generic, TypeVar
 
 import equinox as eqx
 import jax
@@ -30,6 +30,7 @@
 from .._solution import RESULTS
 from .backtracking import BacktrackingArmijo
 from .gauss_newton import NewtonDescent
+from .zoom import Zoom
 
 
 _Hessian = TypeVar(
@@ -118,10 +119,6 @@ class AbstractQuasiNewton(
     structure and the Hessian update state, while the latter is called to compute an
     update to the approximation of the Hessian or the inverse Hessian.
 
-    Already supported schemes to form inverse Hessian and Hessian approximations are
-    implemented in `optimistix.AbstractBFGS`, `optimistix.AbstractDFP` and
-    `optimistix.AbstractLBFGS`.
-
     Supports the following `options`:
 
     - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
@@ -210,29 +207,40 @@ def step(
         f_eval, lin_fn, aux_eval = jax.linearize(
             lambda _y: fn(_y, args), state.y_eval, has_aux=True
         )
+
+        if self.search._needs_grad_at_y_eval:
+            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode)
+            f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
+        else:
+            f_eval_info = FunctionInfo.Eval(f_eval)
+
         step_size, accept, search_result, search_state = self.search.step(
             state.first_step,
             y,
             state.y_eval,
             state.f_info,
-            FunctionInfo.Eval(f_eval),
+            f_eval_info,  # pyright: ignore  # TODO Fix (jhaffner)
            state.search_state,
         )
 
         def accepted(descent_state):
-            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
+            nonlocal f_eval_info
+
+            if not self.search._needs_grad_at_y_eval:
+                grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
+                f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
 
             f_eval_info, hessian_update_state = self.update_hessian(
                 y,
                 state.y_eval,
                 state.f_info,
-                FunctionInfo.EvalGrad(f_eval, grad),
+                cast(FunctionInfo.EvalGrad, f_eval_info),
                 state.hessian_update_state,
             )
 
             descent_state = self.descent.query(
                 state.y_eval,
-                f_eval_info,  # pyright: ignore
+                f_eval_info,
                 descent_state,
             )
             y_diff = (state.y_eval**ω - y**ω).ω
@@ -447,6 +455,7 @@ def __init__(
         norm: Callable[[PyTree], Scalar] = max_norm,
         use_inverse: bool = True,
         verbose: frozenset[str] = frozenset(),
+        search: AbstractSearch = Zoom(initial_guess_strategy="one"),
     ):
         self.rtol = rtol
         self.atol = atol
@@ -606,6 +615,7 @@ def __init__(
         norm: Callable[[PyTree], Scalar] = max_norm,
         use_inverse: bool = True,
         verbose: frozenset[str] = frozenset(),
+        search: AbstractSearch = Zoom(initial_guess_strategy="one"),
     ):
         self.rtol = rtol
         self.atol = atol
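The constructor hunks above give the quasi-Newton solvers in this file (e.g. `optimistix.BFGS`) a `search` argument defaulting to the Zoom linesearch. A sketch of overriding it, assuming the keyword names shown above and that `BacktrackingArmijo` remains exported at the package level:

import optimistix as optx

# After this commit: defaults to Zoom(initial_guess_strategy="one").
bfgs_default = optx.BFGS(rtol=1e-6, atol=1e-6)

# Swap in a backtracking Armijo search instead.
bfgs_backtracking = optx.BFGS(
    rtol=1e-6,
    atol=1e-6,
    search=optx.BacktrackingArmijo(),
)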

optimistix/_solver/trust_region.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,5 @@
 import abc
-from typing import TypeAlias, TypeVar
+from typing import ClassVar, TypeAlias, TypeVar
 
 import equinox as eqx
 import jax.numpy as jnp
@@ -46,6 +46,7 @@ class _AbstractTrustRegion(AbstractSearch[Y, _FnInfo, _FnEvalInfo, _TrustRegionS
     low_cutoff: AbstractVar[ScalarLike]
     high_constant: AbstractVar[ScalarLike]
     low_constant: AbstractVar[ScalarLike]
+    _needs_grad_at_y_eval: ClassVar[bool] = False
 
     def __post_init__(self):
         # You would not expect `self.low_cutoff` or `self.high_cutoff` to
