@@ -83,7 +83,7 @@ def compute_loss(
         seq_len % chunk_size == 0
     ), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-    new_logprobs = calculate_logprobs(
+    new_logprobs, entropies = calculate_logprobs(
         autocast_dtype,
         trainer,
         inputs["tokens"],
@@ -94,7 +94,7 @@ def compute_loss(
         reference_logprobs=False,
     )
     if config.beta > 0.0:
-        ref_logprobs = calculate_logprobs(
+        ref_logprobs, _ = calculate_logprobs(
             autocast_dtype,
             trainer,
             inputs["tokens"],
@@ -143,9 +143,14 @@ def compute_loss(
     kl_div = kl_div * weights * assistant_mask
     mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
     mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
+
+    # Compute mean entropy for the current step
+    shifted_entropies = shift_tensor(entropies, 0.0)
+    mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
 
     trainer._metrics["learning_rate"].append(config.learning_rate)
     trainer._metrics["policy_loss"].append(mean_policy_loss.item())
+    trainer._metrics["entropy"].append(mean_entropy.item())
     if config.beta > 0.0:
         trainer._metrics["kl_div"].append(mean_kl.item())
     return mean_policy_loss + config.beta * mean_kl
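
For intuition, the masked-mean reduction above can be sketched on toy tensors (a minimal sketch with illustrative values, not code from this PR; `entropies`, `weights`, and `assistant_mask` stand in for the real per-token tensors):

```python
import torch

# Toy per-token entropies for one sequence of 6 tokens.
entropies = torch.tensor([[1.2, 0.8, 0.5, 0.9, 0.3, 0.7]])

# Only assistant tokens (positions 2..5 here) count toward the metric.
assistant_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0, 1.0, 1.0]])

# Per-token weights; uniform in the simplest case.
weights = torch.ones_like(entropies)

# Weighted sum over assistant tokens, normalized by the token count;
# the epsilon guards against an all-zero mask.
mean_entropy = (entropies * weights * assistant_mask).sum() / (
    assistant_mask.sum() + 1e-6
)
print(mean_entropy.item())  # ~(0.5 + 0.9 + 0.3 + 0.7) / 4 = 0.6
```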
@@ -235,7 +240,7 @@ def calculate_logprobs(
     lm_head_t: torch.Tensor,
     chunk_size: int,
     reference_logprobs: bool,
-) -> torch.Tensor:  # Returns shape [B, S]
+) -> tuple[torch.Tensor, torch.Tensor]:  # Returns (log_probs, entropy), both shape [B, S]
     with (
         torch.amp.autocast_mode.autocast(device_type="cuda", dtype=autocast_dtype),
         torch.inference_mode() if reference_logprobs else nullcontext(),
@@ -258,14 +263,19 @@ def _calculate_logprobs(
     hidden_states: torch.Tensor,  # Shape [B, S, H]
     next_input_ids: torch.Tensor,  # Shape [B, S]
     chunk_size: int,
-) -> torch.Tensor:  # Returns shape [B, S]
+) -> tuple[torch.Tensor, torch.Tensor]:  # Returns (log_probs, entropy), both shape [B, S]
     batch_size, seq_len, _ = hidden_states.shape
     # Output shape is [B, S]
     log_probs = torch.empty(
         (batch_size, seq_len),
         dtype=hidden_states.dtype,
         device=hidden_states.device,
     )
+    entropy = torch.empty(
+        (batch_size, seq_len),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
     # Ensure lm_head_t is in the same dtype as hidden_states
     lm_head_t = lm_head_t.to(hidden_states.dtype)
 
@@ -281,15 +291,25 @@ def _calculate_logprobs(
         )  # [B, chunk_size]
         chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1)  # [B, chunk_size]
         log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
+
+        # Compute entropy for the chunk
+        log_probs_full = chunk_logits - chunk_logsumexp.unsqueeze(-1)
+        chunk_entropy = (-torch.exp(log_probs_full) * log_probs_full).sum(
+            dim=-1
+        )  # [B, chunk_size]
+        entropy[:, i : i + chunk_size] = chunk_entropy
+
         del (
             chunk_hs,
             chunk_input_ids,
             chunk_logits,
             chunk_selected_logits,
             chunk_logsumexp,
+            log_probs_full,
+            chunk_entropy,
         )
     del hidden_states
-    return log_probs
+    return log_probs, entropy
 
 
 def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
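
As a standalone illustration of the entropy the diff computes (a minimal sketch, not the repository's code): per position, H = -sum_v p_v * log p_v, evaluated in log space via logsumexp exactly as in `_calculate_logprobs`, just without the chunking:

```python
import torch


def token_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Per-position entropy over the vocab dim; logits [B, S, V] -> [B, S]."""
    # Normalize in log space for numerical stability.
    log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    return (-torch.exp(log_probs) * log_probs).sum(dim=-1)


logits = torch.randn(2, 4, 10)
h = token_entropy(logits)

# Cross-check against an explicit softmax.
p = torch.softmax(logits, dim=-1)
assert torch.allclose(h, -(p * p.log()).sum(dim=-1), atol=1e-5)
```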