Putting the NN parameters as a non-first argument of the loss function results in a weird error. PaddlePaddle/Paddle#10333 might be related.

See the following script for a minimal reproduction:
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return x


x = jnp.zeros((1024, 64))
y = jnp.zeros((1024,))
key = jax.random.PRNGKey(0)
example_input = np.zeros((1, 64))
model = Model()
critic_parameters = model.init(key, x)
model.apply = jax.jit(model.apply)
model_optimizer = optax.adam(learning_rate=0.04, eps=1e-5)
model_optimizer_state = model_optimizer.init(critic_parameters)


@jax.jit
def update_works(
    x, y,
    critic_parameters,
    model_optimizer_state, key
):
    def model_loss(critic_parameters, x, y):
        newvalue = model.apply(critic_parameters, x)
        v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
        return v_loss

    loss, grads = jax.value_and_grad(model_loss)(critic_parameters, x, y)
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
    critic_parameters = optax.apply_updates(critic_parameters, updates)
    return loss, critic_parameters, model_optimizer_state


update_works(x, y, critic_parameters, model_optimizer_state, key)


@jax.jit
def update_not_works(
    x, y,
    critic_parameters,
    model_optimizer_state, key
):
    def model_loss(x, y, critic_parameters):
        newvalue = model.apply(critic_parameters, x)
        v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
        return v_loss

    loss, grads = jax.value_and_grad(model_loss)(x, y, critic_parameters)
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
    critic_parameters = optax.apply_updates(critic_parameters, updates)
    return loss, critic_parameters, model_optimizer_state


update_not_works(x, y, critic_parameters, model_optimizer_state, key)
```
Running `update_works` succeeds, but `update_not_works` results in the following error:
```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/cleanrl/cleanrl/minimal_bug.py", line 59, in <module>
    update_not_works(x, y, critic_parameters, model_optimizer_state, key)
  File "/home/costa/Documents/go/src/github.com/cleanrl/cleanrl/minimal_bug.py", line 56, in update_not_works
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/combine.py", line 54, in update_fn
    updates, new_s = fn(updates, s, params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 326, in update_fn
    mu = _update_moment(updates, state.mu, b1, 1)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 82, in _update_moment
    return jax.tree_map(
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 83, in <lambda>
    lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict'
```
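For context, my reading of the traceback: `jax.value_and_grad` differentiates with respect to the first argument by default (`argnums=0`), so in `update_not_works` the returned `grads` is a single array (the gradient with respect to `x`), while Adam's first-moment state `mu` mirrors the parameter `FrozenDict`. Since `jax.tree_map` follows the structure of its first argument, the whole `FrozenDict` of moments gets handed to the lambda as one leaf and multiplied by a float. A minimal sketch of that mismatch (the names and shapes here are just illustrative, not taken from the script above):

```python
import jax
import jax.numpy as jnp
from flax.core import FrozenDict

grads = jnp.zeros((1024, 64))  # a single array leaf: the gradient w.r.t. x
mu = FrozenDict({"params": {"Dense_0": {"kernel": jnp.zeros((64, 64))}}})

# tree_map flattens its first argument; grads is one leaf, so the entire
# FrozenDict `mu` is passed as `t`, and `0.9 * t` raises the same TypeError.
jax.tree_map(lambda g, t: 0.1 * g + 0.9 * t, grads, mu)
```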
Note that `update_works` and `update_not_works` only have the following lines of code difference:

```diff
-def model_loss(x, y, critic_parameters):
+def model_loss(critic_parameters, x, y):
     newvalue = model.apply(critic_parameters, x)
     v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
     return v_loss
-loss, grads = jax.value_and_grad(model_loss)(x, y, critic_parameters)
+loss, grads = jax.value_and_grad(model_loss)(critic_parameters, x, y)
```
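If the parameters need to stay in a non-first position, a sketch of a workaround, assuming the rest of `update_not_works` is unchanged, is to point `argnums` at the parameter argument explicitly:

```python
# Sketch: keep critic_parameters as the third argument, but tell JAX to
# differentiate with respect to argument index 2 instead of the default 0.
loss, grads = jax.value_and_grad(model_loss, argnums=2)(x, y, critic_parameters)
```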
```python
>>> jax.__version__
'0.3.12'
>>> import optax
>>> optax.__version__
'0.1.2'
```