Skip to content

Commit b60f05e

Browse files
authored
Add unit tests to ensure valid example configs (#120)
1 parent a94eefc commit b60f05e

File tree

5 files changed

+54
-9
lines changed

5 files changed

+54
-9
lines changed

configs/ppo_gptj.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ method:
3535
cliprange: 0.2 # clip range
3636
cliprange_value: 0.2 # clip range for the value function
3737
vf_coef: 0.2 # value term weight
38+
scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
39+
ref_mean: null
40+
ref_std: null # rescale rewards with this deviation
41+
cliprange_reward: 10
3842
gen_kwargs:
3943
max_length: 48 # LM max sample gen length
4044
min_length: 48 # LM min sample gen length

examples/experiments/grounded_program_synthesis/config/trlx_ppo_config.yml renamed to examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@ train:
1010
total_steps: 80000 # Train for max(epochs, total_steps)
1111
batch_size: 8 # batch size
1212

13-
lr_ramp_steps: 100 # learning rate warm up
14-
lr_decay_steps: 79000 # learning rate decay
15-
weight_decay: 1.0e-6 # weight decay param
16-
learning_rate_init: 1.412e-4 # init learning rate
17-
learning_rate_target: 1.412e-4 # target final learning rate
13+
lr_init: 1.412e-4 # init learning rate
14+
lr_target: 1.412e-4 # target final learning rate
1815
opt_betas: [0.9, 0.95] # adam betas
16+
opt_eps: 1.0e-8 # adam eps
17+
weight_decay: 1.0e-6 # weight decay param
1918

2019
checkpoint_interval: 1000000 # checkpoint interval
2120
eval_interval: 16 # eval interval
@@ -36,6 +35,10 @@ method:
3635
cliprange: 0.2 # clip range
3736
cliprange_value: 0.2 # clip range for the value function
3837
vf_coef: 0.2 # value term weight
38+
scale_reward: False # False|"ref"|"running" estimate against which to scale rewards
39+
cliprange_reward: 10
40+
ref_mean: null
41+
ref_std: null
3942
gen_kwargs:
4043
max_length: 256 # LM max sample gen length
4144
min_length: 48 # LM min sample gen length

examples/experiments/grounded_program_synthesis/lang.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import random
21
import copy
2+
import json
3+
import random
4+
from pathlib import Path
35
from pprint import pprint
6+
47
from tqdm import tqdm
5-
import json
68
from transformers import AutoTokenizer
79

810

@@ -388,5 +390,6 @@ def basic_stats(dataset, tokenizer):
388390
test_data = create_synthetic_dataset(2_000)
389391
print(f"Train data size: {len(train_data)}")
390392
print(f"Test data size: {len(test_data)}")
393+
Path("dataset").mkdir(parents=True, exist_ok=True)
391394
write_to_json(train_data, "dataset/train.json")
392395
write_to_json(test_data, "dataset/test.json")

examples/experiments/grounded_program_synthesis/train_trlx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def reward_fn(samples):
4949
return reward_list
5050

5151

52-
default_config = yaml.safe_load(open("config/trlx_ppo_config.yml"))
52+
default_config = yaml.safe_load(open("configs/trlx_ppo_config.yml"))
5353

5454

5555
def main(hparams={}):
@@ -60,7 +60,6 @@ def main(hparams={}):
6060
train_prompts = list(dataset.load_datapoints(split="train"))[:1000]
6161

6262
model = trlx.train(
63-
"reshinthadith/codegen_350M_list_manip_5_len",
6463
reward_fn=reward_fn,
6564
prompts=train_prompts,
6665
config=config,

tests/test_configs.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
3+
from trlx.data.configs import TRLConfig
4+
from typing import List
5+
6+
7+
def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]:
8+
"""Returns all sub-directories of `dir` named `configs`."""
9+
config_dirs = []
10+
for root, dirs, _ in os.walk(dir):
11+
for d in dirs:
12+
if d == config_dir_name:
13+
config_dirs.append(os.path.join(root, d))
14+
return config_dirs
15+
16+
17+
def _get_yaml_filepaths(dir: str) -> List[str]:
18+
"""Returns a list of `yml` filepaths in `dir`."""
19+
filepaths = []
20+
for file in os.listdir(dir):
21+
if file.endswith(".yml"):
22+
filepaths.append(os.path.join(dir, file))
23+
return filepaths
24+
25+
26+
def test_repo_trl_configs():
    """Tests to ensure all default configs in the repository are valid.

    Gathers the `.yml` files from the top-level `configs` directory and from
    every `configs` directory under `examples` (both paths are relative to
    the current working directory), then checks that each one can be loaded
    with `TRLConfig.load_yaml`.

    Raises:
        AssertionError: If a collected path is not a file, does not end in
            `.yml`, or fails to load.
    """
    config_dirs = ["configs", *_get_config_dirs("examples")]
    # Flatten the per-directory file lists into one list.
    config_files = [
        file
        for config_dir in config_dirs
        for file in _get_yaml_filepaths(config_dir)
    ]
    for file in config_files:
        assert os.path.isfile(file), f"Config file {file} does not exist."
        assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
        try:
            TRLConfig.load_yaml(file)
        except Exception as e:
            # Raise explicitly instead of `assert False`, which is removed
            # under `python -O`; chaining keeps the loader's traceback.
            raise AssertionError(
                f"Failed to load config file `{file}` with error `{e}`"
            ) from e

0 commit comments

Comments
 (0)