@@ -453,13 +453,15 @@ def data_collator(features): # No data collation is needed in GRPO
         )
 
         if self.accelerator.is_main_process:
-            self.vllm_client = VLLMClient(args.vllm_server_host, args.vllm_server_port, args.vllm_server_timeout)
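+            # Use keyword arguments here: `VLLMClient` also takes a `group_port` positional parameter before the
+            # connection timeout, so passing the timeout positionally would misassign it.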
+            self.vllm_client = VLLMClient(host=args.vllm_server_host, server_port=args.vllm_server_port, connection_timeout=args.vllm_server_timeout)
 
         # vLLM specific sampling arguments
         self.guided_decoding_regex = args.vllm_guided_decoding_regex
 
         self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
-
+        # When using LoRA, the initial model may differ slightly from the one loaded into vLLM,
+        # so we force a full weight sync once at the beginning of training.
+        self.vllm_never_synced_before = True
         # When using vLLM, the main process is responsible for loading the model weights. This can cause process
         # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
         # synchronize all processes after vLLM has been fully initialized.
@@ -602,6 +604,11 @@ def _move_model_to_vllm(self):
         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
         gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
 
+        # The zero3 + peft + very_large_model combination needs a dedicated code path
+        if zero_stage_3 and is_peft_model(self.model):
+            self._move_model_to_vllm_especially_for_zero3_plus_peft()
+            return  # early return: the dedicated method handles the whole transfer
+
         if is_peft_model(self.model):
             # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
             # adapters in a sharded manner is not supported.
@@ -635,6 +642,80 @@ def _move_model_to_vllm(self):
         # Reset cache on main process
         if self.accelerator.is_main_process:
             self.vllm_client.reset_prefix_cache()
+        return
+
+    def _move_model_to_vllm_especially_for_zero3_plus_peft(self):
+        """
+        Why is this special method needed for the (zero3 + peft + very_large_model) combination, and why is that
+        combination so nasty?
+
+        1. `model.merge_adapter()` must be executed after `unwrap_model_for_generation`
+        2. `unwrap_model_for_generation` can cause GPU OOM if the model is very large
+        3. Usually, that OOM can be resolved by setting `gather_deepspeed3_params=False`
+        4. But guess what? `gather_deepspeed3_params=False` makes `model.merge_adapter()` fail
+
+        So we rewrite the `merge_adapter` logic: to avoid GPU OOM, we merge the adapters module by module. The basic
+        idea is:
+
+        1. First deal with the LoRA weights only, merging and transferring one tuner layer at a time
+        2. Then deal with whatever parameters are left behind
+        """
+        from peft.tuners.tuners_utils import BaseTunerLayer, onload_layer
+
+        # Update the weights in vLLM. When using DeepSpeed ZeRO Stage 3, we need to gather the parameters before
+        # updating the weights.
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+        if not zero_stage_3 or not is_peft_model(self.model):
+            raise RuntimeError(
+                "This special `_move_model_to_vllm` is designed only for zero3 + peft + very_large_model. "
+                "Only a nasty problem needs a nasty solution like this."
+            )
+
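+        # Bookkeeping: map each parameter's Python id to its fully-qualified name, so the name can still be
+        # recovered after merging and so we can track which parameters have not been transferred yet.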
+        parameter_to_transfer_map_id_name = {id(param): name for name, param in self.model.named_parameters()}
+        # 1. First deal with the LoRA weights only; this part is very, very nasty 😂
+        for module in self.model.modules():
+            # `modules()` yields not only leaf modules but also their parent modules
+            if isinstance(module, BaseTunerLayer):
+                # `onload_layer` temporarily loads any offloaded (CPU/disk) weights of this layer onto the device,
+                # which is required before merging
+                with onload_layer(module):
+                    # collect all the parameters of this small module
+                    param_list_of_this_small_module = [param for _, param in module.named_parameters()]
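+                    # Gathering only this one module's shards (instead of the whole model at once, as
+                    # `unwrap_model_for_generation` would) is what keeps peak GPU memory low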
+                    with deepspeed.zero.GatheredParameters(param_list_of_this_small_module) if zero_stage_3 else nullcontext():
+                        # the parameters must be gathered before calling `module.merge`
+                        module.merge(adapter_names=None)
+                        for _, param in module.named_parameters():
+                            param_python_id = id(param)
+                            # recover the absolute name of the parameter from the bookkeeping map
+                            absolute_name = parameter_to_transfer_map_id_name[param_python_id]
+                            # one less weight to worry about
+                            parameter_to_transfer_map_id_name.pop(param_python_id)
+                            # only the main process is responsible for transferring weights
+                            if self.accelerator.is_main_process:
+                                # When using PEFT, we need to recover the original parameter name and discard some
+                                # parameters
+                                absolute_name = absolute_name.removeprefix("base_model.model.").replace(".base_layer", "")
+                                if self.model.prefix in absolute_name:
+                                    continue
+                                # When using modules_to_save, remove its prefix and discard the original module
+                                if "original_module" in absolute_name:
+                                    continue
+                                absolute_name = absolute_name.replace("modules_to_save.default.", "")
+                                # Finally, it is time to transfer. 🌟
+                                self.vllm_client.update_named_param(absolute_name, param.data)
+                        # and of course, we must unmerge before exiting `GatheredParameters`
+                        module.unmerge()
+        # 2. Then deal with whatever parameters are left behind
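+        # (these are base-model parameters that live outside any tuner layer, e.g. embeddings and layer norms)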
+        remaining_param_list = [
+            (name, param)
+            for name, param in self.model.named_parameters()
+            if name in parameter_to_transfer_map_id_name.values()
+        ]
+        for name, param in remaining_param_list:
+            with deepspeed.zero.GatheredParameters([param]) if zero_stage_3 else nullcontext():
+                if self.accelerator.is_main_process:
+                    name = name.removeprefix("base_model.model.").replace(".base_layer", "")
+                    if (
+                        self.model.prefix in name
+                        or "original_module" in name
+                        or "modules_to_save.default." in name
+                    ):
+                        raise RuntimeError(
+                            "Something must be wrong: all LoRA-related weights should already have been transferred."
+                        )
+                    self.vllm_client.update_named_param(name, param.data)
+
+        # Reset the prefix cache after updating weights
+        if self.accelerator.is_main_process:
+            self.vllm_client.reset_prefix_cache()
 
     @profiling_decorator
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
@@ -670,7 +751,8 @@ def _generate_and_score_completions(
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
-            if self.state.global_step != self._last_loaded_step:
+            if self.state.global_step != self._last_loaded_step or self.vllm_never_synced_before:
+                self.vllm_never_synced_before = False
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step