File tree Expand file tree Collapse file tree 5 files changed +54
-9
lines changed
examples/experiments/grounded_program_synthesis Expand file tree Collapse file tree 5 files changed +54
-9
lines changed Original file line number Diff line number Diff line change @@ -35,6 +35,10 @@ method:
35
35
cliprange : 0.2 # clip range
36
36
cliprange_value : 0.2 # clip range for the value function
37
37
vf_coef : 0.2 # value term weight
38
+ scale_reward : False # False | "ref" | "running" estimate against which to scale rewards
39
+ ref_mean : null
40
+ ref_std : null # rescale rewards with this deviation
41
+ cliprange_reward : 10
38
42
gen_kwargs :
39
43
max_length : 48 # LM max sample gen length
40
44
min_length : 48 # LM min sample gen length
Original file line number Diff line number Diff line change @@ -10,12 +10,11 @@ train:
10
10
total_steps : 80000 # Train for max(epochs, total_steps)
11
11
batch_size : 8 # batch size
12
12
13
- lr_ramp_steps : 100 # learning rate warm up
14
- lr_decay_steps : 79000 # learning rate decay
15
- weight_decay : 1.0e-6 # weight decay param
16
- learning_rate_init : 1.412e-4 # init learning rate
17
- learning_rate_target : 1.412e-4 # target final learning rate
13
+ lr_init : 1.412e-4 # init learning rate
14
+ lr_target : 1.412e-4 # target final learning rate
18
15
opt_betas : [0.9, 0.95] # adam betas
16
+ opt_eps : 1.0e-8 # adam eps
17
+ weight_decay : 1.0e-6 # weight decay param
19
18
20
19
checkpoint_interval : 1000000 # checkpoint interval
21
20
eval_interval : 16 # eval interval
@@ -36,6 +35,10 @@ method:
36
35
cliprange : 0.2 # clip range
37
36
cliprange_value : 0.2 # clip range for the value function
38
37
vf_coef : 0.2 # value term weight
38
+ scale_reward : False # False|"ref"|"running" estimate against which to scale rewards
39
+ cliprange_reward : 10
40
+ ref_mean : null
41
+ ref_std : null
39
42
gen_kwargs :
40
43
max_length : 256 # LM max sample gen length
41
44
min_length : 48 # LM min sample gen length
Original file line number Diff line number Diff line change 1
- import random
2
1
import copy
2
+ import json
3
+ import random
4
+ from pathlib import Path
3
5
from pprint import pprint
6
+
4
7
from tqdm import tqdm
5
- import json
6
8
from transformers import AutoTokenizer
7
9
8
10
@@ -388,5 +390,6 @@ def basic_stats(dataset, tokenizer):
388
390
test_data = create_synthetic_dataset (2_000 )
389
391
print (f"Train data size: { len (train_data )} " )
390
392
print (f"Test data size: { len (test_data )} " )
393
+ Path ("dataset" ).mkdir (parents = True , exist_ok = True )
391
394
write_to_json (train_data , "dataset/train.json" )
392
395
write_to_json (test_data , "dataset/test.json" )
Original file line number Diff line number Diff line change @@ -49,7 +49,7 @@ def reward_fn(samples):
49
49
return reward_list
50
50
51
51
52
- default_config = yaml .safe_load (open ("config /trlx_ppo_config.yml" ))
52
+ default_config = yaml .safe_load (open ("configs /trlx_ppo_config.yml" ))
53
53
54
54
55
55
def main (hparams = {}):
@@ -60,7 +60,6 @@ def main(hparams={}):
60
60
train_prompts = list (dataset .load_datapoints (split = "train" ))[:1000 ]
61
61
62
62
model = trlx .train (
63
- "reshinthadith/codegen_350M_list_manip_5_len" ,
64
63
reward_fn = reward_fn ,
65
64
prompts = train_prompts ,
66
65
config = config ,
Original file line number Diff line number Diff line change
1
+ import os
2
+
3
+ from trlx .data .configs import TRLConfig
4
+ from typing import List
5
+
6
+
7
+ def _get_config_dirs (dir : str , config_dir_name : str = "configs" ) -> List [str ]:
8
+ """Returns all sub-directories of `dir` named `configs`."""
9
+ config_dirs = []
10
+ for root , dirs , _ in os .walk (dir ):
11
+ for d in dirs :
12
+ if d == config_dir_name :
13
+ config_dirs .append (os .path .join (root , d ))
14
+ return config_dirs
15
+
16
+
17
+ def _get_yaml_filepaths (dir : str ) -> List [str ]:
18
+ """Returns a list of `yml` filepaths in `dir`."""
19
+ filepaths = []
20
+ for file in os .listdir (dir ):
21
+ if file .endswith (".yml" ):
22
+ filepaths .append (os .path .join (dir , file ))
23
+ return filepaths
24
+
25
+
26
def test_repo_trl_configs():
    """Tests to ensure all default configs in the repository are valid.

    Collects every `.yml` file from the top-level `configs` directory and
    from each `configs` directory found under `examples`, then checks that
    each one loads through `TRLConfig.load_yaml`. Both paths are relative,
    so they resolve against the current working directory.
    """
    config_dirs = ["configs", *_get_config_dirs("examples")]
    # Flatten the per-directory lists into a single list of file paths.
    config_files = [
        path for config_dir in config_dirs for path in _get_yaml_filepaths(config_dir)
    ]
    for file in config_files:
        assert os.path.isfile(file), f"Config file {file} does not exist."
        assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
        try:
            TRLConfig.load_yaml(file)
        except Exception as e:
            # Raise explicitly instead of `assert False`: it is not stripped
            # under `python -O`, and chaining keeps the original traceback.
            raise AssertionError(
                f"Failed to load config file `{file}` with error `{e}`"
            ) from e
You can’t perform that action at this time.
0 commit comments