
Conversation

BalzaniEdoardo

@BalzaniEdoardo BalzaniEdoardo commented May 6, 2025

L-BFGS Implementation for quasi_newton.py

Hello, and first of all, thank you for the great package!

In this PR, I’m working on implementing the L-BFGS algorithm within the quasi_newton.py module, targeting part of #116. Before consolidating the code with tests and full integration, I’d appreciate guidance on some design decisions.


Implementation Overview

  • The descent direction is computed via _lim_mem_hess_inv_operator_fn, which implements the two-loop recursion using the history of parameter and gradient residuals.
  • _lim_mem_hess_inv_operator acts as an operator factory: it partially evaluates _lim_mem_hess_inv_operator_fn with the current residuals and returns a lineax.FunctionLinearOperator.
  • Currently, the operator is materialized before returning, which likely defeats the purpose of using an implicit representation.
  • The buffers with the residuals are stored in a pytree of arrays with the same structure as the parameters, but with an additional dimension (of length "buffer size").
  • Residuals are stored in a dictionary within _QuasiNewtonState.
  • The Hessian update returns an additional dictionary carrying the updated residual state.
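For reference, the standard two-loop recursion that the operator implements can be sketched as follows. This is a toy flat-array version (names are illustrative); the PR's actual implementation works over pytrees and a circular buffer:

```python
import jax.numpy as jnp

def two_loop_recursion(grad, s_history, y_history, rho):
    """L-BFGS two-loop recursion: apply the approximate inverse
    Hessian to `grad`, given the m stored (s_k, y_k) pairs.

    s_history, y_history: arrays of shape (m, n), oldest first.
    rho: array of shape (m,), rho_k = 1 / <s_k, y_k>.
    """
    m = s_history.shape[0]
    q = grad
    alphas = []
    # First loop: newest to oldest.
    for i in reversed(range(m)):
        alpha = rho[i] * jnp.dot(s_history[i], q)
        q = q - alpha * y_history[i]
        alphas.append(alpha)
    # Initial scaling gamma = <s, y> / <y, y> from the most recent pair.
    gamma = jnp.dot(s_history[-1], y_history[-1]) / jnp.dot(
        y_history[-1], y_history[-1]
    )
    r = gamma * q
    # Second loop: oldest to newest, reusing the stored alphas.
    for i in range(m):
        beta = rho[i] * jnp.dot(y_history[i], r)
        r = r + (alphas[m - 1 - i] - beta) * s_history[i]
    return r
```

A quick sanity check: with a single stored pair from an exact quadratic with Hessian c·I (so y = c·s), the recursion returns grad / c exactly.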

Questions

  1. Materialization vs. Tree Equality

    To satisfy this assertion in _iterate.py:

    assert eqx.tree_equal(static_state, new_static_state) is True

    I had to materialize the FunctionLinearOperator. Is there a way to avoid this and retain the implicit operator while still passing this check?

  2. The memory buffer size is currently fixed at 10 iterations. What’s the best way to expose this as a user-settable parameter without altering the broader solver API in quasi_newton.py?

  3. I have not tested this systematically yet, but for low-dimensional parameters, JIT-compiling the function _lim_mem_hess_inv_operator_fn and calling it directly seems significantly faster than constructing a FunctionLinearOperator. Below is line_profiler output for a minimise call:

Timer unit: 1e-06 s

Total time: 0.052245 s
File: /Users/ebalzani/Code/optimistix/optimistix/_solver/quasi_newton.py
Function: _lim_mem_hess_inv_operator at line 120

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  120                                           @line_profiler.profile
  121                                           def _lim_mem_hess_inv_operator(
  122                                                   residual_par: PyTree[Array],
  123                                                   residual_grad: PyTree[Array],
  124                                                   rho: Array,
  125                                                   index_start: Array,
  126                                                   input_shape: Optional[PyTree[jax.ShapeDtypeStruct]] = None,
  127                                           ):
  128                                               """Define a `lineax` linear operator implementing the L-BFGS inverse Hessian approximation.
  129                                           
  130                                               This operator computes the action of the approximate inverse Hessian on a vector `pytree`
  131                                               using the limited-memory BFGS (L-BFGS) two-loop recursion. It does not materialize the matrix
  132                                               explicitly but returns a `lineax.FunctionLinearOperator`.
  133                                           
  134                                               - `residual_par`: History of parameter updates `s_k = x_{k+1} - x_k`
  135                                               - `residual_grad`: History of gradient updates `y_k = g_{k+1} - g_k`
  136                                               - `rho`: Reciprocal dot products `rho_k = 1 / ⟨s_k, y_k⟩`
  137                                               - `index_start`: Index of the most recent update in the circular buffer
  138                                           
  139                                               Returns a `lineax.FunctionLinearOperator` with input and output shape matching a single element
  140                                               of `residual_par`.
  141                                           
  142                                               """
  143         4          2.0      0.5      0.0      operator_func = partial(
  144         2          0.0      0.0      0.0          _lim_mem_hess_inv_operator_fn,
  145         2          0.0      0.0      0.0          residual_par=residual_par,
  146         2          0.0      0.0      0.0          residual_grad=residual_grad,
  147         2          0.0      0.0      0.0          rho=rho,
  148         2          0.0      0.0      0.0          index_start=index_start
  149                                               )
  150         2          0.0      0.0      0.0      input_shape = (
  151         2        942.0    471.0      1.8          jax.eval_shape(lambda: jtu.tree_map(lambda x: x[0], residual_par))
  152         2          0.0      0.0      0.0          if input_shape is None
  153                                                   else input_shape
  154                                               )
  155         4      28219.0   7054.8     54.0      op = lx.FunctionLinearOperator(
  156         2          0.0      0.0      0.0          operator_func,
  157         2          0.0      0.0      0.0          input_shape,
  158         2          1.0      0.5      0.0          tags=lx.positive_semidefinite_tag,
  159                                               )
  160         2      23081.0  11540.5     44.2      return lx.materialise(op)

Let me know how to move on from here and thanks in advance for the insights!

PS tagging my collaborator here too: @bagibence

@johannahaffner
Collaborator

Awesome! And that was impressively quick.
I can go through this by Thursday :)

Collaborator

@johannahaffner johannahaffner left a comment

This looks very reasonable. I've left a first round of comments. Let me know what you think!

)
input_shape = (
jax.eval_shape(lambda: jtu.tree_map(lambda x: x[0], residual_par))
if input_shape is None
Collaborator

jax.eval_shape(lambda: y) should always be the input shape, so I think we can use that directly here and avoid making it an optional argument. I can't think of a case where we would want to use any other shape, since the Hessian is always with respect to y and always symmetric.
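For context, jax.eval_shape traces abstractly and returns only shape/dtype information, so calling it on a closure over y costs no computation:

```python
import jax
import jax.numpy as jnp

# No arrays are allocated here: jax.eval_shape propagates only
# shape and dtype information through the traced function.
y = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
struct = jax.eval_shape(lambda: y)

# The result mirrors the pytree structure of y, with
# jax.ShapeDtypeStruct leaves in place of concrete arrays.
assert struct["w"].shape == (3, 2)
```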

Author

Yes, there is no case in which that's not true! The idea was to provide it directly as a static arg, but once the function is jit-compiled that won't matter, I guess.

Collaborator

Exactly, this is inside the jitted region and the shape of y is static.

input_shape,
tags=lx.positive_semidefinite_tag,
)
return lx.materialise(op)
Collaborator

I think the shape complaint you got when just returning the function linear operator is due to the jaxpr in it, we solve this in a few places by just returning the dynamic portion of the equinox module, e.g. here. (And here for some background.)

Owner

Incidentally I've been meaning to track down what's going on here - this indicates that we are retracing jaxprs, which is not ideal from a compile-time perspective. In a more perfect world we can just directly use the one we already have!

Collaborator

Where do you think that originates?

Author

@BalzaniEdoardo BalzaniEdoardo May 12, 2025

Using eqx.filter makes sense here. I believe the issue comes from the fact that the jaxprs create a unique id for each traced variable:

(Pdb) elem.invars
[Var(id=4544898944):float64[1]]
(Pdb) elem_.invars
[Var(id=4775038400):float64[1]]

here elem and elem_ are the jaxpr vars in the scope of eqx.tree_equal

Author

As far as the jaxpr is concerned, I followed the suggestion of filtering the tree.
I am implementing the filtering in the __call__ method of LBFGSUpdate:

  • Get the static part of the tree before calling update/no_update
  • In update/no_update, get the operator and filter for the dynamic sub-tree
  • After the update call, combine the static part with the dynamic one

Does this make sense?

Collaborator

Filtering as you've done here makes perfect sense!

@johannahaffner
Collaborator

To address this question (others in the comments):

  1. JIT-compiling the function _lim_mem_hess_inv_operator_fn and calling it directly seems significantly faster than the construction of a FunctionLinearOperator.

Maybe I'm not seeing it, but did you include the timing results for JIT compilation of just the function that creates the operator?

@BalzaniEdoardo
Author

Thanks for the feedback!

I gave it another pass, let me know if this is good enough for me to start writing tests.

@BalzaniEdoardo
Author

To address this question (others in the comments):

  1. JIT-compiling the function _lim_mem_hess_inv_operator_fn and calling it directly seems significantly faster than the construction of a FunctionLinearOperator.

Maybe I'm not seeing it, but did you include the timing results for JIT compilation of just the function that creates the operator?

I am not sure when this could be a bottleneck; likely only for small problems, for which evaluating the loss and its gradient is very fast. I am attaching a quick benchmarking script:

from time import perf_counter

import jax
import jax.numpy as jnp
from jax import random
from optimistix._solver.quasi_newton import _make_lbfgs_operator, _lbfgs_operator_fn

jax.config.update("jax_enable_x64", True)

key = random.PRNGKey(123)
par_shape = (1,)
memory_size = 10

key, subkey1, subkey2, subkey3 = random.split(key, 4)

X = random.normal(subkey1, (30, *par_shape))
true_pars = random.normal(subkey2, par_shape)
noise = 0.8 * random.normal(subkey3, (30,))
y = jnp.dot(X, true_pars) + noise

init_par = jnp.zeros(par_shape)

grad_diff = jnp.zeros((memory_size, *par_shape))
param_diff = jnp.zeros((memory_size, *par_shape))
inner_products = jnp.zeros((memory_size, *par_shape))
index_start = jnp.array(0)

# ---------------------------
# Without JIT
# ---------------------------
print("=== Without JIT ===")

with jax.disable_jit(True):
    t0 = perf_counter()
    _lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
    print("Function call (tracing JAXPR):", perf_counter() - t0)

    t0 = perf_counter()
    _lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
    print("Function call (interpreted, no trace):", perf_counter() - t0)

    t0 = perf_counter()
    op = _make_lbfgs_operator(
        y_diff_history=param_diff,
        grad_diff_history=grad_diff,
        inner_history=inner_products,
        index_start=index_start,
    )
    print("LinearOperator construction (traces internal matvec):", perf_counter() - t0)

    t0 = perf_counter()
    op = _make_lbfgs_operator(
        y_diff_history=param_diff,
        grad_diff_history=grad_diff,
        inner_history=inner_products,
        index_start=index_start,
    )
    op.mv(init_par)
    print("LinearOperator construction + mv (new trace):", perf_counter() - t0)


# ---------------------------
# With JIT
# ---------------------------
print("\n=== With JIT ===")

t0 = perf_counter()
_lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
print("Function call (JIT trace + compile + run):", perf_counter() - t0)

t0 = perf_counter()
_lbfgs_operator_fn(init_par, param_diff, grad_diff, inner_products, index_start)
print("Function call (JIT execution only):", perf_counter() - t0)

t0 = perf_counter()
op_jit = _make_lbfgs_operator(
    y_diff_history=param_diff,
    grad_diff_history=grad_diff,
    inner_history=inner_products,
    index_start=index_start,
)
op_jit.mv(init_par)
print("LinearOperator JIT construction + mv (compile + run):", perf_counter() - t0)

t0 = perf_counter()
op_jit = _make_lbfgs_operator(
    y_diff_history=param_diff,
    grad_diff_history=grad_diff,
    inner_history=inner_products,
    index_start=index_start,
)
op_jit.mv(init_par)
print("LinearOperator JIT reconstruction + mv (new trace + run):", perf_counter() - t0)

t0 = perf_counter()
op_jit.mv(init_par)
print("LinearOperator mv call (JIT execution only):", perf_counter() - t0)

which outputs:

=== Without JIT ===
Function call (tracing JAXPR): 1.0509237500373274
Function call (interpreted, no trace): 0.030687999911606312
LinearOperator construction (traces internal matvec): 0.1505412079859525
LinearOperator construction + mv (new trace): 0.9351932499557734
=== With JIT ===
Function call (JIT trace + compile + run): 0.05465291708242148
Function call (JIT execution only): 2.062495332211256e-05
LinearOperator JIT construction + mv (compile + run): 0.07817629189230502
LinearOperator JIT reconstruction + mv (new trace + run): 0.002992084017023444
LinearOperator mv call (JIT execution only): 0.0002868750598281622

@johannahaffner
Collaborator

Thank you for the update! I can review tomorrow night.

Collaborator

@johannahaffner johannahaffner left a comment

Nice work! I've left a few small comments and I can follow along much more easily now, thanks for the descriptive variable names!

For tests you can just add it to the list of minimisers defined in tests/helpers.py.



class LBFGSUpdate(AbstractQuasiNewtonUpdate, strict=True):
"""Private intermediate class for LBFGS updates."""
Collaborator

This should be a public class, I think. We want to expose the update classes to users who may wish to build custom solvers, in which case they are required. If you have a reference for this specific implementation, it would be great to add it here!

Small thing: can we move the definition of the update class up, so that it is grouped with the other update classes? We then define the abstract solver and have the concrete ones at the bottom of the file 🙈

Author

Absolutely 😀 I totally understand the need for linting the code!

recursion. It does not materialize the matrix explicitly but returns a
`lineax.FunctionLinearOperator`.

- `y_diff_history`: History of parameter updates `s_k = x_{k+1} - x_k`
Collaborator

Can we add a small comment that s_k and y_k are the typical variable names used in the literature? Since our y is something else.

self.descent = NewtonDescent()
self.search = search
self.hessian_update = LBFGSUpdate(
use_inverse=True,
Collaborator

This only supports use_inverse = True right now, correct? Do you know what this would look like for the approximation of the Hessian itself, not its inverse?

Author

I looked this up and I believe that an approximation of the Hessian itself may be possible, starting from a representation like the one in chapter 7.2 of this:
https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf

where the author talks about the compact representation of the update. Let me know if my intuition is correct and I'll dig into it more!

Author

PS: is the approximation of the Hessian something you would want?

Owner

@patrick-kidger patrick-kidger left a comment

Okay, apologies for taking so long to get around to this!

I've not checked the details of the algorithms precisely, but other than that I think this basically all looks reasonable to me. I've left lots of nitty comments below on edge cases and code tidiness and such.

Collaborator

@johannahaffner johannahaffner left a comment

Addressed some small fixes and moved the function to create an identity operator back to quasi Newton, without the tag shenanigans :)

@johannahaffner
Collaborator

Alright, tweaked some more things!

Of note:

  • I suggest we do the handling of self.use_inverse in a separate PR that in turn moves the Hessian update machinery into solver methods @patrick-kidger, and leave this as-is here (see comments above).
  • Likewise, I'm happy to refine the jaxpr handling in the course of improving our static handling of these (to make everything compatible with jax.vmap).
  • @BalzaniEdoardo can you do the jnp.where safeguard for the no-jit + debug + vmap edge case? It seems like you know what would need to change :)
  • Finally, I'm questioning whether failing to unpack the shape of the y_diff history isn't actually an informative way to fail in the case in which y is an empty pytree.

@BalzaniEdoardo
Author

  • @BalzaniEdoardo can you do the jnp.where safeguard for the no-jit + debug + vmap edge case? It seems like you know what would need to change :)

Sounds good, I'll add that tomorrow.
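For reference, the jnp.where safeguard being discussed is presumably the usual "double-where" pattern for guarding a conditional division (a generic sketch, not the PR's actual code):

```python
import jax
import jax.numpy as jnp

def safe_reciprocal(inner):
    # Inner where: substitute a harmless denominator so the division
    # itself never sees zero; this keeps vmap/debug_nans and the
    # backward pass NaN-free.
    safe = jnp.where(inner > 0, inner, 1.0)
    # Outer where: select the fallback value on the guarded branch.
    return jnp.where(inner > 0, 1.0 / safe, 0.0)

# The gradient at inner = 0 is well-defined (zero), not NaN: the
# unselected 1/x branch never contributes an inf * 0 term.
grad_at_zero = jax.grad(safe_reciprocal)(0.0)
```

Without the inner where, jax.grad at zero would propagate a NaN through the discarded branch even though the forward value is fine.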

@johannahaffner
Collaborator

I think you're missing a jnp.asarray around the inner in the set :)

@johannahaffner
Collaborator

History length is now a state attribute, and will become a solver-level attribute when the update methods become solver methods :)

Owner

@patrick-kidger patrick-kidger left a comment

Okay I think this LGTM! I have only nitty remarks/questions, see below.

@johannahaffner happy to merge this into dev :)

history_length: int
y_diff_history: PyTree[Y]
grad_diff_history: PyTree[Y]
y_diff_grad_diff_cross_inner: Float[Array, " history_length history_length"]
Owner

Just to confirm, this is intentionally only used when use_inverse=False?

Collaborator

@johannahaffner johannahaffner Jul 13, 2025

Yes, we have different update states for the approximation of the Hessian and its inverse, and they do use different inner / outer products.

Comment on lines +384 to +389
y_diff_grad_diff_cross_inner = state.y_diff_grad_diff_cross_inner.at[
state.index_start % self.history_length
].set(v_tree_dot(state.grad_diff_history, y_diff))
y_diff_grad_diff_cross_inner = y_diff_grad_diff_cross_inner.at[
:, state.index_start % self.history_length
].set(0)
Owner

IIUC this is a trick to gradually fill in the lower triangular part of matrix (across multiple calls to _update), with zero diagonal, and the upper triangular part will be filled with either zeros or nonsense data depending where we are? And the upper triangular part is fine because it's not read by the triangular solve calls we use, so we're happy with it containing meaningless data?

This is subtle enough that I think it deserves a comment 😄

Collaborator

Good catch, added an explanation!
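The indexing pattern quoted above can be illustrated with a standalone toy sketch, where scalar row values stand in for the actual inner products:

```python
import jax.numpy as jnp

m = 3  # history length
cross = jnp.zeros((m, m))

# Simulate m successive updates into the circular buffer, mimicking
# the row-write + column-zero pattern from the PR.
for k, row_value in enumerate([1.0, 2.0, 3.0]):
    i = k % m
    # Write the new slot's row of inner products...
    cross = cross.at[i].set(row_value)
    # ...then zero its column: stale products involving the slot
    # that was just overwritten must never be read again.
    cross = cross.at[:, i].set(0)

# After m updates the matrix is strictly lower triangular with a
# zero diagonal; anything left above the diagonal is never read by
# the triangular solves.
```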

# We know that gamma > 0 because we only update the Hessian approximation if the
# inner product is positive, to maintain positive definiteness of the Hessian
# approximation, and thus this operator is only ever called in that case.
latest_y_diff, latest_grad_diff = jtu.tree_map(
Collaborator

@johannahaffner johannahaffner Jul 13, 2025

Is this the oldest pair, or the most current one @BalzaniEdoardo? Wikipedia indexes this with k-m.

Collaborator

(I think that means that it is the oldest?)

Author

It is the most recent in the code, because the index is incremented by one before _update is called, so subtracting one selects the most recent update. This matches Byrd et al. (1994), Eq. (3.12), and the optax implementation, where the most recent (s_{k-1}, y_{k-1}) pair is used (see attached figure).

I also noticed this discrepancy with the Wikipedia page, which suggests using (s_{k−m}, y_{k−m}) instead. I'm not aware of an alternative that is provably better, so I went with what seems to be the most common practice.

[Screenshot 2025-07-14: attached figure]
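The convention described here can be sketched as follows (illustrative index arithmetic, assuming index_start points at the next slot to be written, as described above):

```python
import jax.numpy as jnp

def gamma_from_history(s_history, y_history, index_start):
    """Initial inverse-Hessian scaling gamma = <s, y> / <y, y>,
    taken from the most recent pair in a circular buffer of size m.

    Since the index is incremented before the update, the most
    recent pair lives at index_start - 1 (mod m).
    """
    m = s_history.shape[0]
    latest = (index_start - 1) % m
    s, y = s_history[latest], y_history[latest]
    return jnp.dot(s, y) / jnp.dot(y, y)
```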

Collaborator

Perfect, thank you!

@johannahaffner johannahaffner merged commit 8dbc8e5 into patrick-kidger:dev Jul 14, 2025
2 checks passed