Draft
Changes from 128 commits
Commits
136 commits
acf56b8
Add very basic AbstractSolver
bagibence Jun 19, 2025
412d4c6
Add wrapper for Jaxopt solvers
bagibence Jun 19, 2025
02b1f59
Use JaxoptWrapper for (Prox)SVRG
bagibence Jun 19, 2025
7fb12a7
Export the wrappers
bagibence Jun 19, 2025
5079e40
Create solver_registry
bagibence Jun 19, 2025
37481e4
Simplify BaseRegressor.instantiate_solver
bagibence Jun 19, 2025
3c2387f
Remove BaseRegressor.solver_kwargs setter checks
bagibence Jun 19, 2025
42f77e5
Update tests
bagibence Jun 19, 2025
4f3c0bd
Add generic types to AbstractSolver
bagibence Jun 26, 2025
fc692da
Types in JaxoptWrapper
bagibence Jun 26, 2025
4ab4f4a
Add wrapper for Optimistix solvers
bagibence Jun 26, 2025
c33c8f3
Add proximal gradient using Optax and Optimistix
bagibence Jun 26, 2025
a8491b5
Update solver registry
bagibence Jun 26, 2025
439ace7
Make a test handle the new proximal gradient
bagibence Jun 26, 2025
ea86249
Add JAXopt and Optimistix NonlinearCG
bagibence Jun 27, 2025
3c6d715
Fix RecursionError when deepcopying JaxoptWrapper
bagibence Jun 27, 2025
cde1c63
Fix a test: solver state is not necessarily a NamedTuple
bagibence Jun 27, 2025
ef55bf5
Remove duplicate max_steps in OptimistixConfig
bagibence Jun 27, 2025
2f837df
Pass verbose and norm to OptaxMinimiser in custom ProximalGradient
bagibence Jun 27, 2025
f6c7219
Add OptaxOptimistixGradientDescent
bagibence Jun 27, 2025
097d52b
Rename _prox_grad to optax_optimistix_solvers
bagibence Jun 27, 2025
768109d
Delete unused BaseRegressor._inspect_solver_kwargs
bagibence Jun 27, 2025
cc31c56
Introduce AbstractSolver.get_accepted_arguments
bagibence Jun 27, 2025
da8d28d
Hndle jaxopt-optimistix parameter name mismatch in one place
bagibence Jun 27, 2025
565622a
Remove _replace_maxiter
bagibence Jun 27, 2025
2e8fab4
build user_args dict in a more straightforward way
bagibence Jun 27, 2025
f9ef6c1
Remove unused code
bagibence Jun 27, 2025
ac7aff1
Remove some unnecessary asserts in tests
bagibence Jun 27, 2025
e4b618d
Pass things by keyword and reorder methods
bagibence Jun 27, 2025
f5a3df1
whitespaces
bagibence Jun 27, 2025
779f0bb
Relax a test
bagibence Jul 3, 2025
7da37bf
Clean up handling norm and verbose for Optimistix
bagibence Jul 3, 2025
90d1504
Expose `nesterov` parameter in Optax-based GDs
bagibence Jul 3, 2025
6a2c909
Keep accepted arguments as set to avoid unnecessary conversions
bagibence Jul 3, 2025
80085fb
Add SolverAdapter and some docstrings
bagibence Jul 3, 2025
44dae64
Fix import paths
bagibence Jul 3, 2025
bd10725
Generate docstring for solver adapters.
bagibence Jul 3, 2025
38ffdfd
Add __init__ to AbstractSolver and docstrings
bagibence Jul 3, 2025
be4183b
Add docstrings and remove unused imports
bagibence Jul 3, 2025
fe1a264
Inline a small method
bagibence Jul 3, 2025
7b91499
typo
bagibence Jul 3, 2025
ac2ee09
Run isort
bagibence Jul 29, 2025
6e7bb2d
typo
bagibence Jul 4, 2025
7cf3893
Update solver registry
bagibence Jul 4, 2025
a53fd90
fix import
bagibence Jul 4, 2025
9303933
Add optax, optimistix deps
bagibence Jul 4, 2025
6e39ec2
Add option to control optimization backend for tests with env var
bagibence Jul 4, 2025
4fa8070
Add backend-specific envs to tox and mention in CONTRIBUTING.md
bagibence Jul 4, 2025
477eceb
Add backend-specific tests to ci.yml
bagibence Jul 4, 2025
9d22ee1
Make solver implementation files start with underscore
bagibence Jul 4, 2025
61fda2e
pytest -n auto for the backend-specific envs in tox
bagibence Jul 4, 2025
11668ce
Run tox -e fix
bagibence Jul 4, 2025
15f9264
Type annotations
bagibence Jul 7, 2025
d87b4b1
Constrain Optax version
bagibence Jul 7, 2025
42db029
lint
bagibence Jul 8, 2025
4569ece
Switch Cauchy convergence to use max_norm
bagibence Jul 8, 2025
34fdfce
Set f_struct dtype dynamically
bagibence Jul 8, 2025
9927b0d
Put back an assert: solver.fun is callable
bagibence Jul 8, 2025
1df7027
Disable curvature test in OptaxOptimistixProximalGradient's linesearch
bagibence Jul 8, 2025
edd733d
Removed unused import
bagibence Jul 8, 2025
797bece
Docstring, imports
bagibence Jul 8, 2025
f9b6a27
Comment out currently unused code
bagibence Jul 8, 2025
e9c0ef0
Docstrings and comments
bagibence Jul 8, 2025
4e8e993
Add OptaxOptimistixLBFGS
bagibence Jul 8, 2025
158bcdc
lint
bagibence Jul 8, 2025
9c38ef7
Try adding a purely optimistix-based proximal gradient
bagibence Jul 7, 2025
8eb9a12
Enable OptimistixProximalGradient
bagibence Jul 8, 2025
78837e0
Move the purely optimistix-based ProximalGradient inside nemos
bagibence Jul 8, 2025
0a00752
Increase stepsize in ProxBacktrackingArmijo if accepted
bagibence Jul 11, 2025
4be5c2f
First version of DirectProximalGradient
bagibence Jul 11, 2025
2bd27b7
Move fista_line_search into ProximalGradient and use Cauchy
bagibence Jul 11, 2025
2affa66
remove todo
bagibence Jul 11, 2025
cc5943f
Rename nesterov parameter to acceleration in Optax-based GD solvers
bagibence Jul 11, 2025
94b2030
Type annotations
bagibence Jul 11, 2025
8f3c220
Fix ProximalGradient logic
bagibence Jul 11, 2025
654c8bb
Pass fun instead of fun_with_aux to init solvers in OptimistixAdapter
bagibence Jul 11, 2025
88c14df
Add clarification, rename variables, store fun value in state
bagibence Jul 11, 2025
5ec5dae
Remove unnecessary solver_init_kwargs.pop calls
bagibence Jul 14, 2025
cbd733e
Add missing type annotation
bagibence Jul 14, 2025
4263650
Use super() in OptimistixWrapper.get_accepted_arguments
bagibence Jul 14, 2025
39403a9
Remove OptimistixLBFGS until optimistix.LBFGS is released
bagibence Jul 14, 2025
31e9dbb
Add AbstractSolver.run_iterator
bagibence Jul 14, 2025
d1cc5eb
Make AbstractSolver.run_iterator optional to overwrite
bagibence Jul 16, 2025
f08e942
Fix bug in OptaxOptimistixProximalGradient
bagibence Jul 16, 2025
75c7017
Remove DirectProximalGradient
bagibence Jul 16, 2025
aaafcf0
Extend SolverAdapter.get_accepted_arguments
bagibence Jul 16, 2025
1f87b90
Remove OptimistixProximalGradient
bagibence Jul 17, 2025
298d620
Use super().__init__ in OptaxOptimistixProximalGradient
bagibence Jul 17, 2025
feab8ba
Move maxiter property to OptimistixAdapter
bagibence Jul 17, 2025
9a0ddd9
Fix bug in OptimistixAdapter
bagibence Jul 17, 2025
81d507a
Add StochasticSolver, StochasticMixin, use it, test it
bagibence Jul 17, 2025
f4c05d8
Type annotate *args: Any
bagibence Jul 17, 2025
bda7b3f
Replace JaxoptWrapper._extend_args with hyperparams_prox tuple
bagibence Jul 22, 2025
436cac2
Set JAXopt solvers as default for now
bagibence Jul 22, 2025
05917b4
OptimistixConfig docstring
bagibence Jul 22, 2025
227fd5f
Consistent JaxoptStepResult type
bagibence Jul 29, 2025
e159d95
OptimizationInfo and get_optim_info
bagibence Jul 29, 2025
173198e
Clear up the run_iterator, maxiter inheritance incompatibility
bagibence Jul 29, 2025
1d01aae
Rename JaxoptWrapper to JaxoptAdapter
bagibence Jul 29, 2025
0bf1cf8
Fix tolerance and use get_optim_info in a test
bagibence Jul 29, 2025
dffd2d2
Keep JAXopt's LBFGS as default for now
bagibence Jul 29, 2025
3026936
Add a rough draft about solvers to developer notes
bagibence Jul 29, 2025
a2989bb
Merge branch 'development' into solver_adapters
bagibence Jul 31, 2025
5b4e70f
Make SolverAdapter explicitly abstract
bagibence Jul 31, 2025
2c2db68
Documentation
bagibence Jul 31, 2025
bf188aa
Use more explicit _f_struct_factory
Aug 7, 2025
d623ff2
Remove things related to stochastic optimization
Aug 7, 2025
be1b725
Remove unused arg in BaseRegressor.instantiate_solver, update typing
Aug 7, 2025
15dd13f
Rename Optax-based solvers
bagibence Aug 7, 2025
9ea9189
Update mentions of jaxopt in the developer notes
bagibence Aug 7, 2025
274c59a
Update mentions of jaxopt in the docs
bagibence Aug 7, 2025
f2af14d
Update how SolverAdapter creates the docstring
bagibence Aug 7, 2025
65941ca
Make OptimistixOptaxSolver abstract, move it, improve its docs
bagibence Aug 7, 2025
31db049
Remove prox from accepted arguments by proximal jaxopt solvers
bagibence Aug 7, 2025
8013fec
Add get_solver_documentation helper function.
bagibence Aug 7, 2025
3008f81
Add jaxtyping dep, import ArrayLike and PyTree from there.
bagibence Aug 8, 2025
17bb6a3
Undo previous changes and just use Pytree
bagibence Aug 8, 2025
07b333e
TODO -> TODO:, NOTE -> NOTE: for highlighting
bagibence Aug 8, 2025
5299543
Type annotate solver_registry
bagibence Aug 8, 2025
90a1401
Improve get_solver_documentation and add it to the API reference
bagibence Aug 8, 2025
033f00a
Types in OptimizationInfo were off
bagibence Aug 8, 2025
681535d
Update docs
bagibence Aug 14, 2025
ac527a8
Only check and save solver.fun if it exists
bagibence Aug 14, 2025
4f2d7a9
Update documentation
bagibence Aug 14, 2025
532e81d
Remove comment. Developer notes about solvers are a bit nicer now
bagibence Aug 14, 2025
e51af2d
Small docs update
bagibence Aug 14, 2025
e5d74c8
Add how-to guide on creating custom solvers.
bagibence Aug 14, 2025
2a5e458
Delete scratch cell from docs/how_to_guide/custom_solvers.md
bagibence Aug 14, 2025
e0c4f2c
Apply suggestions from review in developer guide
bagibence Aug 31, 2025
9e24d0a
Typing in example
bagibence Aug 15, 2025
1bfa8ee
Any->Params in solver methods in nemos.typing
bagibence Aug 15, 2025
009a055
Require StepResult to be a tuple
bagibence Aug 15, 2025
fbd91d3
Remove todo comment
bagibence Aug 15, 2025
ce00d22
Remove note comments
bagibence Aug 31, 2025
7fa7dca
Admonition for SVRG and run_iterator
bagibence Aug 31, 2025
f671e32
Add get_flattener_unflattener and use it in how-to guide
bagibence Aug 31, 2025
52 changes: 52 additions & 0 deletions .github/workflows/ci.yml
@@ -49,6 +49,56 @@ jobs:
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

tox_backend_jaxopt:
if: ${{ !github.event.pull_request.draft }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.10', '3.11', '3.12']
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4 # Use v4 for compatibility with pyproject.toml
with:
python-version: ${{ matrix.python-version }}
cache: pip

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox

- name: Run solver-dependent tests with JAXopt backend
run: tox -e backend-jaxopt

tox_backend_optimistix:
if: ${{ !github.event.pull_request.draft }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.10', '3.11', '3.12']
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4 # Use v4 for compatibility with pyproject.toml
with:
python-version: ${{ matrix.python-version }}
cache: pip

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox

- name: Run solver-dependent tests with Optimistix backend
run: tox -e backend-optimistix

tox_check:
if: ${{ !github.event.pull_request.draft }}
runs-on: ubuntu-latest
@@ -110,6 +160,8 @@ jobs:
- prevent_docs_absolute_links
- tox_check
- check-relative-links
# - tox_backend_jaxopt
# - tox_backend_optimistix
runs-on: ubuntu-latest
steps:
- name: Decide whether all tests and notebooks succeeded
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
@@ -173,6 +173,12 @@ There are several options for how to run a subset of tests:
- Run a specific test within a specific module: `pytest tests/test_glm.py::test_func`
- Another example specifying a test method via the command line: `pytest tests/test_glm.py::GLMClass::test_func`

To run tests with solvers implemented with either the `jaxopt` or the `optimistix` backend, set the `NEMOS_SOLVER_BACKEND` environment variable. For example, for a single test run:
`NEMOS_SOLVER_BACKEND=jaxopt pytest`

There are also dedicated tox environments that do this automatically and run a subset of tests that depend on the solvers:
`tox -e backend-jaxopt,backend-optimistix` will run solver-dependent tests with both backends separately.

#### Adding tests

New tests can be added in any of the existing `tests/test_*.py` scripts. Tests should be functions, contained within classes. The class contains a bunch of related tests
13 changes: 13 additions & 0 deletions docs/api_reference.rst
@@ -193,6 +193,19 @@ Utility functions for running convolution over the sample axis.
create_convolutional_predictor


The ``nemos.solvers`` module
----------------------------
JAX-based optimizers used for parameter fitting.

.. currentmodule:: nemos.solvers

.. autosummary::
:toctree: generated/solvers
:nosignatures:

get_solver_documentation


The ``nemos.identifiability_constraints`` module
------------------------------------------------
Functions to apply identifiability constraints to rank-deficient feature matrices, ensuring the uniqueness of model
2 changes: 1 addition & 1 deletion docs/developers_notes/01-base_class.md
@@ -6,7 +6,7 @@ The `base_class` module introduces the `Base` class and abstract classes definin

The `Base` class is envisioned as the foundational component for any object type (e.g., basis, regression, dimensionality reduction, clustering, observation models, regularizers etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_regressor.BaseRegressor` is the building block for GLMs, GAMs, etc., while [`observation_models.Observations`](nemos.observation_models.Observations) is the building block for the Poisson observations, Gamma observations, ... etc.).

Designed to be compatible with the `scikit-learn` API, the class inherits directly from [`sklearn.BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html#sklearn.base.BaseEstimator). The class facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules, while customizing the `set_param` method for working with NeMoS basis objects. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.
Designed to be compatible with the `scikit-learn` API, the class inherits directly from [`sklearn.BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html#sklearn.base.BaseEstimator). The class facilitates access to `scikit-learn`'s robust pipeline and cross-validation modules, while customizing the `set_param` method for working with NeMoS basis objects. This is achieved while leveraging the accelerated computational capabilities of `jax` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.

Below a scheme of how we envision the architecture of the NeMoS models.

5 changes: 2 additions & 3 deletions docs/developers_notes/02-base_regressor.md
@@ -34,7 +34,7 @@ Public attributes are stored as properties:

- `regularizer`: An instance of the [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) class. The setter for this property accepts either the instance directly or a string that is used to instantiate the appropriate regularizer.
- `regularizer_strength`: A float quantifying the amount of regularization.
- `solver_name`: One of the `jaxopt` solver supported solvers, currently "GradientDescent", "BFGS", "LBFGS", "ProximalGradient" and, "NonlinearCG".
- `solver_name`: One of the supported solvers in the [solver registry](nemos.solvers._solver_registry.solver_registry), currently "GradientDescent", "BFGS", "LBFGS", "ProximalGradient", "SVRG", and "NonlinearCG".
- `solver_kwargs`: Extra keyword arguments to be passed at solver initialization.
- `solver_init_state`, `solver_update`, `solver_run`: Read-only properties exposing partially evaluated `solver.init_state`, `solver.update`, and `solver.run` methods. The partial evaluation guarantees a consistent API for all solvers.

@@ -45,8 +45,7 @@ Typically, in `YourRegressor` you will call `self.solver_init_state` at the para
:::{admonition} Solvers
:class: note

Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines as long as they respect the `jaxopt` api (i.e., have a `run`, `init_state`, and [`update`](nemos.glm.GLM.update) method with the appropriate input/output types).
We rely on `jaxopt` because it provides a comprehensive set of robust, GPU accelerated, batchable and differentiable optimizers in JAX, that are highly customizable. In the future we may provide a number of custom solvers optimized for convex stochastic optimization.
We implement a set of standard solvers in NeMoS, relying on various backends. In the future we plan to add support for user-defined solvers, since in principle any object that adheres to the [`AbstractSolver`](nemos.solvers._abstract_solver.AbstractSolver) interface should be compatible with NeMoS. For more information about the solver interface and the available solvers, see the [developer notes about solvers](07-solvers.md).
:::

## Contributor Guidelines
4 changes: 2 additions & 2 deletions docs/developers_notes/03-glm.md
@@ -15,7 +15,7 @@ Our design aligns with the `scikit-learn` API, facilitating seamless integration

The classes provided here are modular by design, offering a standard foundation for any GLM variant.

Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.), a regularization strategies (Ridge, Lasso, etc.) and an optimization scheme during initialization. This is done using the [`nemos.observation_models.Observations`](nemos.observation_models.Observations), [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) objects as well as the compatible `jaxopt` solvers, respectively.
Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.), a regularization strategy (Ridge, Lasso, etc.), and an optimization scheme during initialization. This is done using the [`nemos.observation_models.Observations`](nemos.observation_models.Observations) and [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) objects as well as the compatible solvers, respectively.


<figure markdown>
@@ -41,7 +41,7 @@ The [`GLM`](nemos.glm.GLM) class provides a direct implementation of the GLM mod
- **`intercept_`**: Stores the bias terms' solutions as `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
- **`dof_resid_`**: The degrees of freedom of the model's residual. This quantity is used to estimate the scale parameter (see below) and to compute frequentist confidence intervals.
- **`scale_`**: The scale parameter of the observation distribution, which together with the rate, uniquely specifies a distribution of the exponential family. Example: a 1D Gaussian is specified by the mean which is the rate, and the standard deviation, which is the scale.
- **`solver_state_`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#).
- **`solver_state_`**: Indicates the solver's state. For specific solver states, refer to the [solver implementations](nemos.solvers).

Additionally, the [`GLM`](nemos.glm.GLM) class inherits the attributes of `BaseRegressor`, see the [relative note](02-base_regressor.md) for more information.

2 changes: 1 addition & 1 deletion docs/developers_notes/06-regularizer.md
@@ -21,7 +21,7 @@ Abstract Class Regularizer
```

:::{note}
If we need advanced adaptive solvers (e.g., Adam, LAMB etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`, see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
If we need advanced adaptive solvers (e.g., Adam, LAMB etc.) in the future, we can use [`Optax`](https://optax.readthedocs.io/en/latest/) solvers through [`OptimistixOptaxSolver`](nemos.solvers._optimistix_solvers.OptimistixOptaxSolver). See [`OptimistixOptaxLBFGS`](nemos.solvers._optax_optimistix_solvers.OptimistixOptaxLBFGS) for an example.
:::

(the-abstract-class-regularizer)=
151 changes: 151 additions & 0 deletions docs/developers_notes/07-solvers.md
@@ -0,0 +1,151 @@
# The `solvers` Module

## Background

Initially, NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.
Because JAXopt is no longer maintained, we added support for alternative optimization backends.

Some of JAXopt's functionality was ported to [Optax](https://optax.readthedocs.io/en/latest/) by Google, and [Optimistix](https://docs.kidger.site/optimistix/) was started by the community to fill the gaps after JAXopt's deprecation.

To support flexibility and long-term maintenance, NeMoS now has a backend-agnostic solver interface, allowing the use of solvers from different backend libraries with different interfaces.

## `AbstractSolver` interface
This interface is defined by [`AbstractSolver`](nemos.solvers.AbstractSolver) and mostly follows the JAXopt API.
All solvers implemented in NeMoS are subclasses of `AbstractSolver`; however, subclassing is not required for implementing solvers that can be used with NeMoS (see [custom solvers](#custom_solvers)).

The `AbstractSolver` interface requires implementing the following methods:
- `__init__`: All solver parameters and settings go here. The other methods only take the solver state, the current or initial solution (model parameters), and the input data for the objective function.
- `init_state`: Initialize the solver state.
- `update`: Take one step of the optimization algorithm.
- `run`: Run a full optimization.
- `get_accepted_arguments`: Return the set of argument names that can be passed to `__init__`. These are the parameters users can change by passing `solver_kwargs` to `BaseRegressor` / `GLM`.
- `get_optim_info`: Collect diagnostic information about the optimization run into an `OptimizationInfo` namedtuple.

`AbstractSolver` is a generic class parametrized by `SolverState` and `StepResult`.
In concrete subclasses, `SolverState` should be the type of the solver state, and `StepResult` the type returned by each step of the solver, typically a tuple of the parameters and the solver state.
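
To make the method set concrete, here is a minimal sketch of a plain gradient-descent solver implementing these methods as a standalone class (per the note above, subclassing `AbstractSolver` is not required). The signatures, the `GDState` type, and the hyperparameter names are illustrative, not the actual NeMoS ones.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class GDState(NamedTuple):
    """Illustrative SolverState: iteration counter and last gradient norm."""

    iter_num: int
    grad_norm: float


class SimpleGradientDescent:
    """Toy solver following the method set described above (signatures illustrative)."""

    def __init__(self, fun, stepsize: float = 1e-2, maxiter: int = 500, tol: float = 1e-6):
        # all solver parameters and settings go through __init__
        self.fun, self.stepsize, self.maxiter, self.tol = fun, stepsize, maxiter, tol
        self._last_state: Optional[GDState] = None

    def init_state(self, init_params, *args) -> GDState:
        return GDState(iter_num=0, grad_norm=jnp.inf)

    def update(self, params, state: GDState, *args):
        # one optimization step; the return value plays the role of StepResult
        grads = jax.grad(self.fun)(params, *args)
        new_params = params - self.stepsize * grads  # assumes array-valued params
        return new_params, GDState(state.iter_num + 1, jnp.linalg.norm(grads))

    def run(self, init_params, *args):
        params, state = init_params, self.init_state(init_params, *args)
        for _ in range(self.maxiter):
            params, state = self.update(params, state, *args)
            if state.grad_norm < self.tol:
                break
        self._last_state = state
        return params, state

    def get_accepted_arguments(self):
        return {"stepsize", "maxiter", "tol"}

    def get_optim_info(self):
        # in NeMoS this returns an OptimizationInfo namedtuple; a plain dict stands in here
        return {
            "num_steps": self._last_state.iter_num,
            "converged": bool(self._last_state.grad_norm < self.tol),
        }
```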

### Optimization info
Because different libraries store information about the optimization run in different places (Optimistix in its stats dict, Optax and JAXopt in their solver state), we standardize some common diagnostics.
These are saved in `solver.optimization_info`, which is of type `OptimizationInfo`.

`OptimizationInfo` holds the following fields:
- `function_val`: The final value of the objective function. As not all solvers store this by default, and it's potentially expensive to evaluate, this field is optional.
- `num_steps`: The number of steps taken by the solver.
- `converged`: Whether the optimization converged according to the solver's criteria.
- `reached_max_steps`: Whether the solver reached the maximum number of steps allowed.
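
As a rough illustration, the shape of this namedtuple and a typical check might look like the sketch below; the field names come from the list above, but the exact types in NeMoS may differ.

```python
from typing import NamedTuple, Optional


class OptimizationInfoSketch(NamedTuple):
    """Illustrative stand-in for OptimizationInfo, using the fields listed above."""

    function_val: Optional[float]  # optional: not all solvers store it
    num_steps: int
    converged: bool
    reached_max_steps: bool


# e.g. what you might see after a run that hit the step limit
info = OptimizationInfoSketch(
    function_val=None, num_steps=1000, converged=False, reached_max_steps=True
)
if not info.converged and info.reached_max_steps:
    print(f"stopped after {info.num_steps} steps without converging")
```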

## Adapters
Support for existing solvers from external libraries and the custom implementation of (Prox-)SVRG is done through adapters that "translate" between the interfaces of these external solvers and the `AbstractSolver` interface.

Creating adapters for existing solvers can be done in multiple ways.
In our experience, wrapping solver objects through adapters provides a clean way of doing this, and we recommend that adapters for new optimization libraries follow this pattern.

[`SolverAdapter`](nemos.solvers.SolverAdapter) provides methods for wrapping existing solvers.
Each subclass of `SolverAdapter` has to define the methods of `AbstractSolver`, as well as a `_solver_cls` class variable indicating the type of solver it wraps.
During construction it has to set a `_solver` attribute that is a concrete instance of `_solver_cls`.

Default method implementations:
- A default implementation of `get_accepted_arguments` is provided, returning the union of arguments accepted by `__init__`, `_solver_cls`, and `_solver_cls.__init__`, minus the ones required by `AbstractSolver.__init__`.
- `__getattr__` dispatches every attribute call to the wrapped `_solver`.
- `__init_subclass__` generates a docstring for the adapter including accepted arguments and the wrapped solver's documentation.
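
The following is a generic, hedged sketch of this wrapping pattern, not the actual NeMoS implementation; the toy backend class and the constructor signatures are placeholders.

```python
import inspect


class ToyBackendLBFGS:
    """Stand-in for a third-party solver with its own constructor arguments."""

    def __init__(self, fun, maxiter: int = 100, tol: float = 1e-6):
        self.fun, self.maxiter, self.tol = fun, maxiter, tol


class ToyAdapter:
    """Sketch of the adapter pattern: hold the wrapped backend solver in _solver."""

    _solver_cls = ToyBackendLBFGS  # concrete subclasses point this at the wrapped type

    def __init__(self, fun, **solver_kwargs):
        # construction leaves a concrete instance of _solver_cls in self._solver
        self._solver = self._solver_cls(fun, **solver_kwargs)

    def __getattr__(self, name):
        # only called when normal lookup fails, so adapter attributes take precedence;
        # everything else is forwarded to the wrapped solver
        return getattr(self._solver, name)

    @classmethod
    def get_accepted_arguments(cls):
        # roughly: the arguments of the adapter's and the wrapped solver's __init__
        own = set(inspect.signature(cls.__init__).parameters) - {"self"}
        wrapped = set(inspect.signature(cls._solver_cls.__init__).parameters) - {"self"}
        return own | wrapped
```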

Currently we provide adapters for two optimization backends:
- [`OptimistixAdapter`](nemos.solvers.OptimistixAdapter) wraps Optimistix solvers.
- [`JaxoptAdapter`](nemos.solvers.JaxoptAdapter) wraps JAXopt solvers. As `SVRG` and `ProxSVRG` follow the JAXopt interface, these are also wrapped with `JaxoptAdapter`.


## List of available solvers

```
Abstract Class AbstractSolver
├─ Abstract Subclass SolverAdapter
│ │
│ ├─ Abstract Subclass OptimistixAdapter
│ │ │
│ │ ├─ Concrete Subclass OptimistixBFGS
│ │ ├─ Concrete Subclass OptimistixLBFGS
│ │ ├─ Concrete Subclass OptimistixNonlinearCG
│ │ └─ Concrete Subclass OptaxOptimistixSolver
│ │ │
│ │ ├─ Concrete Subclass OptaxOptimistixLBFGS
│ │ ├─ Concrete Subclass OptaxOptimistixGradientDescent
│ │ └─ Concrete Subclass OptaxOptimistixProximalGradient
│ │
│ └─ Abstract Subclass JaxoptAdapter
│ │
│ ├─ Concrete Subclass JaxoptLBFGS
│ ├─ Concrete Subclass JaxoptGradientDescent
│ ├─ Concrete Subclass JaxoptProximalGradient
│ ├─ Concrete Subclass JaxoptBFGS
│ ├─ Concrete Subclass JaxoptNonlinearCG
│ │
│ ├─ Concrete Subclass WrappedSVRG
│ └─ Concrete Subclass WrappedProxSVRG
```

`OptaxOptimistixSolver` runs Optax solvers, using `optimistix.OptaxMinimiser` to drive the full optimization loop.
Optimistix has no implementation of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd`, which does support it.
> Review comment (Collaborator): Can you add a reference to the Nesterov acceleration?

Note that `OptaxOptimistixSolver` allows using any solver from Optax (e.g., Adam). See `OptaxOptimistixGradientDescent` for a template of how to wrap new Optax solvers.
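
As a standalone, hedged sketch of the same idea (calling Optax and Optimistix directly rather than through NeMoS; the objective function and hyperparameter values are placeholders):

```python
import jax.numpy as jnp
import optax
import optimistix as optx


def loss_fn(params, data):
    # toy least-squares objective, used only for illustration
    X, y = data
    return jnp.mean((X @ params - y) ** 2)


# optax.sgd provides Nesterov acceleration, which Optimistix itself does not implement
optax_solver = optax.sgd(learning_rate=1e-2, momentum=0.9, nesterov=True)
# OptaxMinimiser lets Optimistix drive the full optimization loop for an Optax solver
solver = optx.OptaxMinimiser(optax_solver, rtol=1e-6, atol=1e-6)

X, y = jnp.ones((100, 3)), jnp.full(100, 0.5)
sol = optx.minimise(
    loss_fn, solver, jnp.zeros(3), args=(X, y), max_steps=10_000, throw=False
)
print(sol.value)  # optimized parameters
```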

## Custom solvers
To use your own solver in `nemos`, write a solver that adheres to the `AbstractSolver` interface and it should be straightforward to plug in.
While not strictly necessary, subclassing `AbstractSolver` is one way to ensure adherence to the interface.

Currently, the solver registry defines which implementation is used for each algorithm, so the registry entry has to be overridden to tell `nemos` to use a custom class; in the future we [plan to support passing any solver to `BaseRegressor`](https://github.com/flatironinstitute/nemos/issues/378).
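
Under the current registry-based approach, switching an algorithm to a custom implementation might look roughly like the sketch below. This is not a supported public API: the import path is the one referenced in these notes, and treating the registry as a name-to-class mapping is an assumption here.

```python
# Hedged sketch only: point the registry entry for an algorithm at a custom class
# that follows the AbstractSolver interface.
from nemos.solvers._solver_registry import solver_registry

# `MyGradientDescent` would be a user-defined class implementing the interface:
# solver_registry["GradientDescent"] = MyGradientDescent
```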

We might also define an `ImplementsSolverInterface`-style protocol to easily check whether user-supplied solvers define the methods required by the interface.

## Stochastic optimization
To run stochastic (mini-batch) optimization, JAXopt used a `run_iterator` method.
Instead of the full input data, `run_iterator` accepts a generator or iterator that provides batches of data, as sketched below.
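
A hedged sketch of such an iterator is shown here; the commented-out `run_iterator` call assumes the JAXopt-style signature and is illustrative only.

```python
import numpy as np


def batch_iterator(X, y, batch_size: int, n_epochs: int, seed: int = 0):
    """Yield (X_batch, y_batch) tuples in a random order, epoch by epoch."""
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    for _ in range(n_epochs):
        order = rng.permutation(n_samples)
        for start in range(0, n_samples, batch_size):
            idx = order[start:start + batch_size]
            yield X[idx], y[idx]


# hypothetical usage, assuming a JAXopt-style signature:
# params, state = solver.run_iterator(init_params, batch_iterator(X, y, 32, 5))
```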

For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin`, which borrows the implementation from JAXopt (or some version of it; see below).
We will likely define an interface or protocol for this, allowing custom (user-defined) solvers to also implement their own version.
We will also have to decide how this will be exposed to users at the level of `BaseRegressor` and `GLM`.

Note that (Prox-)SVRG is especially well-suited for stochastic optimization; however, it currently requires the optimization loop to be implemented separately, as it is a bit more involved than what `run_iterator` does.
> Review comment (Collaborator): Wrap this in an admonition on SVRG, to make it more visible.
>
> Reply (Author): Done. It's an info now. Is that good or should it be a warning instead?

A potential solution to this would be to provide a separate method that accepts the full data and takes care of the batching. That might be a more convenient alternative to the current `run_iterator` as well.

## Note on line searches vs. fixed stepsize in Optimistix
By default, Optimistix doesn't expose the `search` attribute of its concrete solvers, but we might want to switch flexibly between line searches and constant learning rates depending on whether `stepsize` is passed to the solver.
> Review comment (Collaborator): I like this idea.

A solution would be to create short redefinitions of the required solvers that take `search` as an argument to `__init__`, for example:
```python
class BFGS(AbstractBFGS[Y, Aux, _Hessian]):
    # Sketch only: assumes AbstractBFGS, NewtonDescent, Zoom, AbstractSearch, and
    # max_norm come from Optimistix internals, lx is lineax, and PyTree/Scalar are
    # the usual jaxtyping aliases.
    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: NewtonDescent
    search: AbstractSearch
    verbose: frozenset[str]

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = max_norm,
        use_inverse: bool = True,
        verbose: frozenset[str] = frozenset(),
        search: AbstractSearch = Zoom(initial_guess_strategy="one"),
    ):
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        self.use_inverse = use_inverse
        self.descent = NewtonDescent(linear_solver=lx.Cholesky())
        self.search = search
        self.verbose = verbose
```

and then handle `stepsize` in the adapter with something like:

```python
if "stepsize" in solver_init_kwargs:
assert "search" not in solver_init_kwargs, "Specify either search or stepsize"
solver_init_kwargs["search"] = optx.LearningRate(
solver_init_kwargs.pop("stepsize")
)
```
1 change: 1 addition & 0 deletions docs/developers_notes/README.md
@@ -12,6 +12,7 @@
04-basis_module.md
05-observation_models.md
06-regularizer.md
07-solvers.md
```

## Introduction