Commit 114998b
[feat] Add benchmark tools (#357)
* feat(configs): add `tags` config option
* feat(scripts): add benchmark tools
* refactor(reference): clean up debug prints
* style(reference): satisfy isort
* style(reference): satisfy CI's isort
* feat(scripts/benchmark): add `ppo_sentiments_t5`
* fix(benchmark): `ddp` -> `zero2-bf16` even with 1 process
* fix(benchmark): rename `wandb` project name
* feat(reference): separate metrics per experiment
* chore(benchmark): add `ppo_hh` to runs, but keep it under 2 hours
* feat(reference): add git hashes to descriptions
* fix(ppo_sentiments_t5): use `hparams` from sys.argv
* chore(benchmark): limit `ppo_hh`'s `total_steps` across branches
* fix(reference): set `max_runs_to_show` to 2
* feat(benchmark): add hh 6b to set of runs
* style: satisfy black
* feat(reference): add a few simple prints
* feat(benchmark): pin dependencies
* fix(benchmark): ignore git apply patch failed error
* chore(README): add a link to reference runs
* refactor(reference): move script under `trlx` (same as sweeps)
* chore(benchmark): remove patch for other branches
* revert(ppo_hh): restore default `total_steps`
* style(reference): satisfy black
* style(reference): satisfy isort
* feat(README): add benchmarking instruction
1 parent 9d85215 commit 114998b

7 files changed: 184 additions, 3 deletions

README.md

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,8 @@ For more usage see [examples](./examples). You can also try the colab notebooks
 | Simulacra (GPT2, ILQL) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)|
 | Sentiment (GPT2, ILQL) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_sentiments.ipynb)|

+Latest runs of the examples are on our [Weights & Biases](https://wandb.ai/sorry/trlx-references/reportlist)
+
 ## How to Train

 You can train a model using a reward function or a reward-labeled dataset.
@@ -99,6 +101,11 @@ For more usage see the [NeMo README](./trlx/models)
 python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
 ```

+#### Benchmark your trlX fork against trlX's `main` branch
+```bash
+python -m trlx.reference octocat/trlx-fork:fix-branch
+```
+
 ## Logging

 trlX uses the standard Python `logging` library to log training information to the console. The default logger is set to the `INFO` level, which means that `INFO`, `WARNING`, `ERROR`, and `CRITICAL` level messages will be printed to standard output.
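The new `trlx.reference` entry point also accepts an `--against` flag (defined in this commit's `trlx/reference.py`) to pick the reference branch explicitly. A sketch of such an invocation, where `octocat/trlx-fork:fix-branch` is the placeholder fork from the README example:

```bash
# Benchmark a fork's branch against an explicit reference branch;
# both arguments use the `origin:branch` format parsed by trlx/reference.py
python -m trlx.reference octocat/trlx-fork:fix-branch --against CarperAI/trlx:main
```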

examples/hh/ppo_hh.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@
     default_config.method.chunk_size = 16
 elif config_name == "6B":
     default_config.train.batch_size = 4
+    default_config.train.seq_length = 512
     default_config.train.total_steps = 6000
     default_config.train.checkpoint_dir = "checkpoints/ppo_hh_6B"
     default_config.model.model_path = "Dahoas/pythia-6B-static-sft"

examples/ppo_sentiments_t5.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,6 @@
+import json
 import os
+import sys
 from typing import Dict, List

 import numpy as np
@@ -166,4 +168,5 @@ def tokenize(sample):


 if __name__ == "__main__":
-    main()
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
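With this change the script reads an optional JSON config override from `argv[1]`, matching how `scripts/benchmark.sh` invokes it. A minimal usage sketch (the `total_steps` value here is an arbitrary illustration; it is one of the `TrainConfig` fields touched elsewhere in this commit):

```bash
# No argument: run with the default hparams
python examples/ppo_sentiments_t5.py
# JSON in argv[1]: override config values, as benchmark.sh does
python examples/ppo_sentiments_t5.py '{"train": {"total_steps": 100}}'
```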

scripts/benchmark.sh

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+#!/bin/bash
+set -e
+
+origin=CarperAI/trlx
+branch=main
+entity=null
+only_hash=false
+only_tiny=false
+
+while [[ "$#" -gt 0 ]]; do
+    case $1 in
+        --origin) origin="$2"; shift ;;
+        --branch) branch="$2"; shift ;;
+        --public) entity='"CarperAI"' ;;
+        --only_hash) only_hash=true ;;
+        --only_tiny) only_tiny=true ;;
+        *) echo "Unknown parameter passed: $1"; exit 1 ;;
+    esac
+    shift
+done
+
+dir=`mktemp -d -p .`
+if [ ! -d "$dir" ]; then
+    echo "Couldn't create a temporary directory, aborting"
+    exit 1
+fi
+
+cd $dir
+trap "rm -rf ../$dir" EXIT
+
+git clone --depth 1 --single-branch -b $branch https://github.com/$origin .
+
+hash=`find . -not \( -path ./.git -prune \) -not -name "*.md" -type f -print0 | sort -z | xargs -0 sha1sum | sha1sum | cut -f1 -d" "`
+git_hash=`git log --format=%h/%s/%as -n1`
+
+if [ "$only_hash" = true ]; then
+    echo "$hash"
+    echo "$git_hash"
+    exit 0
+fi
+
+python -m venv venv
+. venv/bin/activate
+python -m pip install pip --upgrade
+pip install -r requirements.txt
+pip install -e .
+
+args='{"train": {"project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}'
+python examples/randomwalks/ilql_randomwalks.py "$args"
+python examples/randomwalks/ppo_randomwalks.py "$args"
+
+if [ "$only_tiny" = true ]; then
+    exit 0
+fi
+
+rm -rf ../benchmark_logs && mkdir ../benchmark_logs
+
+CUDA_VISIBLE_DEVICES=0 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8880 examples/ppo_sentiments.py "$args" > ../benchmark_logs/ppo_sentiments.log 2>&1 &
+CUDA_VISIBLE_DEVICES=1 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8881 examples/sft_sentiments.py "$args" > ../benchmark_logs/sft_sentiments.log 2>&1 &
+CUDA_VISIBLE_DEVICES=2 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8882 examples/ilql_sentiments.py "$args" > ../benchmark_logs/ilql_sentiments.log 2>&1 &
+CUDA_VISIBLE_DEVICES=3 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8883 examples/ppo_sentiments_t5.py "$args" > ../benchmark_logs/ppo_sentiments_t5.log 2>&1 &
+
+wait
+
+args='{"train": {"total_steps": 1500, "seq_length": 512, "project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}'
+CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py "$args"
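A usage sketch for the flags the script parses above (`octocat/trlx-fork` and `fix-branch` are placeholders):

```bash
# Print the content hash and git hash of CarperAI/trlx:main, then exit
./scripts/benchmark.sh --only_hash
# Run only the cheap randomwalks examples on a fork's branch,
# logging runs under the public CarperAI entity
./scripts/benchmark.sh --origin octocat/trlx-fork --branch fix-branch --only_tiny --public
```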

trlx/data/configs.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Set
+from typing import Any, Dict, List, Optional, Set

 import yaml

@@ -220,6 +220,7 @@ class TrainConfig:

     tracker: Optional[str] = "wandb"
     logging_dir: Optional[str] = None
+    tags: Optional[List[str]] = field(default_factory=list)

     seed: int = 1000
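Since the example scripts accept a JSON config override in `argv[1]` (the mechanism `benchmark.sh` relies on), the new `tags` field can be set the same way. A minimal sketch, with a made-up tag value:

```bash
# Tag a run so it can later be filtered via the W&B API, e.g. {"tags": {"$in": [...]}}
python examples/ppo_sentiments.py '{"train": {"tags": ["my-experiment"]}}'
```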

trlx/reference.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# python -m trlx.reference CarperAI/trlx:add-benchmark-tools --against CarperAI/trlx:main
+
+import argparse
+import os
+import subprocess
+
+import wandb
+import wandb.apis.reports as wb
+
+parser = argparse.ArgumentParser()
+parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`")
+parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch")
+parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs")
+args = parser.parse_args()
+
+pr_origin = ref_origin = "CarperAI/trlx"
+pr_branch = args.branch
+ref_branch = args.against
+if ":" in pr_branch:
+    pr_origin, pr_branch = pr_branch.rsplit(":", 1)
+if ":" in ref_branch:
+    ref_origin, ref_branch = ref_branch.rsplit(":", 1)
+
+out = os.popen(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} --only_hash")
+pr_hash, pr_git_hash = [x[:-1] for x in out.readlines()]
+
+out = os.popen(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} --only_hash")
+ref_hash, ref_git_hash = [x[:-1] for x in out.readlines()]
+
+print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}")
+print(f"{ref_origin}:{ref_branch} {ref_hash=} {ref_git_hash=}")
+
+api = wandb.Api()
+project_name = "CarperAI/trlx-references" if args.public else "trlx-references"
+public = "--public" if args.public else ""
+
+runs = api.runs(project_name, filters={"tags": {"$in": [ref_hash]}})
+if runs:
+    print(f"On {ref_branch} @{ref_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}")
+else:
+    print(f"Making runs on {ref_branch} @{ref_git_hash}")
+    subprocess.run(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} {public}".split())
+
+runs = api.runs(project_name, filters={"tags": {"$in": [pr_hash]}})
+if runs:
+    print(f"On {pr_branch} @{pr_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}")
+else:
+    print(f"Making runs on {pr_branch} @{pr_git_hash}")
+    subprocess.run(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} {public}".split())
+
+report = wb.Report(
+    project=project_name.split("/")[1] if args.public else project_name,
+    title=f"{pr_branch} v. {ref_branch}",
+    description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}",
+)
+blocks = []
+
+experiment_names = set(x.name.split(":")[0] for x in api.runs(project_name))
+for name in experiment_names:
+    filters = {"$and": [{"display_name": {"$regex": f"^{name}"}}, {"tags": {"$in": [pr_hash, ref_hash]}}]}
+
+    runs = api.runs(project_name, filters=filters)
+    metrics = set(sum([[metric for metric in run.history().columns if not metric.startswith("_")] for run in runs], []))
+
+    metrics_panels = [
+        wb.LinePlot(
+            title=f"{metric}",
+            x="Step",
+            y=[metric],
+            title_x="Step",
+            smoothing_show_original=True,
+            max_runs_to_show=2,
+            plot_type="line",
+            font_size="auto",
+            legend_position="north",
+        )
+        for metric in metrics
+    ]
+
+    # sort the most important metrics to be shown first
+    major_metrics = set()
+    for metric in metrics:
+        if metric.startswith("reward") or metric.startswith("metric"):
+            major_metrics.add(metric)
+    metrics = metrics - major_metrics
+
+    blocks.extend(
+        [
+            wb.H1(text=name),
+            wb.PanelGrid(
+                panels=[panel for panel in metrics_panels if panel.title in major_metrics],
+                runsets=[wb.Runset(project=project_name, filters=filters)],
+            ),
+            wb.PanelGrid(
+                panels=[panel for panel in metrics_panels if panel.title in metrics],
+                runsets=[wb.Runset(project=project_name, filters=filters)],
+            ),
+        ]
+    )
+
+report.blocks = blocks
+report.save()
+print(report.url)
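Putting the script's flags together, a hypothetical end-to-end invocation (fork name is a placeholder; `--public` switches both the benchmark runs and the report to the CarperAI entity, per the argparse definitions above):

```bash
# Compare two branches and publish the runs and report under the public CarperAI entity
python -m trlx.reference octocat/trlx-fork:fix-branch --against CarperAI/trlx:main --public
```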

trlx/trainer/accelerate_base_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def __init__(self, config, **kwargs):  # noqa: C901
                 "name": run_name,
                 "entity": self.config.train.entity_name,
                 "group": self.config.train.group_name,
-                "tags": ["/".join(get_git_tag())],
+                "tags": self.config.train.tags + ["/".join(get_git_tag())],
                 "mode": "disabled" if os.environ.get("debug", False) else "online",
             }
