35 changes: 35 additions & 0 deletions benchmark/c51/README.md
@@ -0,0 +1,35 @@
# Categorical DQN (C51) Benchmark

This directory contains instructions to reproduce our Categorical DQN (C51) experiments.

Prerequisites:
* Python 3.8+
* [Poetry](https://python-poetry.org)
* [GitHub CLI](https://cli.github.com/)


## Reproducing CleanRL's Categorical DQN (C51) Benchmark

### Classic Control

```bash
git clone https://github.com/vwxyzjn/cleanrl.git && cd cleanrl
gh pr checkout 157
poetry install
bash benchmark/c51/classic_control.sh
```

Note that you may need to change `--wandb-entity openrlbenchmark` to your own W&B entity if you have not been granted access to the `openrlbenchmark/cleanrl` project.


### Atari Games

```bash
git clone https://github.com/vwxyzjn/cleanrl.git && cd cleanrl
gh pr checkout 159
poetry install
poetry install -E atari
bash benchmark/c51/atari.sh
```

Note that you may need to change `--wandb-entity openrlbenchmark` to your own W&B entity if you have not been granted access to the `openrlbenchmark/cleanrl` project.
14 changes: 14 additions & 0 deletions benchmark/c51/atari.sh
@@ -0,0 +1,14 @@
# PongNoFrameskip-v4
poetry run python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# BeamRiderNoFrameskip-v4
poetry run python cleanrl/c51_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# BreakoutNoFrameskip-v4
poetry run python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
14 changes: 14 additions & 0 deletions benchmark/c51/classic_control.sh
@@ -0,0 +1,14 @@
# CartPole-v1
poetry run python cleanrl/c51.py --env-id CartPole-v1 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id CartPole-v1 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id CartPole-v1 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# Acrobot-v1
poetry run python cleanrl/c51.py --env-id Acrobot-v1 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id Acrobot-v1 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id Acrobot-v1 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# MountainCar-v0
poetry run python cleanrl/c51.py --env-id MountainCar-v0 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id MountainCar-v0 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/c51.py --env-id MountainCar-v0 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
18 changes: 11 additions & 7 deletions cleanrl/c51.py
@@ -36,7 +36,7 @@ def parse_args():
# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=50000,
parser.add_argument("--total-timesteps", type=int, default=500000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
@@ -50,21 +50,21 @@ def parse_args():
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--target-network-frequency", type=int, default=1000,
parser.add_argument("--target-network-frequency", type=int, default=500,
help="the timesteps it takes to update the target network")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
help="the maximum norm for the gradient clipping")
parser.add_argument("--batch-size", type=int, default=32,
parser.add_argument("--batch-size", type=int, default=128,
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.02,
parser.add_argument("--end-e", type=float, default=0.05,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.5,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=10000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=1,
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
@@ -156,7 +156,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
-args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
+args.buffer_size,
+envs.single_observation_space,
+envs.single_action_space,
+device,
+handle_timeout_termination=True,
)
start_time = time.time()

@@ -195,7 +199,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# ALGO LOGIC: training.
if global_step > args.learning_starts and global_step % args.train_frequency == 0:
data = rb.sample(args.batch_size)

with torch.no_grad():
_, next_pmfs = target_network.get_action(data.next_observations)
next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
@@ -219,6 +222,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()

if global_step % 100 == 0:
writer.add_scalar("losses/loss", loss.item(), global_step)
old_val = (old_pmfs * q_network.atoms).sum(1)
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
8 changes: 7 additions & 1 deletion cleanrl/c51_atari.py
@@ -177,7 +177,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
-args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
+args.buffer_size,
+envs.single_observation_space,
+envs.single_action_space,
+device,
+optimize_memory_usage=True,
+handle_timeout_termination=True,
)
start_time = time.time()

@@ -240,6 +245,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()

if global_step % 100 == 0:
writer.add_scalar("losses/loss", loss.item(), global_step)
old_val = (old_pmfs * q_network.atoms).sum(1)
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
181 changes: 181 additions & 0 deletions docs/rl-algorithms/c51.md
@@ -0,0 +1,181 @@
# Categorical DQN (C51)

## Overview

C51 brings a distributional perspective to DQN: instead of learning a single expected value for each action, C51 learns to predict a full distribution of returns for that action. Empirically, C51 demonstrates impressive performance in the Arcade Learning Environment (ALE).
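
Concretely, C51 models the return $Z(s, a)$ of each state-action pair as a categorical distribution over a fixed support of $N$ atoms $z_1, \dots, z_N$ (the canonical choice $N = 51$ gives the algorithm its name), and it recovers Q-values as the expectation of that distribution. Sketching the formulation of the paper in our notation:

$$
Q(s, a) = \mathbb{E}\left[Z(s, a)\right] = \sum_{i=1}^{N} p_i(s, a) \, z_i, \qquad z_i = V_{\min} + (i - 1)\,\frac{V_{\max} - V_{\min}}{N - 1}.
$$

The greedy action is the one with the largest expectation, and the training target is the distributional Bellman update $R + \gamma Z(s', a^{\star})$ projected back onto the fixed support, where $a^{\star}$ is greedy under the target network's expectation.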


Original papers:

* [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887)

## Implemented Variants


| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_ataripy) | For playing Atari games. It uses convolutional layers and common Atari pre-processing techniques. |
| :material-github: [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), :material-file-document: [docs](/rl-algorithms/c51/#c51py) | For classic control tasks like `CartPole-v1`. |


Below are our single-file implementations of C51:


## `c51_atari.py`

[c51_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py) has the following features:

* For playing Atari games. It uses convolutional layers and common Atari pre-processing techniques.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage

```bash
poetry install -E atari
python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4
```


### Explanation of the logged metrics

Running `python cleanrl/c51_atari.py` will automatically record various metrics, such as losses and episodic returns, in TensorBoard. Below is the documentation for these metrics:

* `charts/episodic_return`: episodic return of the game
* `charts/SPS`: number of steps per second
* `losses/loss`: the cross-entropy loss between the predicted return distribution at step $t$ and the projected target distribution at step $t+1$ (see the sketch below)
* `losses/q_values`: implemented as `(old_pmfs * q_network.atoms).sum(1)`, i.e., the probability of obtaining each return $x$ (`old_pmfs`) multiplied by $x$ itself (`q_network.atoms`), summed over atoms and averaged over the sample drawn from the replay buffer; useful for gauging whether under- or over-estimation happens
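
To make `losses/loss` concrete, below is a minimal sketch of the projection and cross-entropy computation, paraphrasing the training step of `c51_atari.py`/`c51.py` with assumed tensor names (it is not a verbatim excerpt):

```python
import torch


def c51_loss(old_pmfs, next_pmfs, rewards, dones, atoms, gamma=0.99):
    """Cross-entropy between the predicted distribution for the taken actions
    (`old_pmfs`) and the projected Bellman target built from the target
    network's distribution for the greedy next actions (`next_pmfs`).

    Shapes: old_pmfs, next_pmfs -> (batch, n_atoms); rewards, dones -> (batch, 1);
    atoms -> (n_atoms,), the fixed support of returns.
    """
    v_min, v_max = atoms[0].item(), atoms[-1].item()
    delta_z = atoms[1] - atoms[0]

    # Bellman-update every atom, then clip it back onto the fixed support
    next_atoms = rewards + gamma * atoms * (1 - dones)   # (batch, n_atoms)
    tz = next_atoms.clamp(v_min, v_max)

    # fractional position of each updated atom on the support
    b = (tz - v_min) / delta_z
    l = b.floor().clamp(0, len(atoms) - 1)
    u = b.ceil().clamp(0, len(atoms) - 1)

    # split each atom's probability mass between its two neighboring atoms
    d_m_l = (u + (l == u).float() - b) * next_pmfs
    d_m_u = (b - l) * next_pmfs
    target_pmfs = torch.zeros_like(next_pmfs)
    for i in range(target_pmfs.size(0)):
        target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
        target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])

    # logged every 100 steps as `losses/loss`
    return (-(target_pmfs * old_pmfs.clamp(1e-5, 1 - 1e-5).log()).sum(-1)).mean()
```

The `losses/q_values` entry is then simply the expectation `(old_pmfs * atoms).sum(1)` averaged over the batch.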


### Implementation details

[c51_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py) is based on (Bellemare et al., 2017)[^1] but presents a few implementation differences:

1. (Bellemare et al., 2017)[^1] injects stochasticity into the environment ("on each frame the environment rejects the agent's selected action with probability $p = 0.25$"), but `c51_atari.py` does not do this
1. `c51_atari.py` uses a self-contained evaluation scheme: it reports the episodic returns obtained throughout training, whereas (Bellemare et al., 2017)[^1] trains with the equivalent of `--end-e=0.01` but reports episodic returns from a separate evaluation process with `--end-e=0.001` (see "5.2. State-of-the-Art Results" on page 7).
1. `c51_atari.py` rescales the gradient so that its norm does not exceed `0.5`, as done in PPO (:material-github: [ppo2/model.py#L102-L108](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L102-L108)); a sketch of this clipping step is shown below the list.
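
For item 3, here is a minimal, self-contained PyTorch sketch of such a gradient-clipping step; the toy network, optimizer, and loss below are placeholders rather than an excerpt from `c51_atari.py`:

```python
import torch
import torch.nn as nn

# Placeholder model and data; the real script clips the Q-network's gradients
# according to its `--max-grad-norm` argument (default 0.5) in the same spirit.
q_network = nn.Linear(4, 2)
optimizer = torch.optim.Adam(q_network.parameters(), lr=2.5e-4)
loss = q_network(torch.randn(8, 4)).pow(2).mean()

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(q_network.parameters(), max_norm=0.5)  # rescale gradients so their global norm is at most 0.5
optimizer.step()
```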


### Experiment results

PR :material-github: [vwxyzjn/cleanrl#159](https://github.com/vwxyzjn/cleanrl/pull/159) tracks our effort to conduct experiments, and the reproduction instructions can be found at :material-github: [vwxyzjn/cleanrl/benchmark/c51](https://github.com/vwxyzjn/cleanrl/tree/master/benchmark/c51).

Below are the average episodic returns for `c51_atari.py`.


| Environment | `c51_atari.py` 10M steps | (Bellemare et al., 2017, Figure 14)[^1] 50M steps | (Hessel et al., 2018, Figure 5)[^3] |
| ----------- | ----------- | ----------- | ---- |
| BreakoutNoFrameskip-v4 | 467.00 ± 96.11 | 748 | ~500 at 10M steps, ~600 at 50M steps |
| PongNoFrameskip-v4 | 19.32 ± 0.92 | 20.9 | ~20 at 10M steps, ~20 at 50M steps |
| BeamRiderNoFrameskip-v4 | 9986.96 ± 1953.30 | 14,074 | ~12000 at 10M steps, ~14000 at 50M steps |


Note that we save computational time by reducing the number of timesteps from 50M to 10M, yet in 10M steps our `c51_atari.py` scores the same as or higher than the original DQN results reported in (Mnih et al., 2015).


Learning curves:

<div class="grid-container">
<img src="../c51/BeamRiderNoFrameskip-v4.png">

<img src="../c51/BreakoutNoFrameskip-v4.png">

<img src="../c51/PongNoFrameskip-v4.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Atari-CleanRL-s-C51--VmlldzoxNzI0NzQ0" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>



## `c51.py`

[c51.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`


### Usage

```bash
python cleanrl/c51.py --env-id CartPole-v1
```


### Explanation of the logged metrics

See [related docs](/rl-algorithms/c51/#explanation-of-the-logged-metrics) for `c51_atari.py`.

### Implementation details

The [c51.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py) shares the same implementation details as [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), except that `c51.py` uses different hyperparameters and a different neural network architecture. Specifically,

1. `c51.py` uses a simpler neural network whose final layer outputs `n_atoms` logits per action, as follows (see the sketch after this list for how this flat output is used):
```python
self.network = nn.Sequential(
    nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, env.single_action_space.n * n_atoms),  # one distribution of `n_atoms` logits per action
)
```
2. `c51.py` runs with different hyperparameters:

```bash
python c51.py --total-timesteps 500000 \
--learning-rate 2.5e-4 \
--buffer-size 10000 \
--gamma 0.99 \
--target-network-frequency 500 \
--max-grad-norm 0.5 \
--batch-size 128 \
--start-e 1 \
--end-e 0.05 \
--exploration-fraction 0.5 \
--learning-starts 10000 \
--train-frequency 10
```
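
As a complement to item 1 above, here is a minimal, self-contained sketch of how such a flat output can be turned into per-action return distributions and Q-values; the sizes and names are assumptions for a CartPole-like task, not an excerpt from `c51.py`:

```python
import torch
import torch.nn as nn

n_actions, n_atoms = 2, 101                 # assumed number of actions and atoms
atoms = torch.linspace(-100, 100, n_atoms)  # assumed fixed support of returns

network = nn.Sequential(
    nn.Linear(4, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, n_actions * n_atoms),     # one n_atoms-way distribution per action
)

obs = torch.randn(8, 4)                     # batch of observations
logits = network(obs)
pmfs = torch.softmax(logits.view(-1, n_actions, n_atoms), dim=2)  # per-action pmfs
q_values = (pmfs * atoms).sum(2)            # expected return of each action
actions = q_values.argmax(dim=1)            # greedy actions
```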


### Experiment results

PR :material-github: [vwxyzjn/cleanrl#159](https://github.com/vwxyzjn/cleanrl/pull/159) tracks our effort to conduct experiments, and the reproduction instructions can be found at :material-github: [vwxyzjn/cleanrl/benchmark/c51](https://github.com/vwxyzjn/cleanrl/tree/master/benchmark/c51).

Below are the average episodic returns for `c51.py`.


| Environment | `c51.py` |
| ----------- | ----------- |
| CartPole-v1 | 498.51 ± 1.77 |
| Acrobot-v1 | -88.81 ± 8.86 |
| MountainCar-v0 | -167.71 ± 26.85 |


Note that C51 has no official benchmark on classic control environments, so we did not include a comparison. That said, our `c51.py` achieved near-perfect scores on `CartPole-v1` and `Acrobot-v1`, and it also obtained successful runs in the sparse-reward environment `MountainCar-v0`.


Learning curves:

<div class="grid-container">
<img src="../c51/CartPole-v1.png">

<img src="../c51/Acrobot-v1.png">

<img src="../c51/MountainCar-v0.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Classic-Control-CleanRL-s-C51--VmlldzoxODIwMTE4" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>


[^1]: Bellemare, M. G., Dabney, W., & Munos, R. (2017). A Distributional Perspective on Reinforcement Learning. ICML.
[^2]: \[Proposal\] Formal API handling of truncation vs termination. https://github.com/openai/gym/issues/2510
[^3]: Hessel, M., Modayil, J., Hasselt, H. V., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M. G., & Silver, D. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.
Binary file added docs/rl-algorithms/c51/Acrobot-v1.png
Binary file added docs/rl-algorithms/c51/CartPole-v1.png
Binary file added docs/rl-algorithms/c51/MountainCar-v0.png
Binary file added docs/rl-algorithms/c51/PongNoFrameskip-v4.png
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -79,6 +79,7 @@ nav:
- rl-algorithms/overview.md
- rl-algorithms/ppo.md
- rl-algorithms/dqn.md
- rl-algorithms/c51.md
- rl-algorithms/ddpg.md
- rl-algorithms/sac.md
- rl-algorithms/td3.md