Add zoom linesearch #143
Conversation
Nice! I'm just going through the L-BFGS PR in #135. Since this is a larger change, shall we schedule a Zoom discussion for next week? I think the backtracking strong Wolfe would be nice to take a look at. If you still have it around, just put it in a separate draft PR as-is and I'll take a look :)
Sure, let's do that! I'll write you an email.
Two very minor comments, not very important in the grand scheme of things. I took a look through, a bit more high-level for now.
Let's walk through this tomorrow, starting from the math and really looking at what is required and where. I'm also going to take a look at the optax and JAXopt implementations from which you are adapting things.
In particular we will need solutions for:
- communicating the first direction (`y - y_eval` after an accepted step) to the search;
- getting gradients at the trial iterate in a way that a) does not compromise the performance of existing solvers, which depends on skipping this, and b) does not introduce hard-coded coupling to specific search flavours in the solvers (if it can be avoided) — see the sketch at the end of this comment;
- how to avoid passing a `lin_fn` to the search, since this is a solver-level implementation detail (as you have seen, the least squares solvers do not provide this);
- figuring out if this needs to be a two-loop setup / if this is a good fit for the existing "flattened" loop architecture.
We could very well be looking at a break to the existing logic here, so this will need some thinking :)
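As a point of reference for the gradient question above, here is a toy sketch of getting the gradient at the trial iterate from a `lin_fn` purely via JAX's linearise/transpose machinery (the objective and names are stand-ins, not the optimistix internals):

```python
import jax
import jax.numpy as jnp


def fn(y):
    # Stand-in scalar objective; in the solvers this would be the user's function.
    return jnp.sum(y**2)


y_eval = jnp.array([1.0, 2.0, 3.0])
f_eval, lin_fn = jax.linearize(fn, y_eval)  # lin_fn(v) computes the JVP, i.e. J @ v
# Transposing the (linear) JVP and feeding it a unit cotangent recovers the gradient.
(grad,) = jax.linear_transpose(lin_fn, y_eval)(jnp.ones_like(f_eval))
```

Whether this beats simply calling `jax.value_and_grad` at `y_eval` depends on what has already been computed in the solver step.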
optimistix/_solver/quasi_newton.py (outdated)
@@ -556,7 +578,7 @@ class DFP(AbstractQuasiNewton[Y, Aux, _Hessian]):
     atol: float
     norm: Callable[[PyTree], Scalar]
     descent: NewtonDescent
-    search: BacktrackingArmijo
+    search: AbstractSearch
We intentionally provide concrete classes in the concrete solvers :)
Okay. Do you want me to revert this back or switch to `Zoom`? Also, should I remove `search` from `__init__`?
As discussed, some suggestions for the implementation:
class AbstractSolver(...):
    ...

    def step(...):
        f_eval, lin_fn = ...
        if self.grad_at_y_eval:
            grad = lin_to_grad(...)
            f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
        else:
            f_eval_info = FunctionInfo.Eval(f_eval)

        # ...call the search as normal, now passing an explicit f_eval_info

        def accepted(descent_state):
            # Reverse logic: here we need to compute the gradient if we have not
            # already done that. If we have already computed the gradient, then
            # the accepted branch closes over grad.
            ...
I got rid of the …

Thanks for the pseudocode, I did it like that. However, since it's ultimately determined by the search used whether it needs the gradient at …

I agree that it would be nice to have that, and I made the …. If you want to get rid of …, I'll try and see how big of a change it would be to add …. Update: I just added …. Update 2: As an alternative, I modified ….
I think that is clear!
Yes, like it is done here:
This is a very good point and I do like this solution!
And these are excellent points too! Then let's hold off on mainstreaming this change - I think the wrapper class approach makes sense, and could potentially even be extended to include other point-specific information, such as the slope?
Done, raising …

I included the slope in …. I also switched to a ….
(Branch force-pushed from 2cfa5d4 to 00e82f4.)
Just rebased on latest dev, will take a long look at it tomorrow! Sorry for being so slow on this, will get the ball rolling now :)
I have a few initial comments - most of them are questions at this point :)
)
_step_fn = ft.partial(self._step, y, y_eval, f_info, f_eval_info)

accept, state = jax.lax.cond(
Note for later: we can use `filter_cond` from the misc module in case we want to put a non-array into the `ZoomState`.
)

# write these into the state
state = eqx.tree_at(
If I'm not mistaken, then we modify the state in various places using a `tree_at`, which makes it hard to follow what happens where. Could that be changed? It probably requires breaking up this large state into several sub-states, e.g. for the interpolation? I think it would be very desirable to update the state in one place only.
I think I see what you mean. I have now reduced it to a single `tree_at`. I'll think about how we could get rid of that. Or is that fine?
Nothing wrong with a `tree_at` in principle! It makes it a little harder to see what does not get updated, though. Do we have elements of the search state that are unchanged here? Intuitively I would assume that the search state is updated at every `search.step`, even if that means intentionally not updating certain elements of it. (Pretty much like `solver.state` is also updated in every `solver.step`, even though some elements may emerge unchanged when we take the `rejected` branch.)
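For concreteness, the two update styles in a toy example (made-up fields, not the actual `ZoomState`):

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class SearchState(eqx.Module):
    step_size: jax.Array
    done: jax.Array


state = SearchState(step_size=jnp.array(1.0), done=jnp.array(False))

# Style 1: patch individual leaves wherever convenient.
state = eqx.tree_at(lambda s: s.step_size, state, jnp.array(0.5))

# Style 2: rebuild the whole state in one place, passing unchanged fields through
# explicitly, so it is obvious what is (and is not) updated at each search step.
state = SearchState(step_size=jnp.array(0.5), done=state.done)
```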
)

return ZoomState(
    ls_iter_num=jnp.array(0),
Do we care about the number of iterations in the increase/zoom stage, or do we only care about the outermost of these conditions, namely whether we are about to start a new linesearch, having successfully completed the last one? In the latter case, does this contain the same information as `done` and `failed`?
We need to keep track of the number of iterations for checking if we reached the max number of steps allowed (`maxls`).
Makes sense! Can we call it `line_search_max_steps`?
    current_slope=_slope_init,
    #
    interval_found=jnp.array(False),
    done=jnp.array(False),
Just to confirm, do `interval_found`, `done` and `failed` encode completely different elements of the state? As in, they do not have a logical relationship that would allow us to express one of these through the others?
They are different.

`interval_found` is separate, it's for switching between the stages of the algorithm: looking for an interval and "zooming into it".

`done` and `failed` are related, but not as simply as `done = ~failed`; they kind of work together:
- not done, not failed: continue.
- done, not failed: accept the currently tested stepsize.
- not done, failed: try the safe stepsize on the next iteration.

If needed, we might be able to eliminate `state.done` if `_search_interval` and `_zoom_into_interval` return `accept, state` instead of just the state.
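To spell that table out as a toy sketch (names hypothetical, just the flag logic):

```python
import jax.numpy as jnp

# Stand-ins for state.done / state.failed and the two candidate step sizes.
done, failed = jnp.array(False), jnp.array(True)
proposed_stepsize, safe_stepsize = jnp.array(0.7), jnp.array(0.1)

accept = done & ~failed  # only the (done, not failed) case accepts the tested stepsize
next_stepsize = jnp.where(failed, safe_stepsize, proposed_stepsize)
```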
Thank you for the explanation! Let's keep it for now, I will try to organise my thoughts with pen and paper.
I added a BFGS and an L-BFGS using zoom to the tests.
@@ -221,7 +221,7 @@ class Zoom(AbstractSearch[Y, _FnInfo, FunctionInfo.EvalGrad, ZoomState]):
     # TODO decide on defaults
     c1: float = 1e-4
     c2: float = 0.9
-    c3: float = 1e-6
+    c3: float | None = 1e-6
Can we document why `c3` is optional and what the search does if it is not specified?
I added `Zoom.__init__.__doc__` with an explanation for this, and comments in `Zoom.decrease_condition_with_approx`. Let me know if you want to add more exhaustive documentation somewhere.

The Optax docs are really good, although a bit confusing, because if I understand things correctly, they don't necessarily switch to the approximate condition but allow either the Armijo or the approximate condition when the function value barely changes.
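Roughly, the behaviour I went for looks like this (a simplified sketch, not the exact code; in particular, treating `c3` as a relative-change threshold is just how I have written it here):

```python
import jax.numpy as jnp


def decrease_condition_with_approx(stepsize, f0, slope0, f_eval, slope_eval, c1, c3):
    # Standard Armijo sufficient-decrease test.
    armijo = f_eval <= f0 + c1 * stepsize * slope0
    if c3 is None:
        return armijo
    # When f barely changes, floating-point cancellation makes the Armijo test
    # unreliable, so additionally allow a slope-based (Hager-Zhang style) test.
    approx = slope_eval <= (2 * c1 - 1) * slope0
    barely_changed = jnp.abs(f_eval - f0) <= c3 * jnp.abs(f0)
    return jnp.where(barely_changed, armijo | approx, armijo)
```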
This has been a source of confusion for me too. I'm also wondering if we should actually offer this feature in this search, as opposed to creating an abstract base class for Zoom line searches that then have slightly different features.
We're doing something similar with the linear and classical trust region searches, for example. The concrete classes / searches that subclass the base class don't have to do a whole lot differently, but we can separate the functionality a little more cleanly, and document for users what is happening.
Does that make sense to you? I think there are a few extra features here - same for the initial step length selection - and quite a bit of control flow already, making it tricky to understand what is supposed to happen if these different options interact.
Thanks!
I'm making dents in it, I promise 🙈 Can I rebase this on a tidy development branch?
Of course, go ahead! I'll also add the documentation for …
(Branch force-pushed from 632e183 to dd40198.)
(Branch force-pushed from dd40198 to 032782c.)
Ok, done! I also squashed your work on the Zoom line search into a single commit (attributed to you, of course)! The tests now pass and we're up to date with the latest version of …. Sorry for the confusion! This is not really part of the PR experience I want to offer. Now back to the math :)
Here is another round of comments :) Some of them are still quite basic. Do you think that we could factor out the core common functionalities and then provide the extra features in separate searches (e.g. the Hager-Zhang condition)? Extra concrete classes could be quite lean, and this may be preferable to having several features that enable extra functionality. Does that make sense to you?
    c: FloatScalar,
    value_c: FloatScalar,
) -> FloatScalar:
    """
Nit: text in docstrings starts on the same line, e.g.

"""A docstring.
With extra text.
"""

    Args:
        a: scalar
        value_a: value of a function f at a
`a`, `b`, `c` and the associated values are step lengths, right? And `f` is not the function to be minimised, but the step length selection function, where the step lengths are what is usually called …; `f` would usually be called ….

Can we clarify that here? I think one way to do it would be to change the names here. Is there some ordering of the values `a`, `b` and `c`, e.g. `a < b < c`? Perhaps we can also call `f` something else, e.g. `steplength_fn`.
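For reference, the standard quadratic-interpolation step with the renamed arguments (just the textbook formula, not necessarily the PR's exact helper):

```python
def quadratic_min(a, value_a, slope_a, b, value_b):
    # Minimiser of the quadratic fitted through (a, value_a) with slope slope_a at a,
    # and passing through (b, value_b); a and b are step lengths and value_* are
    # values of the step length selection function.
    db = b - a
    denom = 2.0 * (value_b - value_a - slope_a * db)
    return a - slope_a * db**2 / denom
```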
        xmin: point at which p'(xmin) = 0
    """
    C = slope_a
    db = b - a
If the names of `a` and `b` don't make it obvious, call this `distance_{...}` or `distance_to_{...}`, since it is in relation to `a`.
    value_cubic_ref: FloatScalar,
) -> FloatScalar:
    """
    Find a stepsize by minimizing the cubic or quadratic curve fitted to
Nit: docstring, as above.
delta = jnp.abs(hi - lo)
left = jnp.minimum(hi, lo)
right = jnp.maximum(hi, lo)
cubic_chk = 0.2 * delta  # cubic guess has to be at least this far from the sides
How about `cubic_threshold` and `quadratic_threshold`?
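i.e. something along these lines (a sketch using the suggested names; the 0.1 factor for the quadratic case is an assumption here, only the 0.2 cubic factor appears in the diff):

```python
import jax.numpy as jnp


def safeguarded_guess(cubic_guess, quad_guess, lo, hi):
    # Reject an interpolated guess that lands too close to either end of the
    # bracket, falling back to the quadratic guess and finally to bisection.
    delta = jnp.abs(hi - lo)
    left = jnp.minimum(hi, lo)
    right = jnp.maximum(hi, lo)
    cubic_threshold = 0.2 * delta
    quadratic_threshold = 0.1 * delta
    cubic_ok = (cubic_guess > left + cubic_threshold) & (cubic_guess < right - cubic_threshold)
    quad_ok = (quad_guess > left + quadratic_threshold) & (quad_guess < right - quadratic_threshold)
    bisection = 0.5 * (left + right)
    return jnp.where(cubic_ok, cubic_guess, jnp.where(quad_ok, quad_guess, bisection))
```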
set_cubic_to_hi = set_hi_to_middle | set_hi_to_lo

# do the updates
new_stepsize_hi, new_point_hi = tree_where(
I think it might be slightly clearer to update the step sizes first, and then use `new_point_high = (new_step_size_high * state.point_high**ω).ω`, instead of tree-mapping a tuple that we then unpack again immediately.
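For reference, a toy example of that ω pattern with stand-in leaves (ω lifts arithmetic to act leafwise on a pytree, and `.ω` unwraps the result):

```python
import jax.numpy as jnp
from equinox.internal import ω

new_stepsize_hi = jnp.array(0.5)
point_hi = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}  # stand-in pytree

# Scale every leaf of the pytree by the (already updated) step size.
new_point_hi = (new_stepsize_hi * point_hi**ω).ω
```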
# evaluate the slope along the descent direction for the new stepsize
new_stepsize = state.y_eval_stepsize
new_point = PointEvalGrad(y_eval, f_eval_info)
slope_at_new_point = new_point.compute_grad_dot(state.descent_direction)
same as above, this is now the slope of the step size selection function
# If the interval is found, this will zoom into the correct interval.
# If not, it still sets lo and hi, but that's okay because it will not be used
# by _zoom_into_interval, and we will just return here in the next iteration.
new_stepsize_lo, new_point_lo, new_slope_lo = tree_where(
See above, I think it would be preferable to update these separately instead of wrapping a tuple and then unpacking them again :)
    descent_direction=state.descent_direction,
)

def fake_first_step(
Does this have to be a separate method?
    accept = jnp.array(True)
    return accept, state

def _safe_step(
Same here, maybe just delineate that in `step` with a comment (what is now in the docstring)?
As we talked about, I implemented a zoom linesearch that satisfies the strong Wolfe conditions, which is particularly useful for quasi-Newton methods.
After some iterations, the current version ended up similar to the ones in Optax and JAXopt, but compatible with Optimistix's flattened loop. It works but definitely has some rough edges, so please let me know what changes you'd like me to make.
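For reference, the two conditions being targeted, as a rough sketch (names are illustrative; `slope0` and `slope_eval` are directional derivatives along the descent direction at the current iterate and at `y_eval`):

```python
import jax.numpy as jnp


def strong_wolfe(stepsize, f0, slope0, f_eval, slope_eval, c1=1e-4, c2=0.9):
    # Sufficient decrease (Armijo) plus the strong curvature condition.
    armijo = f_eval <= f0 + c1 * stepsize * slope0
    curvature = jnp.abs(slope_eval) <= c2 * jnp.abs(slope0)
    return armijo & curvature
```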
The main complication is that to check the curvature condition, we need the slope at `y_eval`, and for that we need the gradient at `y_eval`. I couldn't do that with the current interface of `AbstractSearch.step`, so I modified it to accept `lin_fn` and `options`, which are then used to calculate the gradient inside the linesearch.

This might not be the optimal solution. For example, as `lin_fn` doesn't exist in `AbstractGaussNewton.step`, I just passed the identity function there. Let me know if you know a better way to handle this.

To avoid recalculating the gradient in the `accepted` branch of `AbstractQuasiNewton.step`, I added a check that reads it from the search state if the solver is using the zoom linesearch, which might not be the most elegant solution.

I haven't added any new tests yet, but I updated `BFGS` and `DFP` to use zoom instead of backtracking, and all tests pass for me. I've also tested it with the L-BFGS implementation from #135. After that's merged, `LBFGS` should also be adapted to use `Zoom`.

I added, then removed, a `StrongWolfeBacktracking` class that does backtracking and checks the strong Wolfe conditions. I can add it again or open a separate PR if that's something you want to have.