Skip to content

Commit f27b6d7

Browse files
kinalmehtavwxyzjn
andauthored
prototype jax with dqn (#222)
* Prototype JAX + DQN * formatting changes * bug fix: predicted q value in mse * Prototype JAX + DQN + Atari * formatting changes * Fix `UNKNOWN: CUDNN_STATUS_EXECUTION` * update mse loss calculation to be (target-pred) instead of (pred-target) * Fix image format and Conv padding * Adapting to the TrainState API * Add assets * Add my benchmark script * fix benchmark script embed it was pointing to c51, fixed it to point to dqn * docs: add DQN + JAX documentation * jit action selection and linear_schedule * docs fix * update docs * change documentation addr * add test cases * update ci * Add warning on installing jax on windows * fix pre-commit * revert back changes * update benchmark scripts * Add docs * update docs Co-authored-by: Costa Huang <[email protected]>
1 parent 5bfdd45 commit f27b6d7

22 files changed

+694
-8
lines changed

.github/workflows/tests.yaml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ jobs:
3636
run: poetry run pip install setuptools==59.5.0
3737
- name: Run core tests
3838
run: poetry run pytest tests/test_classic_control.py
39+
- name: Install jax
40+
if: runner.os == 'Linux' || runner.os == 'macOS'
41+
run: poetry install -E jax
42+
- name: Run core tests with jax
43+
if: runner.os == 'Linux' || runner.os == 'macOS'
44+
run: poetry run pytest tests/test_classic_control_jax.py
3945

4046
test-atari-envs:
4147
strategy:
@@ -62,6 +68,12 @@ jobs:
6268
run: poetry run pip install setuptools==59.5.0
6369
- name: Run atari tests
6470
run: poetry run pytest tests/test_atari.py
71+
- name: Install jax
72+
if: runner.os == 'Linux' || runner.os == 'macOS'
73+
run: poetry install -E jax
74+
- name: Run core tests with jax
75+
if: runner.os == 'Linux' || runner.os == 'macOS'
76+
run: poetry run pytest tests/test_atari_jax.py
6577

6678
test-pybullet-envs:
6779
strategy:
@@ -136,7 +148,7 @@ jobs:
136148
with:
137149
poetry-version: ${{ matrix.poetry-version }}
138150

139-
# pybullet tests
151+
# mujoco tests
140152
- name: Install core dependencies
141153
run: poetry install -E pytest
142154
- name: Install pybullet dependencies

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ You may also use a prebuilt development environment hosted in Gitpod:
124124
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
125125
|[Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
126126
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
127+
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](/rl-algorithms/dqn/#dqn_jaxpy) |
128+
| | [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), [docs](/rl-algorithms/dqn/#dqn_atari_jaxpy) |
127129
|[Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) | [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51py) |
128130
| | [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy) |
129131
|[Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) | [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy) |

benchmark/dqn.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,19 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
1111
--command "poetry run python cleanrl/dqn_atari.py --track --capture-video" \
1212
--num-seeds 3 \
1313
--workers 1
14+
15+
poetry install -E "jax"
16+
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
17+
xvfb-run -a python -m cleanrl_utils.benchmark \
18+
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
19+
--command "poetry run python cleanrl/dqn_jax.py --track --capture-video" \
20+
--num-seeds 3 \
21+
--workers 1
22+
23+
poetry install -E "atari jax"
24+
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
25+
xvfb-run -a python -m cleanrl_utils.benchmark \
26+
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
27+
--command "poetry run python cleanrl/dqn_atari_jax.py --track --capture-video" \
28+
--num-seeds 3 \
29+
--workers 1

cleanrl/dqn_atari_jax.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy
2+
import argparse
3+
import os
4+
import random
5+
import time
6+
from distutils.util import strtobool
7+
8+
os.environ[
9+
"XLA_PYTHON_CLIENT_MEM_FRACTION"
10+
] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
11+
12+
import flax
13+
import flax.linen as nn
14+
import gym
15+
import jax
16+
import jax.numpy as jnp
17+
import numpy as np
18+
import optax
19+
from flax.training.train_state import TrainState
20+
from stable_baselines3.common.atari_wrappers import (
21+
ClipRewardEnv,
22+
EpisodicLifeEnv,
23+
FireResetEnv,
24+
MaxAndSkipEnv,
25+
NoopResetEnv,
26+
)
27+
from stable_baselines3.common.buffers import ReplayBuffer
28+
from torch.utils.tensorboard import SummaryWriter
29+
30+
31+
def parse_args():
32+
# fmt: off
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
35+
help="the name of this experiment")
36+
parser.add_argument("--seed", type=int, default=1,
37+
help="seed of the experiment")
38+
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
39+
help="if toggled, `torch.backends.cudnn.deterministic=False`")
40+
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
41+
help="if toggled, cuda will be enabled by default")
42+
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
43+
help="if toggled, this experiment will be tracked with Weights and Biases")
44+
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
45+
help="the wandb's project name")
46+
parser.add_argument("--wandb-entity", type=str, default=None,
47+
help="the entity (team) of wandb's project")
48+
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
49+
help="weather to capture videos of the agent performances (check out `videos` folder)")
50+
51+
# Algorithm specific arguments
52+
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
53+
help="the id of the environment")
54+
parser.add_argument("--total-timesteps", type=int, default=10000000,
55+
help="total timesteps of the experiments")
56+
parser.add_argument("--learning-rate", type=float, default=1e-4,
57+
help="the learning rate of the optimizer")
58+
parser.add_argument("--buffer-size", type=int, default=1000000,
59+
help="the replay memory buffer size")
60+
parser.add_argument("--gamma", type=float, default=0.99,
61+
help="the discount factor gamma")
62+
parser.add_argument("--target-network-frequency", type=int, default=1000,
63+
help="the timesteps it takes to update the target network")
64+
parser.add_argument("--batch-size", type=int, default=32,
65+
help="the batch size of sample from the reply memory")
66+
parser.add_argument("--start-e", type=float, default=1,
67+
help="the starting epsilon for exploration")
68+
parser.add_argument("--end-e", type=float, default=0.01,
69+
help="the ending epsilon for exploration")
70+
parser.add_argument("--exploration-fraction", type=float, default=0.10,
71+
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
72+
parser.add_argument("--learning-starts", type=int, default=80000,
73+
help="timestep to start learning")
74+
parser.add_argument("--train-frequency", type=int, default=4,
75+
help="the frequency of training")
76+
args = parser.parse_args()
77+
# fmt: on
78+
return args
79+
80+
81+
def make_env(env_id, seed, idx, capture_video, run_name):
82+
def thunk():
83+
env = gym.make(env_id)
84+
env = gym.wrappers.RecordEpisodeStatistics(env)
85+
if capture_video:
86+
if idx == 0:
87+
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
88+
env = NoopResetEnv(env, noop_max=30)
89+
env = MaxAndSkipEnv(env, skip=4)
90+
env = EpisodicLifeEnv(env)
91+
if "FIRE" in env.unwrapped.get_action_meanings():
92+
env = FireResetEnv(env)
93+
env = ClipRewardEnv(env)
94+
env = gym.wrappers.ResizeObservation(env, (84, 84))
95+
env = gym.wrappers.GrayScaleObservation(env)
96+
env = gym.wrappers.FrameStack(env, 4)
97+
env.seed(seed)
98+
env.action_space.seed(seed)
99+
env.observation_space.seed(seed)
100+
return env
101+
102+
return thunk
103+
104+
105+
# ALGO LOGIC: initialize agent here:
106+
class QNetwork(nn.Module):
107+
action_dim: int
108+
109+
@nn.compact
110+
def __call__(self, x):
111+
x = jnp.transpose(x, (0, 2, 3, 1))
112+
x = x / (255.0)
113+
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
114+
x = nn.relu(x)
115+
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
116+
x = nn.relu(x)
117+
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
118+
x = nn.relu(x)
119+
x = x.reshape((x.shape[0], -1))
120+
x = nn.Dense(512)(x)
121+
x = nn.relu(x)
122+
x = nn.Dense(self.action_dim)(x)
123+
return x
124+
125+
126+
class TrainState(TrainState):
127+
target_params: flax.core.FrozenDict
128+
129+
130+
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
131+
slope = (end_e - start_e) / duration
132+
return max(slope * t + start_e, end_e)
133+
134+
135+
if __name__ == "__main__":
136+
args = parse_args()
137+
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
138+
if args.track:
139+
import wandb
140+
141+
wandb.init(
142+
project=args.wandb_project_name,
143+
entity=args.wandb_entity,
144+
sync_tensorboard=True,
145+
config=vars(args),
146+
name=run_name,
147+
monitor_gym=True,
148+
save_code=True,
149+
)
150+
writer = SummaryWriter(f"runs/{run_name}")
151+
writer.add_text(
152+
"hyperparameters",
153+
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
154+
)
155+
156+
# TRY NOT TO MODIFY: seeding
157+
random.seed(args.seed)
158+
np.random.seed(args.seed)
159+
key = jax.random.PRNGKey(args.seed)
160+
key, q_key = jax.random.split(key, 2)
161+
162+
# env setup
163+
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
164+
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
165+
166+
obs = envs.reset()
167+
168+
q_network = QNetwork(action_dim=envs.single_action_space.n)
169+
170+
q_state = TrainState.create(
171+
apply_fn=q_network.apply,
172+
params=q_network.init(q_key, obs),
173+
target_params=q_network.init(q_key, obs),
174+
tx=optax.adam(learning_rate=args.learning_rate),
175+
)
176+
177+
q_network.apply = jax.jit(q_network.apply)
178+
# This step is not necessary as init called on same observation and key will always lead to same initializations
179+
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
180+
181+
rb = ReplayBuffer(
182+
args.buffer_size,
183+
envs.single_observation_space,
184+
envs.single_action_space,
185+
"cpu",
186+
optimize_memory_usage=True,
187+
handle_timeout_termination=True,
188+
)
189+
190+
@jax.jit
191+
def update(q_state, observations, actions, next_observations, rewards, dones):
192+
q_next_target = q_network.apply(q_state.target_params, next_observations) # (batch_size, num_actions)
193+
q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,)
194+
next_q_value = rewards + (1 - dones) * args.gamma * q_next_target
195+
196+
def mse_loss(params):
197+
q_pred = q_network.apply(params, observations) # (batch_size, num_actions)
198+
q_pred = q_pred[np.arange(q_pred.shape[0]), actions.squeeze()] # (batch_size,)
199+
return ((q_pred - next_q_value) ** 2).mean(), q_pred
200+
201+
(loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
202+
q_state = q_state.apply_gradients(grads=grads)
203+
return loss_value, q_pred, q_state
204+
205+
start_time = time.time()
206+
207+
# TRY NOT TO MODIFY: start the game
208+
obs = envs.reset()
209+
for global_step in range(args.total_timesteps):
210+
# ALGO LOGIC: put action logic here
211+
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
212+
if random.random() < epsilon:
213+
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
214+
else:
215+
# obs = jax.device_put(obs)
216+
logits = q_network.apply(q_state.params, obs)
217+
actions = logits.argmax(axis=-1)
218+
actions = jax.device_get(actions)
219+
220+
# TRY NOT TO MODIFY: execute the game and log data.
221+
next_obs, rewards, dones, infos = envs.step(actions)
222+
223+
# TRY NOT TO MODIFY: record rewards for plotting purposes
224+
for info in infos:
225+
if "episode" in info.keys():
226+
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
227+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
228+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
229+
writer.add_scalar("charts/epsilon", epsilon, global_step)
230+
break
231+
232+
# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
233+
real_next_obs = next_obs.copy()
234+
for idx, d in enumerate(dones):
235+
if d:
236+
real_next_obs[idx] = infos[idx]["terminal_observation"]
237+
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
238+
239+
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
240+
obs = next_obs
241+
242+
# ALGO LOGIC: training.
243+
if global_step > args.learning_starts and global_step % args.train_frequency == 0:
244+
data = rb.sample(args.batch_size)
245+
# perform a gradient-descent step
246+
loss, old_val, q_state = update(
247+
q_state,
248+
data.observations.numpy(),
249+
data.actions.numpy(),
250+
data.next_observations.numpy(),
251+
data.rewards.flatten().numpy(),
252+
data.dones.flatten().numpy(),
253+
)
254+
255+
if global_step % 100 == 0:
256+
writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
257+
writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
258+
print("SPS:", int(global_step / (time.time() - start_time)))
259+
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
260+
261+
# update the target network
262+
if global_step % args.target_network_frequency == 0:
263+
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
264+
265+
envs.close()
266+
writer.close()

0 commit comments

Comments
 (0)