@@ -160,6 +160,25 @@ def __len__(self) -> int:
         return self.num_samples * self.mini_repeat_count * self.repeat_count
 
 
+# torch.nanstd doesn't exist, so we define it here
+def nanstd(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
+
+    Args:
+        tensor (`torch.Tensor`):
+            Input tensor of shape `(N,)`.
+
+    Returns:
+        `torch.Tensor`:
+            Standard deviation of the tensor, ignoring NaNs.
+    """
+    variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2)  # Compute variance ignoring NaNs
+    count = torch.sum(~torch.isnan(tensor))  # Count of non-NaN values
+    variance *= count / (count - 1)  # Bessel's correction
+    return torch.sqrt(variance)
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
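Note on the hunk above (not part of the patch): a minimal sketch of how the new nanstd helper behaves on a reward column where NaN marks completions a reward function did not score. The values are made up, and the snippet assumes the nanstd defined above is in scope.

import torch

rewards = torch.tensor([0.5, float("nan"), 1.5, 2.5])  # NaN = reward function not applied

print(torch.std(rewards))   # tensor(nan) -- a plain std is poisoned by the NaN
print(nanstd(rewards))      # tensor(1.)  -- NaNs dropped, Bessel's correction applied
print(torch.std(torch.tensor([0.5, 1.5, 2.5])))  # tensor(1.) -- matches std over the non-NaN entries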
@@ -856,8 +875,10 @@ def _generate_and_score_completions(
                 reward_func_name = reward_func.__name__
             # Only calculate mean for samples where this reward function was applied (non-NaN values)
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
-            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
-        self._metrics[mode]["reward"].append(rewards.mean().item())
+            self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
+            std_rewards = nanstd(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
+        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
 
         if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
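For context on the logging change above, a rough sketch of the per-reward-function aggregation, not the trainer's actual data path: the tensor values, reward function names, and metrics dict are illustrative stand-ins, and nanstd is the helper added earlier in this patch.

import torch
from collections import defaultdict

# Two reward functions scored over four completions; NaN = function did not apply.
rewards_per_func = torch.tensor([
    [1.0, float("nan")],
    [0.0, 2.0],
    [1.0, float("nan")],
    [0.0, 4.0],
])
reward_func_names = ["format_reward", "accuracy_reward"]  # hypothetical names
metrics = defaultdict(list)

for i, name in enumerate(reward_func_names):
    column = rewards_per_func[:, i]
    metrics[f"rewards/{name}/mean"].append(torch.nanmean(column).item())
    metrics[f"rewards/{name}/std"].append(nanstd(column).item())  # nanstd from the patch above

print(dict(metrics))
# {'rewards/format_reward/mean': [0.5], 'rewards/format_reward/std': [0.577...],
#  'rewards/accuracy_reward/mean': [3.0], 'rewards/accuracy_reward/std': [1.414...]}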
@@ -938,7 +959,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
         self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
 
-        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        is_clipped = (coef_1 < (1 - self.epsilon_low)) | (coef_1 > (1 + self.epsilon_high))
         clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
         return loss
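A small sketch of what the corrected clip_ratio metric measures, with illustrative values only: coef_1, epsilon_low/epsilon_high, and completion_mask are stand-ins for the trainer's variables. A token counts as clipped when its probability ratio leaves the [1 - epsilon_low, 1 + epsilon_high] trust region, rather than when one surrogate loss happens to be smaller than the other.

import torch

epsilon_low, epsilon_high = 0.2, 0.2  # illustrative clipping range
coef_1 = torch.tensor([[0.7, 1.0, 1.3],
                       [1.1, 0.5, 1.0]])  # made-up per-token ratios pi_theta / pi_old
completion_mask = torch.tensor([[1.0, 1.0, 1.0],
                                [1.0, 1.0, 0.0]])  # last token of second row is padding

# Clipped = ratio outside the trust region, independent of which surrogate term won.
is_clipped = (coef_1 < (1 - epsilon_low)) | (coef_1 > (1 + epsilon_high))
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
print(clip_ratio)  # tensor(0.6000): 3 of the 5 unmasked tokens fall outside [0.8, 1.2]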