@@ -204,7 +204,7 @@ def loss(self, batch: PPORLBatch):
             returns=returns,
             mask=mask,
         )
-        self.approx_kl = stats["policy/approx_kl"]  # Update kl controller stats
+
         return loss, stats
 
     def setup_rollout_logging(self, config):
@@ -232,7 +232,7 @@ def post_epoch_callback(self):
         self.make_experience(self.config.method.num_rollouts, self.iter_count)
 
     def post_backward_callback(self):
-        self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size)
+        self.kl_ctl.update(self.mean_kl.item(), n_steps=self.config.train.batch_size)
 
     def prepare_learning(self):
         eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
@@ -438,52 +438,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             ref_logprobs = ref_logprobs.cpu()
             prompt_tensors = prompt_tensors.cpu()
             sample_outputs = sample_outputs.cpu()
+            values = values.cpu()[:, :-1]
 
             # Estimate the KL divergence between the model and reference model
             if self.config.model.model_arch_type == "seq2seq":
-                values = values.cpu()[:, :-1]
+                attention_mask = sample_outputs != self.tokenizer.pad_token_id
                 start = 0
-
-                # Get the number of non-padding tokens for each sample
-                # This assumes all padding is on the right side
-                padding_token: int = 0
-                ends = (sample_outputs[:, start:] != padding_token).sum(1)
-
-                # Get the logprobs and values, for tokens that are not padding
-                # or beginning of sequences tokens. These are from the model
-                # (not the reference model)
-                all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
-                all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-
-                kl_divergence_estimate: List[torch.Tensor] = [
-                    -self.kl_ctl.value
-                    * (
-                        logprobs[sample_idx, start : ends[sample_idx]]
-                        - ref_logprobs[sample_idx, start : ends[sample_idx]]
-                    )
-                    for sample_idx in range(n_samples)
-                ]
-
-            # Else if not seq2seq (i.e. causal)
             else:
-                values = values.cpu()[:, :-1]
                 start = prompt_tensors.shape[1] - 1
-                ends = start + attention_mask[:, start:].sum(1)
-                all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-                all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
 
-                kl_divergence_estimate = -self.kl_ctl.value * (logprobs - ref_logprobs)
-                kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)]
+            ends = start + attention_mask[:, start:].sum(1)
+
+            # Get the logprobs and values, for tokens that are not padding
+            # or beginning of sequences tokens. These are from the model (not the reference model)
+            all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
+            all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
+
+            log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1].cpu()
+            self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
+            kl_penalty = self.kl_ctl.value * -log_ratio
+            kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]
 
             rollout_count = 0
 
             for sample_idx in range(n_samples):
-                sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx]
-
-                if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0:
+                if len(kl_penalty[sample_idx]) == 0 or len(all_logprobs[sample_idx]) == 0:
                     continue
 
-                rewards = sample_kl_divergence_estimate
+                rewards = kl_penalty[sample_idx]
                 rewards[-1] += scores[sample_idx].cpu()
 
                 ppo_rl_elements.append(
@@ -502,6 +484,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             tbar.update(min(rollout_count, num_rollouts))
         tbar.close()
 
+        if torch.distributed.is_initialized():
+            torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)
+
+        stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl)
         stats["kl_ctl_value"] = self.kl_ctl.value
         stats["time/exp"] = exp_time
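
Note on the KL estimate introduced above (not part of the commit): with log_ratio = logprobs - ref_logprobs, the per-token quantity exp(log_ratio) - 1 - log_ratio is non-negative for any value and roughly log_ratio**2 / 2 when the ratio is small, so its masked mean gives a stable, non-negative estimate of the policy/reference KL, whereas the plain mean of log_ratio can swing negative on a finite batch. Below is a minimal standalone sketch comparing the two forms, using only PyTorch and dummy log-probability tensors; all names are illustrative and not taken from trlX.

import torch

torch.manual_seed(0)

# Dummy per-token log-probabilities for the policy and the reference model
logprobs = torch.randn(4, 16) * 0.1
ref_logprobs = logprobs - 0.05 * torch.randn(4, 16)

log_ratio = logprobs - ref_logprobs

# Plain mean of the log-ratio: can be negative on a finite sample
naive_kl = log_ratio.mean()

# Form used in the diff: exp(r) - 1 - r is >= 0 for every token
k3_kl = (log_ratio.exp() - 1 - log_ratio).mean()

print(f"mean log-ratio estimate: {naive_kl.item():.6f}")
print(f"exp(r) - 1 - r estimate: {k3_kl.item():.6f}")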