Equinox Integration #2005
Conversation
I'm running into issues with stateful operations. Equinox manages state functionally, i.e.

class Model(eqx.Module):
    norm: eqx.nn.BatchNorm
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.norm = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear = eqx.nn.Linear(in_features=32, out_features=32, key=key)

    def __call__(self, x, state):
        x, state = self.norm(x, state)
        x = self.linear(x)
        return x, state

model, state = eqx.nn.make_with_state(Model)(key=rng_key)
...
for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)

This is a bit different than the other JAX neural net libraries. I think there are two options:

feedback appreciated!
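The make_step referenced above isn't shown in the thread; as a rough illustration of the functional style it implies (plain dicts standing in for the Equinox module and its state, and the 0.1 learning rate purely illustrative), each step returns new pytrees rather than mutating anything:

```python
import jax
import jax.numpy as jnp

# Nothing is mutated in place: every step returns a fresh params
# pytree and a fresh state pytree, mirroring Equinox's functional
# handling of module state. All names and shapes are illustrative.
def step(params, state, x):
    new_state = {"count": state["count"] + 1}
    grads = jax.grad(lambda p: jnp.sum((p["w"] * x) ** 2))(params)
    new_params = {"w": params["w"] - 0.1 * grads["w"]}
    return new_params, new_state

params = {"w": jnp.ones(2)}
state = {"count": 0}
for _ in range(3):
    params, state = step(params, state, jnp.array([1.0, 2.0]))
```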
x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
...
batched_nn = jax.vmap(
    nn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)
Unfortunately there's some overhead for the end user when static args like state are included.

As far as I'm aware, static kwargs aren't possible with vmap. Otherwise this would be much easier and I could simply do something like

def eqx_module(...):
    ...
    return jax.vmap(partial(nn, state=state), ...)

I'm guessing given the complexity I should probably add an official example to walk through stateful/non-stateful usage.
I think any vmap logic should happen at the equinox level, not at the numpyro eqx_module level.
To make sure I have it correct, are you saying that vmap should not be used and assumed inside eqx_module (and therefore the current implementation is better than what I mention in my previous comment)?
I meant that vmap should be used neither in the implementation of eqx_module nor in the numpyro model as vmap(eqx_module(...)). We should use eqx_module(name, vmapped_module) or something like that, where vmapped_module is an equinox module (but it is vmapped through a non-numpyro mechanism).
Yeah I see what you're saying.

I just realized I can use eqx.filter_vmap ahead of time to solve for this outside of the numpyro model, do you think this is sufficient?
linear = eqx.nn.Linear(in_features=3, out_features=1, key=rng_key)
batched_linear = eqx.filter_vmap(linear)

def model(x):
    # Use the pre-initialized, vmapped module
    nn = eqx_module("nn", batched_linear)
    y = nn(x)

edit: given the typical behavior of equinox users is to use jax.vmap, I should probably avoid this pattern and use jax.vmap instead of eqx.filter_vmap
Yes, I meant so, but I'm not sure if it works. It looks to me that f is an equinox module, hence the compatibility relies solely on equinox.
As far as I can tell this pattern is working! It registers parameters properly and recovers the original parameters - all of the tests I added are passing (I mimicked the tests from the nnx_module).

It's just unfortunate it requires vmap inside the numpyro model, and stopping the gradient for the state.
The usage of vmap here looks clean to me. Regarding gradient, it is more like a numpyro issue. I think we can add stop gradient somewhere in numpyro to deal with it. How about adding stop gradient for mutable state at this line:

Line 58 in fc3a7b1

    return result["loss"], result["mutable_state"]
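As a rough sketch of the suggestion (the loss below is a toy stand-in for result["loss"], and the state dict is illustrative), gradients through an arbitrary pytree of mutable state such as BatchNorm running statistics can be blocked by mapping jax.lax.stop_gradient over every leaf:

```python
import jax
import jax.numpy as jnp

# Map stop_gradient over every leaf of the state pytree so autodiff
# treats the state as a constant, while gradients w.r.t. the
# parameters still flow.
def stop_gradient_tree(tree):
    return jax.tree.map(jax.lax.stop_gradient, tree)

def loss(w, state):
    state = stop_gradient_tree(state)
    return jnp.sum(w * state["running_mean"])

w = jnp.ones(3)
state = {"running_mean": jnp.arange(3.0)}
grad_w = jax.grad(loss)(w, state)                 # gradients w.r.t. w flow
grad_state = jax.grad(loss, argnums=1)(w, state)  # all zeros: state is blocked
```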
Sadly that didn't work, AutoDelta still breaks when svi.init is called inside of svi.run - I committed here though for reference.

In the traceback I'm adding below, it seems like during svi.init, AutoDelta attempts to get self._init_locs and fails when running the postprocess_fn.

In particular, the trace fails within numpyro.infer.util.constrain_fn, which gets called in that postprocess_fn mentioned above.
Traceback:
# test svi - trace error with AutoDelta
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
> svi.run(random.PRNGKey(100), 10)
test/contrib/test_module.py:574:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
numpyro/infer/svi.py:402: in run
svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
numpyro/infer/svi.py:185: in init
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
numpyro/handlers.py:191: in get_trace
self(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/handlers.py:846: in __call__
return cloned_seeded_fn.__call__(*args, **kwargs)
numpyro/handlers.py:847: in __call__
return super().__call__(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/infer/autoguide.py:556: in __call__
self._setup_prototype(*args, **kwargs)
numpyro/infer/autoguide.py:534: in _setup_prototype
for k, v in self._postprocess_fn(self._init_locs).items()
numpyro/handlers.py:846: in __call__
return cloned_seeded_fn.__call__(*args, **kwargs)
numpyro/handlers.py:847: in __call__
return super().__call__(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/infer/util.py:225: in constrain_fn
model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
numpyro/handlers.py:191: in get_trace
self(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
numpyro/primitives.py:121: in __call__
return self.fn(*args, **kwargs)
test/contrib/test_module.py:560: in model
y, state = batched_nn(x, mutable_holder["state"])
test/contrib/test_module.py:543: in __call__
x, state = self.bn(x, state)
../../.pyenv/versions/3.11.7/lib/python3.11/contextlib.py:81: in inner
return func(*args, **kwds)
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/equinox/nn/_batch_norm.py:163: in __call__
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:1060: in op
return getattr(self.aval, f"_{name}")(self, *args)
../../.pyenv/versions/3.11.7/envs/numpyro_dev/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:579: in deferring_binary_op
return binary_op(*args)
I did something crazy in my local environment and stopped the gradient in numpyro.mutable's return output, and it does fix this issue, but changing a primitive seems bad and it does break 6 tests.

Is there anything more obvious that I could adjust inside of AutoDelta to resolve this? As a reminder, it seems like this issue is only relevant to AutoDelta so far (for instance, the test passes without stopping the gradient for the state when using AutoNormal).

As an alternative, if we go back to just stopping the gradient manually inside the model, all of the tests pass - it'd just be nice to avoid that if possible.
Alright, I just pushed a working version - it turns out that AutoDelta does work for a typical workflow, but in that specific unit test it breaks if x is sampled within the model (still not sure why, but svi.init fails in that scenario).

I decided to undo all of the code that stops gradients and just sample x outside of the model block in that unit test, and everything passes now.

This probably isn't ideal, because it does seem like there's some sort of bug that requires stopping the gradient in the model when using AutoDelta + equinox + a latent variable.

Let me know what you think!
cc @patrick-kidger if you have any feedback. cc @danielward27 as well, since you've worked with both numpyro and equinox before.
Ah, awesome! Taking a quick scan over some of the questions I've seen raised above:

It's not super obvious, but actually Equinox has an init/apply style like so:

init_fn = eqx.nn.Foo
apply_fn = eqx.nn.Foo.__call__
params = init_fn(hyperparams)
output = apply_fn(params, input)
This is an intentional omission to encourage folks to use just one kind of nested-object-abstraction! But if it's useful to you here for compatibility, then something that does this can be set up like so:

from typing import Any

import equinox as eqx
import jax
from jaxtyping import PyTree

def to_dict(tree: PyTree) -> dict[str, Any]:
    out = {}

    def to_dict_impl(path, leaf):
        out[jax.tree_util.keystr(path)] = leaf

    jax.tree.map_with_path(to_dict_impl, tree)
    return out

def from_dict(data: dict, tree: PyTree) -> PyTree:
    def from_dict_impl(path, _):
        return data[jax.tree_util.keystr(path)]

    return jax.tree.map_with_path(from_dict_impl, tree)

mlp = eqx.nn.MLP(2, 2, 2, 2, key=jax.random.key(0))
data = to_dict(mlp)
mlp2 = from_dict(data, mlp)
print(data)

I hope that helps!
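These to_dict/from_dict helpers are generic pytree utilities, so the round trip can be checked without any Equinox module involved. The sketch below restates them on a plain dict/list pytree (using jax.tree_util.tree_map_with_path, which newer JAX also exposes as jax.tree.map_with_path):

```python
import jax
import jax.numpy as jnp

# Flatten any pytree to a flat {path-string: leaf} dict, then rebuild
# it using a reference tree of the same structure.
def to_dict(tree):
    out = {}

    def to_dict_impl(path, leaf):
        out[jax.tree_util.keystr(path)] = leaf

    jax.tree_util.tree_map_with_path(to_dict_impl, tree)
    return out

def from_dict(data, tree):
    def from_dict_impl(path, _):
        return data[jax.tree_util.keystr(path)]

    return jax.tree_util.tree_map_with_path(from_dict_impl, tree)

tree = {"w": jnp.ones(2), "layers": [jnp.zeros(1), jnp.arange(3.0)]}
roundtripped = from_dict(to_dict(tree), tree)
```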
This is very helpful, thank you so much!
fehiepsi
left a comment
Thanks @kylejcaron! Could you add new entries in docs for mutable and eqx_module? https://github.com/pyro-ppl/numpyro/tree/master/docs/source
fehiepsi
left a comment
Awesome work! Thanks @kylejcaron!
This PR is still in draft mode. Let me know if it is ready to merge.

It's ready to go, and thanks for all of the help!
Closes #1709.
This adds Equinox support to numpyro's contrib module.