LBFGS Hess approx #135
Conversation
Awesome! And that was impressively quick.
This looks very reasonable. I've left a first round of comments. Let me know what you think!
optimistix/_solver/quasi_newton.py
Outdated
)
input_shape = (
    jax.eval_shape(lambda: jtu.tree_map(lambda x: x[0], residual_par))
    if input_shape is None
`jax.eval_shape(lambda: y)` should always be the input shape, so I think we can use that directly here and avoid making it an optional argument. I can't think of a case where we would want to use any other shape, since the Hessian is always with respect to `y` and always symmetric.
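As a small illustration of that suggestion (the parameter tree `y` here is a hypothetical stand-in for the solver's `y`), `jax.eval_shape` returns the shape/dtype structure without performing any actual computation:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree standing in for the solver's `y`.
y = {"weights": jnp.zeros((3, 2)), "bias": jnp.zeros(3)}

# eval_shape traces the function abstractly: it returns a pytree of
# jax.ShapeDtypeStruct leaves and performs no FLOPs.
input_shape = jax.eval_shape(lambda: y)
print(input_shape["weights"].shape)  # (3, 2)
```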
Yes, there is no case in which that's not true! The idea was to provide it directly as a static arg, but once the function is jit-compiled that won't matter, I guess.
Exactly, this is inside the jitted region and the shape of `y` is static.
optimistix/_solver/quasi_newton.py
Outdated
    input_shape,
    tags=lx.positive_semidefinite_tag,
)
return lx.materialise(op)
Incidentally I've been meaning to track down what's going on here - this indicates that we are retracing jaxprs, which is not ideal from a compile-time perspective. In a more perfect world we can just directly use the one we already have!
Where do you think that originates?
Using `eqx.filter` makes sense here. I believe the issue comes from the fact that the jaxpr creates a unique id for each traced variable:

    (Pdb) elem.invars
    [Var(id=4544898944):float64[1]]
    (Pdb) elem_.invars
    [Var(id=4775038400):float64[1]]

Here `elem` and `elem_` are the jaxpr vars in the scope of `eqx.tree_equal`.
As for the jaxpr, I followed the suggestion of filtering the tree. I am implementing the filtering in the `__call__` method of `LBFGSUpdate`:

- Get the static part of the tree before calling `update`/`no_update`.
- In `update`/`no_update`, get the operator and filter for the dynamic sub-tree.
- After the update call, combine the static part with the dynamic one.

Does this make sense?
Filtering as you've done here makes perfect sense!
To address this question (others are addressed in the comments): maybe I am not seeing it, but did you include the timing results for the JIT compilation of just the function that creates the operator?
Thanks for the feedback! I gave it another pass, let me know if this is good enough for me to start writing tests.
I am not sure when this could be a bottleneck; likely only for small problems in which evaluating the loss and its gradient is very fast. I am attaching a quick benchmarking script:

```python
from time import perf_counter

import jax
import jax.numpy as jnp
from jax import random

from optimistix._solver.quasi_newton import _make_lbfgs_operator, _lbfgs_operator_fn

jax.config.update("jax_enable_x64", True)

key = random.PRNGKey(123)
par_shape = (1,)
memory_size = 10

key, subkey1, subkey2, subkey3 = random.split(key, 4)
X = random.normal(subkey1, (30, *par_shape))
true_pars = random.normal(subkey2, par_shape)
noise = 0.8 * random.normal(subkey3, (30,))
y = jnp.dot(X, true_pars) + noise

init_par = jnp.zeros(par_shape)
grad_diff = jnp.zeros((memory_size, *par_shape))
param_diff = jnp.zeros((memory_size, *par_shape))
inner_products = jnp.zeros((memory_size, *par_shape))
index_start = jnp.array(0)

# ---------------------------
# Without JIT
# ---------------------------
print("=== Without JIT ===")
with jax.disable_jit(True):
    t0 = perf_counter()
    _lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
    print("Function call (tracing JAXPR):", perf_counter() - t0)

    t0 = perf_counter()
    _lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
    print("Function call (interpreted, no trace):", perf_counter() - t0)

    t0 = perf_counter()
    op = _make_lbfgs_operator(
        y_diff_history=param_diff,
        grad_diff_history=grad_diff,
        inner_history=inner_products,
        index_start=index_start,
    )
    print("LinearOperator construction (traces internal matvec):", perf_counter() - t0)

    t0 = perf_counter()
    op = _make_lbfgs_operator(
        y_diff_history=param_diff,
        grad_diff_history=grad_diff,
        inner_history=inner_products,
        index_start=index_start,
    )
    op.mv(init_par)
    print("LinearOperator construction + mv (new trace):", perf_counter() - t0)

# ---------------------------
# With JIT
# ---------------------------
print("\n=== With JIT ===")
t0 = perf_counter()
_lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
print("Function call (JIT trace + compile + run):", perf_counter() - t0)

t0 = perf_counter()
_lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
print("Function call (JIT execution only):", perf_counter() - t0)

t0 = perf_counter()
op_jit = _make_lbfgs_operator(
    y_diff_history=param_diff,
    grad_diff_history=grad_diff,
    inner_history=inner_products,
    index_start=index_start,
)
op_jit.mv(init_par)
print("LinearOperator JIT construction + mv (compile + run):", perf_counter() - t0)

t0 = perf_counter()
op_jit = _make_lbfgs_operator(
    y_diff_history=param_diff,
    grad_diff_history=grad_diff,
    inner_history=inner_products,
    index_start=index_start,
)
op_jit.mv(init_par)
print("LinearOperator JIT reconstruction + mv (new trace + run):", perf_counter() - t0)

t0 = perf_counter()
op_jit.mv(init_par)
print("LinearOperator mv call (JIT execution only):", perf_counter() - t0)
```

which outputs:
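One caveat when interpreting such timings (a general JAX note, not specific to this PR): dispatch is asynchronous, so `perf_counter` deltas can under-report unless you block on the result. A minimal sketch:

```python
from time import perf_counter

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.ones(1000)

f(x).block_until_ready()  # warm-up: trace + compile once
t0 = perf_counter()
f(x).block_until_ready()  # block so we time the computation, not the dispatch
elapsed = perf_counter() - t0
print(f"steady-state time: {elapsed:.2e} s")
```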
Thank you for the update! I can review tomorrow night.
Nice work! I've left a few small comments and I can follow along much more easily now, thanks for the descriptive variable names!
For tests you can just add it to the list of minimisers defined in `tests/helpers.py`.
optimistix/_solver/quasi_newton.py
Outdated
class LBFGSUpdate(AbstractQuasiNewtonUpdate, strict=True):
    """Private intermediate class for LBFGS updates."""
This should be a public class, I think. We want to expose the update classes to users who may wish to build custom solvers, in which case they are required. If you have a reference for this specific implementation, it would be great to add it here!
Small thing: can we move the definition of the update class up, so that it is grouped with the other update classes? We would then define the abstract solver, and have the concrete solvers at the bottom of the file. 🙈
Absolutely 😀 I totally understand the need for linting the code!
optimistix/_solver/quasi_newton.py
Outdated
recursion. It does not materialize the matrix explicitly but returns a
`lineax.FunctionLinearOperator`.

- `y_diff_history`: History of parameter updates `s_k = x_{k+1} - x_k`
Can we add a small comment that `s_k` and `y_k` are the typical variable names used in the literature? Since our `y` is something else.
optimistix/_solver/quasi_newton.py
Outdated
self.descent = NewtonDescent()
self.search = search
self.hessian_update = LBFGSUpdate(
    use_inverse=True,
This only supports `use_inverse=True` right now, correct? Do you know what this would look like for the approximation of the Hessian itself, not its inverse?
I looked this up and I believe that a direct approximation of the Hessian may be possible, starting from a representation like the one in chapter 7.2 of this:
https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf
where the author talks about the compact representation of the update. Let me know if my intuition is correct and I'll dig more into it!
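For reference, my recollection of the compact representation in that chapter (originally due to Byrd, Nocedal & Schnabel 1994; please double-check the signs against the book) is:

```latex
B_k = B_0 - \begin{bmatrix} B_0 S_k & Y_k \end{bmatrix}
\begin{bmatrix} S_k^\top B_0 S_k & L_k \\ L_k^\top & -D_k \end{bmatrix}^{-1}
\begin{bmatrix} S_k^\top B_0 \\ Y_k^\top \end{bmatrix},
```

where $S_k = [s_0, \dots, s_{k-1}]$, $Y_k = [y_0, \dots, y_{k-1}]$, $D_k = \mathrm{diag}(s_i^\top y_i)$, and $(L_k)_{i,j} = s_i^\top y_j$ for $i > j$ (zero otherwise). If that recollection is right, the direct Hessian approximation would need the same history buffers plus small triangular/diagonal matrices.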
PS: is the approximation of the Hessian something you would want?
Okay, apologies for taking so long to get around to this!
I've not checked the details of the algorithms precisely, but other than that I think this basically all looks reasonable to me. I've left lots of nitty comments below on edge cases and code tidiness and such.
Addressed some small fixes and moved the function to create an identity operator back to quasi Newton, without the tag shenanigans :)
Alright, tweaked some more things! Of note:
Sounds good, I'll add that tomorrow.
I think you're missing a |
History length is now a state attribute, and will become a solver-level attribute when the update methods become solver methods :)
Okay I think this LGTM! I have only nitty remarks/questions, see below.
@johannahaffner happy to merge this into dev :)
history_length: int
y_diff_history: PyTree[Y]
grad_diff_history: PyTree[Y]
y_diff_grad_diff_cross_inner: Float[Array, " history_length history_length"]
Just to confirm, this is intentionally only used when `use_inverse=False`?
Yes, we have different update states for the approximation of the Hessian and its inverse, and they do use different inner / outer products.
y_diff_grad_diff_cross_inner = state.y_diff_grad_diff_cross_inner.at[
    state.index_start % self.history_length
].set(v_tree_dot(state.grad_diff_history, y_diff))
y_diff_grad_diff_cross_inner = y_diff_grad_diff_cross_inner.at[
    :, state.index_start % self.history_length
].set(0)
IIUC this is a trick to gradually fill in the lower triangular part of the matrix (across multiple calls to `_update`), with zero diagonal, and the upper triangular part will be filled with either zeros or nonsense data depending on where we are? And the upper triangular part is fine because it's not read by the triangular solve calls we use, so we're happy with it containing meaningless data?
This is subtle enough that I think it deserves a comment 😄
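A toy sketch of that circular-buffer fill (hypothetical sizes, not the PR's actual buffers): each step writes row `k % m` with the fresh inner products and then zeroes column `k % m`, so only the lower-triangular entries of the current window are meaningful.

```python
import jax.numpy as jnp

m = 3  # hypothetical history length
cross_inner = jnp.zeros((m, m))

for k, inners in enumerate(
    [jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0])]
):
    # Write the new row of inner products into slot k % m ...
    cross_inner = cross_inner.at[k % m].set(inners)
    # ... then zero the matching column, clearing stale entries and
    # keeping the diagonal at zero.
    cross_inner = cross_inner.at[:, k % m].set(0.0)

print(cross_inner)
```

After two steps the matrix is `[[0, 0, 3], [4, 0, 6], [0, 0, 0]]`: the `4.0` below the diagonal is real data, while the `3.0` and `6.0` above it are leftovers that a lower-triangular solve never reads.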
Good catch, added an explanation!
# We know that gamma > 0 because we only update the Hessian approximation if the
# inner product is positive, to maintain positive definiteness of the Hessian
# approximation, and thus this operator is only ever called in that case.
latest_y_diff, latest_grad_diff = jtu.tree_map(
Is this the oldest pair, or the most current one @BalzaniEdoardo? Wikipedia indexes this with k-m.
(I think that means that it is the oldest?)
It is the most recent in the code, because the index is incremented by one before `_update` is called, so subtracting one selects the most recent update. This matches Byrd et al. (1994) Eq. (3.12) and the `optax` implementation, where they use the most recent `(s_{k-1}, y_{k-1})` pair (see attached figure).
I also noticed this discrepancy from the Wikipedia page, which suggests using `(s_{k−m}, y_{k−m})` instead. I'm not aware of an alternative that is provably better, so I went with what seemed to be the most common practice.

Perfect, thank you!
L-BFGS Implementation for `quasi_newton.py`
Hello, and first of all, thank you for the great package!
In this PR, I'm working on implementing the L-BFGS algorithm within the `quasi_newton.py` module, targeting part of #116. Before consolidating the code with tests and full integration, I'd appreciate guidance on design decisions.

Implementation Overview

- `_lim_mem_hess_inv_operator_fn` implements the two-loop recursion using the history of parameter and gradient residuals.
- `_lim_mem_hess_inv_operator` acts as an operator factory: it partially evaluates `_lim_mem_hess_inv_operator_fn` with the current residuals and returns a `lineax.FunctionLinearOperator`.
- `_QuasiNewtonState`.

Questions
Materialization vs. Tree Equality

To satisfy this assertion in `_iterate.py`, I had to materialise the `FunctionLinearOperator`. Is there a way to avoid this and retain the implicit operator while still passing this check?
The memory buffer size is currently fixed at 10 iterations. What's the best way to expose this as a user-settable parameter without altering the broader solver API in `quasi_newton.py`?

I have not systematically tested it yet, but for low-dimensional parameters, JIT-compiling the function `_lim_mem_hess_inv_operator_fn` and calling it directly seems significantly faster than constructing a `FunctionLinearOperator`. Below is line_profiler output for a minimise call:

Let me know how to move on from here, and thanks in advance for the insights!
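For readers unfamiliar with it, the two-loop recursion mentioned above can be sketched in plain NumPy (a hypothetical standalone helper, not the PR's `_lim_mem_hess_inv_operator_fn`):

```python
import numpy as np

def lbfgs_two_loop(grad, s_hist, y_hist):
    """Apply the L-BFGS inverse-Hessian approximation to `grad`,
    given history pairs (s_i, y_i) ordered oldest to newest."""
    m = len(s_hist)
    rho = [1.0 / np.dot(y, s) for s, y in zip(s_hist, y_hist)]
    q = grad.astype(float).copy()
    alpha = np.empty(m)
    for i in reversed(range(m)):  # newest to oldest
        alpha[i] = rho[i] * np.dot(s_hist[i], q)
        q -= alpha[i] * y_hist[i]
    # Initial scaling gamma = s^T y / y^T y with the most recent pair.
    gamma = np.dot(s_hist[-1], y_hist[-1]) / np.dot(y_hist[-1], y_hist[-1])
    r = gamma * q
    for i in range(m):  # oldest to newest
        beta = rho[i] * np.dot(y_hist[i], r)
        r += (alpha[i] - beta) * s_hist[i]
    return r

# For a 1D quadratic with curvature 2, a single (s, y) pair recovers the
# exact inverse-Hessian action: H^{-1} g = g / 2.
out = lbfgs_two_loop(np.array([4.0]), [np.array([2.0])], [np.array([4.0])])
print(out)  # [2.]
```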
PS tagging my collaborator here too: @bagibence