Commit 9ece8d1 (1 parent: d96e539)

feat: Add training entropy metric

1 file changed: +25 −5 lines

1 file changed

+25
-5
lines changed

src/art/unsloth/train.py

@@ -83,7 +83,7 @@ def compute_loss(
         seq_len % chunk_size == 0
     ), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-    new_logprobs = calculate_logprobs(
+    new_logprobs, entropies = calculate_logprobs(
         autocast_dtype,
         trainer,
         inputs["tokens"],
@@ -94,7 +94,7 @@ def compute_loss(
         reference_logprobs=False,
     )
     if config.beta > 0.0:
-        ref_logprobs = calculate_logprobs(
+        ref_logprobs, _ = calculate_logprobs(
             autocast_dtype,
             trainer,
             inputs["tokens"],
@@ -143,9 +143,14 @@ def compute_loss(
     kl_div = kl_div * weights * assistant_mask
     mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
     mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
+
+    # Compute mean entropy for the current step
+    shifted_entropies = shift_tensor(entropies, 0.0)
+    mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
 
     trainer._metrics["learning_rate"].append(config.learning_rate)
     trainer._metrics["policy_loss"].append(mean_policy_loss.item())
+    trainer._metrics["entropy"].append(mean_entropy.item())
     if config.beta > 0.0:
         trainer._metrics["kl_div"].append(mean_kl.item())
     return mean_policy_loss + config.beta * mean_kl
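
The new metric follows the same masked-average pattern as policy_loss and kl_div above: per-token entropies are realigned with shift_tensor (padding with 0.0 so vacated positions contribute nothing), weighted, masked to assistant tokens, and normalized by the mask sum. Below is a minimal self-contained sketch of that masked mean, with toy tensors and a hypothetical left-shift standing in for shift_tensor (the shift direction is an assumption here, not taken from the repo):

import torch

def shift_left(t: torch.Tensor, pad: float) -> torch.Tensor:
    # Hypothetical stand-in for shift_tensor: drop the first position
    # along the sequence dim and pad the tail with `pad`.
    return torch.cat([t[:, 1:], torch.full_like(t[:, :1], pad)], dim=1)

B, S = 2, 6
entropies = torch.rand(B, S)        # per-token entropies, [B, S]
weights = torch.ones(B, S)          # per-token loss weights, [B, S]
assistant_mask = torch.zeros(B, S)
assistant_mask[:, 3:] = 1.0         # pretend the last tokens are assistant turns

shifted = shift_left(entropies, 0.0)
# Masked mean; the + 1e-6 avoids division by zero when no assistant tokens exist.
mean_entropy = (shifted * weights * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
print(mean_entropy.item())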
@@ -235,7 +240,7 @@ def calculate_logprobs(
     lm_head_t: torch.Tensor,
     chunk_size: int,
     reference_logprobs: bool,
-) -> torch.Tensor:  # Returns shape [B, S]
+) -> tuple[torch.Tensor, torch.Tensor]:  # Returns (log_probs, entropy) both shape [B, S]
     with (
         torch.amp.autocast_mode.autocast(device_type="cuda", dtype=autocast_dtype),
         torch.inference_mode() if reference_logprobs else nullcontext(),
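
Note the conditional context manager in calculate_logprobs: the reference pass (reference_logprobs=True, used only for the KL term) runs under torch.inference_mode() since it needs no gradients, while the policy pass keeps autograd active via nullcontext(). A small sketch of the pattern (toy function, not the repo's code):

import torch
from contextlib import nullcontext

def forward(x: torch.Tensor, no_grad: bool) -> torch.Tensor:
    # inference_mode disables autograd tracking and saves memory;
    # nullcontext leaves gradient tracking untouched.
    with torch.inference_mode() if no_grad else nullcontext():
        return x * 2.0

x = torch.ones(3, requires_grad=True)
print(forward(x, no_grad=False).requires_grad)  # True
print(forward(x, no_grad=True).requires_grad)   # False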
@@ -258,14 +263,19 @@ def _calculate_logprobs(
     hidden_states: torch.Tensor,  # Shape [B, S, H]
     next_input_ids: torch.Tensor,  # Shape [B, S]
     chunk_size: int,
-) -> torch.Tensor:  # Returns shape [B, S]
+) -> tuple[torch.Tensor, torch.Tensor]:  # Returns (log_probs, entropy) both shape [B, S]
     batch_size, seq_len, _ = hidden_states.shape
     # Output shape is [B, S]
     log_probs = torch.empty(
         (batch_size, seq_len),
         dtype=hidden_states.dtype,
         device=hidden_states.device,
     )
+    entropy = torch.empty(
+        (batch_size, seq_len),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
     # Ensure lm_head_t is in the same dtype as hidden_states
     lm_head_t = lm_head_t.to(hidden_states.dtype)

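Preallocating the [B, S] entropy buffer mirrors log_probs and keeps peak memory bounded: the full [B, S, V] logits tensor never materializes, only one [B, chunk_size, V] slice at a time. A simplified reconstruction of the chunked loop (names mirror the diff, but this is a sketch under those assumptions, not the repo's exact implementation):

import torch

def chunked_logprobs_and_entropy(
    hidden_states: torch.Tensor,   # [B, S, H]
    lm_head_t: torch.Tensor,       # [H, V], transposed LM head weight
    next_input_ids: torch.Tensor,  # [B, S]
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, S, _ = hidden_states.shape
    log_probs = torch.empty(B, S, dtype=hidden_states.dtype, device=hidden_states.device)
    entropy = torch.empty_like(log_probs)
    for i in range(0, S, chunk_size):
        # Logits for one chunk only: [B, chunk_size, V]
        logits = hidden_states[:, i : i + chunk_size] @ lm_head_t
        lse = torch.logsumexp(logits, dim=-1)  # [B, chunk_size]
        selected = logits.gather(
            -1, next_input_ids[:, i : i + chunk_size].unsqueeze(-1)
        ).squeeze(-1)
        log_probs[:, i : i + chunk_size] = selected - lse
        full = logits - lse.unsqueeze(-1)  # log-softmax over the vocab
        entropy[:, i : i + chunk_size] = (-full.exp() * full).sum(-1)
    return log_probs, entropy

With a vocabulary in the tens of thousands, this caps the transient logits at B x chunk_size x V elements instead of B x S x V, which is the point of chunking here.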
@@ -281,15 +291,25 @@ def _calculate_logprobs(
         )  # [B, chunk_size]
         chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1)  # [B, chunk_size]
         log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
+
+        # Compute entropy for the chunk
+        log_probs_full = chunk_logits - chunk_logsumexp.unsqueeze(-1)
+        chunk_entropy = (-torch.exp(log_probs_full) * log_probs_full).sum(
+            dim=-1
+        )  # [B, chunk_size]
+        entropy[:, i : i + chunk_size] = chunk_entropy
+
         del (
             chunk_hs,
             chunk_input_ids,
             chunk_logits,
             chunk_selected_logits,
             chunk_logsumexp,
+            log_probs_full,
+            chunk_entropy,
         )
     del hidden_states
-    return log_probs
+    return log_probs, entropy


 def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
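
The per-chunk entropy is the Shannon entropy of the softmax distribution, H = -sum_v p_v * log p_v, computed stably from log-probabilities (logits minus logsumexp) rather than from raw softmax outputs. A quick sanity check of the same formula against torch.distributions.Categorical, using toy logits:

import torch

logits = torch.randn(4, 32)  # [positions, vocab], toy values
logp = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
manual = (-logp.exp() * logp).sum(-1)  # same formula as the diff
reference = torch.distributions.Categorical(logits=logits).entropy()
assert torch.allclose(manual, reference, atol=1e-5)
print(manual)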
