@@ -305,7 +305,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
305305 student_outputs = base_student (
306306 input_ids = inputs ["input_ids" ],
307307 attention_mask = inputs ["attention_mask" ],
308- output_hidden_states = True ,
309308 use_cache = False ,
310309 )
311310
@@ -321,13 +320,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
321320 teacher_outputs = base_teacher (
322321 input_ids = inputs ["input_ids" ],
323322 attention_mask = inputs ["attention_mask" ],
324- output_hidden_states = True ,
325323 use_cache = False ,
326324 )
327325
328326 # hidden states (shifted)
329- student_hidden = student_outputs .last_hidden_state [:, :- 1 ].contiguous ()
330- teacher_hidden = teacher_outputs .last_hidden_state [:, :- 1 ].contiguous ()
327+ student_hidden = student_outputs .last_hidden_state [:, :- 1 ]
328+ teacher_hidden = teacher_outputs .last_hidden_state [:, :- 1 ]
329+
330+ # Release full outputs to free memory
331+ del student_outputs , teacher_outputs
331332
332333 # labels mask and labels (shifted)
333334 labels_mask = inputs ["labels" ] != - 100
@@ -336,6 +337,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
336337 )
337338 true_labels = masked_input_ids [:, 1 :].contiguous ()
338339
340+ # Release intermediate tensors
341+ del labels_mask , masked_input_ids
342+
339343 # heads
340344 student_head = unwrapped_student .get_output_embeddings ()
341345 teacher_head = unwrapped_teacher .get_output_embeddings ()
@@ -350,6 +354,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
350354 student_bias = getattr (student_head , "bias" , None ),
351355 teacher_bias = getattr (teacher_head , "bias" , None ),
352356 )
357+
358+ # Release hidden states after loss computation
359+ del student_hidden , teacher_hidden , true_labels
353360 else :
354361 # compute student output
355362 student_outputs = model (
0 commit comments