# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy
import argparse
import os
import random
import time
from distutils.util import strtobool

os.environ[
    "XLA_PYTHON_CLIENT_MEM_FRACTION"
] = "0.7"  # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
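# By default JAX preallocates most of the available GPU memory at startup; capping the
# fraction at 0.7 leaves headroom for the replay buffer and other processes on the same GPU.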

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=1e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=1000,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=32,
        help="the batch size of samples from the replay memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.01,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.10,
        help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
    parser.add_argument("--learning-starts", type=int, default=80000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=4,
        help="the frequency of training")
    args = parser.parse_args()
    # fmt: on
    return args


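# The thunk below builds a single Atari env with the standard DeepMind preprocessing:
# noop starts, 4-frame max-and-skip, episodic life, FIRE reset where required,
# reward clipping, 84x84 grayscale observations, and a 4-frame stack.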
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
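# QNetwork is the Nature-DQN convolutional architecture: three conv layers
# (8x8 stride 4, 4x4 stride 2, 3x3 stride 1), a 512-unit dense layer, and one
# output per action. Observations arrive as stacked uint8 frames in NCHW order,
# so they are transposed to NHWC (the layout flax.linen.Conv expects) and scaled to [0, 1].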
class QNetwork(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = jnp.transpose(x, (0, 2, 3, 1))
        x = x / (255.0)
        x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
        x = nn.relu(x)
        x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
        x = nn.relu(x)
        x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        return x


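# Extend flax's TrainState with a second parameter pytree so the target network's
# parameters travel alongside the online parameters and the optimizer state.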
class TrainState(TrainState):
    target_params: flax.core.FrozenDict


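# Epsilon decays linearly from start_e to end_e over `duration` steps and is then
# held at end_e. With the defaults (start_e=1, end_e=0.01, exploration-fraction=0.10,
# total-timesteps=10M) the decay spans the first 1M steps.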
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, q_key = jax.random.split(key, 2)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    obs = envs.reset()

    q_network = QNetwork(action_dim=envs.single_action_space.n)

    q_state = TrainState.create(
        apply_fn=q_network.apply,
        params=q_network.init(q_key, obs),
        target_params=q_network.init(q_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )

    q_network.apply = jax.jit(q_network.apply)
    # This step is not strictly necessary: `init` called with the same key and observation
    # always produces the same parameters, so params and target_params already match.
    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

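    # Replay buffer from stable-baselines3: `optimize_memory_usage=True` stores
    # observations and next observations in a shared array to roughly halve the
    # memory footprint, and `handle_timeout_termination=True` keeps TimeLimit
    # truncations from being bootstrapped as true terminal states.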
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        optimize_memory_usage=True,
        handle_timeout_termination=True,
    )

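    # A single jitted DQN update: form the 1-step TD target
    # r + gamma * max_a Q_target(s', a), masked by (1 - done), take the MSE against
    # Q(s, a) for the stored actions, and apply one Adam step. The loss and the
    # predicted Q-values are returned for logging.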
    @jax.jit
    def update(q_state, observations, actions, next_observations, rewards, dones):
        q_next_target = q_network.apply(q_state.target_params, next_observations)  # (batch_size, num_actions)
        q_next_target = jnp.max(q_next_target, axis=-1)  # (batch_size,)
        next_q_value = rewards + (1 - dones) * args.gamma * q_next_target

        def mse_loss(params):
            q_pred = q_network.apply(params, observations)  # (batch_size, num_actions)
            q_pred = q_pred[np.arange(q_pred.shape[0]), actions.squeeze()]  # (batch_size,)
            return ((q_pred - next_q_value) ** 2).mean(), q_pred

        (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
        q_state = q_state.apply_gradients(grads=grads)
        return loss_value, q_pred, q_state

    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
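        # epsilon-greedy: with probability epsilon take a uniformly random action,
        # otherwise act greedily with respect to the online Q-network.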
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            # obs = jax.device_put(obs)
            q_values = q_network.apply(q_state.params, obs)
            actions = q_values.argmax(axis=-1)
            actions = jax.device_get(actions)

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
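        # The vector env auto-resets on done, so `next_obs` already holds the first
        # frame of the new episode; the true terminal observation is recovered from
        # `infos[idx]["terminal_observation"]` before being stored in the buffer.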
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
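        # The sampled batch comes back as torch tensors (stable-baselines3 buffer),
        # so each field is converted to numpy before being handed to the jitted update.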
        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
            data = rb.sample(args.batch_size)
            # perform a gradient-descent step
            loss, old_val, q_state = update(
                q_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.flatten().numpy(),
                data.dones.flatten().numpy(),
            )

            if global_step % 100 == 0:
                writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
                writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # update the target network
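            # optax.incremental_update with step_size 1.0 is a hard copy of the online
            # parameters into the target parameters (polyak averaging with tau=1).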
            if global_step % args.target_network_frequency == 0:
                q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

    envs.close()
    writer.close()