
Add zoom linesearch #143


Draft · wants to merge 1 commit into base: tidy-dev

Conversation

bagibence

As we talked about, I implemented zoom linesearch that satisfies the strong Wolfe conditions, which is particularly useful for quasi-Newton methods.

After some iterations, the current version ended up similar to the ones in Optax and JAXopt, but compatible with Optimistix's flattened loop. It works but definitely has some rough edges, so please let me know what changes you'd like me to make.
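For readers unfamiliar with the terminology: the strong Wolfe conditions combine the Armijo sufficient-decrease test with a curvature test on the absolute slope. A minimal sketch (plain Python scalars for illustration; in Optimistix these would be JAX arrays, and `c1`/`c2` are the conventional constants, not necessarily the PR's defaults):

```python
def strong_wolfe(f0, slope0, f_step, slope_step, stepsize, c1=1e-4, c2=0.9):
    """Check the strong Wolfe conditions for a trial step.

    f0, slope0: value and directional derivative at the current point.
    f_step, slope_step: the same quantities at the trial point.
    slope0 is assumed negative (a descent direction).
    """
    # Armijo sufficient-decrease condition
    armijo = f_step <= f0 + c1 * stepsize * slope0
    # Strong curvature condition: the slope magnitude must have shrunk enough
    curvature = abs(slope_step) <= c2 * abs(slope0)
    return armijo and curvature
```

The curvature test is why the search needs the slope, and hence the gradient, at the trial point.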

The main complication is that to check the curvature condition, we need the slope at y_eval, and for that we need the gradient at y_eval. I couldn't do that with the current interface of AbstractSearch.step, so I modified it to accept lin_fn and options, which are then used to calculate the gradient inside the linesearch.
This might not be the optimal solution. For example, as lin_fn doesn't exist in AbstractGaussNewton.step, I just passed the identity function there. Let me know if you know a better way to handle this.

To avoid recalculating a gradient in the accepted branch of AbstractQuasiNewton.step, I added a check to read it from the search state if it's using zoom linesearch, which might not be the most elegant solution.

I haven't added any new tests yet, but I updated BFGS and DFP to use zoom instead of backtracking, and all tests pass for me.
I've also tested it with the L-BFGS implementation from #135. After that's merged, LBFGS should also be adapted to use Zoom.

I added, then removed, a StrongWolfeBacktracking class that does backtracking and checks the strong Wolfe conditions. I can add it back or open a separate PR if that's something you want to have.

@johannahaffner
Collaborator

johannahaffner commented Jun 13, 2025

Nice! I'm just going through the L-BFGS PR in #135.

Since this is a larger change, shall we schedule a Zoom discussion for next week? I think the backtracking strong Wolfe would be nice to take a look at. If you still have it around, just put it in a separate draft PR as-is and I'll take a look :)

@bagibence
Author

Sure, let's do that! I'll write you an email.

Collaborator

@johannahaffner left a comment

Two very minor comments, not very important in the grand scheme of things. I took a look through, a bit more high-level for now.

Let's walk through this tomorrow, starting from the math and really looking at what is required and where. I'm also going to take a look at the Optax and JAXopt implementations from which you are adapting things.

In particular we will need solutions for:

  • communicating the first direction (y - y_eval after an accepted step) to the search;
  • getting gradients at the trial iterate in a way that a) does not compromise the performance of existing solvers which depends on skipping this, and b) does not introduce hard-coded coupling to specific search flavours in the solvers (if it can be avoided)
  • how to avoid passing a lin_fn to the search, since this is a solver-level implementation detail (as you have seen the least squares solvers do not provide this)
  • figuring out if this needs to be a two-loop setup / if this is a good fit for the existing "flattened" loop architecture

We could very well be looking at a break to the existing logic here, so this will need some thinking :)

@@ -556,7 +578,7 @@ class DFP(AbstractQuasiNewton[Y, Aux, _Hessian]):
atol: float
norm: Callable[[PyTree], Scalar]
descent: NewtonDescent
-    search: BacktrackingArmijo
+    search: AbstractSearch
Collaborator

We intentionally provide concrete classes in the concrete solvers :)

Author

Okay. Do you want me to revert this back or switch to Zoom? Also, should I remove search from __init__?

@johannahaffner
Collaborator

As discussed, some suggestions for the implementation:

  • restrict the type of the f_eval_info argument to just FunctionInfo.EvalGrad and raise a ValueError for all other cases. (Since AbstractGaussNewton cannot provide this type, the changes in this solver can be walked back.)
  • add a boolean attribute grad_at_y_eval to AbstractQuasiNewton and AbstractGradientDescent that defaults to False. Pseudocode below
  • add an at attribute to the FunctionInfo class, so that we can at any time access f_info.at or f_eval_info.at (the type of this attribute is Y).
```python
class AbstractSolver(...):
    ...
    def step(...):
        f_eval, lin_fn = ...
        if self.grad_at_y_eval:
            grad = lin_to_grad(...)
            f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
        else:
            f_eval_info = FunctionInfo.Eval(f_eval)
        # ...call the search as normal, now passing an explicit f_eval_info

        def accepted(descent_state):
            # Reverse logic: here we need to compute the gradient if we have not
            # already done so. If we have, the accepted branch closes over grad.
```

@bagibence
Author

bagibence commented Jun 25, 2025

restrict the type of the f_eval_info argument to just FunctionInfo.EvalGrad and raise a ValueError for all other cases

I got rid of the _FnEvalInfo altogether and switched to FunctionInfo.EvalGrad everywhere. Do you want to keep it this way, or revert to the previous `_FnEvalInfo: TypeAlias = FunctionInfo.EvalGrad`?
I added a TypeError in Zoom.step. Would you prefer ValueError instead?

add a boolean attribute grad_at_y_eval to AbstractQuasiNewton and AbstractGradientDescent that defaults to False. Pseudocode below

Thanks for the pseudocode, I did it like that. However, since it's ultimately the search that determines whether the gradient at y_eval is needed, I added a _needs_grad_at_y_eval class attribute to the search classes instead. Otherwise, when creating a new subclass of AbstractQuasiNewton or AbstractGradientDescent with a concrete search, one would have to keep track of whether that particular search needs the gradient and set the flag appropriately by hand. This way the search tells the caller whether it needs it. How do you like this solution?
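A hypothetical sketch of this dispatch (class names mirror the PR, but these are illustrative stand-ins, not the real Optimistix classes):

```python
class AbstractSearch:
    # Default: searches do not need the gradient at the trial point.
    _needs_grad_at_y_eval: bool = False

class BacktrackingArmijo(AbstractSearch):
    pass  # inherits False: only function values are compared

class Zoom(AbstractSearch):
    # The curvature condition needs the slope (hence gradient) at y_eval.
    _needs_grad_at_y_eval = True

def f_eval_info_for(search, f_eval, compute_grad):
    """Solver-side: only compute the gradient if the search asks for it."""
    if search._needs_grad_at_y_eval:
        return ("EvalGrad", f_eval, compute_grad())
    return ("Eval", f_eval)
```

The point of the design is that a solver subclass never has to know which searches need gradients; the search carries that information itself.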

add an at attribute to the FunctionInfo class, so that we can at any time access f_info.at or f_eval_info.at (the type of this attribute is Y).

I agree that it would be nice to have that, and I made the PointEval and PointEvalGrad classes for that convenience. But wouldn't FunctionInfo.at make passing around both y and f_info (and y_eval and f_eval_info) redundant in many places throughout the codebase? I assumed there was a reason for keeping them decoupled. Perhaps to avoid passing around potentially large parameter pytrees when only needing a scalar function value?

If you want to get rid of PointEval and PointEvalGrad, we can just pass around / store the location and the function info separately. I just did it here bagibence#1

I'll try and see how big of a change it would be to add FunctionInfo.at. I suspect it might be a bit larger in scope than just zoom, so it might be better to break it out into a separate PR.

Update: I just added .at in some FunctionInfo subclasses, and it ended up being smaller than I expected. The tests pass for me, but it still feels redundant without removing y. bagibence#2

Update 2: As an alternative, I modified PointEval and PointEvalGrad to wrap FunctionInfo.Eval and FunctionInfo.EvalGrad in bagibence#3

@johannahaffner
Collaborator

I got rid of the _FnEvalInfo altogether and switched to FunctionInfo.EvalGrad everywhere.

I think that is clear!

Would you prefer ValueError instead?

Yes, like it is done here:

However, since it's ultimately determined by the search used if it needs the gradient at y_eval or not, I added a _needs_grad_at_y_eval class attribute to the search classes instead.

This is a very good point and I do like this solution!

I agree that it would be nice to have that, and I made the PointEval and PointEvalGrad classes for that convenience. But wouldn't FunctionInfo.at make passing around both y and f_info (and y_eval and f_eval_info) redundant in many places throughout the codebase?
[Plus updates 1, 2, 3]

And these are excellent points too! Then let's hold off on mainstreaming this change - I think the wrapper class approach makes sense, and could potentially even be extended to include other point-specific information, such as the slope?

@bagibence
Author

Yes, like it is done here

Done, raising ValueError if f_info or f_eval_info is the wrong type.

I think the wrapper class approach makes sense, and could potentially even be extended to include other point-specific information, such as the slope?

I included the slope in PointEvalGrad here. Let me know if you want to use it like this bagibence#4

I also switched to a frozenset for verbose and added a few toggles. What should be printed, and when, requires a bit more thought.
Some things are printed based on a runtime condition (e.g. interval_too_short), which I don't think verbose_print handles, so for those I kept _cond_print for now.

@johannahaffner
Collaborator

Just rebased on latest dev, will take a long look at it tomorrow! Sorry for being so slow on this, will get the ball rolling now :)

Collaborator

@johannahaffner left a comment

I have a few initial comments - most of them are questions at this point :)

)
_step_fn = ft.partial(self._step, y, y_eval, f_info, f_eval_info)

accept, state = jax.lax.cond(
Collaborator

Note for later: we can use filter_cond from the misc module in case we want to put a non-array into the ZoomState.

)

# write these into the state
state = eqx.tree_at(
Collaborator

If I'm not mistaken, then we modify the state in various places using a tree_at, which makes it hard to follow what happens where. Could that be changed? It probably requires breaking up this large state into several sub-states, e.g. for the interpolation? I think it would be very desirable to update the state in one place only.

Author

I think I see what you mean.
Now I reduced it to a single tree_at. I'll think about how we could get rid of that. Or is that fine?

Collaborator

Nothing wrong with a tree_at in principle! It makes it a little harder to see what does not get updated, though. Do we have elements of the search state that are unchanged here? Intuitively I would assume that the search state is updated at every search.step, even if that means intentionally not updating certain elements of it. (Pretty much like solver.state is also updated in every solver.step, even though some elements may emerge unchanged when we take the rejected branch.)

)

return ZoomState(
ls_iter_num=jnp.array(0),
Collaborator

Do we care about the number of iterations in the increase/zoom stage or do we only care about the outermost of these conditions, namely whether we are about to start a new linesearch, having successfully completed the last one? In the latter case, does this contain the same information as done and failed?

Author

We need to keep track of the number of iterations for checking if we reached the max number of steps allowed (maxls).

Collaborator

Makes sense! Can we call it line_search_max_steps?

current_slope=_slope_init,
#
interval_found=jnp.array(False),
done=jnp.array(False),
Collaborator

Just to confirm, do interval_found, done and failed encode completely different elements of the state? As in, they do not have a logical relationship that would allow us to express one of these through the others?

Author

They are different.

interval_found is separate, it's for switching between the stages of the algorithm: looking for an interval and "zooming into it".

done and failed are related, but not as simply as done = ~failed; they work together:

  • not done, not failed: continue.
  • done, not failed: accept the currently tested stepsize.
  • not done, failed: try the safe stepsize on the next iteration.

If needed, we might be able to eliminate state.done if _search_interval and _zoom_into_interval return accept, state instead of just the state.
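The three meaningful flag combinations can be written out as a small dispatch (illustrative only; the actual implementation branches on JAX arrays with jnp.where/lax.cond rather than Python if):

```python
def next_action(done: bool, failed: bool) -> str:
    """Map the (done, failed) flag pair to the action taken on the next step."""
    if not done and not failed:
        return "continue"            # keep searching
    if done and not failed:
        return "accept"              # accept the currently tested stepsize
    if not done and failed:
        return "try_safe_stepsize"   # fall back to the safe stepsize
    raise ValueError("done and failed are never both set")
```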

Collaborator

Thank you for the explanation! Let's keep it for now, I will try to organise my thoughts with pen and paper.

@bagibence
Author

I added a BFGS and L-BFGS using zoom to the tests.
Also proposed a fix for one of the todos in quasi_newton without which the tests fail.

@@ -221,7 +221,7 @@ class Zoom(AbstractSearch[Y, _FnInfo, FunctionInfo.EvalGrad, ZoomState]):
# TODO decide on defaults
c1: float = 1e-4
c2: float = 0.9
-    c3: float = 1e-6
+    c3: float | None = 1e-6
Collaborator

Can we document why c3 is optional and what the search does if it is not specified?

Author

I added Zoom.__init__.__doc__ with an explanation for this and comments in Zoom.decrease_condition_with_approx.
Let me know if you want to add more exhaustive documentation somewhere.

The Optax docs are really good, although a bit confusing: if I understand correctly, they don't necessarily switch to the approximate condition, but allow either the Armijo condition or the approximate one when the function value doesn't change much.
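To spell out my reading of that behaviour (a hypothetical sketch, not Optax's actual code): when the function value barely changes, the exact Armijo test loses meaning in floating point, so a Hager-Zhang-style slope-based test is accepted as an alternative:

```python
def sufficient_decrease(f0, slope0, f_step, slope_step, stepsize,
                        c1=1e-4, rel_tol=1e-6):
    """Armijo test, with a Hager-Zhang-style approximate fallback.

    slope0 is the (negative) directional derivative at the current point;
    rel_tol is an assumed relative tolerance for "the value didn't change".
    """
    # Exact Armijo sufficient-decrease condition
    armijo = f_step <= f0 + c1 * stepsize * slope0
    # Approximate condition on slopes (Hager & Zhang), usable when the
    # difference f_step - f0 is lost to floating-point noise.
    approx = slope_step <= (2.0 * c1 - 1.0) * slope0
    values_indistinguishable = abs(f_step - f0) <= rel_tol * abs(f0)
    return armijo or (values_indistinguishable and approx)
```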

Collaborator

This has been a source of confusion for me too. I'm also wondering if we should actually offer this feature in this search, as opposed to creating an abstract base class for Zoom line searches that then have slightly different features.

We're doing something similar with the linear and classical trust region searches, for example. The concrete classes / searches that subclass the base class don't have to do a whole lot differently, but we can separate the functionality a little more cleanly, and document for users what is happening.

Does that make sense to you? I think there are a few extra features here - same for the initial step length selection - and quite a bit of control flow already, making it tricky to understand what is supposed to happen if these different options interact.

@johannahaffner
Collaborator

I added a BFGS and L-BFGS using zoom to the tests. Also proposed a fix for one of the todos in quasi_newton without which the tests fail.

Thanks!

@johannahaffner
Collaborator

I'm making dents in it, I promise 🙈 can I rebase this on a tidy development branch?

@bagibence
Author

Of course, go ahead! I'll also add the documentation for c3.

@johannahaffner johannahaffner changed the base branch from dev to tidy-dev August 14, 2025 12:16
@johannahaffner
Collaborator

johannahaffner commented Aug 17, 2025

Ok, done! I also squashed your work on the Zoom line search into a single commit (attributed to you, of course)! The tests now pass and we're up to date with the latest version of tidy-dev. You'll need to pull from this, or perhaps reset your local copy to this version.

Sorry for the confusion! This is not really part of the PR experience I want to offer. Now back to the math :)

Collaborator

@johannahaffner left a comment

Here is another round of comments :) Some of them are still quite basic. Do you think that we could factor out the core common functionalities and then provide the extra features in separate searches (e.g. the Hager-Zhang condition)? Extra concrete classes could be quite lean, and this may be preferable to having several features that enable extra functionality. Does that make sense to you?

c: FloatScalar,
value_c: FloatScalar,
) -> FloatScalar:
"""
Collaborator

Nit: text in docstrings starts on the same line, e.g.

```python
"""A docstring.
With extra text.
"""
```


Args:
a: scalar
value_a: value of a function f at a
Collaborator

a, b, c and the associated values are step lengths, right? And f is not the function to be minimised, but the step length selection function, where the step lengths are what is usually called $\alpha$ in the literature and f would usually be called $\phi$.

Can we clarify that here? I think one way to do it would be to change the names here. Is there some ordering of the values a, b and c, e.g. a < b < c? Perhaps we can also call f something else, e.g. steplength_fn.

xmin: point at which p'(xmin) = 0
"""
C = slope_a
db = b - a
Collaborator

If the names of a and b don't make it obvious, call this distance_{...} or distance_to_{...}, since these are relative to a.
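For context on what these interpolation helpers compute: the zoom stage fits a low-order polynomial to known values and slopes of the steplength function and jumps to its minimizer. A minimal quadratic-fit version (hypothetical name quadmin, not the PR's actual helper):

```python
def quadmin(a, value_a, slope_a, b, value_b):
    """Minimizer of the quadratic q with q(a) = value_a, q'(a) = slope_a
    and q(b) = value_b. Here a and b are trial step lengths (the alphas of
    the literature) and q models the steplength function phi.
    """
    db = b - a
    # Fit q(x) = value_a + slope_a * (x - a) + B * (x - a)**2
    # Assumes B > 0, i.e. the fitted quadratic is convex.
    B = (value_b - value_a - slope_a * db) / db**2
    return a - slope_a / (2.0 * B)
```

For example, for phi(x) = (x - 3)**2 sampled at a = 0 and b = 1, the fit is exact and quadmin recovers the minimizer 3. The cubic variant in the PR adds one more data point to fit a cubic instead.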

value_cubic_ref: FloatScalar,
) -> FloatScalar:
"""
Find a stepsize by minimizing the cubic or quadratic curve fitted to
Collaborator

Nit: docstring, as above.

delta = jnp.abs(hi - lo)
left = jnp.minimum(hi, lo)
right = jnp.maximum(hi, lo)
cubic_chk = 0.2 * delta # cubic guess has to be at least this far from the sides
Collaborator

How about cubic_threshold and quadratic_threshold?

set_cubic_to_hi = set_hi_to_middle | set_hi_to_lo

# do the updates
new_stepsize_hi, new_point_hi = tree_where(
Collaborator

I think it might be slightly clearer to update the step sizes first, and then use new_point_high = (new_step_size_high * state.point_high**ω).ω, instead of tree-mapping a tuple that we then unpack again immediately.

# evaluate the slope along the descent direction for the new stepsize
new_stepsize = state.y_eval_stepsize
new_point = PointEvalGrad(y_eval, f_eval_info)
slope_at_new_point = new_point.compute_grad_dot(state.descent_direction)
Collaborator

same as above, this is now the slope of the step size selection function $\phi$?

# If the interval is found, this will zoom into the correct interval.
# If not, it still sets lo and hi, but that's okay because it will not be used
# by _zoom_into_interval, and we will just return here in the next iteration.
new_stepsize_lo, new_point_lo, new_slope_lo = tree_where(
Collaborator

See above, I think it would be preferable to update these separately instead of wrapping a tuple and then unpacking them again :)

descent_direction=state.descent_direction,
)

def fake_first_step(
Collaborator

Does this have to be a separate method?

accept = jnp.array(True)
return accept, state

def _safe_step(
Collaborator

Same here, maybe just delineate that in step with a comment (what is now in the docstring)?
