Commit 0270960

Update TrainConfig optimizer hyperparameters (#82)
* Update `TrainConfig` optimizer hyperparameters
* Update `TrainConfig` docstring
1 parent 87f6127 · commit 0270960

7 files changed: +35 −38 lines
configs/ilql_config.yml

Lines changed: 4 additions & 5 deletions
@@ -10,12 +10,11 @@ train:
   epochs: 10
   total_steps: 10000

-  lr_ramp_steps: 100
-  lr_decay_steps: 3366
-  weight_decay: 1e-6
-  learning_rate_init: 1e-4
-  learning_rate_target: 1e-4
+  lr_init: 1.0e-4
+  lr_target: 1.0e-4
   opt_betas: [0.9, 0.95]
+  opt_eps: 1.0e-8
+  weight_decay: 1.0e-6

   checkpoint_interval: 1000
   eval_interval: 16
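Worth noting (an inference, not stated in the commit message): the values are written as `1.0e-4` / `1.0e-8` rather than `1e-4`, because PyYAML's default YAML 1.1 resolver only treats scientific notation as a float when the mantissa contains a decimal point and the exponent is signed; a bare `1e-6` loads as a string. A minimal check with PyYAML:

import yaml  # PyYAML

doc = """
weight_decay_a: 1e-6    # no dot in the mantissa -> YAML 1.1 loads it as a string
weight_decay_b: 1.0e-6  # dot plus signed exponent -> loads as a float
"""

loaded = yaml.safe_load(doc)
print(type(loaded["weight_decay_a"]))  # <class 'str'>
print(type(loaded["weight_decay_b"]))  # <class 'float'>

This would also explain why the `float(...)` casts around the learning-rate values are dropped in `trlx/model/accelerate_base_model.py` below.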

configs/ppo_config.yml

Lines changed: 4 additions & 5 deletions
@@ -10,12 +10,11 @@ train:
   total_steps: 10000 # Train for max(epochs, total_steps)
   batch_size: 128 # batch size

-  lr_ramp_steps: 100 # learning rate warm up
-  lr_decay_steps: 79000 # learning rate decay
-  weight_decay: 1.0e-6 # weight decay param
-  learning_rate_init: 1.412e-4 # init learning rate
-  learning_rate_target: 1.412e-4 # target final learning rate
+  lr_init: 1.412e-4 # init learning rate
+  lr_target: 1.412e-4 # target final learning rate
   opt_betas: [0.9, 0.95] # adam betas
+  opt_eps: 1.0e-8 # adam eps
+  weight_decay: 1.0e-6 # weight decay param

   checkpoint_interval: 10000 # checkpoint interval
   eval_interval: 16 # eval interval

configs/ppo_gptj.yml

Lines changed: 4 additions & 5 deletions
@@ -10,12 +10,11 @@ train:
   total_steps: 80000 # Train for max(epochs, total_steps)
   batch_size: 8 # batch size

-  lr_ramp_steps: 100 # learning rate warm up
-  lr_decay_steps: 79000 # learning rate decay
-  weight_decay: 1.0e-6 # weight decay param
-  learning_rate_init: 1.412e-4 # init learning rate
-  learning_rate_target: 1.412e-4 # target final learning rate
+  lr_init: 1.412e-4 # init learning rate
+  lr_target: 1.412e-4 # target final learning rate
   opt_betas: [0.9, 0.95] # adam betas
+  opt_eps: 1.0e-8 # adam eps
+  weight_decay: 1.0e-6 # weight decay param

   checkpoint_interval: 1000000 # checkpoint interval
   eval_interval: 16 # eval interval

configs/test_config.yml

Lines changed: 4 additions & 5 deletions
@@ -10,12 +10,11 @@ train:
   total_steps: 1000 # Train for max(epochs, total_steps)
   batch_size: 16 # batch size

-  lr_ramp_steps: 100 # learning rate warm up
-  lr_decay_steps: 79000 # learning rate decay
-  weight_decay: 1.0e-6 # weight decay param
-  learning_rate_init: 1.412e-4 # init learning rate
-  learning_rate_target: 1.412e-4 # target final learning rate
+  lr_init: 1.412e-4 # init learning rate
+  lr_target: 1.412e-4 # target final learning rate
   opt_betas: [0.9, 0.95] # adam betas
+  opt_eps: 1.0e-8 # adam eps
+  weight_decay: 1.0e-6 # weight decay param

   checkpoint_interval: 10000 # checkpoint interval
   eval_interval: 128 # eval interval

examples/randomwalks.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def metric_fn(samples):
 config = TRLConfig.load_yaml("configs/ilql_config.yml")
 config.train.gen_size = 10
 config.train.epochs = 100
-config.train.learning_rate_init = 1e-3
+config.train.lr_init = 1e-3
 config.method.alpha = 0.1

 config.model.tokenizer_path = ""

trlx/data/configs.py

Lines changed: 14 additions & 15 deletions
@@ -48,20 +48,20 @@ class TrainConfig:
     :param batch_size: Batch size for training
     :type batch_size: int

-    :param lr_ramp_steps: Number of steps before learning rate reaches learning_rate_init
-    :type lr_ramp_steps: int
+    :param lr_init: Initial learning rate value
+    :type lr_init: float

-    :param lr_decay_steps: Number of after ramp up steps before learning rate decays to learning_rate_target
-    :type lr_decay_steps: int
+    :param lr_target: Target learning rate after decay
+    :type lr_target: float

-    :param weight_decay: Weight decay for optimizer
-    :type weight_decay: float
+    :param opt_betas: Beta parameters for Adam optimizer
+    :type opt_betas: Tuple[float]

-    :param learning_rate_init: Initial learning rate after ramp up
-    :type learning_rate_init: float
+    :param opt_eps: Epsilon for optimizer
+    :type opt_eps: float

-    :param learning_rate_target: Target learning rate after decay
-    :type learning_rate_target: float
+    :param weight_decay: Weight decay for optimizer
+    :type weight_decay: float

     :param checkpoint_interval: Save model every checkpoint_interval steps
     :type checkpoint_interval: int
@@ -90,12 +90,11 @@ class TrainConfig:
     epochs: int
     batch_size: int

-    lr_ramp_steps: int
-    lr_decay_steps: int
-    weight_decay: float
-    learning_rate_init: float
-    learning_rate_target: float
+    lr_init: float
+    lr_target: float
     opt_betas: Tuple[float]
+    opt_eps: float
+    weight_decay: float

     checkpoint_interval: int
     eval_interval: int
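For context, here is how the renamed fields would be read after loading one of the updated configs — a sketch assuming `TRLConfig` is exposed from `trlx.data.configs`, as in `examples/randomwalks.py` above:

from trlx.data.configs import TRLConfig

# Load the updated PPO config; the optimizer settings now live under the
# renamed TrainConfig fields.
config = TRLConfig.load_yaml("configs/ppo_config.yml")

print(config.train.lr_init)       # 0.0001412
print(config.train.lr_target)     # 0.0001412
print(config.train.opt_betas)     # [0.9, 0.95]
print(config.train.opt_eps)       # 1e-08
print(config.train.weight_decay)  # 1e-06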

trlx/model/accelerate_base_model.py

Lines changed: 4 additions & 2 deletions
@@ -80,14 +80,16 @@ def __init__(self, config, train_mode=True):

         self.opt = torch.optim.AdamW(
             self.model.parameters(),
-            lr=float(self.config.train.learning_rate_init),
+            lr=self.config.train.lr_init,
             betas=self.config.train.opt_betas,
+            eps=self.config.train.opt_eps,
+            weight_decay=self.config.train.weight_decay,
         )

         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             self.opt,
             self.config.train.total_steps,
-            eta_min=float(self.config.train.learning_rate_target),
+            eta_min=self.config.train.lr_target,
         )

     def tokenize(self, text: Iterable[str]):

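For reference, a self-contained sketch of the optimizer/scheduler wiring that results from this change, with the values from `configs/ppo_config.yml` inlined (the `Linear` model is only a stand-in):

import torch

model = torch.nn.Linear(8, 8)  # stand-in for the actual policy model

# Optimizer built from the renamed TrainConfig fields.
opt = torch.optim.AdamW(
    model.parameters(),
    lr=1.412e-4,          # lr_init
    betas=(0.9, 0.95),    # opt_betas
    eps=1.0e-8,           # opt_eps
    weight_decay=1.0e-6,  # weight_decay
)

# Cosine decay from lr_init down to lr_target over total_steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt,
    T_max=10000,       # total_steps
    eta_min=1.412e-4,  # lr_target
)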