Skip to content

Commit d5704b3

Browse files
committed
Try ADAN
1 parent 99686c8 commit d5704b3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

cleanrl/tqc_td3_jax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def main():
229229
apply_fn=actor.apply,
230230
params=actor.init(actor_key, obs),
231231
target_params=actor.init(actor_key, obs),
232-
tx=optax.adam(learning_rate=args.learning_rate),
232+
tx=optax.adan(learning_rate=args.learning_rate),
233233
)
234234

235235
agent = Agent(actor, actor_state)
@@ -249,7 +249,7 @@ def main():
249249
obs,
250250
jnp.array([envs.action_space.sample()]),
251251
),
252-
tx=optax.adam(learning_rate=args.learning_rate),
252+
tx=optax.adan(learning_rate=args.learning_rate),
253253
)
254254
qf2_state = RLTrainState.create(
255255
apply_fn=qf.apply,
@@ -263,7 +263,7 @@ def main():
263263
obs,
264264
jnp.array([envs.action_space.sample()]),
265265
),
266-
tx=optax.adam(learning_rate=args.learning_rate),
266+
tx=optax.adan(learning_rate=args.learning_rate),
267267
)
268268
actor.apply = jax.jit(actor.apply)
269269
qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm"))

0 commit comments

Comments
 (0)