We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 99686c8 commit d5704b3Copy full SHA for d5704b3
cleanrl/tqc_td3_jax.py
@@ -229,7 +229,7 @@ def main():
229
apply_fn=actor.apply,
230
params=actor.init(actor_key, obs),
231
target_params=actor.init(actor_key, obs),
232
- tx=optax.adam(learning_rate=args.learning_rate),
+ tx=optax.adan(learning_rate=args.learning_rate),
233
)
234
235
agent = Agent(actor, actor_state)
@@ -249,7 +249,7 @@ def main():
249
obs,
250
jnp.array([envs.action_space.sample()]),
251
),
252
253
254
qf2_state = RLTrainState.create(
255
apply_fn=qf.apply,
@@ -263,7 +263,7 @@ def main():
263
264
265
266
267
268
actor.apply = jax.jit(actor.apply)
269
qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm"))
0 commit comments