
Commit 6942c97

jit action selection and linear_schedule

1 parent 89dcbb4 commit 6942c97

File tree

2 files changed: +67 additions, -27 deletions


cleanrl/dqn_atari_jax.py

Lines changed: 34 additions & 14 deletions
@@ -9,6 +9,8 @@
     "XLA_PYTHON_CLIENT_MEM_FRACTION"
 ] = "0.7"  # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
 
+import functools
+
 import flax
 import flax.linen as nn
 import gym
@@ -127,9 +129,23 @@ class TrainState(TrainState):
     target_params: flax.core.FrozenDict
 
 
-def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
-    slope = (end_e - start_e) / duration
-    return max(slope * t + start_e, end_e)
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 5, 6, 7))
+def select_action(rng_seed, start_e, end_e, duration, t, num_actions, num_envs, network_def, online_params, obs):
+    def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+        slope = (end_e - start_e) / duration
+        return jnp.maximum(slope * t + start_e, end_e)
+
+    epsilon = linear_schedule(start_e, end_e, duration, t)
+    rng, rng_1, rng_2 = jax.random.split(rng_seed, 3)
+    return (
+        rng,
+        jnp.where(
+            jax.random.uniform(rng_1) < epsilon,
+            jax.random.randint(rng_2, (num_envs,), 0, num_actions),
+            network_def.apply(online_params, obs).argmax(axis=-1),
+        ),
+        epsilon,
+    )
 
 
 if __name__ == "__main__":
@@ -157,7 +173,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     random.seed(args.seed)
     np.random.seed(args.seed)
     key = jax.random.PRNGKey(args.seed)
-    key, q_key = jax.random.split(key, 2)
+    key, q_key, act_key = jax.random.split(key, 3)
 
     # env setup
     envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
@@ -174,7 +190,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         tx=optax.adam(learning_rate=args.learning_rate),
     )
 
-    q_network.apply = jax.jit(q_network.apply)
     # This step is not necessary as init called on same observation and key will always lead to same initializations
     q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
 
@@ -208,14 +223,19 @@ def mse_loss(params):
     obs = envs.reset()
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
-        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
-        if random.random() < epsilon:
-            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
-        else:
-            # obs = jax.device_put(obs)
-            logits = q_network.apply(q_state.params, obs)
-            actions = logits.argmax(axis=-1)
-            actions = jax.device_get(actions)
+        act_key, actions, epsilon = select_action(
+            act_key,
+            args.start_e,
+            args.end_e,
+            args.exploration_fraction * args.total_timesteps,
+            global_step,
+            envs.single_action_space.n,
+            envs.num_envs,
+            q_network,
+            q_state.params,
+            obs,
+        )
+        actions = jax.device_get(actions)
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_obs, rewards, dones, infos = envs.step(actions)
@@ -226,7 +246,7 @@ def mse_loss(params):
                 print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                 writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                 writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
-                writer.add_scalar("charts/epsilon", epsilon, global_step)
+                writer.add_scalar("charts/epsilon", jax.device_get(epsilon), global_step)
                 break
 
         # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
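
Note on the schedule itself: the math is unchanged, but because it now runs under jit, Python's max becomes jnp.maximum and the schedule constants are static arguments. A minimal standalone sketch of that behavior (plain JAX, not CleanRL code; the values are illustrative):

import functools

import jax
import jax.numpy as jnp


# start_e, end_e, and duration are static: they are baked into the compiled
# program at trace time. t stays traced, so repeated calls with new step
# counts reuse the same compiled function instead of retracing.
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def linear_schedule(t, start_e, end_e, duration):
    slope = (end_e - start_e) / duration  # plain Python arithmetic on static values
    return jnp.maximum(slope * t + start_e, end_e)  # clamp at end_e on-device


for step in (0, 250_000, 1_000_000):
    print(step, float(linear_schedule(step, 1.0, 0.05, 500_000)))  # approx. 1.0, 0.525, 0.05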

cleanrl/dqn_jax.py

Lines changed: 33 additions & 13 deletions
@@ -1,5 +1,6 @@
 # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
 import argparse
+import functools
 import os
 import random
 import time
@@ -100,9 +101,23 @@ class TrainState(TrainState):
     target_params: flax.core.FrozenDict
 
 
-def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
-    slope = (end_e - start_e) / duration
-    return max(slope * t + start_e, end_e)
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 5, 6, 7))
+def select_action(rng_seed, start_e, end_e, duration, t, num_actions, num_envs, network_def, online_params, obs):
+    def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+        slope = (end_e - start_e) / duration
+        return jnp.maximum(slope * t + start_e, end_e)
+
+    epsilon = linear_schedule(start_e, end_e, duration, t)
+    rng, rng_1, rng_2 = jax.random.split(rng_seed, 3)
+    return (
+        rng,
+        jnp.where(
+            jax.random.uniform(rng_1) < epsilon,
+            jax.random.randint(rng_2, (num_envs,), 0, num_actions),
+            network_def.apply(online_params, obs).argmax(axis=-1),
+        ),
+        epsilon,
+    )
 
 
 if __name__ == "__main__":
@@ -130,7 +145,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     random.seed(args.seed)
     np.random.seed(args.seed)
     key = jax.random.PRNGKey(args.seed)
-    key, q_key = jax.random.split(key, 2)
+    key, q_key, act_key = jax.random.split(key, 3)
 
     # env setup
     envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
@@ -147,7 +162,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         tx=optax.adam(learning_rate=args.learning_rate),
     )
 
-    q_network.apply = jax.jit(q_network.apply)
     # This step is not necessary as init called on same observation and key will always lead to same initializations
     q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
 
@@ -180,13 +194,19 @@ def mse_loss(params):
     obs = envs.reset()
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
-        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
-        if random.random() < epsilon:
-            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
-        else:
-            logits = q_network.apply(q_state.params, obs)
-            actions = logits.argmax(axis=-1)
-            actions = jax.device_get(actions)
+        act_key, actions, epsilon = select_action(
+            act_key,
+            args.start_e,
+            args.end_e,
+            args.exploration_fraction * args.total_timesteps,
+            global_step,
+            envs.single_action_space.n,
+            envs.num_envs,
+            q_network,
+            q_state.params,
+            obs,
+        )
+        actions = jax.device_get(actions)
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_obs, rewards, dones, infos = envs.step(actions)
@@ -197,7 +217,7 @@ def mse_loss(params):
                 print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                 writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                 writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
-                writer.add_scalar("charts/epsilon", epsilon, global_step)
+                writer.add_scalar("charts/epsilon", jax.device_get(epsilon), global_step)
                 break
 
         # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
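
The dqn_jax.py changes mirror the Atari script. A self-contained sketch of how the jitted select_action is driven, using a toy MLP in place of the scripts' QNetwork and dummy observations (the network, shapes, and hyperparameter values below are stand-ins, not the actual ones from these files):

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


class QNetwork(nn.Module):
    # toy two-layer Q-network used only for this sketch
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(self.action_dim)(x)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 5, 6, 7))
def select_action(rng_seed, start_e, end_e, duration, t, num_actions, num_envs, network_def, online_params, obs):
    # linear epsilon schedule, computed inside the jitted function
    def linear_schedule(start_e, end_e, duration, t):
        slope = (end_e - start_e) / duration
        return jnp.maximum(slope * t + start_e, end_e)

    epsilon = linear_schedule(start_e, end_e, duration, t)
    rng, rng_1, rng_2 = jax.random.split(rng_seed, 3)
    return (
        rng,
        jnp.where(
            jax.random.uniform(rng_1) < epsilon,                      # one draw decides explore vs. exploit for the whole batch
            jax.random.randint(rng_2, (num_envs,), 0, num_actions),   # random actions
            network_def.apply(online_params, obs).argmax(axis=-1),    # greedy actions
        ),
        epsilon,
    )


num_envs, num_actions, obs_dim = 1, 4, 8
network = QNetwork(action_dim=num_actions)
key = jax.random.PRNGKey(0)
key, init_key, act_key = jax.random.split(key, 3)
obs = jnp.zeros((num_envs, obs_dim))
params = network.init(init_key, obs)

for step in range(3):
    act_key, actions, epsilon = select_action(
        act_key, 1.0, 0.05, 500_000, step, num_actions, num_envs, network, params, obs
    )
    print(step, jax.device_get(actions), float(epsilon))

Because network_def is passed as a static argument, the Flax module (a hashable dataclass) becomes part of the compilation key, while params and obs remain traced, so updating the parameters between steps does not trigger recompilation.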
