Commit 0ec3f57

fix

1 parent 79836e0 commit 0ec3f57
1 file changed: 13 additions, 4 deletions

trl/trainer/grpo_trainer.py

Lines changed: 13 additions & 4 deletions
@@ -785,6 +785,7 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
     def _get_last_hidden_state(
         self,
         model,
+        unwrapped_model,
         input_ids,
         attention_mask,
         logits_to_keep,
@@ -821,6 +822,7 @@ def _get_last_hidden_state(
         # This goes through gradient checkpointing and other training wrappers
         # Note: With logits_to_keep, the model only computes logits for the last N tokens,
         # so the full [B, T, V] tensor is NOT materialized - only [B, logits_to_keep, V]
+        # This preserves training wrappers while minimizing logits computation
         outputs = model(**model_inputs)
         last_hidden_state = outputs.hidden_states[-1]
         # Exclude the last value: it corresponds to the next token pred
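For context on the logits_to_keep note above, here is a minimal, hypothetical sketch (not part of this commit) of how passing logits_to_keep to a causal-LM forward keeps the logits at [B, logits_to_keep, V] while the hidden states still cover the full sequence. The model id and shapes are illustrative assumptions, and it presumes a transformers release whose causal-LM forward accepts logits_to_keep.

# Hypothetical sketch, not from this commit: effect of logits_to_keep on output shapes.
# Assumes a transformers version whose causal-LM forward accepts `logits_to_keep`;
# the model id below is only an illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer(["The quick brown fox jumps"], return_tensors="pt")
logits_to_keep = 2  # only the last 2 positions need logits for the loss

with torch.no_grad():
    outputs = model(**inputs, logits_to_keep=logits_to_keep, output_hidden_states=True)

print(outputs.logits.shape)             # [B, logits_to_keep, V]: the full [B, T, V] is never built
print(outputs.hidden_states[-1].shape)  # [B, T, H]: hidden states still cover every position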
@@ -1020,10 +1022,12 @@ def _move_model_to_vllm(self):
                         self._sync_fsdp2_params_to_vllm(self.model)
                 else:
                     # DeepSpeed ZeRO-3 with PEFT
-                    for name, param in self.model.named_parameters():
+                    # When using liger kernel, unwrap the model to get actual parameters for weight syncing
+                    model_for_sync = self.accelerator.unwrap_model(self.model) if self.use_liger_kernel else self.model
+                    for name, param in model_for_sync.named_parameters():
                         # When using PEFT, we need to recover the original parameter name and discard some parameters
                         name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-                        if self.model.prefix in name:
+                        if model_for_sync.prefix in name:
                             continue
                         # When module to save, remove its prefix and discard the original module
                         if "original_module" in name:
@@ -1048,7 +1052,11 @@ def _move_model_to_vllm(self):
                 elif fsdp_version == 2:
                     self._sync_fsdp2_params_to_vllm(self.model)
             else:
-                for name, param in self.model.named_parameters():
+                # For non-FSDP, non-PEFT models
+                # When using liger kernel, unwrap the model to get actual parameters for weight syncing
+                # This ensures we access the correct parameters in both server and colocate modes
+                model_for_sync = self.accelerator.unwrap_model(self.model) if self.use_liger_kernel else self.model
+                for name, param in model_for_sync.named_parameters():
                     name = self._fix_param_name_to_vllm(name)
                     with gather_if_zero3([param]):
                         if self.vllm_mode == "server" and self.accelerator.is_main_process:
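As a rough illustration of why the unwrap step matters for the sync loop, the sketch below uses a plain wrapper module as a stand-in for whatever training wrapper the liger path introduces (that stand-in is an assumption). The point is only that wrappers change parameter names, which accelerator.unwrap_model undoes before the names are sent to vLLM.

# Hypothetical sketch: training wrappers change parameter names, so the sync loop
# must run on the unwrapped model. Wrapper is a stand-in, not the real wrapper used in TRL.
import torch

class Wrapper(torch.nn.Module):
    """Stand-in for a training wrapper (DDP-style .module indirection)."""
    def __init__(self, module):
        super().__init__()
        self.module = module

inner = torch.nn.Linear(4, 4)
wrapped = Wrapper(inner)

print([n for n, _ in inner.named_parameters()])    # ['weight', 'bias']
print([n for n, _ in wrapped.named_parameters()])  # ['module.weight', 'module.bias']

# accelerator.unwrap_model(self.model) returns the inner module, so the names fed to
# vLLM match what the inference engine expects.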
@@ -1798,11 +1806,12 @@ def compute_liger_loss(self, model, unwrapped_model, inputs):
         # Use unwrapped_model for lm_head access (handles FSDP properly)
         lm_head_model = unwrapped_model.base_model.model if is_peft_model(unwrapped_model) else unwrapped_model
 
-        # Get the last hidden state using the wrapped model's inner transformer
+        # Get the last hidden state using the wrapped model to preserve training wrappers
         # This calls model.model (e.g., Qwen2Model) to avoid materializing the full logits tensor
         # Using the wrapped model preserves training wrappers (gradient checkpointing, etc.)
         last_hidden_state = self._get_last_hidden_state(
             model,  # Use the wrapped model to preserve training wrappers
+            unwrapped_model,
             input_ids,
             attention_mask,
             logits_to_keep,
