Commit ff0d077

Restructure sweeps for reuse (#102)
* chore(readme): update instructions
* refactor(sweep): reuse existing examples and configs
* fix(sweep): enable checkpointing for hyperband
* feat(sweep): add accelerate support
* fix(sweep): report with new params space
* feat(sweep): replace generic names
* chore(ppo_config): update better values
* chore(sweep): set max_concurrent_trials to default
* chore(examples): update the rest of examples to a new main signature
* chore(readme): update sweep instruction
* chore(sweep): add warning/confirmation check before importing
* chore(sweep): update sweep instruction
* update(config): to more stable values
1 parent 3db86ca commit ff0d077


17 files changed, +292 -225 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ accelerate launch examples/simulacra.py
 
 #### Use Ray Tune to launch hyperparameter sweep
 ```bash
-python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
+python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
 ```
 
 For more usage see [examples](./examples)
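With the new `trlx.sweep` entry point a sweep targets an example file directly, so every example now exposes a `main(hparams={})` function that merges the sampled hyperparameters into its default config before calling `trlx.train` (see the updated examples below). A minimal sketch of that contract, mirroring the `ppo_sentiments.py` changes in this commit; the model name, reward function, and prompts here are placeholders and are not part of the diff:

```python
# Sketch of the entry-point contract the sweep relies on: load a default
# config at import time, merge sampled hyperparameters, and hand the merged
# config to trlx.train. Model, reward, and prompts below are placeholders.
import yaml

import trlx
from trlx.data.configs import TRLConfig

default_config = yaml.safe_load(open("configs/ppo_config.yml"))


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    trlx.train(
        "lvwerra/gpt2-imdb",  # placeholder model name
        reward_fn=lambda samples: [float(len(s)) for s in samples],  # placeholder reward
        prompts=["placeholder prompt"] * 8,  # placeholder prompts
        config=config,
    )


if __name__ == "__main__":
    main()
```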

configs/ilql_config.yml

Lines changed: 7 additions & 7 deletions
@@ -8,16 +8,16 @@ train:
   seq_length: 64
   batch_size: 128
   epochs: 100
-  total_steps: 10000
+  total_steps: 1000

-  lr_init: 1.0e-4
-  lr_target: 1.0e-4
+  lr_init: 5.0e-5
+  lr_target: 5.0e-5
   opt_betas: [0.9, 0.95]
   opt_eps: 1.0e-8
   weight_decay: 1.0e-6

   checkpoint_interval: 1000
-  eval_interval: 128
+  eval_interval: 100

   pipeline: "PromptPipeline"
   orchestrator: "OfflineOrchestrator"
@@ -29,7 +29,7 @@ method:
   gamma: 0.99
   cql_scale: 0.1
   awac_scale: 1
-  alpha: 0.005
-  steps_for_target_q_sync: 1
-  betas: [16]
+  alpha: 0.001
+  steps_for_target_q_sync: 5
+  betas: [4]
   two_qs: true

configs/ppo_config.yml

Lines changed: 6 additions & 6 deletions
@@ -10,14 +10,14 @@ train:
   total_steps: 10000 # Train for max(epochs, total_steps)
   batch_size: 128 # batch size

-  lr_init: 1.412e-4 # init learning rate
-  lr_target: 1.412e-4 # target final learning rate
+  lr_init: 1.0e-4 # init learning rate
+  lr_target: 1.0e-4 # target final learning rate
   opt_betas: [0.9, 0.95] # adam betas
   opt_eps: 1.0e-8 # adam eps
   weight_decay: 1.0e-6 # weight decay param

   checkpoint_interval: 10000 # checkpoint interval
-  eval_interval: 16 # eval interval
+  eval_interval: 100 # eval interval

   pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load
@@ -28,15 +28,15 @@ method:
   num_rollouts: 128 # Number of rollouts to collect per epoch
   chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
   ppo_epochs: 4 # Number of ppo epochs
-  init_kl_coef: 0.2 # init kl coefficient
+  init_kl_coef: 0.05 # init kl coefficient
   target: 6 # target kl coefficient, set None for fixed kl coef
   horizon: 10000 # PPO horizon
   gamma: 1 # PPO discount
   lam: 0.95 # PPO lambda
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
-  vf_coef: 2.3 # value term weight
-  scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
+  vf_coef: 1 # value term weight
+  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
   ref_mean: null
   ref_std: null # rescale rewards with this deviation
   cliprange_reward: 10
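These defaults are what the examples now load at import time and then override per sweep trial. A quick, illustrative way to inspect the merged values, assuming the same loading path the examples use (`yaml.safe_load` plus `TRLConfig.update`) and attribute-style access on the resulting config:

```python
# Illustrative check of the updated defaults; assumes the resulting TRLConfig
# exposes its sections as attributes, which is how the examples consume it.
import yaml

from trlx.data.configs import TRLConfig

default_config = yaml.safe_load(open("configs/ppo_config.yml"))
config = TRLConfig.update(default_config, {})  # no sweep overrides

print(config.train.lr_init)        # 1.0e-4 after this commit (was 1.412e-4)
print(config.method.init_kl_coef)  # 0.05 (was 0.2)
print(config.method.vf_coef)       # 1 (was 2.3)
```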

configs/ray_tune_configs/ppo_config.yml

Lines changed: 0 additions & 68 deletions
This file was deleted.

configs/sweeps/ilql_sweep.yml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+tune_config:
+  mode: "max"
+  metric: "metrics/sentiments"
+  search_alg: "random"
+  scheduler: "fifo"
+  num_samples: 32
+
+lr_init:
+  strategy: "loguniform"
+  values: [0.00001, 0.01]
+tau:
+  strategy: "uniform"
+  values: [0.6, 0.9]
+steps_for_target_q_sync:
+  strategy: "choice"
+  values: [1, 5, 10]
+alpha:
+  strategy: "loguniform"
+  values: [0.001, 1.0]

configs/sweeps/ppo_sweep.yml

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+tune_config:
+  mode: "max"
+  metric: "mean_reward"
+  search_alg: "random"
+  scheduler: "fifo"
+  num_samples: 32
+
+# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
+lr_init:
+  strategy: "loguniform"
+  values: [0.00001, 0.01]
+init_kl_coef:
+  strategy: "uniform"
+  values: [0, 0.2]
+vf_coef:
+  strategy: "uniform"
+  values: [0.5, 2]
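Each entry pairs a `strategy` with its `values`, following the Ray Tune search-space API linked in the comment above. A minimal sketch of how such entries might map onto Tune's sample functions; the helper below is hypothetical and not necessarily how `trlx.sweep` builds its parameter space internally:

```python
# Hypothetical translation of a sweep-config entry into a Ray Tune sample
# function; `strategy` is assumed to name the Tune constructor and `values`
# to carry its arguments.
from ray import tune


def to_tune_param(strategy, values):
    if strategy == "loguniform":
        return tune.loguniform(*values)  # e.g. lr_init over [1e-5, 1e-2]
    if strategy == "uniform":
        return tune.uniform(*values)     # e.g. init_kl_coef over [0, 0.2]
    if strategy == "choice":
        return tune.choice(values)       # e.g. steps_for_target_q_sync in {1, 5, 10}
    raise ValueError(f"unknown strategy: {strategy}")


# Param space corresponding to ppo_sweep.yml above
param_space = {
    "lr_init": to_tune_param("loguniform", [0.00001, 0.01]),
    "init_kl_coef": to_tune_param("uniform", [0, 0.2]),
    "vf_coef": to_tune_param("uniform", [0.5, 2]),
}
```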

examples/architext.py

Lines changed: 29 additions & 18 deletions
@@ -9,22 +9,33 @@ def reward_fn(samples):
     return [-sample.count(":") for sample in samples]


+prompts = [
+    "[prompt] the bedroom is adjacent to the living room [layout]",
+    "[prompt] a bedroom is adjacent to the living room [layout]",
+    "[prompt] the bedroom is adjacent to the kitchen [layout]",
+    "[prompt] a bedroom is adjacent to the kitchen [layout]",
+    "[prompt] the bedroom is adjacent to the kitchen [layout]",
+    "[prompt] the kitchen is adjacent to the bathroom [layout]",
+    "[prompt] a bathroom is adjacent to the living room [layout]",
+    "[prompt] the bathroom is adjacent to the living room [layout]",
+    "[prompt] the bedroom is not adjacent to the living room [layout]",
+    "[prompt] a bedroom is not adjacent to the living room [layout]",
+    "[prompt] the bedroom is not adjacent to the kitchen [layout]",
+    "[prompt] a bedroom is not adjacent to the kitchen [layout]",
+    "[prompt] the bedroom is not adjacent to the kitchen [layout]",
+    "[prompt] the kitchen is not adjacent to the bathroom [layout]",
+]
+
+default_config = yaml.safe_load(open("configs/ppo_config.yml"))
+
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
+    model = trlx.train(
+        "architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config
+    )
+
+
 if __name__ == "__main__":
-    prompts = [
-        "[prompt] the bedroom is adjacent to the living room [layout]",
-        "[prompt] a bedroom is adjacent to the living room [layout]",
-        "[prompt] the bedroom is adjacent to the kitchen [layout]",
-        "[prompt] a bedroom is adjacent to the kitchen [layout]",
-        "[prompt] the bedroom is adjacent to the kitchen [layout]",
-        "[prompt] the kitchen is adjacent to the bathroom [layout]",
-        "[prompt] a bathroom is adjacent to the living room [layout]",
-        "[prompt] the bathroom is adjacent to the living room [layout]",
-        "[prompt] the bedroom is not adjacent to the living room [layout]",
-        "[prompt] a bedroom is not adjacent to the living room [layout]",
-        "[prompt] the bedroom is not adjacent to the kitchen [layout]",
-        "[prompt] a bedroom is not adjacent to the kitchen [layout]",
-        "[prompt] the bedroom is not adjacent to the kitchen [layout]",
-        "[prompt] the kitchen is not adjacent to the bathroom [layout]",
-    ]
-
-    model = trlx.train("architext/gptj-162M", reward_fn=reward_fn, prompts=prompts)
+    main()

examples/experiments/grounded_program_synthesis/train_trlx.py

Lines changed: 21 additions & 13 deletions
@@ -1,10 +1,9 @@
-# Toy example of optimizing textual interior designs to output the least number of rooms
-# Also see https://architext.design/
 import trlx
 from trlx.data.configs import TRLConfig
 from lang import Interpreter
 import json
 import logging
+import yaml


 logger = logging.getLogger(__name__)
@@ -50,6 +49,25 @@ def reward_fn(samples):
     return reward_list


+default_config = yaml.safe_load(open("config/trlx_ppo_config.yml"))
+
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
+    # Dataset
+    dataset = DSLDataset()
+    train_prompts = list(dataset.load_datapoints(split="train"))[:1000]
+
+    model = trlx.train(
+        "reshinthadith/codegen_350M_list_manip_5_len",
+        reward_fn=reward_fn,
+        prompts=train_prompts,
+        config=config,
+    )
+    model.save_pretrained("dataset/trained_model")
+
+
 if __name__ == "__main__":
     # TEST REWARD FUNTION
     assert (
@@ -67,15 +85,5 @@ def reward_fn(samples):
             ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"]
         )
     ) == [-0.5]
-    # Datset
-    dataset = DSLDataset()
-    train_prompts = list(dataset.load_datapoints(split="train"))[:1000]
-    trl_config = TRLConfig.load_yaml("config/trlx_ppo_config.yml")

-    model = trlx.train(
-        "reshinthadith/codegen_350M_list_manip_5_len",
-        reward_fn=reward_fn,
-        prompts=train_prompts,
-        config=trl_config,
-    )
-    model.save_pretrained("dataset/trained_model")
+    main()

examples/ilql_sentiments.py

Lines changed: 9 additions & 1 deletion
@@ -2,16 +2,23 @@
 from transformers import pipeline

 import trlx
+import yaml
 from typing import List, Dict
 import os
+from trlx.data.configs import TRLConfig


 def get_positive_score(scores):
     "Extract value associated with a positive sentiment from pipeline's output"
     return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


-def main():
+default_config = yaml.safe_load(open("configs/ilql_config.yml"))
+
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
     sentiment_fn = pipeline(
         "sentiment-analysis",
         "lvwerra/distilbert-imdb",
@@ -32,6 +39,7 @@ def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
         dataset=(imdb["text"], imdb["label"]),
         eval_prompts=["I don't know much about Hungarian underground"] * 64,
         metric_fn=metric_fn,
+        config=config,
     )


examples/ppo_sentiments.py

Lines changed: 9 additions & 1 deletion
@@ -4,18 +4,25 @@
 from datasets import load_dataset
 from transformers import pipeline
 import os
+import yaml

 import trlx
 import torch
 from typing import List
+from trlx.data.configs import TRLConfig


 def get_positive_score(scores):
     "Extract value associated with a positive sentiment from pipeline's output"
     return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


-def main():
+default_config = yaml.safe_load(open("configs/ppo_config.yml"))
+
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
     if torch.cuda.is_available():
         device = int(os.environ.get("LOCAL_RANK", 0))
     else:
@@ -43,6 +50,7 @@ def reward_fn(samples: List[str]) -> List[float]:
         reward_fn=reward_fn,
         prompts=prompts,
         eval_prompts=["I don't know much about Hungarian underground"] * 64,
+        config=config,
     )


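Because every example now exposes `main(hparams)`, a single configuration can also be run outside of Ray Tune by calling that entry point directly. This is an illustrative sketch only: the file-loading approach and the flat hyperparameter keys below are assumptions, and neither is pinned down by this diff.

```python
# Illustrative only: run one configuration of examples/ppo_sentiments.py the
# way a sweep trial would, by calling its new main(hparams) entry point.
# The flat key names ("lr_init", "init_kl_coef") are assumptions about the
# format TRLConfig.update accepts; run from the repository root so the
# relative config paths resolve.
import importlib.util

spec = importlib.util.spec_from_file_location(
    "ppo_sentiments", "examples/ppo_sentiments.py"
)
example = importlib.util.module_from_spec(spec)
spec.loader.exec_module(example)

example.main({"lr_init": 3.0e-5, "init_kl_coef": 0.1})
```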