LBFGS Hess approx #135
@@ -1,6 +1,7 @@
 import abc
 from collections.abc import Callable
-from typing import Any, Generic, TypeVar, Union
+from functools import partial
+from typing import Any, Generic, TypeVar, Union, Optional

 import equinox as eqx
 import jax
@@ -66,6 +67,97 @@ def _identity_pytree(pytree: PyTree[Array]) -> lx.PyTreeLinearOperator:
     )


+@jax.jit
+def _lim_mem_hess_inv_operator_fn(
+    pytree: PyTree[Array],
+    residual_par: PyTree[Array],
+    residual_grad: PyTree[Array],
+    rho: Array,
+    index_start: Array,
+):
+    """LBFGS descent linear operator: apply the inverse Hessian approximation to `pytree`
+    via the two-loop recursion.
+    """
+    history_len = rho.shape[0]
+    circ_index = (jnp.arange(history_len) + index_start) % history_len
+
+    # First loop: iterate backwards and compute alpha coefficients
+    def backward_iter(q, indx):
+        dy, dg, r = jtu.tree_map(lambda x: x[indx], (residual_par, residual_grad, rho))
+        alpha = r * tree_dot(dy, q)
+        q = (q**ω - alpha * dg**ω).ω
+        return q, alpha
+
+    # Second loop: iterate forwards and apply correction using stored alpha
+    def forward_iter(args, indx):
+        q, alpha = args
+        ai = alpha[indx]
+        r = rho[circ_index[indx]]
+        dy, dg = jtu.tree_map(
+            lambda x: x[circ_index[indx]],
+            (residual_par, residual_grad),
+        )
+        bi = r * tree_dot(dg, q)
+        q = (q**ω + (dy**ω * (ai - bi))).ω
+        return (q, alpha), None
+
+    q, alpha = jax.lax.scan(backward_iter, pytree, circ_index, reverse=True)
+    # Scale by gamma_k, the usual initial approximation H_0 = gamma_k * I, built from
+    # the most recent (s_k, y_k) pair.
+    dym, dgm = jtu.tree_map(
+        lambda x: x[index_start % history_len],
+        (residual_par, residual_grad),
+    )
+    dyg = tree_dot(dym, dgm)
+    dgg = tree_dot(dgm, dgm)
+    gamma_k = jnp.where(dgg > 1e-10, dyg / dgg, 1.0)
+    q = (gamma_k * q**ω).ω
+    (q, _), _ = jax.lax.scan(forward_iter, (q, alpha), jnp.arange(history_len), reverse=False)
+    return q
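Aside: the function above is the classic L-BFGS two-loop recursion, expressed over pytrees and `lax.scan`. A minimal standalone sketch of the same algorithm on flat arrays, for reference only (hypothetical helper, not part of this diff; the history is assumed oldest-first):

import jax.numpy as jnp

def lbfgs_two_loop(grad, s_hist, y_hist, rho):
    """Return an approximation of H^{-1} @ grad from the (s_k, y_k) history."""
    m = rho.shape[0]
    q = grad
    alphas = [None] * m
    for k in reversed(range(m)):  # backward pass, newest to oldest
        alphas[k] = rho[k] * jnp.dot(s_hist[k], q)
        q = q - alphas[k] * y_hist[k]
    # Initial Hessian approximation H_0 = gamma * I, from the most recent pair.
    gamma = jnp.dot(s_hist[-1], y_hist[-1]) / jnp.dot(y_hist[-1], y_hist[-1])
    q = gamma * q
    for k in range(m):  # forward pass, oldest to newest
        beta = rho[k] * jnp.dot(y_hist[k], q)
        q = q + s_hist[k] * (alphas[k] - beta)
    return q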
+
+
+def _lim_mem_hess_inv_operator(
+    residual_par: PyTree[Array],
+    residual_grad: PyTree[Array],
+    rho: Array,
+    index_start: Array,
+    input_shape: Optional[PyTree[jax.ShapeDtypeStruct]] = None,
+):
+    """Define a `lineax` linear operator implementing the L-BFGS inverse Hessian approximation.
+
+    This operator computes the action of the approximate inverse Hessian on a vector `pytree`
+    using the limited-memory BFGS (L-BFGS) two-loop recursion. It does not materialize the
+    matrix explicitly but returns a `lineax.FunctionLinearOperator`.
+
+    - `residual_par`: History of parameter updates `s_k = x_{k+1} - x_k`
+    - `residual_grad`: History of gradient updates `y_k = g_{k+1} - g_k`
+    - `rho`: Reciprocal dot products `rho_k = 1 / ⟨s_k, y_k⟩`
+    - `index_start`: Index of the most recent update in the circular buffer
+
+    Returns a `lineax.FunctionLinearOperator` with input and output shape matching a single
+    element of `residual_par`.
+    """
+    operator_func = partial(
+        _lim_mem_hess_inv_operator_fn,
+        residual_par=residual_par,
+        residual_grad=residual_grad,
+        rho=rho,
+        index_start=index_start,
+    )
+    input_shape = (
+        jax.eval_shape(lambda: jtu.tree_map(lambda x: x[0], residual_par))
+        if input_shape is None
+        else input_shape
+    )
+    op = lx.FunctionLinearOperator(
+        operator_func,
+        input_shape,
+        tags=lx.positive_semidefinite_tag,
+    )
+    return lx.materialise(op)

Review discussion (on the `input_shape is None` default):
- "Yes, there is no case in which that's not true! The idea was to provide it directly as a static argument, but once the function is jit-compiled that won't matter, I guess."
- "Exactly, this is inside the jitted region and the shape of the inputs is already known."

Review discussion (on `lx.materialise(op)`):
- "Incidentally, I've been meaning to track down what's going on here: this indicates that we are retracing jaxprs, which is not ideal from a compile-time perspective. In a more perfect world we can just directly use the one we already have!"
- "Where do you think that originates?"
- "Using pdb:
      (Pdb) elem.invars
      [Var(id=4544898944):float64[1]]
      (Pdb) elem_.invars
      [Var(id=4775038400):float64[1]]
  here."
- "For what it concerns the jaxpr, I followed the suggestion of filtering the tree. Does this make sense?"
- "Filtering as you've done here makes perfect sense!"
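For readers unfamiliar with `lineax`, a minimal sketch of the matrix-free pattern used above, with a stand-in matvec in place of the two-loop recursion (illustrative only, not the PR's code):

import jax
import jax.numpy as jnp
import lineax as lx

def matvec(v):
    # Stand-in for the two-loop recursion: a simple diagonal scaling.
    return 2.0 * v

in_structure = jax.eval_shape(lambda: jnp.zeros((3,)))
op = lx.FunctionLinearOperator(matvec, in_structure, tags=lx.positive_semidefinite_tag)
dense = lx.materialise(op)  # builds a concrete representation of the operator
print(op.mv(jnp.ones(3)))   # matrix-free application
print(dense.as_matrix())    # 2 * identity, shape (3, 3)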
 def _outer(tree1, tree2):
     def leaf_fn(x):
         return jtu.tree_map(lambda leaf: jnp.tensordot(x, leaf, axes=0), tree2)

@@ -74,7 +166,7 @@ def leaf_fn(x):

 _Hessian = TypeVar(
-    "_Hessian", FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv
+    "_Hessian", FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv,
 )

@@ -96,7 +188,8 @@ def __call__(
         y_eval: Y,
         f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
         f_eval_info: FunctionInfo.EvalGrad,
-    ) -> Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv]:
+        hess_update_state: dict,
+    ) -> tuple[Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv], dict]:
         """Called whenever we want to update the Hessian approximation. This is usually
         in the `accepted` branch of the `step` method of an
         [`optimistix.AbstractQuasiNewton`][] minimiser.
@@ -143,7 +236,8 @@ def __call__(
         y_eval: Y,
         f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
         f_eval_info: FunctionInfo.EvalGrad,
-    ) -> Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv]:
+        hess_update_state: dict,
+    ) -> tuple[Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv], dict]:
         f_eval = f_eval_info.f
         grad = f_eval_info.grad
         y_diff = (y_eval**ω - y**ω).ω
@@ -167,9 +261,9 @@ def __call__(
         )
         if self.use_inverse:
             # in this case `hessian` is the new inverse hessian
-            return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian)
+            return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian), {}
         else:
-            return FunctionInfo.EvalGradHessian(f_eval, grad, hessian)
+            return FunctionInfo.EvalGradHessian(f_eval, grad, hessian), {}


 class DFPUpdate(_AbstractBFGSDFPUpdate, strict=True):
@@ -292,6 +386,9 @@ class _QuasiNewtonState(
     result: RESULTS
     # Used in compat.py
     num_accepted_steps: Int[Array, ""]
+    # update state
+    hess_update_state: dict


 class AbstractQuasiNewton(
@@ -335,13 +432,31 @@ def init(
     ) -> _QuasiNewtonState:
         f = tree_full_like(f_struct, 0)
         grad = tree_full_like(y, 0)
+        hess_update_state = dict()
         if self.hessian_update.use_inverse:
-            hessian_inv = _identity_pytree(y)
+            if isinstance(self, LBFGS):
+                hess_update_state = dict(
+                    history_len=10,
+                    start_index=jnp.array(0),
+                    residual_y=jtu.tree_map(lambda y: jnp.zeros((10, *y.shape)), y),
+                    residual_grad=jtu.tree_map(lambda y: jnp.zeros((10, *y.shape)), y),
+                    rho=jnp.zeros(10),
+                )
+                hessian_inv = _lim_mem_hess_inv_operator(
+                    hess_update_state["residual_y"],
+                    hess_update_state["residual_grad"],
+                    hess_update_state["rho"],
+                    hess_update_state["start_index"],
+                )
+            else:
+                hessian_inv = _identity_pytree(y)
             f_info = FunctionInfo.EvalGradHessianInv(f, grad, hessian_inv)
         else:
             hessian = _identity_pytree(y)
             f_info = FunctionInfo.EvalGradHessian(f, grad, hessian)
         f_info_struct = eqx.filter_eval_shape(lambda: f_info)

         return _QuasiNewtonState(
             first_step=jnp.array(True),
             y_eval=y,
@@ -352,6 +467,7 @@ def init(
             terminate=jnp.array(False),
             result=RESULTS.successful,
             num_accepted_steps=jnp.array(0),
+            hess_update_state=hess_update_state,
         )

     def step(
@@ -379,11 +495,12 @@ def step(
         def accepted(descent_state):
             grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)

-            f_eval_info = self.hessian_update(
+            f_eval_info, hess_update_state = self.hessian_update(
                 y,
                 state.y_eval,
                 state.f_info,
                 FunctionInfo.EvalGrad(f_eval, grad),
+                state.hess_update_state,
             )

             descent_state = self.descent.query(
@@ -399,12 +516,12 @@ def accepted(descent_state):
             terminate = jnp.where(
                 state.first_step, jnp.array(False), terminate
             )  # Skip termination on first step
-            return state.y_eval, f_eval_info, aux_eval, descent_state, terminate
+            return state.y_eval, f_eval_info, aux_eval, descent_state, terminate, hess_update_state

         def rejected(descent_state):
-            return y, state.f_info, state.aux, descent_state, jnp.array(False)
+            return y, state.f_info, state.aux, descent_state, jnp.array(False), state.hess_update_state

-        y, f_info, aux, descent_state, terminate = filter_cond(
+        y, f_info, aux, descent_state, terminate, hess_update_state = filter_cond(
             accept, accepted, rejected, state.descent_state
         )

@@ -438,6 +555,7 @@ def rejected(descent_state):
             terminate=terminate,
             result=result,
             num_accepted_steps=state.num_accepted_steps + jnp.where(accept, 1, 0),
+            hess_update_state=hess_update_state,
         )
         return y, state, aux

@@ -603,3 +721,114 @@ def __init__(
         Valid entries are `step_size`, `loss`, `y`. For example
         `verbose=frozenset({"step_size", "loss"})`.
         """
+
+
+class LBFGSUpdate(AbstractQuasiNewtonUpdate, strict=True):
+    """Private intermediate class for LBFGS updates."""

Review discussion (on this class being private):
- "This should be a public class, I think. We want to expose the update classes to users who may wish to build custom solvers, in which case they are required. If you have a reference for this specific implementation, it would be great to add it here! Small thing: can we move the definition of the update class up, so that it is grouped with the other update classes? We then define the abstract solver and then have the concrete ones at the bottom of the file. 🙈"
- "Absolutely 😀 I totally understand the need for linting the code!"

+
+    use_inverse = True
+
+    def no_update(self, inner, grad_diff, y_diff, f_info, start_index):
+        return f_info.hessian_inv, jnp.array(0)
+
+    def update(self, inner, grad_diff, y_diff, f_info, start_index):
+        assert isinstance(f_info, FunctionInfo.EvalGradHessianInv)
+        # update the start index
+        return _lim_mem_hess_inv_operator(
+            y_diff,
+            grad_diff,
+            inner,
+            start_index,
+        ), jnp.array(1)
+
+    def __call__(
+        self,
+        y: Y,
+        y_eval: Y,
+        f_info: Union[FunctionInfo.EvalGradHessian, FunctionInfo.EvalGradHessianInv],
+        f_eval_info: FunctionInfo.EvalGrad,
+        hess_update_state: dict,
+    ) -> tuple[FunctionInfo.EvalGradHessianInv, dict]:
+        f_eval = f_eval_info.f
+        grad = f_eval_info.grad
+        y_diff = (y_eval**ω - y**ω).ω
+        grad_diff = (grad**ω - f_info.grad**ω).ω
+
+        history_len = hess_update_state["history_len"]
+        residual_y = hess_update_state["residual_y"]
+        residual_grad = hess_update_state["residual_grad"]
+        start_index = hess_update_state["start_index"]
+        rho = hess_update_state["rho"]
+
+        # update states
+        residual_y = jtu.tree_map(lambda x, z: x.at[start_index].set(z), residual_y, y_diff)
+        residual_grad = jtu.tree_map(lambda x, z: x.at[start_index].set(z), residual_grad, grad_diff)
+        rho = rho.at[start_index].set(1.0 / tree_dot(y_diff, grad_diff))
+        rho = jnp.where(jnp.isinf(rho), 0, rho)
+
+        hessian, update = filter_cond(
+            rho[start_index] != 0,
+            self.update,
+            self.no_update,
+            rho,
+            residual_grad,
+            residual_y,
+            f_info,
+            start_index,
+        )
+        # increment circular index
+        start_index = (start_index + update) % history_len
+
+        update_state = dict(
+            history_len=history_len,
+            start_index=jnp.array(start_index),
+            residual_y=residual_y,
+            residual_grad=residual_grad,
+            rho=rho,
+        )
+
+        return FunctionInfo.EvalGradHessianInv(f_eval, grad, hessian), update_state

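The history in `hess_update_state` is a fixed-size circular buffer: each accepted step overwrites the slot at `start_index` and then advances the index modulo the history length. A minimal sketch of that bookkeeping in isolation (hypothetical names, arrays only):

import jax.numpy as jnp

history_len = 4
s_history = jnp.zeros((history_len, 3))  # e.g. a history of parameter updates s_k
start_index = jnp.array(0)

def push(history, index, new_entry):
    # Overwrite the slot at `index`, then advance the index modulo the buffer length.
    history = history.at[index].set(new_entry)
    index = (index + 1) % history_len
    return history, index

s_history, start_index = push(s_history, start_index, jnp.array([1.0, 2.0, 3.0]))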
+class LBFGS(AbstractQuasiNewton[Y, Aux, _Hessian], strict=True):
+    """L-BFGS (Limited-memory Broyden–Fletcher–Goldfarb–Shanno) minimisation algorithm.
+
+    This is a quasi-Newton optimisation algorithm that approximates the inverse Hessian
+    using a limited history of gradient and parameter updates. Unlike full BFGS, which
+    stores a dense matrix, L-BFGS maintains a memory-efficient representation suitable
+    for large-scale problems.
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+      compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+      which is usually more efficient. Changing this can be useful when the target
+      function does not support reverse-mode automatic differentiation.
+    """
+
+    rtol: float
+    atol: float
+    norm: Callable[[PyTree], Scalar]
+    descent: NewtonDescent
+    search: AbstractSearch
+    hessian_update: AbstractQuasiNewtonUpdate
+    use_inverse: bool
+    verbose: frozenset[str]
+
+    def __init__(
+        self,
+        rtol: float,
+        atol: float,
+        norm: Callable[[PyTree], Scalar] = max_norm,
+        verbose: frozenset[str] = frozenset(),
+        search: AbstractSearch = BacktrackingArmijo(),
+    ):
+        self.rtol = rtol
+        self.atol = atol
+        self.norm = norm
+        self.use_inverse = True
+        self.descent = NewtonDescent(linear_solver=lx.Cholesky())
+        # TODO(raderj): switch out `BacktrackingArmijo` with a better line search.
+        self.search = search
+        self.hessian_update = LBFGSUpdate()
+        self.verbose = verbose

Review discussion (on the `search` argument of `__init__`):
- "Let's not make the search an argument here, and just specify it in the body of `__init__`. (And mini-thing: make `verbose` come last, again for consistency.)"
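Assuming the new class ends up exported at the package level (e.g. as `optimistix.LBFGS`, which this diff does not show), usage would presumably mirror the existing BFGS solver. A hedged sketch:

import jax.numpy as jnp
import optimistix as optx

def rosenbrock(y, args):
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2

# rtol/atol control the convergence tolerances, as for the other optimistix minimisers.
solver = optx.LBFGS(rtol=1e-8, atol=1e-8)
sol = optx.minimise(rosenbrock, solver, jnp.array([2.0, 2.0]))
print(sol.value)  # expected to approach [1.0, 1.0]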