Commit ded2e5e

fix(ppo_trainer): update AdaptiveKLController with correct KL (#361)
1 parent: adbf8fc

File tree

1 file changed (+21, −35)


trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 21 additions & 35 deletions
@@ -204,7 +204,7 @@ def loss(self, batch: PPORLBatch):
             returns=returns,
             mask=mask,
         )
-        self.approx_kl = stats["policy/approx_kl"]  # Update kl controller stats
+
         return loss, stats
 
     def setup_rollout_logging(self, config):
@@ -232,7 +232,7 @@ def post_epoch_callback(self):
         self.make_experience(self.config.method.num_rollouts, self.iter_count)
 
     def post_backward_callback(self):
-        self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size)
+        self.kl_ctl.update(self.mean_kl.item(), n_steps=self.config.train.batch_size)
 
     def prepare_learning(self):
         eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
@@ -438,52 +438,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
             ref_logprobs = ref_logprobs.cpu()
             prompt_tensors = prompt_tensors.cpu()
             sample_outputs = sample_outputs.cpu()
+            values = values.cpu()[:, :-1]
 
             # Estimate the KL divergence between the model and reference model
             if self.config.model.model_arch_type == "seq2seq":
-                values = values.cpu()[:, :-1]
+                attention_mask = sample_outputs != self.tokenizer.pad_token_id
                 start = 0
-
-                # Get the number of non-padding tokens for each sample
-                # This assumes all padding is on the right side
-                padding_token: int = 0
-                ends = (sample_outputs[:, start:] != padding_token).sum(1)
-
-                # Get the logprobs and values, for tokens that are not padding
-                # or beginning of sequences tokens. These are from the model
-                # (not the reference model)
-                all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
-                all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-
-                kl_divergence_estimate: List[torch.Tensor] = [
-                    -self.kl_ctl.value
-                    * (
-                        logprobs[sample_idx, start : ends[sample_idx]]
-                        - ref_logprobs[sample_idx, start : ends[sample_idx]]
-                    )
-                    for sample_idx in range(n_samples)
-                ]
-
-            # Else if not seq2seq (i.e. causal)
             else:
-                values = values.cpu()[:, :-1]
                 start = prompt_tensors.shape[1] - 1
-                ends = start + attention_mask[:, start:].sum(1)
-                all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-                all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
 
-                kl_divergence_estimate = -self.kl_ctl.value * (logprobs - ref_logprobs)
-                kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)]
+            ends = start + attention_mask[:, start:].sum(1)
+
+            # Get the logprobs and values, for tokens that are not padding
+            # or beginning of sequences tokens. These are from the model (not the reference model)
+            all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
+            all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
+
+            log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1].cpu()
+            self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
+            kl_penalty = self.kl_ctl.value * -log_ratio
+            kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]
 
             rollout_count = 0
 
             for sample_idx in range(n_samples):
-                sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx]
-
-                if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0:
+                if len(kl_penalty[sample_idx]) == 0 or len(all_logprobs[sample_idx]) == 0:
                     continue
 
-                rewards = sample_kl_divergence_estimate
+                rewards = kl_penalty[sample_idx]
                 rewards[-1] += scores[sample_idx].cpu()
 
                 ppo_rl_elements.append(
@@ -502,6 +484,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
             tbar.update(min(rollout_count, num_rollouts))
         tbar.close()
 
+        if torch.distributed.is_initialized():
+            torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)
+
+        stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl)
         stats["kl_ctl_value"] = self.kl_ctl.value
         stats["time/exp"] = exp_time

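For context on what consumes this value: `post_backward_callback` now passes `self.mean_kl.item()` to `self.kl_ctl.update(...)`. The sketch below shows the usual Ziegler-et-al.-style adaptive update such a controller performs; it is an assumption for illustration only, and the class name, constructor arguments, and clipping constants are not taken from this diff, so trlx's actual `AdaptiveKLController` may differ.

```python
import numpy as np


class AdaptiveKLControllerSketch:
    """Illustrative adaptive KL coefficient (Ziegler et al., 2019 style).

    Hypothetical stand-in: trlx's real AdaptiveKLController lives elsewhere in
    the repository and is not reproduced here.
    """

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef  # current coefficient multiplying the KL penalty
        self.target = target       # desired KL between policy and reference model
        self.horizon = horizon     # number of samples over which to adapt

    def update(self, current: float, n_steps: int) -> None:
        # Grow the coefficient when the measured KL exceeds the target,
        # shrink it when the policy hugs the reference too closely.
        proportional_error = float(np.clip(current / self.target - 1, -0.2, 0.2))
        self.value *= 1 + proportional_error * n_steps / self.horizon


# Example: a non-negative mean KL (as produced by the commit) nudges the coefficient.
ctl = AdaptiveKLControllerSketch(init_kl_coef=0.05, target=6.0, horizon=10_000)
ctl.update(current=0.8, n_steps=128)
print(ctl.value)  # slightly below 0.05, since the measured KL is under the target
```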