
Conversation

@kylejcaron (Contributor)

Closes #1709.

This adds Equinox support to numpyro's contrib module

@kylejcaron (Contributor Author)

I'm running into issues with stateful operations.

Equinox manages state functionally, e.g.

import equinox as eqx

class Model(eqx.Module):
    norm: eqx.nn.BatchNorm
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.norm = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.linear = eqx.nn.Linear(in_features=3, out_features=3, key=key)

    def __call__(self, x, state):
        x, state = self.norm(x, state)
        x = self.linear(x)
        return x, state

model, state = eqx.nn.make_with_state(Model)(key=rng_key)  
...

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)

This is a bit different from the other JAX neural net libraries. I think there are two options:

  1. Provide eqx_module() with an uninitialized model class and manage the state under the hood (the current approach in this PR). I'm having a bit of trouble getting this to work.
    • This might also be a bit awkward for Equinox users, who are used to managing state themselves.
  2. Trust users to eagerly initialize outside the model, with the state when applicable. They'd also have to know to pass the state through the call function, and I'm not sure how registering the state with numpyro_mutable works with this.

Feedback appreciated!

@fehiepsi fehiepsi added the WIP label Mar 13, 2025
x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))

batched_nn = jax.vmap(
    nn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)
@kylejcaron (Contributor Author) Mar 17, 2025

Unfortunately there's some overhead for the end user when static args like state are included.

As far as I'm aware, static kwargs aren't possible with vmap; otherwise this would be much easier and I could simply do something like

def eqx_module(...):
    ...
    return jax.vmap(partial(nn, state=state), ...)

I'm guessing, given the complexity, I should probably add an official example to walk through stateful and non-stateful usage.
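To illustrate the `(0, None)` axes pattern discussed above, here is a minimal pure-JAX sketch (no Equinox; the `apply` function and `state` dict are made-up stand-ins) showing how an unbatched state argument can be threaded through `jax.vmap` positionally, even though static kwargs are not supported:

```python
import jax
import jax.numpy as jnp

# Hypothetical stateful apply function: `state` is shared (unbatched)
# across the batch, mirroring in_axes=(0, None) above.
def apply(x, state):
    # out_axes=(0, None) requires the second output to be identical for
    # every batch element; returning `state` unchanged satisfies this.
    return x * state["scale"], state

batched_apply = jax.vmap(apply, in_axes=(0, None), out_axes=(0, None))

xs = jnp.arange(4.0)                 # batched input
state = {"scale": jnp.asarray(2.0)}  # unbatched "state"
ys, new_state = batched_apply(xs, state)
```

This is the same mechanism the snippet above relies on: `state` rides along as a plain positional argument with a `None` axis, so no static-kwarg support is needed.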

@fehiepsi (Member)

I think any vmap logic should happen at the equinox level, not at numpyro eqx_module level.

@kylejcaron (Contributor Author)

To make sure I have it correct, are you saying that vmap should not be used and assumed inside eqx_module (and therefore the current implementation is better than what I mention in my previous comment)?

@fehiepsi (Member) Mar 17, 2025

I meant that vmap should be used neither in the implementation of eqx_module nor in the numpyro model as vmap(eqx_module(...)). We should use eqx_module(name, vmapped_module) or something like that, where vmapped_module is an Equinox module (but vmapped through a non-numpyro mechanism).

@kylejcaron (Contributor Author) Mar 17, 2025

Yeah, I see what you're saying.

I just realized I can use eqx.filter_vmap ahead of time to solve for this outside of the numpyro model. Do you think this is sufficient?

linear = eqx.nn.Linear(in_features=3, out_features=1, key=rng_key)
batched_linear = eqx.filter_vmap(linear)

def model(x):
    # Use the pre-initialized, vmapped module
    nn = eqx_module("nn", batched_linear)
    y = nn(x)

Edit: given that the typical behavior of Equinox users is to use jax.vmap, I should probably avoid this pattern and use jax.vmap instead of eqx.filter_vmap.

@fehiepsi (Member)

Yes, I meant so, but I'm not sure if it works. It looks to me like f is an Equinox module, hence the compatibility relies solely on Equinox.

@kylejcaron (Contributor Author)

As far as I can tell this pattern is working! It registers parameters properly and recovers the original parameters; all of the tests I added are passing (I mimicked the tests from nnx_module).

It's just unfortunate that it requires vmap inside the numpyro model, and stopping the gradient for the state.
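As a minimal sketch of what "stopping the gradient for the state" means here (pure JAX; the loss function and state contents are made up for illustration), `jax.lax.stop_gradient` applies leaf-wise to a whole state pytree:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, state):
    # Treat the running statistics as constants so no gradient flows
    # back through the state (analogous to the BatchNorm state above).
    state = jax.lax.stop_gradient(state)
    pred = w * state["running_mean"]
    return jnp.sum(pred**2)

w = jnp.asarray(3.0)
state = {"running_mean": jnp.asarray(2.0)}

# d/dw of (2w)^2 is 8w, so the gradient at w=3 is 24; the state
# contributes no gradient because it was stopped.
grad_w = jax.grad(loss_fn)(w, state)
```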

@fehiepsi (Member)

The usage of vmap here looks clean to me. Regarding the gradient, it is more of a numpyro issue. I think we can add a stop-gradient somewhere in numpyro to deal with it. How about adding a stop-gradient for the mutable state at this line:

return result["loss"], result["mutable_state"]

@kylejcaron (Contributor Author) Mar 27, 2025

Sadly that didn't work; AutoDelta still breaks when svi.init is called inside of svi.run. I committed here though for reference.

In the traceback I'm adding below, it seems that during svi.init, AutoDelta attempts to get self._init_locs and fails when running the postprocess_fn.

In particular, the trace fails within numpyro.infer.util.constrain_fn, which gets called in that postprocess_fn mentioned above.

Traceback:
        # test svi - trace error with AutoDelta
        guide = AutoDelta(model)
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
>       svi.run(random.PRNGKey(100), 10)

test/contrib/test_module.py:574: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
numpyro/infer/svi.py:402: in run
    svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
numpyro/infer/svi.py:185: in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
numpyro/handlers.py:191: in get_trace
    self(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/handlers.py:846: in __call__
    return cloned_seeded_fn.__call__(*args, **kwargs)
numpyro/handlers.py:847: in __call__
    return super().__call__(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/infer/autoguide.py:556: in __call__
    self._setup_prototype(*args, **kwargs)
numpyro/infer/autoguide.py:534: in _setup_prototype
    for k, v in self._postprocess_fn(self._init_locs).items()
numpyro/handlers.py:846: in __call__
    return cloned_seeded_fn.__call__(*args, **kwargs)
numpyro/handlers.py:847: in __call__
    return super().__call__(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/infer/util.py:225: in constrain_fn
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
numpyro/handlers.py:191: in get_trace
    self(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
    return self.fn(*args, **kwargs)
test/contrib/test_module.py:560: in model
    y, state = batched_nn(x, mutable_holder["state"])
test/contrib/test_module.py:543: in __call__
    x, state = self.bn(x, state)
../../.pyenv/versions/3.11.7/lib/python3.11/contextlib.py:81: in inner
    return func(*args, **kwds)
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/equinox/nn/_batch_norm.py:163: in __call__
    running_mean = (1 - momentum) * batch_mean + momentum * running_mean
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:1060: in op
    return getattr(self.aval, f"_{name}")(self, *args)
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:579: in deferring_binary_op
    return binary_op(*args)

I did something crazy in my local environment and stopped the gradient in numpyro.mutable's return output, and it does fix this issue, but changing a primitive seems bad and it breaks 6 tests.

[Screenshot of the failing tests, 2025-03-27]

Is there anything more obvious that I could adjust inside of AutoDelta to resolve this? As a reminder, this issue seems to be relevant only to AutoDelta so far (for instance, the test passes without stopping the gradient for the state when using AutoNormal).

As an alternative, if we go back to just stopping the gradient manually inside the model, all of the tests pass. It'd just be nice to avoid that if possible.

@kylejcaron (Contributor Author) Mar 27, 2025

Alright, I just pushed a working version. It turns out that AutoDelta does work for a typical workflow, but in that specific unit test it breaks if x is sampled within the model (still not sure why, but svi.init fails in that scenario).

I decided to undo all of the code that stops gradients and just sample x outside of the model block in that unit test, and everything passes now.

This probably isn't ideal, because it does seem like there's some sort of bug that requires stopping the gradient in the model when using AutoDelta + Equinox + a latent variable.

Let me know what you think!

@kylejcaron kylejcaron requested a review from fehiepsi March 17, 2025 17:23
@kylejcaron (Contributor Author)

kylejcaron commented Mar 26, 2025

cc @patrick-kidger if you have any feedback

cc @danielward27 as well since you've worked with both numpyro and equinox before

@patrick-kidger

patrick-kidger commented Mar 26, 2025

Ah, awesome! Taking a quick scan over some of the questions I've seen raised above:

I would suggest to let users eagerly initialize the nn_module because no init_fn, apply_fn pattern happens here.

It's not super obvious but actually Equinox has an init/apply style like so:

init_fn = eqx.nn.Foo
apply_fn = eqx.nn.Foo.__call__

params = init_fn(hyperparams)
output = apply_fn(params, input)

equinox doesn't have tools to go back and forth between dictionaries and pytrees

This is an intentional omission to encourage folks to use just one kind of nested-object-abstraction! But if it's useful to you here for compatibility, then something that does this can be set up like so:

from typing import Any

import equinox as eqx
import jax
from jaxtyping import PyTree

def to_dict(tree: PyTree) -> dict[str, Any]:
    out = {}

    def to_dict_impl(path, leaf):
        out[jax.tree_util.keystr(path)] = leaf

    jax.tree.map_with_path(to_dict_impl, tree)
    return out

def from_dict(data: dict, tree: PyTree) -> PyTree:
    def from_dict_impl(path, _):
        return data[jax.tree_util.keystr(path)]

    return jax.tree.map_with_path(from_dict_impl, tree)
    
mlp = eqx.nn.MLP(2, 2, 2, 2, key=jax.random.key(0))
data = to_dict(mlp)
mlp2 = from_dict(data, mlp)
print(data)
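(These helpers can also be exercised without Equinox on a plain pytree. A hedged round-trip sketch, using `jax.tree_util.tree_map_with_path`, the longer-standing spelling of `jax.tree.map_with_path`:)

```python
import jax
import jax.numpy as jnp

def to_dict(tree):
    out = {}
    def to_dict_impl(path, leaf):
        # Render the key path (e.g. "['w']") as the dictionary key.
        out[jax.tree_util.keystr(path)] = leaf
    jax.tree_util.tree_map_with_path(to_dict_impl, tree)
    return out

def from_dict(data, tree):
    def from_dict_impl(path, _):
        return data[jax.tree_util.keystr(path)]
    return jax.tree_util.tree_map_with_path(from_dict_impl, tree)

# Round trip on a plain dict-of-arrays pytree (no Equinox module needed).
tree = {"b": jnp.zeros(2), "w": jnp.ones((2, 2))}
data = to_dict(tree)
restored = from_dict(data, tree)
```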

I hope that helps!

@kylejcaron (Contributor Author)

kylejcaron commented Mar 26, 2025

I hope that helps!

This is very helpful, thank you so much!

@fehiepsi (Member) left a comment

Thanks @kylejcaron! Could you add new entries in docs for mutable and eqx_module? https://github.com/pyro-ppl/numpyro/tree/master/docs/source


@fehiepsi (Member) left a comment

Awesome work! Thanks @kylejcaron!

@fehiepsi (Member)

This pr is still in draft mode. Let me know if it is ready to merge.

@kylejcaron kylejcaron marked this pull request as ready for review March 28, 2025 02:36
@kylejcaron (Contributor Author)

This pr is still in draft mode. Let me know if it is ready to merge.

It's ready to go, and thanks for all of the help!

@fehiepsi fehiepsi merged commit ab1f0dc into pyro-ppl:master Mar 28, 2025
10 checks passed
Linked issue: Equinox models integration
3 participants