
Conversation

bagibence
Collaborator

Because JAXopt is no longer maintained, NeMoS is migrating its optimization backend to Optimistix and Optax.
As these are not yet a full replacement for JAXopt, solvers from both backends will be supported, at least initially.
This is achieved through a unified interface for solvers which defines the interaction between the optimization backend's solvers and NeMoS models.

The interface every solver must implement in order to be compatible with BaseRegressor is defined in AbstractSolver, and it mostly follows the previous interface of JAXopt solvers.
Compatibility with existing JAXopt, Optimistix, and (Prox-)SVRG solvers is provided by adapter classes; their common base class is SolverAdapter.

Instead of looking up solver classes based on their name in nemos.solvers and jaxopt, the class used to implement each algorithm is explicitly defined in the solver registry.
Currently solvers are created based on the algorithm's name, but this could be extended in the future to allow user-defined solvers to be passed to BaseRegressor. As long as a solver implements the AbstractSolver interface, it should be compatible with NeMoS. (In this case the compatibility check with the regularizer could be disabled or rewritten.)
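For illustration, the registry approach looks roughly like this (a minimal sketch with placeholder solver classes; only the `run` signature mirrors the actual AbstractSolver, everything else is made up):

```python
import abc

class AbstractSolver(abc.ABC):
    @abc.abstractmethod
    def run(self, init_params, *args):
        """Run the full optimization loop and return the final params and state."""

class MyGradientDescent(AbstractSolver):
    """Placeholder standing in for an actual adapter class."""
    def __init__(self, stepsize=1e-3):
        self.stepsize = stepsize

    def run(self, init_params, *args):
        return init_params, None  # dummy body; a real solver optimizes here

# algorithm name -> implementing class, instead of looking classes up by name
# in nemos.solvers / jaxopt
SOLVER_REGISTRY = {"GradientDescent": MyGradientDescent}

def create_solver(name, **solver_init_kwargs):
    return SOLVER_REGISTRY[name](**solver_init_kwargs)

solver = create_solver("GradientDescent", stepsize=1e-2)
```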

As Optimistix doesn't implement ProximalGradient, this is currently backed by a custom implementation, OptaxOptimistixProximalGradient, which runs Optax's SGD with linesearch followed by the proximal operator. This seems to work in practice, but is not as theoretically sound as jaxopt.ProximalGradient.
We contributed an implementation of LBFGS to Optimistix (PR for L-BFGS update, PR for zoom linesearch), and a wrapper class for it is already included here (just commented out). Until these PRs are merged and released, JAXopt's or Optax's implementation could be used as the default.
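For illustration, the SGD-then-prox scheme described above looks roughly like this (a simplified sketch: no linesearch, a made-up least-squares loss and lasso prox, not the actual OptaxOptimistixProximalGradient):

```python
import jax
import jax.numpy as jnp
import optax

def soft_threshold(x, thresh):
    # proximal operator of thresh * ||x||_1
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - thresh, 0.0)

def smooth_loss(params, X, y):
    return jnp.mean((X @ params - y) ** 2)

learning_rate, reg_strength = 1e-2, 1e-3
opt = optax.sgd(learning_rate)

X, y = jnp.ones((10, 3)), jnp.zeros(10)
params = jnp.zeros(3)
opt_state = opt.init(params)

for _ in range(100):
    grads = jax.grad(smooth_loss)(params, X, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # proximal step on the non-smooth penalty, scaled by the step size
    params = soft_threshold(params, learning_rate * reg_strength)
```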

Because in this interface every solver parameter must be passed to the solver's constructor, BaseRegressor.instantiate_solver is simplified. Each solver can expose its accepted arguments explicitly through get_accepted_arguments; if this is not defined, BaseRegressor tries to infer it.
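For example, the fallback inference could look something like this (a sketch; the actual logic in BaseRegressor may differ):

```python
import inspect

def accepted_arguments(solver_class):
    # prefer the explicit declaration if the solver provides one
    if hasattr(solver_class, "get_accepted_arguments"):
        return set(solver_class.get_accepted_arguments())
    # otherwise fall back to inspecting the constructor's signature
    params = inspect.signature(solver_class.__init__).parameters
    return {name for name in params if name not in ("self", "args", "kwargs")}
```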

There are a few differences between the parameters accepted by JAXopt and Optimistix. These differences are handled on solver instantiation, but a warning is raised (a sketch of the translation follows the list):

  • JAXopt uses maxiter, Optimistix uses max_steps
  • JAXopt uses a single tol value, while Optimistix uses atol and rtol
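A sketch of the kind of translation done at instantiation (the real mapping lives in the adapter classes, and the warning text here is illustrative):

```python
import warnings

def translate_solver_kwargs(kwargs, backend):
    kwargs = dict(kwargs)
    if backend == "optimistix":
        if "maxiter" in kwargs:
            warnings.warn("'maxiter' is translated to Optimistix's 'max_steps'")
            kwargs["max_steps"] = kwargs.pop("maxiter")
        if "tol" in kwargs:
            warnings.warn("'tol' is used for both 'atol' and 'rtol' in Optimistix")
            kwargs["atol"] = kwargs["rtol"] = kwargs.pop("tol")
    return kwargs
```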

Currently the NEMOS_SOLVER_BACKEND environment variable can be used to run tests with a specific backend. I added tox environments that run solver-dependent tests with both backends, and added these to the CI.
This switch could be moved to the main code, so that users have the option to choose the backend they want.
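If the switch is moved into the main code, it could be as simple as this (illustrative only; the function name and the default backend are assumptions):

```python
import os

def get_solver_backend():
    backend = os.environ.get("NEMOS_SOLVER_BACKEND", "optimistix")
    if backend not in ("jaxopt", "optimistix"):
        raise ValueError(f"Unknown solver backend: {backend!r}")
    return backend
```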

Collaborator

@BalzaniEdoardo BalzaniEdoardo left a comment


Big picture questions:

  1. Do we still need JAXopt? It looks like the transition would already be completed with this PR.
  2. How does this change how the user interacts with solvers? It looks like it doesn't. Is there a reason why the solver registry is public if one can only pass strings to specify solvers?
  3. Does this interface allow arbitrary Optimistix and Optax solvers? I.e., can we eventually allow passing an instance of a solver directly?
  4. For this PR we need some developer guide note explaining the interface: stuff like, do we expect any solver implementing the API to be compatible? Do you have any notes on how to write an interface to a new solver library, other than matching the API?
  5. Can you implement the proximal gradient in a separate PR?
  6. If we are not dropping JAXopt yet, then we could split the Optax/Optimistix changes into two separate PRs and add JAXopt to the registry instead of what we have now. But if you think that we can drop JAXopt here, then let's continue on this one.

Next week Billy is joining and we'll continue discussing this.

pass

@abc.abstractmethod
def run(self, init_params: Params, *args) -> StepResult:
Collaborator

Maybe have an abstract run_iterator that takes in a generator; for solvers that are not stochastic, raise a ValueError (or a more specific error if one exists).

Collaborator Author

Yes, that's great! I forgot about that.

Collaborator Author

I added it to AbstractSolver.
How do you want to control which solvers should have it and which shouldn't?
How about adding a base implementation copying jaxopt.StochasticSolver's, and only allowing it if the class has a class variable _stochastic=True?
Or a StochasticSolverMixin with the implementation, inherited by every class that should have it?
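Roughly something like this (a simplified sketch; init_state and update are assumed to exist on the concrete solver, and jaxopt's extra kwargs handling is left out):

```python
class StochasticSolverMixin:
    """Run the solver's update step over batches yielded by an iterator."""

    def run_iterator(self, init_params, iterator, *args, maxiter=100):
        params = init_params
        state = self.init_state(init_params, *args)
        for _, batch in zip(range(maxiter), iterator):
            params, state = self.update(params, state, *batch, *args)
        return params, state
```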

**solver_init_kwargs,
)

def _extend_args(self, args):
Collaborator

Can we avoid the if statement by storing the extended args? Or initializing as (reg_strength,) or an empty tuple?

Collaborator Author

Why do you want to avoid it? args are passed to the methods by the code using the solver, so they can't be stored at construction.
The prepending could be done in other ways; I like that this is explicit.

tags=self.config.tags,
)

self.stats.update(solution.stats)
Collaborator

Why do we have a self.stats here while the other adapter doesn't? Can we replicate the stats in the JAXopt adapter too? You can parse the JAXopt state and extract as much info as possible.

Collaborator Author

It's only because Optimistix saves the number of steps taken at the end of a minimise run in the stats instead of in the state itself like JAXopt does, and I wanted a quick way to access that in the tests.

I haven't given this too much thought, so we could think about whether we want to standardize what we keep at the end of the runs.

@bagibence
Collaborator Author

bagibence commented Jul 14, 2025

Thanks for taking a look!

Do we still need JAXopt? It looks like the transition would already be completed with this PR.

I think it's good to keep JAXopt at least as a backup option for a while, perhaps providing a way to switch the backend, like the environment variable mentioned above.

How does this change how the user interacts with solvers? It looks like it doesn't.

It doesn't. I aimed for no changes being required on the user's side, so as not to break existing analysis code.

Is there a reason why the solver registry is public if one can only pass strings to specify solvers?

Do you mean that it's not called _solver_registry?
Actually, overwriting the registry dict already gives users a way to plug in their own classes. I'm not sure if that's a feature or a bug.

Does this interface allow arbitrary Optimistix and Optax solvers? I.e., can we eventually allow passing an instance of a solver directly?

Do you mean passing an optimistix.GradientDescent directly? It doesn't allow that immediately, but I think it could easily be modified (in BaseRegressor) to allow passing instances of anything that implements the solver interface. So for example GLM(solver=OptimistixGradientDescent(learning_rate=1e-3)). Also, OptimistixWrapper could be modified to accept an already instantiated solver object instead of constructing _solver. Or we could just have a function that wraps an existing solver instance.

For this PR we need some developer guide note explaining the interface: stuff like, do we expect any solver implementing the API to be compatible? Do you have any notes on how to write an interface to a new solver library, other than matching the API?

Yes, that makes sense. I would expect any solver implementing the API to be compatible. The only hurdle I see is the compatibility check in the regularizers, which works based on the solver's name.
I think for a new library, subclassing SolverAdapter would be the sanest approach. So e.g. for Optax, an alternative to relying on Optimistix's OptaxMinimiser wrapper could be an OptaxWrapper whose run method implements the optimization loop itself, like done here.
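Very roughly, such a wrapper's run could look like this (a sketch only: a fixed-step for loop instead of a proper while loop with convergence checks, and the loss_fn handling is made up):

```python
import jax
import optax

class OptaxWrapper:
    def __init__(self, loss_fn, optimizer, max_steps=1000):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.max_steps = max_steps

    def run(self, init_params, *args):
        params = init_params
        opt_state = self.optimizer.init(params)
        for _ in range(self.max_steps):
            grads = jax.grad(self.loss_fn)(params, *args)
            updates, opt_state = self.optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
        return params, opt_state
```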

Can you implement the proximal gradient in a separate PR?

Sure!

If we are not dropping JAXopt yet, then we could split the Optax/Optimistix changes into two separate PRs and add JAXopt to the registry instead of what we have now. But if you think that we can drop JAXopt here, then let's continue on this one.

Yes, that also makes sense. I think it's useful to keep JAXopt for now and only drop it once the switch is done and you are satisfied with it.

Edit: quote formatting was off

bagibence added 25 commits July 29, 2025 14:36
Defines the interface.
Will add type annotations later.
Instead of trying to find solvers in packages, have an explicit list.
According to the new interface, __init__ will receive everything, and will raise an error at that point.
Simplify how the solver object is accessed in test_solvers and test_glm.
Cleanup for previous commit
@bagibence
Collaborator Author

I removed things related to stochastic optimization; this will be done later in #376

Bence Bagi and others added 10 commits August 7, 2025 12:55
Currently using PyTree from jaxtyping. Might want to just alias to Any.
Moved solver interface types to typing.
Use StepResult and SolverState in GLM and BaseRegressor.
For now PyTree is imported as Pytree to not introduce too many changes.
This PyTree is parametrizable, which is used in the solvers module.
ArrayLike is the same as jax.typing.ArrayLike. Importing from jaxtyping for consistency.
I will open an issue about potentially using jaxtyping.
@bagibence bagibence mentioned this pull request Aug 14, 2025
Collaborator

@BalzaniEdoardo BalzaniEdoardo left a comment

Let's have @billbrod take a look, but after you address my latest comments, this LGTM.


## Background

In the beginning NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.
Collaborator

Suggested change
In the beginning NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.
In the earlier versions, NeMoS relied on [JAXopt](https://jaxopt.github.io/stable/) as its optimization backend.

│ └─ Concrete Subclass WrappedProxSVRG
```

`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.
Collaborator

Suggested change
`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.
`OptaxOptimistixSolver` is an adapter for Optax solvers, relying on `optimistix.OptaxMinimiser` to run the full optimization loop.

```

`OptaxOptimistixSolver` is for using Optax solvers, utilizing `optimistix.OptaxMinimiser` to run the full optimization loop.
Optimistix does not have implementations of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd` which does support it.
Collaborator

Can you add a reference to the Nesterov acceleration?

To run stochastic (~mini-batch) optimization, JAXopt used a `run_iterator` method.
Instead of the full input data `run_iterator` accepts a generator / iterator that provides batches of data.

For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin` which borrows the implementation from JAXopt. (Or some version of it. See below.).
Collaborator

Suggested change
For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin` which borrows the implementation from JAXopt. (Or some version of it. See below.).
For solvers defined in `nemos` that can be used this way, we will likely provide `StochasticMixin` which borrows the implementation from JAXopt (Or some version of it, see below).

We will likely define an interface or protocol for this, allowing custom (user-defined) solvers to also implement their own version.
We will also have to decide on how this will be exposed to users on the level of `BaseRegressor` and `GLM`.

Note that (Prox-)SVRG is especially well-suited for running stochastic optimization, however it currently requires the optimization loop to be implemented separately as it is a bit more involved than what is done by `run_iterator`.
Collaborator

Wrap this in an admonition on SVRG, to make it more visible.


Currently, the solver registry defines which implementation to use for each algorithm, so that has to be overwritten in order to tell NeMoS to use a custom class.

This is hacky and not an intended use-case for now, but in the future we are [planning to support passing any solver to `BaseRegressor`](https://github.com/flatironinstitute/nemos/issues/378).
Collaborator

I am not sure it makes sense to have a guideline for something we do not recommend. It might be better to copy-paste this note into issue #378.

split_indices = np.cumsum(sizes)[:-1]
flat_params = np.concat([x.flatten() for x in params])
def unpacker(_flat_params):
Collaborator

This could be generalized and included in nemos as a utility function that we expose. It's useful in general:

import jax
import jax.numpy as jnp

def get_packer_unpacker(parameter_tree):
    flat, struct = jax.tree_util.tree_flatten(parameter_tree)
    shapes = [x.shape for x in flat]
    sizes = jnp.array([x.size for x in flat], dtype=int)
    split_indices = jnp.cumsum(sizes[:-1])

    def packer(parameter_tree):
        flat = jax.tree_util.tree_leaves(parameter_tree)
        return jnp.concatenate([x.flatten() for x in flat])

    def unpacker(flat_params):
        split_params = jnp.split(flat_params, split_indices)
        split_params = [x.reshape(s) for x, s in zip(split_params, shapes)]
        return jax.tree_util.tree_unflatten(struct, split_params)

    return packer, unpacker
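For example, with a made-up parameter tree the round-trip would be:

```python
params = {"coef": jnp.ones((2, 3)), "intercept": jnp.zeros(2)}
packer, unpacker = get_packer_unpacker(params)
flat = packer(params)        # 1-D array of length 8
restored = unpacker(flat)    # same structure and shapes as params
```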

Comment on lines +1688 to +1694
# NOTE not testing these anymore since non-JAXopt solvers' state is not necessarily a namedtuple
# assert (
# hasattr(state, "_fields")
# and hasattr(state, "_field_defaults")
# and hasattr(state, "_asdict")
# )

Collaborator

remove note

Comment on lines +1652 to +1657
# NOTE not testing these anymore since non-JAXopt solvers' state is not necessarily a namedtuple
# assert (
# hasattr(state, "_fields")
# and hasattr(state, "_field_defaults")
# and hasattr(state, "_asdict")
# )
Collaborator

remove note

setenv =
NEMOS_SOLVER_BACKEND = jaxopt
commands =
pytest -n auto tests/test_solvers.py tests/test_glm.py tests/test_convergence.py tests/test_regularizer.py
Collaborator

I know it is a pain, but can you group all GLM tests that call fit, initialize the solver, or interact with the solver in any other way into a dedicated class, and run only that class?

    pytest -n auto tests/test_solvers.py tests/test_glm.py::TestSolverRelated tests/test_convergence.py tests/test_regularizer.py
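e.g., roughly (class and fixture names here are placeholders):

```python
class TestSolverRelated:
    """GLM tests that call fit, initialize the solver, or otherwise touch it."""

    def test_fit_runs(self, glm_instance, X, y):  # hypothetical fixtures
        glm_instance.fit(X, y)
```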
