Skip to content

Commit 095544e

Browse files
qgallouedeckashif
andauthored
Fix GKD Liger memory spike (huggingface#4140)
Co-authored-by: Kashif Rasul <[email protected]>
1 parent 06c059b commit 095544e

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

tests/test_gkd_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import os
1616

17-
import pytest
1817
import torch
1918
import torch.nn.functional as F
2019
from datasets import load_dataset
@@ -239,7 +238,6 @@ def test_gkd_trainer(self):
239238
assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2")
240239

241240
@require_liger_kernel
242-
@pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
243241
def test_gkd_trainer_with_liger(self):
244242
training_args = GKDConfig(
245243
output_dir=self.tmp_dir,

trl/trainer/gkd_trainer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
305305
student_outputs = base_student(
306306
input_ids=inputs["input_ids"],
307307
attention_mask=inputs["attention_mask"],
308-
output_hidden_states=True,
309308
use_cache=False,
310309
)
311310

@@ -321,13 +320,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
321320
teacher_outputs = base_teacher(
322321
input_ids=inputs["input_ids"],
323322
attention_mask=inputs["attention_mask"],
324-
output_hidden_states=True,
325323
use_cache=False,
326324
)
327325

328326
# hidden states (shifted)
329-
student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
330-
teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
327+
student_hidden = student_outputs.last_hidden_state[:, :-1]
328+
teacher_hidden = teacher_outputs.last_hidden_state[:, :-1]
329+
330+
# Release full outputs to free memory
331+
del student_outputs, teacher_outputs
331332

332333
# labels mask and labels (shifted)
333334
labels_mask = inputs["labels"] != -100
@@ -336,6 +337,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
336337
)
337338
true_labels = masked_input_ids[:, 1:].contiguous()
338339

340+
# Release intermediate tensors
341+
del labels_mask, masked_input_ids
342+
339343
# heads
340344
student_head = unwrapped_student.get_output_embeddings()
341345
teacher_head = unwrapped_teacher.get_output_embeddings()
@@ -350,6 +354,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
350354
student_bias=getattr(student_head, "bias", None),
351355
teacher_bias=getattr(teacher_head, "bias", None),
352356
)
357+
358+
# Release hidden states after loss computation
359+
del student_hidden, teacher_hidden, true_labels
353360
else:
354361
# compute student output
355362
student_outputs = model(

0 commit comments

Comments
 (0)