@@ -785,6 +785,7 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
     def _get_last_hidden_state(
         self,
         model,
+        unwrapped_model,
         input_ids,
         attention_mask,
         logits_to_keep,
@@ -821,6 +822,7 @@ def _get_last_hidden_state(
         # This goes through gradient checkpointing and other training wrappers
         # Note: With logits_to_keep, the model only computes logits for the last N tokens,
         # so the full [B, T, V] tensor is NOT materialized - only [B, logits_to_keep, V]
+        # This preserves training wrappers while minimizing logits computation
         outputs = model(**model_inputs)
         last_hidden_state = outputs.hidden_states[-1]
         # Exclude the last value: it corresponds to the next token pred
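For reference, a minimal standalone sketch (plain PyTorch, toy sizes; not the trainer's code) of why restricting the vocabulary projection to the last `logits_to_keep` positions keeps memory at [B, N, V] rather than [B, T, V]:

```python
# Minimal sketch with assumed toy sizes: project only the kept positions through lm_head,
# so the full-sequence logits tensor is never allocated.
import torch
import torch.nn as nn

B, T, H, V, logits_to_keep = 2, 512, 64, 32000, 16

lm_head = nn.Linear(H, V, bias=False)
last_hidden_state = torch.randn(B, T, H)          # what outputs.hidden_states[-1] provides

kept = last_hidden_state[:, -logits_to_keep:, :]  # [B, N, H]
logits = lm_head(kept)                            # [B, N, V] -- not [B, T, V]
print(logits.shape)                               # torch.Size([2, 16, 32000])
```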
@@ -1020,10 +1022,12 @@ def _move_model_to_vllm(self):
                     self._sync_fsdp2_params_to_vllm(self.model)
                 else:
                     # DeepSpeed ZeRO-3 with PEFT
-                    for name, param in self.model.named_parameters():
+                    # When using the liger kernel, unwrap the model to get the actual parameters for weight syncing
+                    model_for_sync = self.accelerator.unwrap_model(self.model) if self.use_liger_kernel else self.model
+                    for name, param in model_for_sync.named_parameters():
                         # When using PEFT, we need to recover the original parameter name and discard some parameters
                         name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-                        if self.model.prefix in name:
+                        if model_for_sync.prefix in name:
                             continue
                         # When module to save, remove its prefix and discard the original module
                         if "original_module" in name:
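As a rough illustration of why the unwrap matters (the toy wrapper and flag below are stand-ins, not the actual accelerate/liger wrappers): wrapped modules report prefixed parameter names, which would not match what vLLM expects when syncing weights.

```python
# Toy sketch: a wrapper (stand-in for the accelerate/liger training wrapper) prefixes
# parameter names with "module.", so syncing must read names from the unwrapped module.
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap(model):  # stand-in for self.accelerator.unwrap_model(...)
    return model.module if isinstance(model, Wrapper) else model

base = nn.Linear(4, 4)
wrapped = Wrapper(base)
use_liger_kernel = True  # assumed flag, mirrors self.use_liger_kernel

model_for_sync = unwrap(wrapped) if use_liger_kernel else wrapped
print([n for n, _ in wrapped.named_parameters()])         # ['module.weight', 'module.bias']
print([n for n, _ in model_for_sync.named_parameters()])  # ['weight', 'bias']
```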
@@ -1048,7 +1052,11 @@ def _move_model_to_vllm(self):
             elif fsdp_version == 2:
                 self._sync_fsdp2_params_to_vllm(self.model)
             else:
-                for name, param in self.model.named_parameters():
+                # For non-FSDP, non-PEFT models.
+                # When using the liger kernel, unwrap the model to get the actual parameters for weight syncing;
+                # this ensures we access the correct parameters in both server and colocate modes.
+                model_for_sync = self.accelerator.unwrap_model(self.model) if self.use_liger_kernel else self.model
+                for name, param in model_for_sync.named_parameters():
                     name = self._fix_param_name_to_vllm(name)
                     with gather_if_zero3([param]):
                         if self.vllm_mode == "server" and self.accelerator.is_main_process:
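A schematic sketch of the per-parameter sync loop this branch implements; `gather_if_zero3`, the push callbacks, and the flags below are stand-ins rather than the trainer's real helpers:

```python
# Schematic only: gather a (possibly ZeRO-3-sharded) parameter, then push it to vLLM,
# either via the server path on the main process or directly in colocate mode.
from contextlib import contextmanager
import torch.nn as nn

@contextmanager
def gather_if_zero3(params):  # stand-in: a real implementation would all-gather ZeRO-3 shards
    yield

def sync_params(model_for_sync, vllm_mode, is_main_process, push_server, push_colocate):
    for name, param in model_for_sync.named_parameters():
        with gather_if_zero3([param]):
            if vllm_mode == "server" and is_main_process:
                push_server(name, param.data)     # e.g. an RPC to the vLLM server
            elif vllm_mode == "colocate":
                push_colocate(name, param.data)   # e.g. load into the colocated engine

sync_params(
    nn.Linear(2, 2), "server", True,
    push_server=lambda n, t: print("server <-", n, tuple(t.shape)),
    push_colocate=lambda n, t: print("colocate <-", n, tuple(t.shape)),
)
```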
@@ -1798,11 +1806,12 @@ def compute_liger_loss(self, model, unwrapped_model, inputs):
         # Use unwrapped_model for lm_head access (handles FSDP properly)
         lm_head_model = unwrapped_model.base_model.model if is_peft_model(unwrapped_model) else unwrapped_model

-        # Get the last hidden state using the wrapped model's inner transformer
+        # Get the last hidden state using the wrapped model to preserve training wrappers
         # This calls model.model (e.g., Qwen2Model) to avoid materializing the full logits tensor
         # Using the wrapped model preserves training wrappers (gradient checkpointing, etc.)
         last_hidden_state = self._get_last_hidden_state(
             model,  # Use the wrapped model to preserve training wrappers
+            unwrapped_model,
             input_ids,
             attention_mask,
             logits_to_keep,
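For intuition about what happens downstream of this call, here is a plain-PyTorch sketch (not the Liger kernel's API) of computing per-token log-probabilities from the last hidden state in chunks, so the full [B, T, V] logits tensor is never materialized at once:

```python
# Conceptual sketch only: chunk the sequence dimension, project each chunk through the
# lm_head weight, and gather the label log-probs, so only [B, chunk, V] exists at a time.
import torch
import torch.nn as nn
import torch.nn.functional as F

def chunked_token_logps(last_hidden_state, lm_head_weight, labels, chunk_size=4):
    # last_hidden_state: [B, T, H] (already excludes the final position), labels: [B, T]
    logps = []
    for start in range(0, last_hidden_state.size(1), chunk_size):
        h = last_hidden_state[:, start:start + chunk_size, :]             # [B, C, H]
        logits = h @ lm_head_weight.T                                     # [B, C, V]
        ids = labels[:, start:start + chunk_size].unsqueeze(-1)
        logps.append(torch.gather(F.log_softmax(logits, -1), -1, ids).squeeze(-1))
    return torch.cat(logps, dim=1)                                        # [B, T]

B, T, H, V = 2, 8, 16, 100
hidden, head = torch.randn(B, T, H), nn.Linear(H, V, bias=False)
labels = torch.randint(0, V, (B, T))
print(chunked_token_logps(hidden, head.weight, labels).shape)  # torch.Size([2, 8])
```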