9 changes: 4 additions & 5 deletions configs/ilql_config.yml
@@ -10,12 +10,11 @@ train:
epochs: 10
total_steps: 10000

lr_ramp_steps: 100
lr_decay_steps: 3366
weight_decay: 1e-6
learning_rate_init: 1e-4
learning_rate_target: 1e-4
lr_init: 1.0e-4
lr_target: 1.0e-4
opt_betas: [0.9, 0.95]
opt_eps: 1.0e-8
weight_decay: 1.0e-6

checkpoint_interval: 1000
eval_interval: 16
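The renamed fields are read straight off the loaded config object. Below is a minimal sketch, assuming the same TRLConfig.load_yaml entry point used in examples/randomwalks.py later in this diff, of loading the updated configs/ilql_config.yml and inspecting the new optimizer settings; the values in the comments simply echo the YAML above.

```python
from trlx.data.configs import TRLConfig

# Load the updated ILQL config and read the renamed optimizer fields.
config = TRLConfig.load_yaml("configs/ilql_config.yml")

print(config.train.lr_init)       # 1.0e-4
print(config.train.lr_target)     # 1.0e-4
print(config.train.opt_betas)     # [0.9, 0.95]
print(config.train.opt_eps)       # 1.0e-8
print(config.train.weight_decay)  # 1.0e-6

# Overrides must now use the new names; `learning_rate_init` no longer exists.
config.train.lr_init = 1e-3
```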
9 changes: 4 additions & 5 deletions configs/ppo_config.yml
@@ -10,12 +10,11 @@ train:
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

lr_ramp_steps: 100 # learning rate warm up
lr_decay_steps: 79000 # learning rate decay
weight_decay: 1.0e-6 # weight decay param
learning_rate_init: 1.412e-4 # init learning rate
learning_rate_target: 1.412e-4 # target final learning rate
lr_init: 1.412e-4 # init learning rate
lr_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas
opt_eps: 1.0e-8 # adam eps
weight_decay: 1.0e-6 # weight decay param

checkpoint_interval: 10000 # checkpoint interval
eval_interval: 16 # eval interval
9 changes: 4 additions & 5 deletions configs/ppo_gptj.yml
@@ -10,12 +10,11 @@ train:
total_steps: 80000 # Train for max(epochs, total_steps)
batch_size: 8 # batch size

lr_ramp_steps: 100 # learning rate warm up
lr_decay_steps: 79000 # learning rate decay
weight_decay: 1.0e-6 # weight decay param
learning_rate_init: 1.412e-4 # init learning rate
learning_rate_target: 1.412e-4 # target final learning rate
lr_init: 1.412e-4 # init learning rate
lr_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas
opt_eps: 1.0e-8 # adam eps
weight_decay: 1.0e-6 # weight decay param

checkpoint_interval: 1000000 # checkpoint interval
eval_interval: 16 # eval interval
9 changes: 4 additions & 5 deletions configs/test_config.yml
@@ -10,12 +10,11 @@ train:
total_steps: 1000 # Train for max(epochs, total_steps)
batch_size: 16 # batch size

lr_ramp_steps: 100 # learning rate warm up
lr_decay_steps: 79000 # learning rate decay
weight_decay: 1.0e-6 # weight decay param
learning_rate_init: 1.412e-4 # init learning rate
learning_rate_target: 1.412e-4 # target final learning rate
lr_init: 1.412e-4 # init learning rate
lr_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas
opt_eps: 1.0e-8 # adam eps
weight_decay: 1.0e-6 # weight decay param

checkpoint_interval: 10000 # checkpoint interval
eval_interval: 128 # eval interval
2 changes: 1 addition & 1 deletion examples/randomwalks.py
@@ -92,7 +92,7 @@ def metric_fn(samples):
config = TRLConfig.load_yaml("configs/ilql_config.yml")
config.train.gen_size = 10
config.train.epochs = 100
config.train.learning_rate_init = 1e-3
config.train.lr_init = 1e-3
config.method.alpha = 0.1

config.model.tokenizer_path = ""
29 changes: 14 additions & 15 deletions trlx/data/configs.py
@@ -48,20 +48,20 @@ class TrainConfig:
:param batch_size: Batch size for training
:type batch_size: int

:param lr_ramp_steps: Number of steps before learning rate reaches learning_rate_init
:type lr_ramp_steps: int
:param lr_init: Initial learning rate value
:type lr_init: float

:param lr_decay_steps: Number of after ramp up steps before learning rate decays to learning_rate_target
:type lr_decay_steps: int
:param lr_target: Target learning rate after decay
:type lr_target: float

:param weight_decay: Weight decay for optimizer
:type weight_decay: float
:param opt_betas: Beta parameters for Adam optimizer
:type opt_betas: Tuple[float]

:param learning_rate_init: Initial learning rate after ramp up
:type learning_rate_init: float
:param opt_eps: Epsilon for optimizer
:type opt_eps: float

:param learning_rate_target: Target learning rate after decay
:type learning_rate_target: float
:param weight_decay: Weight decay for optimizer
:type weight_decay: float

:param checkpoint_interval: Save model every checkpoint_interval steps
:type checkpoint_interval: int
@@ -90,12 +90,11 @@ class TrainConfig:
epochs: int
batch_size: int

lr_ramp_steps: int
lr_decay_steps: int
weight_decay: float
learning_rate_init: float
learning_rate_target: float
lr_init: float
lr_target: float
opt_betas: Tuple[float]
opt_eps: float
weight_decay: float

checkpoint_interval: int
eval_interval: int
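For YAML files outside this repository that still use the old names, a purely hypothetical migration helper is sketched below. The old-to-new mapping follows the TrainConfig field changes above; lr_ramp_steps and lr_decay_steps are dropped because the cosine schedule in accelerate_base_model.py is parameterized by total_steps and lr_target, so they have no direct replacement.

```python
# Hypothetical helper (not part of trlx) for migrating an old-style `train:`
# section, already parsed into a dict, to the renamed fields.
OLD_TO_NEW = {
    "learning_rate_init": "lr_init",
    "learning_rate_target": "lr_target",
}
DROPPED = {"lr_ramp_steps", "lr_decay_steps"}  # no direct equivalent


def migrate_train_section(train: dict) -> dict:
    migrated = {}
    for key, value in train.items():
        if key in DROPPED:
            continue
        migrated[OLD_TO_NEW.get(key, key)] = value
    # New optimizer fields, filled with the values used in the updated configs.
    migrated.setdefault("opt_betas", [0.9, 0.95])
    migrated.setdefault("opt_eps", 1.0e-8)
    return migrated
```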
6 changes: 4 additions & 2 deletions trlx/model/accelerate_base_model.py
@@ -80,14 +80,16 @@ def __init__(self, config, train_mode=True):

self.opt = torch.optim.AdamW(
self.model.parameters(),
lr=float(self.config.train.learning_rate_init),
lr=self.config.train.lr_init,
betas=self.config.train.opt_betas,
eps=self.config.train.opt_eps,
weight_decay=self.config.train.weight_decay,
)

self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.opt,
self.config.train.total_steps,
eta_min=float(self.config.train.learning_rate_target),
eta_min=self.config.train.lr_target,
)

def tokenize(self, text: Iterable[str]):
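To make the roles of the new fields concrete, here is a rough, self-contained sketch, with a toy model and loss standing in for the real trlx training loop, of how lr_init, opt_betas, opt_eps, weight_decay and lr_target feed the AdamW optimizer and cosine schedule constructed above.

```python
import torch

model = torch.nn.Linear(8, 8)  # toy stand-in for the configured model
total_steps = 10000            # train.total_steps

opt = torch.optim.AdamW(
    model.parameters(),
    lr=1.412e-4,          # train.lr_init
    betas=(0.9, 0.95),    # train.opt_betas
    eps=1.0e-8,           # train.opt_eps
    weight_decay=1.0e-6,  # train.weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, total_steps, eta_min=1.412e-4  # train.lr_target
)

for step in range(total_steps):
    opt.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()  # placeholder loss
    loss.backward()
    opt.step()
    # Anneals the learning rate from lr_init toward lr_target; note that with
    # lr_init == lr_target, as in the shipped configs, the schedule is flat.
    scheduler.step()
```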