
Putting the NN parameters as a non-first argument in loss fn results in a weird error: TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict' #366

@vwxyzjn

Description

Putting the NN parameters as a non-first argument in the loss function results in a confusing error. PaddlePaddle/Paddle#10333 might be related.

See the following script for a minimal reproduction:

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return x

x = jnp.zeros((1024, 64))
y = jnp.zeros((1024,))

key = jax.random.PRNGKey(0)
model = Model()
critic_parameters = model.init(key, x)
model.apply = jax.jit(model.apply)
model_optimizer = optax.adam(learning_rate=0.04, eps=1e-5)
model_optimizer_state = model_optimizer.init(critic_parameters)


# Works: critic_parameters is the first argument of model_loss, which is
# what jax.value_and_grad differentiates by default (argnums=0).
@jax.jit
def update_works(
    x, y, 
    critic_parameters,
    model_optimizer_state, key
):
    def model_loss(critic_parameters, x, y):
        newvalue = model.apply(critic_parameters, x)
        v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
        return v_loss
    loss, grads = jax.value_and_grad(model_loss)(critic_parameters, x, y)
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
    critic_parameters = optax.apply_updates(critic_parameters, updates)
    return loss, critic_parameters, model_optimizer_state
update_works(x, y, critic_parameters, model_optimizer_state, key)

# Fails: critic_parameters is now the third argument of model_loss, but
# jax.value_and_grad still differentiates with respect to the first one (x).
@jax.jit
def update_not_works(
    x, y, 
    critic_parameters,
    model_optimizer_state, key
):
    def model_loss(x, y, critic_parameters):
        newvalue = model.apply(critic_parameters, x)
        v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
        return v_loss
    loss, grads = jax.value_and_grad(model_loss)(x, y, critic_parameters)
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
    critic_parameters = optax.apply_updates(critic_parameters, updates)
    return loss, critic_parameters, model_optimizer_state
update_not_works(x, y, critic_parameters, model_optimizer_state, key)

Running update_works succeeds, but update_not_works fails with the following error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/cleanrl/cleanrl/minimal_bug.py", line 59, in <module>
    update_not_works(x, y, critic_parameters, model_optimizer_state, key)
  File "/home/costa/Documents/go/src/github.com/cleanrl/cleanrl/minimal_bug.py", line 56, in update_not_works
    updates, model_optimizer_state = model_optimizer.update(grads, model_optimizer_state)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/combine.py", line 54, in update_fn
    updates, new_s = fn(updates, s, params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 326, in update_fn
    mu = _update_moment(updates, state.mu, b1, 1)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 82, in _update_moment
    return jax.tree_map(
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/optax/_src/transform.py", line 83, in <lambda>
    lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict'
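
The cause appears to be that jax.value_and_grad differentiates with respect to the first positional argument by default (argnums=0). In update_not_works that first argument is x, so grads comes back as a single (1024, 64) array instead of a FrozenDict matching critic_parameters, and the Adam state (initialized from critic_parameters) no longer lines up: inside optax's jax.tree_map, the whole FrozenDict of moments gets multiplied by a float, hence the TypeError. A minimal sketch checking this outside of jit, reusing the names from the script above:

def model_loss(x, y, critic_parameters):
    newvalue = model.apply(critic_parameters, x)
    return 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()

# value_and_grad defaults to argnums=0, so this is d(loss)/d(x),
# not d(loss)/d(critic_parameters).
loss, grads = jax.value_and_grad(model_loss)(x, y, critic_parameters)
print(grads.shape)                                      # (1024, 64), the shape of x
print(jax.tree_util.tree_structure(critic_parameters))  # FrozenDict of layer params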

Note that update_works and update_not_works differ only in the following lines:

-def model_loss(x, y, critic_parameters):
+def model_loss(critic_parameters, x, y):
     newvalue = model.apply(critic_parameters, x)
     v_loss = 0.5 * ((newvalue.squeeze() - y.squeeze()) ** 2).mean()
     return v_loss
-loss, grads = jax.value_and_grad(model_loss)(x, y, critic_parameters)
+loss, grads = jax.value_and_grad(model_loss)(critic_parameters, x, y)
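
An alternative to reordering the arguments is to point jax.value_and_grad at the parameters explicitly via its argnums parameter; a sketch that keeps the update_not_works signature:

# argnums=2 makes value_and_grad differentiate with respect to
# critic_parameters (the third argument), so grads again has the FrozenDict
# structure the optimizer state expects.
loss, grads = jax.value_and_grad(model_loss, argnums=2)(x, y, critic_parameters)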
Versions:

>>> jax.__version__
'0.3.12'
>>> import optax
>>> optax.__version__
'0.1.2'
