Unified solver interface for compatibility with JAXopt and Optimistix #365

# The `solvers` Module

## Background

In the beginning, NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.
As JAXopt is no longer maintained, we added support for alternative optimization backends.

Some of JAXopt's functionality was ported to [Optax](https://optax.readthedocs.io/en/latest/) by Google, and [Optimistix](https://docs.kidger.site/optimistix/) was started by the community to fill the gaps left by JAXopt's deprecation.

To support flexibility and long-term maintenance, NeMoS now has a backend-agnostic solver interface, allowing the use of solvers from different backend libraries with different interfaces.

## `AbstractSolver` interface
This interface is defined by [`AbstractSolver`](nemos.solvers.AbstractSolver) and mostly follows the JAXopt API.
All solvers implemented in NeMoS are subclasses of `AbstractSolver`; however, subclassing is not required for implementing solvers that can be used with NeMoS (see [custom solvers](#custom_solvers)).

The `AbstractSolver` interface requires implementing the following methods:
- `__init__`: all solver parameters and settings should go here. The other methods only take the solver state, the current or initial solution (model parameters), and the input data for the objective function.
- `init_state`: initialize the solver state.
- `update`: take one step of the optimization algorithm.
- `run`: run a full optimization.
- `get_accepted_arguments`: the set of argument names that can be passed to `__init__`. These are the parameters users can change by passing `solver_kwargs` to `BaseRegressor` / `GLM`.
- `get_optim_info`: collect diagnostic information about the optimization run into an `OptimizationInfo` namedtuple.

`AbstractSolver` is a generic class parametrized by `SolverState` and `StepResult`.
In concrete subclasses, `SolverState` should be the type of the solver state, and `StepResult` the type of what is returned by each step of the solver, typically a tuple of the parameters and the solver state.
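
As an illustration, a minimal fixed-stepsize gradient descent written against this interface might look roughly like the sketch below. The method signatures are assumptions based on the description above and on the JAXopt API; the exact signatures required by `AbstractSolver` may differ.

```python
# Hypothetical sketch only: the signatures are modeled on the JAXopt-style
# layout described above, not copied from nemos.solvers.AbstractSolver.
from typing import Any, NamedTuple

import jax


class GDState(NamedTuple):
    """Solver state: here just an iteration counter."""

    iter_num: int


class ToyGradientDescent:
    """Fixed-stepsize gradient descent following the interface sketched above."""

    def __init__(self, fun, stepsize: float = 1e-2, maxiter: int = 100):
        # all solver parameters and settings live in __init__
        self.fun = fun
        self.stepsize = stepsize
        self.maxiter = maxiter

    def init_state(self, init_params, *args) -> GDState:
        return GDState(iter_num=0)

    def update(self, params, state, *args):
        # one optimization step; the return value is the StepResult:
        # a (params, state) tuple
        grads = jax.grad(self.fun)(params, *args)
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - self.stepsize * g, params, grads
        )
        return new_params, GDState(iter_num=state.iter_num + 1)

    def run(self, init_params, *args):
        # a full optimization loop built from init_state + update
        params, state = init_params, self.init_state(init_params, *args)
        for _ in range(self.maxiter):
            params, state = self.update(params, state, *args)
        return params, state

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        return {"stepsize", "maxiter"}

    def get_optim_info(self) -> dict[str, Any]:
        # in NeMoS this would be packed into an OptimizationInfo namedtuple
        return {"num_steps": self.maxiter, "converged": None}
```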

### Optimization info
Because different libraries store information about the optimization run in different places, we decided to standardize some common diagnostics.
Optimistix saves some of these in its stats dict, while Optax and JAXopt store them in their solver state.
These diagnostics are saved in `solver.optimization_info`, which is of type `OptimizationInfo`.

`OptimizationInfo` holds the following fields:
- `function_val`: the final value of the objective function. As not all solvers store this by default, and it is potentially expensive to evaluate, this field is optional.
- `num_steps`: the number of steps taken by the solver.
- `converged`: whether the optimization converged according to the solver's criteria.
- `reached_max_steps`: whether the solver reached the maximum number of steps allowed.
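
The following sketch shows the shape of this namedtuple and how the diagnostics can be checked after a run; the actual definition in `nemos.solvers` may differ in details such as defaults and typing.

```python
# Sketch of the fields listed above, not the actual nemos definition.
from typing import NamedTuple, Optional


class OptimizationInfo(NamedTuple):
    function_val: Optional[float]  # optional: may be expensive or not stored
    num_steps: int
    converged: bool
    reached_max_steps: bool


# After a full optimization, diagnostics can be inspected uniformly across
# backends, e.g.:
#
#     info = solver.optimization_info
#     if not info.converged and info.reached_max_steps:
#         print(f"stopped after {info.num_steps} steps without converging")
```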

## Adapters
Support for existing solvers from external libraries, and for the custom implementation of (Prox-)SVRG, is provided through adapters that "translate" between the interfaces of these external solvers and the `AbstractSolver` interface.

Adapters for existing solvers can be written in multiple ways.
In our experience, wrapping solver objects through adapters is a clean way of doing this, and we recommend that adapters for new optimization libraries follow this pattern.

[`SolverAdapter`](nemos.solvers.SolverAdapter) provides methods for wrapping existing solvers.
Each subclass of `SolverAdapter` has to define the methods of `AbstractSolver`, as well as a `_solver_cls` class variable signaling the type of solver it wraps.
During construction it has to set a `_solver` attribute that is a concrete instance of `_solver_cls`.

Default method implementations:
- A default implementation of `get_accepted_arguments` is provided, returning the arguments to `__init__`, `_solver_cls`, and `_solver_cls.__init__`, and discarding the ones required by `AbstractSolver.__init__`.
- `__getattr__` dispatches every attribute access to the wrapped `_solver`.
- `__init_subclass__` generates a docstring for the adapter, including the accepted arguments and the wrapped solver's documentation.

Currently we provide adapters for two optimization backends:
- [`OptimistixAdapter`](nemos.solvers.OptimistixAdapter) wraps Optimistix solvers.
- [`JaxoptAdapter`](nemos.solvers.JaxoptAdapter) wraps JAXopt solvers. As `SVRG` and `ProxSVRG` follow the JAXopt interface, these are also wrapped with `JaxoptAdapter`.
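
The skeleton of a new adapter might look roughly like the sketch below. The wrapped backend solver here is a stand-in defined inline, and the method signatures are assumptions; the real `SolverAdapter` base class may require additional arguments or hooks.

```python
# Hypothetical sketch of a new adapter.  ToyBackendLBFGS stands in for a solver
# class from an external library with a JAXopt-like interface.
class ToyBackendLBFGS:
    """Stand-in for an external backend solver."""

    def __init__(self, fun, maxiter: int = 100):
        self.fun = fun
        self.maxiter = maxiter

    def init_state(self, init_params, *args): ...
    def update(self, params, state, *args): ...
    def run(self, init_params, *args): ...


class ToyBackendAdapter:  # in nemos this would subclass nemos.solvers.SolverAdapter
    # class variable naming the wrapped solver type
    _solver_cls = ToyBackendLBFGS

    def __init__(self, fun, **solver_kwargs):
        # construct and store a concrete instance of _solver_cls;
        # SolverAdapter.__getattr__ forwards attribute access to self._solver
        self._solver = self._solver_cls(fun, **solver_kwargs)

    def init_state(self, init_params, *args):
        return self._solver.init_state(init_params, *args)

    def update(self, params, state, *args):
        return self._solver.update(params, state, *args)

    def run(self, init_params, *args):
        return self._solver.run(init_params, *args)

    def get_optim_info(self):
        # translate backend-specific state/stats into OptimizationInfo here
        raise NotImplementedError
```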

## List of available solvers

```
Abstract Class AbstractSolver
│
├─ Abstract Subclass SolverAdapter
│   │
│   ├─ Abstract Subclass OptimistixAdapter
│   │   │
│   │   ├─ Concrete Subclass OptimistixBFGS
│   │   ├─ Concrete Subclass OptimistixLBFGS
│   │   ├─ Concrete Subclass OptimistixNonlinearCG
│   │   └─ Concrete Subclass OptaxOptimistixSolver
│   │       │
│   │       ├─ Concrete Subclass OptaxOptimistixLBFGS
│   │       ├─ Concrete Subclass OptaxOptimistixGradientDescent
│   │       └─ Concrete Subclass OptaxOptimistixProximalGradient
│   │
│   └─ Abstract Subclass JaxoptAdapter
│       │
│       ├─ Concrete Subclass JaxoptLBFGS
│       ├─ Concrete Subclass JaxoptGradientDescent
│       ├─ Concrete Subclass JaxoptProximalGradient
│       ├─ Concrete Subclass JaxoptBFGS
│       ├─ Concrete Subclass JaxoptNonlinearCG
│       │
│       ├─ Concrete Subclass WrappedSVRG
│       └─ Concrete Subclass WrappedProxSVRG
```

`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.

Optimistix does not have an implementation of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd`, which does support it.

Note that `OptaxOptimistixSolver` allows using any solver from Optax (e.g., Adam). See `OptaxOptimistixGradientDescent` for a template of how to wrap new Optax solvers.
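
The mechanism this builds on can be sketched directly with Optax and Optimistix; the loss, data, and tolerance values below are illustrative, and the actual `OptaxOptimistixSolver` constructor may expose different arguments.

```python
# Sketch: an Optax optimizer is wrapped in optimistix.OptaxMinimiser, and
# Optimistix drives the full optimization loop.
import jax.numpy as jnp
import optax
import optimistix as optx


def loss(params, data):
    X, y = data
    return jnp.mean((X @ params - y) ** 2)


# any Optax optimizer can be plugged in, e.g. Adam, or SGD with Nesterov
# acceleration (which Optimistix itself does not implement):
optim = optax.sgd(learning_rate=1e-2, momentum=0.9, nesterov=True)
# optim = optax.adam(learning_rate=1e-3)

solver = optx.OptaxMinimiser(optim, rtol=1e-6, atol=1e-6)

X = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0, 6.0])
sol = optx.minimise(loss, solver, jnp.zeros(3), args=(X, y), max_steps=10_000)
print(sol.value)  # approximately [1., 2., 3.]
```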

## Custom solvers
If you want to use your own solver in `nemos`, you just have to write a solver that adheres to the `AbstractSolver` interface, and it should be straightforward to plug in.
While it is not required, subclassing `AbstractSolver` is one way to ensure adherence to the interface.

Currently, the solver registry defines which implementation to use for each algorithm, so it has to be overwritten to tell `nemos` to use a custom class; in the future we are [planning to support passing any solver to `BaseRegressor`](https://github.com/flatironinstitute/nemos/issues/378).

We might also define an `ImplementsSolverInterface` protocol to easily check whether user-supplied solvers define the methods required by the interface.
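
One possible shape for such a protocol, as a sketch rather than an existing nemos API:

```python
# Structural check for user-supplied solvers; names and signatures are illustrative.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ImplementsSolverInterface(Protocol):
    def init_state(self, init_params, *args) -> Any: ...

    def update(self, params, state, *args) -> Any: ...

    def run(self, init_params, *args) -> Any: ...

    def get_optim_info(self) -> Any: ...


# `isinstance(user_solver, ImplementsSolverInterface)` then gives a cheap check
# that the required methods exist (runtime_checkable protocols only check
# method names, not signatures).
```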

## Stochastic optimization
To run stochastic (i.e., mini-batch) optimization, JAXopt used a `run_iterator` method.
Instead of the full input data, `run_iterator` accepts a generator / iterator that yields batches of data.
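
Roughly, and not as JAXopt's exact implementation, the pattern looks like the following sketch: batches are pulled from an iterator and fed to the solver's `update`, one batch per step.

```python
# Sketch of the run_iterator pattern; the solver is any object with
# init_state/update methods as described earlier in this document.
from typing import Iterator, Tuple

import jax.numpy as jnp


def batch_generator(X, y, batch_size: int) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:
    """Yield consecutive mini-batches of the data, cycling forever."""
    n = X.shape[0]
    start = 0
    while True:
        stop = min(start + batch_size, n)
        yield X[start:stop], y[start:stop]
        start = stop if stop < n else 0


def run_iterator_sketch(solver, init_params, batches, maxiter: int):
    """Drive a JAXopt-style solver with mini-batches instead of the full data."""
    params = init_params
    # initialize the state from the first batch (an assumption of this sketch)
    state = solver.init_state(init_params, *next(batches))
    for _ in range(maxiter):
        X_batch, y_batch = next(batches)
        params, state = solver.update(params, state, X_batch, y_batch)
    return params, state
```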

For solvers defined in `nemos` that can be used this way, we will likely provide a `StochasticMixin` that borrows the implementation (or some version of it; see below) from JAXopt.

We will likely define an interface or protocol for this, allowing custom (user-defined) solvers to implement their own version as well.
We will also have to decide how this will be exposed to users at the level of `BaseRegressor` and `GLM`.

Note that (Prox-)SVRG is especially well suited for stochastic optimization; however, it currently requires the optimization loop to be implemented separately, as it is a bit more involved than what `run_iterator` does.

A potential solution would be to provide a separate method that accepts the full data and takes care of the batching. That might also be a more convenient alternative to the current `run_iterator`.
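
A hypothetical convenience method of this kind might look like the sketch below; the name, signature, and batching strategy are illustrative only.

```python
# Hypothetical: take the full data plus a batch size and handle the batching
# internally, instead of requiring a user-supplied iterator.
import jax


def run_batched(solver, init_params, X, y, batch_size: int, maxiter: int, seed: int = 0):
    """Shuffle indices each epoch and step over mini-batches (illustrative only)."""
    n = X.shape[0]
    key = jax.random.PRNGKey(seed)
    params = init_params
    state = solver.init_state(init_params, X[:batch_size], y[:batch_size])
    step = 0
    while step < maxiter:
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            params, state = solver.update(params, state, X[idx], y[idx])
            step += 1
            if step >= maxiter:
                break
    return params, state
```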

## Note on line searches vs. fixed stepsize in Optimistix
By default, Optimistix does not expose the `search` attribute of concrete solvers, but we might want to switch flexibly between line searches and constant learning rates depending on whether `stepsize` is passed to the solver.

A solution would be to create short redefinitions of the required solvers that take `search` as an argument to `__init__`, and to handle `stepsize` in the adapter with something like:
```python
# Sketch based on Optimistix's BFGS, redefined so that `search` is exposed as a
# constructor argument.  Names such as AbstractBFGS, NewtonDescent, Zoom,
# max_norm, AbstractSearch, and `lx` (lineax) are assumed to be imported from
# Optimistix / lineax; some of them live in Optimistix's internal modules.
class BFGS(AbstractBFGS[Y, Aux, _Hessian]):
    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: NewtonDescent
    search: AbstractSearch
    verbose: frozenset[str]

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = max_norm,
        use_inverse: bool = True,
        verbose: frozenset[str] = frozenset(),
        search: AbstractSearch = Zoom(initial_guess_strategy="one"),
    ):
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        self.use_inverse = use_inverse
        self.descent = NewtonDescent(linear_solver=lx.Cholesky())
        self.search = search
        self.verbose = verbose
```

and

```python
# `optx` is Optimistix; LearningRate is a "search" that always uses a fixed step.
if "stepsize" in solver_init_kwargs:
    assert "search" not in solver_init_kwargs, "Specify either search or stepsize"
    solver_init_kwargs["search"] = optx.LearningRate(
        solver_init_kwargs.pop("stepsize")
    )
```
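
For illustration, assuming the redefined `BFGS` from the first snippet is in scope, this adapter-level translation would let users configure either variant:

```python
import optimistix as optx

# default: the Zoom line search set in the redefined BFGS above
line_search_bfgs = BFGS(rtol=1e-6, atol=1e-6)

# a stepsize passed at the adapter level would translate into a fixed-step "search"
fixed_step_bfgs = BFGS(rtol=1e-6, atol=1e-6, search=optx.LearningRate(1e-2))
```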

A second file in this diff adds the new page to the documentation index:

```
04-basis_module.md
05-observation_models.md
06-regularizer.md
07-solvers.md
```