Unified solver interface for compatibility with JAXopt and Optimistix #365
base: development
Conversation
Big picture questions:
- Do we still need JAXopt? It looks like the transition would already be completed with this PR.
- How does this change how the user interacts with solvers? It looks like it doesn't. Is there a reason why the solver registry is public if one can only pass strings to specify solvers?
- Does this interface allow arbitrary Optimistix and Optax solvers? I.e., can we eventually allow passing an instance of a solver directly?
- For this PR we need a developer guide note explaining the interface: do we expect any solver implementing the API to be compatible? Do you have any notes on how to write an interface to a new solver library, other than matching the API?
- Can you implement the proximal gradient in a separate PR?
- If we are not dropping JAXopt yet, then we could split the Optax/Optimistix work into two separate PRs and add JAXopt to the registry instead of what we have now. But if you think that we can drop JAXopt here, then let's continue on this one.
Next week Billy is joining and we can continue discussing this.
```python
        pass

    @abc.abstractmethod
    def run(self, init_params: Params, *args) -> StepResult:
```
Maybe have an abstract `run_iterator` taking in a generator; for solvers that are not stochastic, raise a `ValueError` or a more specific error if one exists.
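For illustration, a minimal sketch of such a default (class and parameter names are assumptions, not the PR's final code):

```python
import abc
from typing import Any, Iterator

class AbstractSolver(abc.ABC):
    # Hypothetical default: solvers that are not stochastic reject run_iterator.
    def run_iterator(self, init_params: Any, iterator: Iterator, *args):
        raise ValueError(
            f"{type(self).__name__} does not support stochastic optimization; "
            "pass the full dataset to `run` instead."
        )
```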
Yes, that's great! I forgot about that.
I added it to `AbstractSolver`.
How do you want to control which solvers should have it and which shouldn't?
How about adding a base implementation copying `jaxopt.StochasticSolver`'s, and only allowing it if the class has a class variable `_stochastic=True`?
Or a `StochasticSolverMixin` with the implementation, and inheriting from that in every class that should have it?
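A rough sketch of the mixin option (names and the loop are illustrative, loosely following the spirit of `jaxopt.StochasticSolver`; the exact signatures are assumptions):

```python
from typing import Iterator

class StochasticSolverMixin:
    """Adds a batched run_iterator loop to solvers that also define
    JAXopt-style `init_state` and `update` methods (assumed here)."""

    def run_iterator(self, init_params, iterator: Iterator, *args, maxiter: int = 100):
        params = init_params
        state = self.init_state(params, *args)
        for step, batch in enumerate(iterator):
            if step >= maxiter:
                break
            # each batch is assumed to be a tuple of arrays, e.g. (X_batch, y_batch)
            params, state = self.update(params, state, *batch)
        return params, state
```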
src/nemos/solvers/_jaxopt_solvers.py
```python
    **solver_init_kwargs,
)

def _extend_args(self, args):
```
Can we avoid the if statement by storing the extended args? Or by initializing as `(reg_strength,)` or an empty tuple?
Why do you want to avoid it? args are passed to the methods by the code using the solver, they can't be stored on construction.
The prepending could be done in other ways, but I like that this is explicit.
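For reference, the reviewer's alternative could look roughly like this (attribute and class names are made up for illustration):

```python
class _ArgsPrefixSketch:
    def __init__(self, regularizer_strength=None):
        # store an empty tuple when there is nothing to prepend
        self._args_prefix = () if regularizer_strength is None else (regularizer_strength,)

    def _extend_args(self, args):
        # prepend the stored prefix without branching at call time
        return self._args_prefix + tuple(args)
```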
```python
    tags=self.config.tags,
)

self.stats.update(solution.stats)
```
Why do we have a `self.stats` here while the other adapter doesn't? Can we replicate the stats in the JAXopt adapter too? You can parse the JAXopt state and extract as much info as we can.
It's only because Optimistix saves the number of steps taken at the end of a `minimise` run in the stats instead of the state itself like JAXopt does, and I wanted a quick way to access those for the tests.
I haven't given this too much thought, so we could think about whether we want to standardize what we keep at the end of the runs.
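For illustration, the JAXopt adapter could populate a similar `stats` dict by reading common fields off the solver state (the helper itself is hypothetical; `iter_num` and `error` are the usual JAXopt state field names):

```python
def stats_from_jaxopt_state(state) -> dict:
    stats = {}
    # most JAXopt solver states expose the step counter as `iter_num`
    if hasattr(state, "iter_num"):
        stats["num_steps"] = int(state.iter_num)
    # copy other scalar diagnostics when present
    for field in ("error", "value"):
        if hasattr(state, field):
            stats[field] = getattr(state, field)
    return stats
```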
Thanks for taking a look!
I think it's good to keep JAXopt at least as a backup option for a while, perhaps providing something like the environment variable mentioned for switching the backend.
It doesn't. I aimed for no changes required from the user's side, so as not to break existing analysis code.
Do you mean that it's not called
Do you mean passing an
Yes, that makes sense. I would expect any solver implementing the API to be compatible. The only hurdle I see is the compatibility check in the regularizers that works based on name.
Sure!
Yes, that also makes sense. I think it's useful to keep JAXopt for now and only drop it once the switch is done and you are satisfied with it. (Edit: quote formatting was off.)
Defines the interface. Will add type annotations later.
Instead of trying to find solvers in packages, have an explicit list.
According to the new interface `__init__` will receive everything, and will raise an error at that point.
Simplify how the solver object is accessed in test_solvers and test_glm.
Cleanup for previous commit
This will go into a separate issue + PR. Updated the developer notes.
I removed things related to stochastic optimization; this will be done later in #376.
Currently using PyTree from jaxtyping. Might want to just alias to Any. Moved solver interface types to typing. Use StepResult and SolverState in GLM and BaseRegressor.
Also a docstring
For now PyTree is imported as Pytree to not introduce too many changes. This PyTree is parametrizable, which is used in the solvers module. ArrayLike is the same as jax.typing.ArrayLike. Importing from jaxtyping for consistency.
I will open an issue about potentially using jaxtyping.
There is a scratch part at the end that will have to be removed based on which version we think is better
Let's have @billbrod take a look, but after you address my latest comments, this LGTM.
## Background

In the beginning NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.
Suggested change: "In the earlier versions, NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend."
```
│ └─ Concrete Subclass WrappedProxSVRG
```

`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.
Suggested change: "`OptaxOptimistixSolver` is an adapter for Optax solvers, relying on `optimistix.OptaxMinimiser` to run the full optimization loop."
`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.
Optimistix does not have implementations of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd`, which does support it.
Can you add a reference to the Nesterov acceleration?
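For reference, a minimal sketch of the wrapping described above, using the public Optax and Optimistix APIs (the learning rate, tolerances, and toy loss are placeholders, not the PR's actual defaults):

```python
import jax.numpy as jnp
import optax
import optimistix as optx

# Nesterov-accelerated gradient descent via optax.sgd, driven by Optimistix's loop.
solver = optx.OptaxMinimiser(
    optax.sgd(learning_rate=1e-2, momentum=0.9, nesterov=True),
    rtol=1e-6,
    atol=1e-6,
)

def loss(params, args):
    # toy quadratic loss standing in for a GLM objective
    return jnp.sum((params - 3.0) ** 2)

sol = optx.minimise(loss, solver, y0=jnp.zeros(2), max_steps=10_000)
print(sol.value)  # parameters close to 3.0
```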
To run stochastic (~mini-batch) optimization, JAXopt used a `run_iterator` method.
Instead of the full input data, `run_iterator` accepts a generator / iterator that provides batches of data.

For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin` which borrows the implementation from JAXopt. (Or some version of it. See below.)
Suggested change: "For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin` which borrows the implementation from JAXopt (or some version of it, see below)."
We will likely define an interface or protocol for this, allowing custom (user-defined) solvers to also implement their own version.
We will also have to decide on how this will be exposed to users on the level of `BaseRegressor` and `GLM`.

Note that (Prox-)SVRG is especially well-suited for running stochastic optimization; however, it currently requires the optimization loop to be implemented separately, as it is a bit more involved than what is done by `run_iterator`.
Wrap this in an admonition on SVRG, to make it more visible.
Currently, the solver registry defines which implementation to use for each algorithm, so that has to be overwritten in order to tell NeMoS to use a custom class.

This is hacky and not an intended use-case for now, but in the future we are [planning to support passing any solver to `BaseRegressor`](https://github.com/flatironinstitute/nemos/issues/378).
I am not sure if it makes sense to have a guideline for something we do not recommend. It might be better to copy-paste this note into issue #378.
```python
split_indices = np.cumsum(sizes)[:-1]
flat_params = np.concat([x.flatten() for x in params])

def unpacker(_flat_params):
```
This could be generalized and included in NeMoS as a utility function that we expose. It's useful in general:
```python
import jax
import jax.numpy as jnp

def get_packer_unpacker(parameter_tree):
    flat, struct = jax.tree_util.tree_flatten(parameter_tree)
    shapes = [x.shape for x in flat]
    sizes = jnp.array([x.size for x in flat], dtype=int)
    split_indices = jnp.cumsum(sizes[:-1])

    def packer(parameter_tree):
        flat = jax.tree_util.tree_leaves(parameter_tree)
        return jnp.concatenate([x.flatten() for x in flat])

    def unpacker(flat_params):
        split_params = jnp.split(flat_params, split_indices)
        split_params = [x.reshape(s) for x, s in zip(split_params, shapes)]
        return jax.tree_util.tree_unflatten(struct, split_params)

    return packer, unpacker
```
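Example usage of the helper above on a small, hypothetical parameter pytree:

```python
import jax.numpy as jnp

params = {"coef": jnp.ones((3, 2)), "intercept": jnp.zeros(2)}
packer, unpacker = get_packer_unpacker(params)
flat = packer(params)       # 1D array with 8 elements
restored = unpacker(flat)   # same structure and leaf shapes as `params`
```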
```python
# NOTE not testing these anymore since non-JAXopt solvers' state is not necessarily a namedtuple
# assert (
#     hasattr(state, "_fields")
#     and hasattr(state, "_field_defaults")
#     and hasattr(state, "_asdict")
# )
```
remove note
```
setenv =
    NEMOS_SOLVER_BACKEND = jaxopt
commands =
    pytest -n auto tests/test_solvers.py tests/test_glm.py tests/test_convergence.py tests/test_regularizer.py
```
I know it is a pain, but can you group all GLM tests that call `fit`, that initialize the solver, or that interact with the solver through any other method, explicitly into a class, and run only that class?
pytest -n auto tests/test_solvers.py tests/test_glm.py::TestSolverRelated tests/test_convergence.py tests/test_regularizer.py
Because JAXopt is no longer maintained, NeMoS is migrating its optimization backend to Optimistix and Optax.
As these are not a full replacement for JAXopt yet, at least in the beginning, solvers from both backends will be supported.
This is achieved through a unified interface for solvers which defines the interaction between the optimization backend's solvers and NeMoS models.
The solver interface each solver must adhere to in order to be compatible with `BaseRegressor` is defined in `AbstractSolver`, and mostly follows the previous interface of JAXopt solvers.

Compatibility with the existing JAXopt, Optimistix, and (Prox-)SVRG solvers is provided by adapter classes. A base class for these is defined in `SolverAdapter`.

Instead of looking up solver classes based on their name in `nemos.solvers` and `jaxopt`, the class used to implement each algorithm is explicitly defined in the solver registry.

Currently solvers are created based on the algorithm's name, but this could be extended in the future to allow user-defined solvers to be passed to `BaseRegressor`. As long as a solver implements the `AbstractSolver` interface, it should be compatible with NeMoS. (In this case the compatibility check with the regularizer could be disabled or rewritten.)
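As a rough illustration of such a user-defined solver, a sketch following the JAXopt-style interface described here (`run` appears in the diff; the other method names and signatures are assumptions):

```python
import jax

class MyGradientSolver:
    """Plain gradient descent implementing the assumed interface."""

    def __init__(self, fun, maxiter=100, stepsize=1e-2, **kwargs):
        self.fun = fun
        self.maxiter = maxiter
        self.stepsize = stepsize

    def init_state(self, init_params, *args):
        return {"iter_num": 0}

    def update(self, params, state, *args):
        grads = jax.grad(self.fun)(params, *args)
        params = jax.tree_util.tree_map(lambda p, g: p - self.stepsize * g, params, grads)
        return params, {"iter_num": state["iter_num"] + 1}

    def run(self, init_params, *args):
        params, state = init_params, self.init_state(init_params, *args)
        for _ in range(self.maxiter):
            params, state = self.update(params, state, *args)
        return params, state
```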
As Optimistix doesn't implement `ProximalGradient`, this is currently set to use a custom implementation in `OptaxOptimistixProximalGradient`, based on Optax's SGD with linesearch followed by the proximal operator. This seems to work in practice, but is not as theoretically sound as `jaxopt.ProximalGradient`.
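For intuition, a minimal sketch of a single proximal gradient step for an L1 penalty (illustrative only; the PR's `OptaxOptimistixProximalGradient` uses Optax's SGD with linesearch rather than a fixed stepsize):

```python
import jax
import jax.numpy as jnp

def prox_lasso(params, scaling):
    # soft-thresholding: the proximal operator of the L1 penalty
    return jnp.sign(params) * jnp.maximum(jnp.abs(params) - scaling, 0.0)

def prox_grad_step(params, loss_fn, reg_strength, stepsize):
    grads = jax.grad(loss_fn)(params)
    params = params - stepsize * grads                  # gradient step
    return prox_lasso(params, stepsize * reg_strength)  # proximal step
```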
We contributed an implementation of `LBFGS` to Optimistix (PR for the L-BFGS update, PR for the zoom linesearch), and a wrapper class for it is already included here (just commented out). Until these PRs are merged and released, JAXopt's or Optax's implementation could be used as the default.
As in this interface every solver parameter must be passed to the solver's constructor, `BaseRegressor.instantiate_solver` is simplified. Each solver can expose its accepted arguments explicitly through `get_accepted_arguments`. If this is not defined, `BaseRegressor` tries to infer it.
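A hypothetical illustration of a solver exposing its accepted arguments (the method name is from this PR; its exact signature and how `BaseRegressor` consumes it are assumptions):

```python
class MyCustomSolver:
    @classmethod
    def get_accepted_arguments(cls):
        # names that may be forwarded to __init__ as solver kwargs
        return {"stepsize", "maxiter", "tol"}
```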
There are basic differences between the parameters accepted by JAXopt and Optimistix. These differences are handled on solver instantiation, but a warning is raised (see the sketch below):

- Instead of JAXopt's `maxiter`, Optimistix uses `max_steps`.
- JAXopt accepts a single `tol` value, while Optimistix uses `atol` and `rtol`.
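A sketch of how such a translation with a warning might look (the helper and its exact behavior are assumptions; only the argument names come from this PR):

```python
import warnings

def translate_solver_kwargs(kwargs: dict) -> dict:
    kwargs = dict(kwargs)
    if "maxiter" in kwargs:
        warnings.warn("`maxiter` is translated to Optimistix's `max_steps`.")
        kwargs["max_steps"] = kwargs.pop("maxiter")
    if "tol" in kwargs:
        warnings.warn("`tol` is translated to Optimistix's `atol` and `rtol`.")
        tol = kwargs.pop("tol")
        kwargs.setdefault("atol", tol)
        kwargs.setdefault("rtol", tol)
    return kwargs
```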
Currently the `NEMOS_SOLVER_BACKEND` environment variable can be used to run tests with a specific backend. I added `tox` environments that run solver-dependent tests with both backends, and added these to the CI.

This switch could be moved to the main code, so that users have the option to choose the backend they want.
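A minimal sketch of what moving the switch into the main code could look like (only the `NEMOS_SOLVER_BACKEND` variable name comes from this PR; the selection logic is an assumption):

```python
import os

def solver_backend() -> str:
    # defaults to the new backend unless the environment explicitly requests JAXopt
    backend = os.environ.get("NEMOS_SOLVER_BACKEND", "optimistix").lower()
    if backend not in ("optimistix", "jaxopt"):
        raise ValueError(f"Unknown solver backend: {backend!r}")
    return backend
```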