Commit 15e959b

Merge pull request #122 from bagibence/optax_linesearch_fix

Support Optax solvers that include a linesearch

2 parents: 5ba7b32 + 7bf75da

2 files changed: +36, -10 lines

optimistix/_solver/optax.py (7 additions, 1 deletion)

@@ -97,7 +97,13 @@ def step(
                 ("loss" in self.verbose, "Loss", f),
                 ("y" in self.verbose, "y", y),
             )
-        updates, new_opt_state = self.optim.update(grads, state.opt_state, y)
+
+        # fix args and discard aux
+        _fn_for_optax = lambda y: fn(y, args)[0]
+
+        updates, new_opt_state = self.optim.update(
+            grads, state.opt_state, y, value=f, grad=grads, value_fn=_fn_for_optax
+        )
         new_y = eqx.apply_updates(y, updates)
         terminate = cauchy_termination(
             self.rtol,
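For reference: the lambda fixes `args` and drops any auxiliary output, so that an Optax linesearch can re-evaluate the objective as a function of `y` alone. A minimal sketch of what this change enables, assuming the post-merge optimistix and a recent optax; the quadratic loss, tolerances, and starting point here are illustrative, not taken from the repository:

import jax.numpy as jnp
import optax
import optimistix as optx


def loss(y, args):
    # Illustrative convex quadratic; the (y, args) signature is what
    # optimistix expects of an objective function.
    return jnp.sum((y - 1.0) ** 2)


# optax.lbfgs() includes a linesearch by default, so its update rule needs the
# value/grad/value_fn arguments that the patched `step` now forwards.
solver = optx.OptaxMinimiser(optax.lbfgs(), rtol=1e-8, atol=1e-8)
sol = optx.minimise(loss, solver, jnp.zeros(3))
# sol.value should be close to [1. 1. 1.]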

tests/helpers.py (29 additions, 9 deletions)

@@ -243,14 +243,20 @@ class DFPClassicalTrustRegionHessian(optx.AbstractQuasiNewton):
 
 
 atol = rtol = 1e-8
-minimisers = (
+_general_minimisers = (
     optx.NelderMead(rtol, atol),
     optx.BFGS(rtol, atol, use_inverse=False),
     optx.BFGS(rtol, atol, use_inverse=True),
     BFGSDampedNewton(rtol, atol),
     BFGSIndirectDampedNewton(rtol, atol),
     # Tighter tolerance needed to have BFGSDogleg pass the JVP test.
     BFGSDogleg(1e-10, 1e-10),
+    optx.OptaxMinimiser(optax.adam(learning_rate=3e-3), rtol=rtol, atol=atol),
+    # optax.lbfgs includes their linesearch by default
+    optx.OptaxMinimiser(optax.lbfgs(), rtol=rtol, atol=atol),
+)
+
+_minim_only = (
     BFGSClassicalTrustRegionHessian(rtol, atol),
     BFGSLinearTrustRegionHessian(rtol, atol),
     BFGSLinearTrustRegion(rtol, atol),
@@ -264,18 +270,32 @@ class DFPClassicalTrustRegionHessian(optx.AbstractQuasiNewton):
     optx.GradientDescent(1.5e-2, rtol, atol),
     # Tighter tolerance needed to have NonlinearCG pass the JVP test.
     optx.NonlinearCG(1e-10, 1e-10),
-    optx.OptaxMinimiser(optax.adam(learning_rate=3e-3), rtol=rtol, atol=atol),
+    # explicitly including a linesearch
+    optx.OptaxMinimiser(
+        optax.chain(
+            optax.sgd(learning_rate=1.0),
+            optax.scale_by_zoom_linesearch(15, curv_rtol=jnp.inf),
+        ),
+        rtol=rtol,
+        atol=atol,
+    ),
+    optx.OptaxMinimiser(
+        optax.chain(
+            optax.sgd(learning_rate=1.0),
+            optax.scale_by_backtracking_linesearch(15),
+        ),
+        rtol=rtol,
+        atol=atol,
+    ),
 )
 
+minimisers = _general_minimisers + _minim_only
+
 # the minimisers can handle least squares problems, but the least squares
 # solvers cannot handle general minimisation problems.
-least_squares_optimisers = _lsqr_only + minimisers
-# Remove ones that work, but are just pretty bad!
-least_squares_optimisers = [
-    x
-    for x in least_squares_optimisers
-    if not isinstance(x, (optx.GradientDescent, optx.NonlinearCG))
-]
+# without the ones that work, but are just pretty bad!
+least_squares_optimisers = _lsqr_only + _general_minimisers
+
 
 #
 # MINIMISATION PROBLEMS
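The two optax.chain test cases above follow the calling convention Optax documents for its linesearch transforms, which is the same convention the patched `step` mirrors. A standalone sketch of that convention, with an illustrative objective and starting point (not taken from the test suite):

import jax
import jax.numpy as jnp
import optax


def f(params):
    # Illustrative objective: squared distance from the all-ones vector.
    return jnp.sum((params - 1.0) ** 2)


opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
params = jnp.array([3.0, -2.0])
opt_state = opt.init(params)
value, grad = jax.value_and_grad(f)(params)
# The linesearch consumes value/grad/value_fn; optax.chain forwards these
# extra arguments only to the transforms that accept them.
updates, opt_state = opt.update(
    grad, opt_state, params, value=value, grad=grad, value_fn=f
)
params = optax.apply_updates(params, updates)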
