# python -m trlx.reference CarperAI/trlx:add-benchmark-tools --against CarperAI/trlx:main
"""Benchmark a git branch against a reference branch and publish a W&B report.

For both the candidate branch and the reference branch this script:

1. asks ``scripts/benchmark.sh --only_hash`` for the benchmark hash and git hash,
2. launches the benchmark unless W&B runs tagged with that benchmark hash already exist,
3. builds a W&B report with one section per experiment, plotting every logged
   metric for the runs of both branches.
"""

import argparse
import os
import re
import subprocess
import sys

import wandb
import wandb.apis.reports as wb

DEFAULT_ORIGIN = "CarperAI/trlx"
BENCHMARK_SCRIPT = os.path.join(".", "scripts", "benchmark.sh")
MAJOR_METRIC_PREFIXES = ("reward", "metric")


def parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`")
    parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch")
    parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs")
    return parser.parse_args(argv)


def split_origin_branch(spec: str, default_origin: str = DEFAULT_ORIGIN):
    """Split ``origin:branch`` into ``(origin, branch)``.

    The split happens on the last ``:``; when ``spec`` has no ``:`` it is taken
    to be a bare branch name on ``default_origin``.
    """
    if ":" in spec:
        origin, branch = spec.rsplit(":", 1)
        return origin, branch
    return default_origin, spec


def benchmark_command(origin: str, branch: str, *extra: str):
    """Build the argv list for ``scripts/benchmark.sh``.

    An argv list (rather than a shell string) keeps branch/origin names from
    being interpreted by a shell or split on whitespace.
    """
    return [BENCHMARK_SCRIPT, "--origin", origin, "--branch", branch, *extra]


def fetch_hashes(origin: str, branch: str):
    """Return ``(benchmark_hash, git_hash)`` as printed by ``benchmark.sh --only_hash``.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
        RuntimeError: if the script does not print exactly two non-empty lines.
    """
    completed = subprocess.run(
        benchmark_command(origin, branch, "--only_hash"),
        stdout=subprocess.PIPE,  # stderr is left attached to the terminal
        text=True,
        check=True,
    )
    # splitlines() handles a missing trailing newline, unlike slicing off the last character
    lines = [line.strip() for line in completed.stdout.splitlines() if line.strip()]
    if len(lines) != 2:
        raise RuntimeError(
            f"expected two lines (benchmark hash, git hash) from {BENCHMARK_SCRIPT} --only_hash "
            f"for {origin}:{branch}, got {lines!r}"
        )
    return lines[0], lines[1]


def ensure_runs(api, project_name: str, origin: str, branch: str, run_hash: str, git_hash: str, public: bool):
    """Launch the benchmark for ``branch`` unless runs tagged ``run_hash`` already exist.

    A non-zero exit status from the benchmark is reported on stderr but does not
    abort the script, so a report can still be built from whatever was logged.
    """
    existing = api.runs(project_name, filters={"tags": {"$in": [run_hash]}})
    if existing:
        names = "\n".join(run.name for run in existing)
        print(f"On {branch} @{git_hash} these runs were already made: \n{names}")
        return

    print(f"Making runs on {branch} @{git_hash}")
    extra = ["--public"] if public else []
    completed = subprocess.run(benchmark_command(origin, branch, *extra))
    if completed.returncode != 0:
        print(
            f"warning: benchmark for {origin}:{branch} exited with status {completed.returncode}",
            file=sys.stderr,
        )


def logged_metrics(runs):
    """Return the sorted names of all history columns in ``runs`` that do not start with ``_``."""
    return sorted({column for run in runs for column in run.history().columns if not column.startswith("_")})


def make_line_plot(metric: str):
    """Create the line-plot panel used for a single metric."""
    return wb.LinePlot(
        title=f"{metric}",
        x="Step",
        y=[metric],
        title_x="Step",
        smoothing_show_original=True,
        max_runs_to_show=2,
        plot_type="line",
        font_size="auto",
        legend_position="north",
    )


def experiment_blocks(api, project_name: str, name: str, hashes):
    """Build the report blocks for one experiment.

    Returns a heading followed by two panel grids: first the metrics whose names
    start with ``reward`` or ``metric``, then all remaining metrics.
    """
    filters = {
        "$and": [
            # escape the name so regex metacharacters in it are matched literally
            {"display_name": {"$regex": f"^{re.escape(name)}"}},
            {"tags": {"$in": list(hashes)}},
        ]
    }
    metrics = logged_metrics(api.runs(project_name, filters=filters))

    # the most important metrics are shown first, in their own grid
    major = [metric for metric in metrics if metric.startswith(MAJOR_METRIC_PREFIXES)]
    minor = [metric for metric in metrics if not metric.startswith(MAJOR_METRIC_PREFIXES)]

    return [
        wb.H1(text=name),
        wb.PanelGrid(
            panels=[make_line_plot(metric) for metric in major],
            runsets=[wb.Runset(project=project_name, filters=filters)],
        ),
        wb.PanelGrid(
            panels=[make_line_plot(metric) for metric in minor],
            runsets=[wb.Runset(project=project_name, filters=filters)],
        ),
    ]


def main(argv=None):
    """Run both benchmarks if needed, then create, save and print the comparison report."""
    args = parse_args(argv)

    pr_origin, pr_branch = split_origin_branch(args.branch)
    ref_origin, ref_branch = split_origin_branch(args.against)

    pr_hash, pr_git_hash = fetch_hashes(pr_origin, pr_branch)
    ref_hash, ref_git_hash = fetch_hashes(ref_origin, ref_branch)

    print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}")
    print(f"{ref_origin}:{ref_branch=} {ref_hash=} {ref_git_hash=}")

    api = wandb.Api()
    project_name = "CarperAI/trlx-references" if args.public else "trlx-references"

    ensure_runs(api, project_name, ref_origin, ref_branch, ref_hash, ref_git_hash, args.public)
    ensure_runs(api, project_name, pr_origin, pr_branch, pr_hash, pr_git_hash, args.public)

    # NOTE(review): with --public only the project part is passed here, so the
    # report's entity is whatever wandb picks by default — confirm that is intended.
    report = wb.Report(
        project=project_name.split("/")[1] if args.public else project_name,
        title=f"{pr_branch} v. {ref_branch}",
        description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}",
    )

    hashes = [pr_hash, ref_hash]
    # Only experiments that have runs for one of the two compared hashes get a
    # section; sorting keeps the report layout stable between invocations.
    compared_runs = api.runs(project_name, filters={"tags": {"$in": hashes}})
    experiment_names = sorted({run.name.split(":")[0] for run in compared_runs})

    blocks = []
    for name in experiment_names:
        blocks.extend(experiment_blocks(api, project_name, name, hashes))

    report.blocks = blocks
    report.save()
    print(report.url)


if __name__ == "__main__":
    main()