Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7e71448
feat(configs): add `tags` config option
maxreciprocate Mar 6, 2023
3c710cf
feat(scripts): add benchmark tools
maxreciprocate Mar 6, 2023
5ab91d2
refactor(reference): clean up debug prints
maxreciprocate Mar 6, 2023
5264961
style(reference): satisfy isort
maxreciprocate Mar 6, 2023
658ca6d
style(reference): satisfy CI's isort
maxreciprocate Mar 6, 2023
4d04cc9
feat(scripts/benchmark): add `ppo_sentiments_t5`
maxreciprocate Mar 7, 2023
445cdc2
fix(benchmark): `ddp` -> `zero2-bf16` even with 1 process
maxreciprocate Mar 8, 2023
5d36f3f
fix(benchmark): rename `wandb` project name
maxreciprocate Mar 8, 2023
8fa0b74
feat(reference): separate metrics per experiment
maxreciprocate Mar 9, 2023
f54f14b
chore(benchmark): add `ppo_hh` to runs, but keep it under 2 hours
maxreciprocate Mar 9, 2023
f8d4c6e
feat(reference): add git hashes to descriptions
maxreciprocate Mar 9, 2023
bcf4a4f
fix(ppo_sentiments_t5): use `hparams` from sys.argv
maxreciprocate Mar 9, 2023
b399b2b
chore(benchmark): limit `ppo_hh`'s `total_steps` across branches
maxreciprocate Mar 10, 2023
617b75d
fix(reference): set `max_runs_to_show` to 2
maxreciprocate Mar 10, 2023
6d476db
feat(benchmark): add hh 6b to set of runs
maxreciprocate Mar 13, 2023
b5778c1
style: satisfy black
maxreciprocate Mar 13, 2023
55d2a7a
feat(reference): add a few simple prints
maxreciprocate Mar 13, 2023
4c0b11f
Merge branch 'main' into add-benchmark-tools
maxreciprocate Mar 22, 2023
0ee4f8a
feat(benchmark): pin dependencies
maxreciprocate Mar 22, 2023
711fb4d
fix(benchmark): ignore git apply patch failed error
maxreciprocate Mar 22, 2023
fb8cdb7
chore(README): add a link to reference runs
maxreciprocate Mar 26, 2023
4e60ee2
refactor(reference): move script under `trlx` (same as sweeps)
maxreciprocate Mar 28, 2023
d42e306
chore(benchmark): remove patch for other branches
maxreciprocate Mar 28, 2023
cef0b5e
revert(ppo_hh): restore default `total_steps`
maxreciprocate Mar 28, 2023
f595b87
style(reference): satisfy black
maxreciprocate Mar 28, 2023
d5d2367
style(reference): satisfy isort
maxreciprocate Mar 28, 2023
5db8e6b
feat(README): add benchmarking instruction
maxreciprocate Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions 0001-feat-configs-add-tags-config-option.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
From 7e7144868f71f437e10a15868b99d2cfcc571f3e Mon Sep 17 00:00:00 2001
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should hide this file away in a subfolder

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to delete this altogether right before this pr is merged, since it just patches other branches with a subset of changes this pr introduces

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, that makes sense

From: reciprocated <[email protected]>
Date: Mon, 6 Mar 2023 10:52:20 +0200
Subject: [PATCH] feat(configs): add `tags` config option

---
trlx/data/configs.py | 3 ++-
trlx/trainer/accelerate_base_trainer.py | 2 +-
2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index 2029700..725b570 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -1,6 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Set
+from typing import Any, Dict, Optional, Set, List

import yaml

@@ -218,6 +218,7 @@ class TrainConfig:

tracker: Optional[str] = "wandb"
logging_dir: Optional[str] = None
+ tags: Optional[List[str]] = field(default_factory=list)

seed: int = 1000

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 0e92efc..0a27804 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -91,7 +91,7 @@ class AccelerateRLTrainer(BaseRLTrainer):
"name": run_name,
"entity": self.config.train.entity_name,
"group": self.config.train.group_name,
- "tags": ["/".join(get_git_tag())],
+ "tags": self.config.train.tags + ["/".join(get_git_tag())],
"mode": "disabled" if os.environ.get("debug", False) else "online",
}

--
2.30.1 (Apple Git-130)
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ For more usage see [examples](./examples). You can also try the colab notebooks
| Simulacra (GPT2, ILQL) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)|
| Sentiment (GPT2, ILQL) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_sentiments.ipynb)|

Latest runs of the examples are on our [Weights & Biases](https://wandb.ai/sorry/trlx-references/reportlist)

## How to Train

You can train a model using a reward function or a reward-labeled dataset.
Expand Down
3 changes: 2 additions & 1 deletion examples/hh/ppo_hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
train=TrainConfig(
seq_length=1024,
epochs=10000,
total_steps=10000,
total_steps=2000,
batch_size=4,
checkpoint_interval=10000,
eval_interval=500,
Expand Down Expand Up @@ -85,6 +85,7 @@
default_config.method.chunk_size = 16
elif config_name == "6B":
default_config.train.batch_size = 4
default_config.train.seq_length = 512
default_config.train.total_steps = 6000
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_6B"
default_config.model.model_path = "Dahoas/pythia-6B-static-sft"
Expand Down
5 changes: 4 additions & 1 deletion examples/ppo_sentiments_t5.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import os
import sys
from typing import Dict, List

import numpy as np
Expand Down Expand Up @@ -165,4 +167,5 @@ def tokenize(sample):


if __name__ == "__main__":
main()
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)
69 changes: 69 additions & 0 deletions scripts/benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/bin/bash
set -e

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add set args to make script die on error

origin=CarperAI/trlx
branch=main
entity=null
only_hash=false
only_tiny=false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no way to set this via references.py


while [[ "$#" -gt 0 ]]; do
case $1 in
--origin) origin="$2"; shift ;;
--branch) branch="$2"; shift ;;
--public) entity='"CarperAI"' ;;
--only_hash) only_hash=true ;;
--only_tiny) only_tiny=true ;;
*) echo "Unknown parameter passed: $1"; exit 1 ;;
esac
shift
done

dir=`mktemp -d -p .`
if [ ! -d "$dir" ]; then
echo "Couldn't create a temporary directory, aborting"
exit 1
fi

cd $dir
trap "rm -rf ../$dir" EXIT
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any case where $dir could be null/empty string? might be good to add as otherwise you could accidentally rm -rf the containing directory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think there is, but an additional check would hurt either


git clone --depth 1 --single-branch -b $branch https://github.com/$origin .

# temporary, adding `tags` config option to old branches
git apply ../0001-feat-configs-add-tags-config-option.patch ||:

hash=`find . -not \( -path ./.git -prune \) -not -name "*.md" -type f -print0 | sort -z | xargs -0 sha1sum | sha1sum | cut -f1 -d" "`
git_hash=`git log --format=%h/%s/%as -n1`

if [ "$only_hash" = true ]; then
echo "$hash"
echo "$git_hash"
exit 0
fi

python -m venv venv
. venv/bin/activate
python -m pip install pip --upgrade
pip install -r requirements.txt
pip install -e .

args='{"train": {"project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth adding the GPU name to the tags (interconnect would be cool too if there's some easy way to get that)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's already logged under "System Hardware" in w&b as GPU type: NVIDIA A100-SXM4-80GB, or do you want to see it specifically as a tag?

python examples/randomwalks/ilql_randomwalks.py "$args"
python examples/randomwalks/ppo_randomwalks.py "$args"

if [ "$only_tiny" = true ]; then
exit 0
fi

rm -rf ../benchmark_logs && mkdir ../benchmark_logs

CUDA_VISIBLE_DEVICES=0 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8880 examples/ppo_sentiments.py "$args" > ../benchmark_logs/ppo_sentiments.log 2>&1 &
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it happen to change the timings at all if you run all of them together on the same machine vs one by one?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CUDA_VISIBLE_DEVICES=1 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8881 examples/sft_sentiments.py "$args" > ../benchmark_logs/sft_sentiments.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8882 examples/ilql_sentiments.py "$args" > ../benchmark_logs/ilql_sentiments.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8883 examples/ppo_sentiments_t5.py "$args" > ../benchmark_logs/ppo_sentiments_t5.log 2>&1 &

wait

args='{"train": {"total_steps": 1500, "seq_length": 512, "project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}'
CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py "$args"
109 changes: 109 additions & 0 deletions scripts/reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# python scripts/reference.py CarperAI/trlx:convert-examples-configs --against CarperAI/trlx:main

import argparse
import os
import subprocess

import wandb
import wandb.apis.reports as wb

parser = argparse.ArgumentParser()
parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`")
parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch")
parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs")
args = parser.parse_args()

pr_origin = ref_origin = "CarperAI/trlx"
pr_branch = args.branch
ref_branch = args.against
if ':' in pr_branch:
pr_origin, pr_branch = pr_branch.rsplit(':', 1)
if ':' in ref_branch:
ref_origin, ref_branch = ref_branch.rsplit(':', 1)

out = os.popen(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} --only_hash")
pr_hash, pr_git_hash = [x[:-1] for x in out.readlines()]

out = os.popen(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} --only_hash")
ref_hash, ref_git_hash = [x[:-1] for x in out.readlines()]

print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}")
print(f"{ref_origin}:{ref_branch} {ref_hash=} {ref_git_hash=}")

api = wandb.Api()
project_name = "CarperAI/trlx-references" if args.public else "trlx-references"
public = "--public" if args.public else ""

runs = api.runs(project_name, filters={"tags": {"$in": [ref_hash]}})
if runs:
print(f"On {ref_branch} @{ref_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}")
else:
print(f"Making runs on {ref_branch} @{ref_git_hash}")
subprocess.run(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} {public}".split())

runs = api.runs(project_name, filters={"tags": {"$in": [pr_hash]}})
if runs:
print(f"On {pr_branch} @{pr_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}")
else:
print(f"Making runs on {pr_branch} @{pr_git_hash}")
subprocess.run(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} {public}".split())

report = wb.Report(
project=project_name.split('/')[1] if args.public else project_name,
title=f"{pr_branch} v. {ref_branch}",
description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}",
)
blocks = []

experiment_names = set(x.name.split(':')[0] for x in api.runs(project_name))
for name in experiment_names:
filters = {"$and": [
{"display_name": {"$regex": f"^{name}"}},
{"tags": {"$in": [pr_hash, ref_hash]}}
]}

runs = api.runs(project_name, filters=filters)
metrics = set(sum([[metric for metric in run.history().columns if not metric.startswith("_")] for run in runs], []))

metrics_panels = [
wb.LinePlot(
title=f"{metric}",
x="Step",
y=[metric],
title_x="Step",
smoothing_show_original=True,
max_runs_to_show=2,
plot_type="line",
font_size="auto",
legend_position="north",
) for metric in metrics
]

# sort the most important metrics to be shown first
major_metrics = set()
for metric in metrics:
if metric.startswith("reward") or metric.startswith("metric"):
major_metrics.add(metric)
metrics = metrics - major_metrics

blocks.extend([
wb.H1(text=name),
wb.PanelGrid(
panels=[panel for panel in metrics_panels if panel.title in major_metrics],
runsets=[wb.Runset(
project=project_name,
filters=filters
)],
),
wb.PanelGrid(
panels=[panel for panel in metrics_panels if panel.title in metrics],
runsets=[wb.Runset(
project=project_name,
filters=filters
)],
),
])

report.blocks = blocks
report.save()
print(report.url)
3 changes: 2 additions & 1 deletion trlx/data/configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from typing import Any, Dict, List, Optional, Set

import yaml

Expand Down Expand Up @@ -220,6 +220,7 @@ class TrainConfig:

tracker: Optional[str] = "wandb"
logging_dir: Optional[str] = None
tags: Optional[List[str]] = field(default_factory=list)

seed: int = 1000

Expand Down
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, config, **kwargs): # noqa: C901
"name": run_name,
"entity": self.config.train.entity_name,
"group": self.config.train.group_name,
"tags": ["/".join(get_git_tag())],
"tags": self.config.train.tags + ["/".join(get_git_tag())],
"mode": "disabled" if os.environ.get("debug", False) else "online",
}

Expand Down