-
Notifications
You must be signed in to change notification settings - Fork 482
[feat] Add benchmark tools #357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
7e71448
3c710cf
5ab91d2
5264961
658ca6d
4d04cc9
445cdc2
5d36f3f
8fa0b74
f54f14b
f8d4c6e
bcf4a4f
b399b2b
617b75d
6d476db
b5778c1
55d2a7a
4c0b11f
0ee4f8a
711fb4d
fb8cdb7
4e60ee2
d42e306
cef0b5e
f595b87
d5d2367
5db8e6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
From 7e7144868f71f437e10a15868b99d2cfcc571f3e Mon Sep 17 00:00:00 2001 | ||
From: reciprocated <[email protected]> | ||
Date: Mon, 6 Mar 2023 10:52:20 +0200 | ||
Subject: [PATCH] feat(configs): add `tags` config option | ||
|
||
--- | ||
trlx/data/configs.py | 3 ++- | ||
trlx/trainer/accelerate_base_trainer.py | 2 +- | ||
2 files changed, 3 insertions(+), 2 deletions(-) | ||
|
||
diff --git a/trlx/data/configs.py b/trlx/data/configs.py | ||
index 2029700..725b570 100644 | ||
--- a/trlx/data/configs.py | ||
+++ b/trlx/data/configs.py | ||
@@ -1,6 +1,6 @@ | ||
from copy import deepcopy | ||
from dataclasses import dataclass, field | ||
-from typing import Any, Dict, Optional, Set | ||
+from typing import Any, Dict, Optional, Set, List | ||
|
||
import yaml | ||
|
||
@@ -218,6 +218,7 @@ class TrainConfig: | ||
|
||
tracker: Optional[str] = "wandb" | ||
logging_dir: Optional[str] = None | ||
+ tags: Optional[List[str]] = field(default_factory=list) | ||
|
||
seed: int = 1000 | ||
|
||
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py | ||
index 0e92efc..0a27804 100644 | ||
--- a/trlx/trainer/accelerate_base_trainer.py | ||
+++ b/trlx/trainer/accelerate_base_trainer.py | ||
@@ -91,7 +91,7 @@ class AccelerateRLTrainer(BaseRLTrainer): | ||
"name": run_name, | ||
"entity": self.config.train.entity_name, | ||
"group": self.config.train.group_name, | ||
- "tags": ["/".join(get_git_tag())], | ||
+ "tags": self.config.train.tags + ["/".join(get_git_tag())], | ||
"mode": "disabled" if os.environ.get("debug", False) else "online", | ||
} | ||
|
||
-- | ||
2.30.1 (Apple Git-130) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add |
||
origin=CarperAI/trlx | ||
branch=main | ||
entity=null | ||
only_hash=false | ||
only_tiny=false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no way to set this via |
||
|
||
while [[ "$#" -gt 0 ]]; do | ||
case $1 in | ||
--origin) origin="$2"; shift ;; | ||
--branch) branch="$2"; shift ;; | ||
--public) entity='"CarperAI"' ;; | ||
--only_hash) only_hash=true ;; | ||
--only_tiny) only_tiny=true ;; | ||
*) echo "Unknown parameter passed: $1"; exit 1 ;; | ||
esac | ||
shift | ||
done | ||
|
||
dir=`mktemp -d -p .` | ||
if [ ! -d "$dir" ]; then | ||
echo "Couldn't create a temporary directory, aborting" | ||
exit 1 | ||
fi | ||
|
||
cd $dir | ||
trap "rm -rf ../$dir" EXIT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there any case where $dir could be null/empty string? might be good to add as otherwise you could accidentally rm -rf the containing directory There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think there is, but an additional check would hurt either |
||
|
||
git clone --depth 1 --single-branch -b $branch https://github.com/$origin . | ||
|
||
# temporary, adding `tags` config option to old branches | ||
git apply ../0001-feat-configs-add-tags-config-option.patch ||: | ||
|
||
hash=`find . -not \( -path ./.git -prune \) -not -name "*.md" -type f -print0 | sort -z | xargs -0 sha1sum | sha1sum | cut -f1 -d" "` | ||
git_hash=`git log --format=%h/%s/%as -n1` | ||
|
||
if [ "$only_hash" = true ]; then | ||
echo "$hash" | ||
echo "$git_hash" | ||
exit 0 | ||
fi | ||
|
||
python -m venv venv | ||
. venv/bin/activate | ||
python -m pip install pip --upgrade | ||
pip install -r requirements.txt | ||
pip install -e . | ||
|
||
args='{"train": {"project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe worth adding the GPU name to the tags (interconnect would be cool too if there's some easy way to get that) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's already logged under "System Hardware" in w&b as |
||
python examples/randomwalks/ilql_randomwalks.py "$args" | ||
python examples/randomwalks/ppo_randomwalks.py "$args" | ||
|
||
if [ "$only_tiny" = true ]; then | ||
exit 0 | ||
fi | ||
|
||
rm -rf ../benchmark_logs && mkdir ../benchmark_logs | ||
|
||
CUDA_VISIBLE_DEVICES=0 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8880 examples/ppo_sentiments.py "$args" > ../benchmark_logs/ppo_sentiments.log 2>&1 & | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does it happen to change the timings at all if you run all of them together on the same machine vs one by one? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nothing noticable https://wandb.ai/sorry/trlx/reports/Timing-difference--VmlldzozODA4MTA4 |
||
CUDA_VISIBLE_DEVICES=1 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8881 examples/sft_sentiments.py "$args" > ../benchmark_logs/sft_sentiments.log 2>&1 & | ||
CUDA_VISIBLE_DEVICES=2 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8882 examples/ilql_sentiments.py "$args" > ../benchmark_logs/ilql_sentiments.log 2>&1 & | ||
CUDA_VISIBLE_DEVICES=3 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8883 examples/ppo_sentiments_t5.py "$args" > ../benchmark_logs/ppo_sentiments_t5.log 2>&1 & | ||
|
||
wait | ||
|
||
args='{"train": {"total_steps": 1500, "seq_length": 512, "project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}' | ||
CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py "$args" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# python scripts/reference.py CarperAI/trlx:convert-examples-configs --against CarperAI/trlx:main | ||
|
||
import argparse | ||
import os | ||
import subprocess | ||
|
||
import wandb | ||
import wandb.apis.reports as wb | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`") | ||
parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch") | ||
parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs") | ||
args = parser.parse_args() | ||
|
||
pr_origin = ref_origin = "CarperAI/trlx" | ||
pr_branch = args.branch | ||
ref_branch = args.against | ||
if ':' in pr_branch: | ||
pr_origin, pr_branch = pr_branch.rsplit(':', 1) | ||
if ':' in ref_branch: | ||
ref_origin, ref_branch = ref_branch.rsplit(':', 1) | ||
|
||
out = os.popen(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} --only_hash") | ||
pr_hash, pr_git_hash = [x[:-1] for x in out.readlines()] | ||
|
||
out = os.popen(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} --only_hash") | ||
ref_hash, ref_git_hash = [x[:-1] for x in out.readlines()] | ||
|
||
print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}") | ||
print(f"{ref_origin}:{ref_branch} {ref_hash=} {ref_git_hash=}") | ||
|
||
api = wandb.Api() | ||
project_name = "CarperAI/trlx-references" if args.public else "trlx-references" | ||
public = "--public" if args.public else "" | ||
|
||
runs = api.runs(project_name, filters={"tags": {"$in": [ref_hash]}}) | ||
if runs: | ||
print(f"On {ref_branch} @{ref_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") | ||
else: | ||
print(f"Making runs on {ref_branch} @{ref_git_hash}") | ||
subprocess.run(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} {public}".split()) | ||
|
||
runs = api.runs(project_name, filters={"tags": {"$in": [pr_hash]}}) | ||
if runs: | ||
print(f"On {pr_branch} @{pr_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") | ||
else: | ||
print(f"Making runs on {pr_branch} @{pr_git_hash}") | ||
subprocess.run(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} {public}".split()) | ||
|
||
report = wb.Report( | ||
project=project_name.split('/')[1] if args.public else project_name, | ||
title=f"{pr_branch} v. {ref_branch}", | ||
description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}", | ||
) | ||
blocks = [] | ||
|
||
experiment_names = set(x.name.split(':')[0] for x in api.runs(project_name)) | ||
for name in experiment_names: | ||
filters = {"$and": [ | ||
{"display_name": {"$regex": f"^{name}"}}, | ||
{"tags": {"$in": [pr_hash, ref_hash]}} | ||
]} | ||
|
||
runs = api.runs(project_name, filters=filters) | ||
metrics = set(sum([[metric for metric in run.history().columns if not metric.startswith("_")] for run in runs], [])) | ||
|
||
metrics_panels = [ | ||
wb.LinePlot( | ||
title=f"{metric}", | ||
x="Step", | ||
y=[metric], | ||
title_x="Step", | ||
smoothing_show_original=True, | ||
max_runs_to_show=2, | ||
plot_type="line", | ||
font_size="auto", | ||
legend_position="north", | ||
) for metric in metrics | ||
] | ||
|
||
# sort the most important metrics to be shown first | ||
major_metrics = set() | ||
for metric in metrics: | ||
if metric.startswith("reward") or metric.startswith("metric"): | ||
major_metrics.add(metric) | ||
metrics = metrics - major_metrics | ||
|
||
blocks.extend([ | ||
wb.H1(text=name), | ||
wb.PanelGrid( | ||
panels=[panel for panel in metrics_panels if panel.title in major_metrics], | ||
runsets=[wb.Runset( | ||
project=project_name, | ||
filters=filters | ||
)], | ||
), | ||
wb.PanelGrid( | ||
panels=[panel for panel in metrics_panels if panel.title in metrics], | ||
runsets=[wb.Runset( | ||
project=project_name, | ||
filters=filters | ||
)], | ||
), | ||
]) | ||
|
||
report.blocks = blocks | ||
report.save() | ||
print(report.url) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we should hide this file away in a subfolder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I plan to delete this altogether right before this pr is merged, since it just patches other branches with a subset of changes this pr introduces
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it, that makes sense