@@ -243,14 +243,20 @@ class DFPClassicalTrustRegionHessian(optx.AbstractQuasiNewton):
 
 
 atol = rtol = 1e-8
-minimisers = (
+_general_minimisers = (
     optx.NelderMead(rtol, atol),
     optx.BFGS(rtol, atol, use_inverse=False),
     optx.BFGS(rtol, atol, use_inverse=True),
     BFGSDampedNewton(rtol, atol),
     BFGSIndirectDampedNewton(rtol, atol),
     # Tighter tolerance needed to have BFGSDogleg pass the JVP test.
     BFGSDogleg(1e-10, 1e-10),
+    optx.OptaxMinimiser(optax.adam(learning_rate=3e-3), rtol=rtol, atol=atol),
+    # optax.lbfgs includes their linesearch by default
+    optx.OptaxMinimiser(optax.lbfgs(), rtol=rtol, atol=atol),
+)
+
+_minim_only = (
     BFGSClassicalTrustRegionHessian(rtol, atol),
     BFGSLinearTrustRegionHessian(rtol, atol),
     BFGSLinearTrustRegion(rtol, atol),
@@ -264,18 +270,32 @@ class DFPClassicalTrustRegionHessian(optx.AbstractQuasiNewton):
     optx.GradientDescent(1.5e-2, rtol, atol),
     # Tighter tolerance needed to have NonlinearCG pass the JVP test.
     optx.NonlinearCG(1e-10, 1e-10),
-    optx.OptaxMinimiser(optax.adam(learning_rate=3e-3), rtol=rtol, atol=atol),
+    # explicitly including a linesearch
+    optx.OptaxMinimiser(
+        optax.chain(
+            optax.sgd(learning_rate=1.0),
+            optax.scale_by_zoom_linesearch(15, curv_rtol=jnp.inf),
+        ),
+        rtol=rtol,
+        atol=atol,
+    ),
+    optx.OptaxMinimiser(
+        optax.chain(
+            optax.sgd(learning_rate=1.0),
+            optax.scale_by_backtracking_linesearch(15),
+        ),
+        rtol=rtol,
+        atol=atol,
+    ),
 )
 
+minimisers = _general_minimisers + _minim_only
+
 # the minimisers can handle least squares problems, but the least squares
 # solvers cannot handle general minimisation problems.
-least_squares_optimisers = _lsqr_only + minimisers
-# Remove ones that work, but are just pretty bad!
-least_squares_optimisers = [
-    x
-    for x in least_squares_optimisers
-    if not isinstance(x, (optx.GradientDescent, optx.NonlinearCG))
-]
+# without the ones that work, but are just pretty bad!
+least_squares_optimisers = _lsqr_only + _general_minimisers
 
 
 #
 # MINIMISATION PROBLEMS
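
For reference, the OptaxMinimiser-with-linesearch solvers added above are constructed and run like any other Optimistix minimiser. A minimal sketch of how one of them would be exercised (not part of the commit; the quadratic loss, start point, and tolerances are made up for illustration, and it assumes an optax version that ships scale_by_zoom_linesearch):

import jax.numpy as jnp
import optax
import optimistix as optx

def loss(y, args):
    # optx.minimise expects fn(y, args) -> scalar.
    return jnp.sum((y - 1.0) ** 2)

# SGD with unit learning rate; the zoom linesearch chooses the actual step size.
optim = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
)
solver = optx.OptaxMinimiser(optim, rtol=1e-8, atol=1e-8)
sol = optx.minimise(loss, solver, jnp.zeros(2))
# sol.value should be close to [1., 1.]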