Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
158 commits
Select commit Hold shift + click to select a range
a8b3ffb
add a linear operator computing the descent update
BalzaniEdoardo May 3, 2025
956329e
removed unused cls
BalzaniEdoardo May 3, 2025
9380261
added update hessian state dict
BalzaniEdoardo May 4, 2025
36a9d3c
partial applied to fn
BalzaniEdoardo May 5, 2025
30669b4
allow pytree at set
BalzaniEdoardo May 6, 2025
c0be4bc
added fixes
BalzaniEdoardo May 6, 2025
45a8f69
added import lbfgs
BalzaniEdoardo May 6, 2025
925a73c
linted
BalzaniEdoardo May 12, 2025
f5e8bc5
improved varnames
BalzaniEdoardo May 12, 2025
f2aa7a6
linted
BalzaniEdoardo May 12, 2025
3ccac0f
fixed docstrings
BalzaniEdoardo May 12, 2025
a5d16ef
revert unnecessary change
BalzaniEdoardo May 12, 2025
d29374e
do not use threshold
BalzaniEdoardo May 12, 2025
5822ad5
improved comment
BalzaniEdoardo May 12, 2025
3dbd061
removed todo
BalzaniEdoardo May 12, 2025
689c876
use hist len attr
BalzaniEdoardo May 13, 2025
b5ba14d
renamed variable
BalzaniEdoardo May 14, 2025
d81fe10
remove unused linear solver
BalzaniEdoardo May 14, 2025
090df71
test lbfgs linear op
BalzaniEdoardo May 14, 2025
85fe537
remove jit
BalzaniEdoardo May 14, 2025
c0a09b9
added LBFGS to tests, added a test for the operator
BalzaniEdoardo May 14, 2025
b83db06
test pre-commit
BalzaniEdoardo May 14, 2025
080afd2
test pre-commit 2
BalzaniEdoardo May 14, 2025
60c9dee
lined
BalzaniEdoardo May 14, 2025
3a8524b
started linting
BalzaniEdoardo May 14, 2025
5956e33
linted test
BalzaniEdoardo May 14, 2025
d4e3624
changes for fixing pyright
BalzaniEdoardo May 15, 2025
0660afc
add assertion to pass pyright
BalzaniEdoardo May 15, 2025
a1d5f17
fix typing
BalzaniEdoardo May 15, 2025
af7be91
forced commit adding back typevars
BalzaniEdoardo May 15, 2025
5c375e8
fix output
BalzaniEdoardo May 15, 2025
1c8f620
remove search from init args
BalzaniEdoardo May 15, 2025
5543c15
hist len second to last
BalzaniEdoardo May 15, 2025
13b2052
reflect arg order in docstrings
BalzaniEdoardo May 15, 2025
ab3a5d9
fix convergence
BalzaniEdoardo May 15, 2025
da41125
linted
BalzaniEdoardo May 15, 2025
43fdd7f
fix var names
BalzaniEdoardo May 15, 2025
677b84c
move state update into update
BalzaniEdoardo May 15, 2025
db3d0a9
add hessian update compute
BalzaniEdoardo May 19, 2025
ed555bf
test identity
BalzaniEdoardo May 19, 2025
2e2d56c
test identity
BalzaniEdoardo May 19, 2025
b05fa4c
fix indexing
BalzaniEdoardo May 28, 2025
96d6c89
added compact repr
BalzaniEdoardo May 31, 2025
be43422
fix linalg
BalzaniEdoardo Jun 2, 2025
3ffedaf
removed debug code
BalzaniEdoardo Jun 2, 2025
6a3add4
fix conversion in strict mode
BalzaniEdoardo Jun 2, 2025
013b029
fix tests
BalzaniEdoardo Jun 2, 2025
424bd51
fix roll_and_set comments
BalzaniEdoardo Jun 2, 2025
33d35b0
lint
BalzaniEdoardo Jun 2, 2025
9eaae6e
removed unused func
BalzaniEdoardo Jun 2, 2025
9082f5e
fixed typing
BalzaniEdoardo Jun 2, 2025
f2c078e
removed verbose options
BalzaniEdoardo Jun 2, 2025
442bbbd
removed verbose options
BalzaniEdoardo Jun 2, 2025
4e967f5
improved docstrings
BalzaniEdoardo Jun 2, 2025
de6c15b
add test for warm-up state
BalzaniEdoardo Jun 2, 2025
47bd5f6
warm-up state test fix
BalzaniEdoardo Jun 2, 2025
84ad3f2
add multiple extra hist
BalzaniEdoardo Jun 2, 2025
b1f801e
add an assertion
BalzaniEdoardo Jun 2, 2025
3a64a5b
renamed parameter back to inner_history
BalzaniEdoardo Jun 3, 2025
7b74a86
add a linear operator computing the descent update
BalzaniEdoardo May 3, 2025
8eb96c0
removed unused cls
BalzaniEdoardo May 3, 2025
2788d2d
added update hessian state dict
BalzaniEdoardo May 4, 2025
ab816f1
partial applied to fn
BalzaniEdoardo May 5, 2025
1f4e217
allow pytree at set
BalzaniEdoardo May 6, 2025
1307644
added fixes
BalzaniEdoardo May 6, 2025
2937fd6
added import lbfgs
BalzaniEdoardo May 6, 2025
a6b5bc1
linted
BalzaniEdoardo May 12, 2025
4896842
improved varnames
BalzaniEdoardo May 12, 2025
644b9ed
linted
BalzaniEdoardo May 12, 2025
b6a9a02
fixed docstrings
BalzaniEdoardo May 12, 2025
0a50627
revert unnecessary change
BalzaniEdoardo May 12, 2025
0239c68
do not use threshold
BalzaniEdoardo May 12, 2025
6caf975
improved comment
BalzaniEdoardo May 12, 2025
32c5d8b
removed todo
BalzaniEdoardo May 12, 2025
92648ac
use hist len attr
BalzaniEdoardo May 13, 2025
16e72ae
renamed variable
BalzaniEdoardo May 14, 2025
263effa
remove unused linear solver
BalzaniEdoardo May 14, 2025
8ae60c7
test lbfgs linear op
BalzaniEdoardo May 14, 2025
cdc9748
remove jit
BalzaniEdoardo May 14, 2025
751a2ec
added LBFGS to tests, added a test for the operator
BalzaniEdoardo May 14, 2025
f5d9ebe
test pre-commit
BalzaniEdoardo May 14, 2025
26154fe
test pre-commit 2
BalzaniEdoardo May 14, 2025
f44e79d
lined
BalzaniEdoardo May 14, 2025
2db0b5a
started linting
BalzaniEdoardo May 14, 2025
8ccf14e
linted test
BalzaniEdoardo May 14, 2025
ba84546
changes for fixing pyright
BalzaniEdoardo May 15, 2025
6a753a8
add assertion to pass pyright
BalzaniEdoardo May 15, 2025
0b1f018
fix typing
BalzaniEdoardo May 15, 2025
25b6e4d
forced commit adding back typevars
BalzaniEdoardo May 15, 2025
a8e33ad
fix output
BalzaniEdoardo May 15, 2025
353c85b
remove search from init args
BalzaniEdoardo May 15, 2025
405df6c
hist len second to last
BalzaniEdoardo May 15, 2025
87c2118
reflect arg order in docstrings
BalzaniEdoardo May 15, 2025
fe8922d
fix convergence
BalzaniEdoardo May 15, 2025
a80823f
linted
BalzaniEdoardo May 15, 2025
e746f3c
fix var names
BalzaniEdoardo May 15, 2025
46e14b5
move state update into update
BalzaniEdoardo May 15, 2025
921f5a3
add hessian update compute
BalzaniEdoardo May 19, 2025
0721680
test identity
BalzaniEdoardo May 19, 2025
3d7fe9d
test identity
BalzaniEdoardo May 19, 2025
cf551b4
fix indexing
BalzaniEdoardo May 28, 2025
3a413c1
added compact repr
BalzaniEdoardo May 31, 2025
89b6f73
fix linalg
BalzaniEdoardo Jun 2, 2025
0d6e298
removed debug code
BalzaniEdoardo Jun 2, 2025
7a0342d
fix conversion in strict mode
BalzaniEdoardo Jun 2, 2025
7cc524c
fix tests
BalzaniEdoardo Jun 2, 2025
1dac261
fix roll_and_set comments
BalzaniEdoardo Jun 2, 2025
9764d54
lint
BalzaniEdoardo Jun 2, 2025
a845bb4
removed unused func
BalzaniEdoardo Jun 2, 2025
8a712f4
fixed typing
BalzaniEdoardo Jun 2, 2025
99d4a7e
removed verbose options
BalzaniEdoardo Jun 2, 2025
b5e77f1
removed verbose options
BalzaniEdoardo Jun 2, 2025
552b74d
improved docstrings
BalzaniEdoardo Jun 2, 2025
8eaf147
add test for warm-up state
BalzaniEdoardo Jun 2, 2025
189b7aa
warm-up state test fix
BalzaniEdoardo Jun 2, 2025
1ba895a
add multiple extra hist
BalzaniEdoardo Jun 2, 2025
5cde2c7
add an assertion
BalzaniEdoardo Jun 2, 2025
c6e2aca
renamed parameter back to inner_history
BalzaniEdoardo Jun 3, 2025
ea10532
re-organising quasi-newton, pyright fixes and some simplifications
Jun 14, 2025
fab3658
ruff + pyupgrade after rebase on latest dev
Jun 14, 2025
adcdbcf
ruff + pyupgrade after rebase on latest dev
Jun 14, 2025
0058b73
re-organising quasi-newton, pyright fixes and some simplifications
Jun 14, 2025
c769268
rebase on dev, pyupgrade + fixes
Jun 14, 2025
093ce74
marking a To-Do for Johanna
Jun 14, 2025
8617111
typo fix: spurious
Jun 14, 2025
c5b8f66
small fixes
Jun 15, 2025
43320a0
split module and create separate file for L-BFGS, mark TODOs for PR c…
Jun 15, 2025
ef6f8b0
take operator test out of rotation for now (relies on private imports)
Jun 15, 2025
647ce42
tiny doc tweaks
Jun 15, 2025
411206c
simplify update
BalzaniEdoardo Jun 16, 2025
ad8d2f5
linted
BalzaniEdoardo Jun 16, 2025
be6853d
merged latest commits
BalzaniEdoardo Jun 17, 2025
36ae69f
Merge pull request #1 from BalzaniEdoardo/simplify_update
BalzaniEdoardo Jun 17, 2025
9d3e953
remove unused import
Jun 18, 2025
367ce4f
formatting fix
Jun 18, 2025
4c29f4a
cleaning the To-Dos
Jun 18, 2025
eb17414
double where trick addressed
BalzaniEdoardo Jun 19, 2025
7a65e2e
use concatenate
BalzaniEdoardo Jun 19, 2025
7da438d
document pyright: ignore in Quasi Newton
Jun 22, 2025
aabde0f
address pyright issues
Jun 22, 2025
836e48e
fix names and some doc tweaks
Jun 22, 2025
2e6cac0
addressed remaining issues: naming, better type hints.
Jun 22, 2025
2a8df02
small fixes
Jul 8, 2025
c6c46be
teeny tiny fix
Jul 8, 2025
6e76c03
move a comment
Jul 8, 2025
e466ced
implement .conj method (abstract in Lineax)
Jul 8, 2025
6d334b7
more small fixes
Jul 8, 2025
523cfed
add safe divide for inner and comment
BalzaniEdoardo Jul 9, 2025
422f839
linted
BalzaniEdoardo Jul 9, 2025
329bac5
Merge branch 'hess_approx' of github.com:BalzaniEdoardo/optimistix in…
BalzaniEdoardo Jul 9, 2025
77dba7a
add as array
BalzaniEdoardo Jul 9, 2025
6cf0725
add asarray
BalzaniEdoardo Jul 9, 2025
d724a4d
write the history length into the state
Jul 12, 2025
02a69b7
fix history length in L-BFGS special tst
Jul 12, 2025
a85bca4
reorder and improve documentation of inverse operator function + some…
Jul 13, 2025
e7a55b3
add comments explicating our strategy for Cholesky solves of partly f…
Jul 13, 2025
89bc63d
document hessian update state in quasi-Newton
Jul 13, 2025
bd09b85
document choice for computation of gamma
Jul 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optimistix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
hestenes_stiefel as hestenes_stiefel,
IndirectDampedNewtonDescent as IndirectDampedNewtonDescent,
IndirectLevenbergMarquardt as IndirectLevenbergMarquardt,
LBFGS as LBFGS,
LearningRate as LearningRate,
LevenbergMarquardt as LevenbergMarquardt,
LinearTrustRegion as LinearTrustRegion,
Expand Down
1 change: 0 additions & 1 deletion optimistix/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def compute_grad_dot(self, y: Y):
FunctionInfo.Residual = Residual
FunctionInfo.ResidualJac = ResidualJac


Eval.__init__.__doc__ = """**Arguments:**

- `f`: the scalar output of a function evaluation `fn(y)`.
Expand Down
1 change: 1 addition & 0 deletions optimistix/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
BFGSUpdate as BFGSUpdate,
DFP as DFP,
DFPUpdate as DFPUpdate,
LBFGS as LBFGS,
)
from .trust_region import (
ClassicalTrustRegion as ClassicalTrustRegion,
Expand Down
251 changes: 240 additions & 11 deletions optimistix/_solver/quasi_newton.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union
from functools import partial
from typing import Any, Generic, TypeVar, Union, Optional

import equinox as eqx
import jax
Expand Down Expand Up @@ -66,6 +67,97 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator:
)


@jax.jit
def _lim_mem_hess_inv_operator_fn(
pytree: PyTree[Array],
residual_par: PyTree[Array],
residual_grad: PyTree[Array],
rho: Array,
index_start: Array
):
"""
LBFGS descent linear operator.

"""
history_len = rho.shape[0]
circ_index = (jnp.arange(history_len) + index_start) % history_len

# First loop: iterate backwards and compute alpha coefficients
def backward_iter(q, indx):
dy, dg, r = jtu.tree_map(lambda x: x[indx], (residual_par, residual_grad, rho))
alpha = r * tree_dot(dy, q)
q = (q ** ω - alpha * dg ** ω).ω
return q, alpha

# Second loop: iterate forwards and apply correction using stored alpha
def forward_iter(args, indx):
q, alpha = args
ai = alpha[indx]
r = rho[circ_index[indx]]
dy, dg = jtu.tree_map(
lambda x: x[circ_index[indx]],
(residual_par, residual_grad)
)
bi = r * tree_dot(dg, q)
q = (q ** ω + (dy ** ω * (ai - bi))).ω
return (q, alpha), None


q, alpha = jax.lax.scan(backward_iter, pytree, circ_index, reverse=True)
dym, dgm = jtu.tree_map(
lambda x: x[index_start % history_len],
(residual_par, residual_grad)
)
dyg = tree_dot(dym, dgm)
dgg = tree_dot(dgm, dgm)
gamma_k = jnp.where(dgg > 1e-10, dyg / dgg, 1.0)
q = (gamma_k * q ** ω).ω
(q, _), _ = jax.lax.scan(forward_iter, (q, alpha), jnp.arange(history_len), reverse=False)
return q


def _lim_mem_hess_inv_operator(
residual_par: PyTree[Array],
residual_grad: PyTree[Array],
rho: Array,
index_start: Array,
input_shape: Optional[PyTree[jax.ShapeDtypeStruct]] = None,
):
"""Define a `lineax` linear operator implementing the L-BFGS inverse Hessian approximation.

This operator computes the action of the approximate inverse Hessian on a vector `pytree`
using the limited-memory BFGS (L-BFGS) two-loop recursion. It does not materialize the matrix
explicitly but returns a `lineax.FunctionLinearOperator`.

- `residual_par`: History of parameter updates `s_k = x_{k+1} - x_k`
- `residual_grad`: History of gradient updates `y_k = g_{k+1} - g_k`
- `rho`: Reciprocal dot products `rho_k = 1 / ⟨s_k, y_k⟩`
- `index_start`: Index of the most recent update in the circular buffer

Returns a `lineax.FunctionLinearOperator` with input and output shape matching a single element
of `residual_par`.

"""
operator_func = partial(
_lim_mem_hess_inv_operator_fn,
residual_par=residual_par,
residual_grad=residual_grad,
rho=rho,
index_start=index_start
)
input_shape = (
jax.eval_shape(lambda: jtu.tree_map(lambda x: x[0], residual_par))
if input_shape is None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.eval_shape(lambda: y) should always be the input shape, so I think we can use that directly here and avoid making it an optional argument. I can't think of a case where we would want to use any other shape, since the Hessian is always with respect to y and always symmetric.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is no case in which that's not true! The idea was to provided it directly as a static arg but once the function is jit-compiled that won't matter I guess

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, this is inside the jitted region and the shape of y is static.

else input_shape
)
op = lx.FunctionLinearOperator(
operator_func,
input_shape,
tags=lx.positive_semidefinite_tag,
)
return lx.materialise(op)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the shape complaint you got when just returning the function linear operator is due to the jaxpr in it, we solve this in a few places by just returning the dynamic portion of the equinox module, e.g. here. (And here for some background.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incidentally I've been meaning to track down what's going on here - this indicates that we are retracing jaxprs, which is not ideal from a compile-time perspective. In a more perfect world we can just directly use the one we already have!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do you think that originates?

Copy link
Author

@BalzaniEdoardo BalzaniEdoardo May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using eqx.filter makes sense here. I believe the issue comes from the fact that the JAXPRs creates a unique id for the traced variables:

(Pdb) elem.invars
[Var(id=4544898944):float64[1]]
(Pdb) elem_.invars
[Var(id=4775038400):float64[1]]

here elem are the JAXPRs vars in the scope of eqx.tree_equal

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it concerns the jaxpr, I followed the suggestion of filtering the tree.
I am implementing the filtering during the __call__ method of the LBFGSUpdate:

  • Get the static part of the tree before calling update/no_update
  • In update/no_update get the operator and filter for the dynamic sub-tree.
  • After the update call combine the static with the dynamic .

Does this make sense?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filtering as you've done here makes perfect sense!



def _outer(tree1, tree2):
def leaf_fn(x):
return jtu.tree_map(lambda leaf: jnp.tensordot(x, leaf, axes=0), tree2)
Expand All @@ -74,7 +166,7 @@ def leaf_fn(x):


_Hessian = TypeVar(
"_Hessian", FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv
"_Hessian", FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv,
)


Expand All @@ -96,7 +188,8 @@ def __call__(
y_eval: Y,
f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
f_eval_info: FunctionInfo.EvalGrad,
) -> Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv]:
hess_update_state: dict,
) -> tuple[Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv], dict]:
"""Called whenever we want to update the Hessian approximation. This is usually
in the `accepted` branch of the `step` method of an
[`optimistix.AbstractQuasiNewton`][] minimiser.
Expand Down Expand Up @@ -143,7 +236,8 @@ def __call__(
y_eval: Y,
f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
f_eval_info: FunctionInfo.EvalGrad,
) -> Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv]:
hess_update_state: dict,
) -> tuple[Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv], dict]:
f_eval = f_eval_info.f
grad = f_eval_info.grad
y_diff = (y_eval**ω - y**ω).ω
Expand All @@ -167,9 +261,9 @@ def __call__(
)
if self.use_inverse:
# in this case `hessian` is the new inverse hessian
return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian)
return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian), {}
else:
return FunctionInfo.EvalGradHessian(f_eval, grad, hessian)
return FunctionInfo.EvalGradHessian(f_eval, grad, hessian), {}


class DFPUpdate(_AbstractBFGSDFPUpdate, strict=True):
Expand Down Expand Up @@ -292,6 +386,9 @@ class _QuasiNewtonState(
result: RESULTS
# Used in compat.py
num_accepted_steps: Int[Array, ""]
# update state
hess_update_state: dict



class AbstractQuasiNewton(
Expand Down Expand Up @@ -335,13 +432,31 @@ def init(
) -> _QuasiNewtonState:
f = tree_full_like(f_struct, 0)
grad = tree_full_like(y, 0)
hess_update_state = dict()
if self.hessian_update.use_inverse:
hessian_inv = _identity_pytree(y)
if isinstance(self, LBFGS):
hess_update_state = dict(
history_len=10,
start_index=jnp.array(0),
residual_y=jtu.tree_map(lambda y: jnp.zeros((10, *y.shape)), y),
residual_grad=jtu.tree_map(lambda y: jnp.zeros((10, *y.shape)), y),
rho=jnp.zeros(10)
)
hessian_inv = _lim_mem_hess_inv_operator(
hess_update_state["residual_y"],
hess_update_state["residual_grad"],
hess_update_state["rho"],
hess_update_state["start_index"]
)
else:
hessian_inv = _identity_pytree(y)
f_info = FunctionInfo.EvalGradHessianInv(f, grad, hessian_inv)
else:
hessian = _identity_pytree(y)
f_info = FunctionInfo.EvalGradHessian(f, grad, hessian)
f_info_struct = eqx.filter_eval_shape(lambda: f_info)


return _QuasiNewtonState(
first_step=jnp.array(True),
y_eval=y,
Expand All @@ -352,6 +467,7 @@ def init(
terminate=jnp.array(False),
result=RESULTS.successful,
num_accepted_steps=jnp.array(0),
hess_update_state=hess_update_state
)

def step(
Expand Down Expand Up @@ -379,11 +495,12 @@ def step(
def accepted(descent_state):
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)

f_eval_info = self.hessian_update(
f_eval_info, hess_update_state = self.hessian_update(
y,
state.y_eval,
state.f_info,
FunctionInfo.EvalGrad(f_eval, grad),
state.hess_update_state
)

descent_state = self.descent.query(
Expand All @@ -399,12 +516,12 @@ def accepted(descent_state):
terminate = jnp.where(
state.first_step, jnp.array(False), terminate
) # Skip termination on first step
return state.y_eval, f_eval_info, aux_eval, descent_state, terminate
return state.y_eval, f_eval_info, aux_eval, descent_state, terminate, hess_update_state

def rejected(descent_state):
return y, state.f_info, state.aux, descent_state, jnp.array(False)
return y, state.f_info, state.aux, descent_state, jnp.array(False), state.hess_update_state

y, f_info, aux, descent_state, terminate = filter_cond(
y, f_info, aux, descent_state, terminate, hess_update_state = filter_cond(
accept, accepted, rejected, state.descent_state
)

Expand Down Expand Up @@ -438,6 +555,7 @@ def rejected(descent_state):
terminate=terminate,
result=result,
num_accepted_steps=state.num_accepted_steps + jnp.where(accept, 1, 0),
hess_update_state=hess_update_state,
)
return y, state, aux

Expand Down Expand Up @@ -603,3 +721,114 @@ def __init__(
Valid entries are `step_size`, `loss`, `y`. For example
`verbose=frozenset({"step_size", "loss"})`.
"""


class LBFGSUpdate(AbstractQuasiNewtonUpdate, strict=True):
"""Private intermediate class for LBFGS updates."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a public class, I think. We want to expose the update classes to users who may wish to build custom solvers, in which case they are required. If you have a reference for this specific implementation, it would be great to add it here!

Small thing: can we move the definition of the update class up, so that it is grouped with the other update classes, we then define the abstract solver and then have the concrete ones at the bottom of the file? 🙈

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely 😀 I totally understand the need for linting the code!


use_inverse = True

def no_update(self, inner, grad_diff, y_diff, f_info, start_index):
return f_info.hessian_inv, jnp.array(0)

def update(self, inner, grad_diff, y_diff, f_info, start_index):
assert isinstance(f_info, FunctionInfo.EvalGradHessianInv)
# update the start index
return _lim_mem_hess_inv_operator(
y_diff,
grad_diff,
inner,
start_index,
), jnp.array(1)

def __call__(
self,
y: Y,
y_eval: Y,
f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
f_eval_info: FunctionInfo.EvalGrad,
hess_update_state: dict,
) -> tuple[FunctionInfo.EvalGradHessianInv, dict]:
f_eval = f_eval_info.f
grad = f_eval_info.grad
y_diff = (y_eval**ω - y**ω).ω
grad_diff = (grad**ω - f_info.grad**ω).ω

history_len = hess_update_state["history_len"]
residual_y = hess_update_state["residual_y"]
residual_grad = hess_update_state["residual_grad"]
start_index = hess_update_state["start_index"]
rho = hess_update_state["rho"]

# update states
residual_y = jtu.tree_map(lambda x, z: x.at[start_index].set(z), residual_y, y_diff)
residual_grad = jtu.tree_map(lambda x, z: x.at[start_index].set(z), residual_grad, grad_diff)
rho = rho.at[start_index].set(1. / tree_dot(y_diff, grad_diff))
rho = jnp.where(jnp.isinf(rho), 0, rho)

hessian, update = filter_cond(
rho[start_index] != 0,
self.update,
self.no_update,
rho,
residual_grad,
residual_y,
f_info,
start_index,
)
# increment circular index
start_index = (start_index + update) % history_len

update_state = dict(
history_len = history_len,
start_index = jnp.array(start_index),
residual_y = residual_y,
residual_grad = residual_grad,
rho= rho,
)

return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian), update_state

class LBFGS(AbstractQuasiNewton[Y, Aux, _Hessian], strict=True):
"""L-BFGS (Limited-memory Broyden–Fletcher–Goldfarb–Shanno) minimisation algorithm.

This is a quasi-Newton optimisation algorithm that approximates the inverse Hessian
using a limited history of gradient and parameter updates. Unlike full BFGS, which stores
a dense matrix, L-BFGS maintains a memory-efficient representation suitable for large-scale
problems.

Supports the following `options`:

- `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
which is usually more efficient. Changing this can be useful when the target
function does not support reverse-mode automatic differentiation.
"""

rtol: float
atol: float
norm: Callable[[PyTree], Scalar]
descent: NewtonDescent
search: AbstractSearch
hessian_update: AbstractQuasiNewtonUpdate
use_inverse: bool
verbose: frozenset[str]

def __init__(
self,
rtol: float,
atol: float,
norm: Callable[[PyTree], Scalar] = max_norm,
verbose: frozenset[str] = frozenset(),
search: AbstractSearch = BacktrackingArmijo(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not make the search an argument here, and just specify it in the body of __init__. Then we're consistent with the overall pattern where a solver is composed of a specific search and a specific descent, and swapping these out can be done by defining a custom solver.

(And mini-thing: make verbose come last, again for consistency.)


):
self.rtol = rtol
self.atol = atol
self.norm = norm
self.use_inverse = True
self.descent = NewtonDescent(linear_solver=lx.Cholesky())
# TODO(raderj): switch out `BacktrackingArmijo` with a better line search.
self.search = search
self.hessian_update = LBFGSUpdate()
self.verbose = verbose