 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+from accelerate.utils import gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version
@@ -435,14 +435,19 @@ def data_collator(features): # No data collation is needed in GRPO
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                     "`pip install vllm` to use it."
                 )
-
-            if self.accelerator.is_main_process:
+            if self.args.vllm_worker_num > self.accelerator.num_processes:
+                raise ValueError(
+                    f"The requested number of vLLM workers (i.e., {self.args.vllm_worker_num}) must not exceed "
+                    f"the number of training processes (i.e., {self.accelerator.num_processes})."
+                )
+            rank = self.accelerator.process_index
+            if rank < self.args.vllm_worker_num:
                 vllm_device = self.args.vllm_device
                 if vllm_device == "auto":
                     if torch.cuda.device_count() == 1:
                         vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
                     else:
-                        vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                        vllm_device = f"cuda:{self.accelerator.num_processes + rank}"  # one dedicated GPU per vLLM worker
                 # Check that the requested device is available
                 if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                     raise ValueError(
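A note on the device layout implied by the hunk above: with N training processes and W vLLM workers, training occupies `cuda:0 .. cuda:N-1` and worker `rank` is placed on `cuda:N+rank`, so the node needs N + W visible GPUs (unless only one GPU exists, in which case it is shared). The sketch below is illustrative only; `auto_vllm_device` and its parameters are hypothetical names, not part of the patch.

```python
# Minimal sketch of the "auto" device mapping used above (hypothetical helper, not trainer code).
def auto_vllm_device(num_training_processes: int, worker_rank: int, num_gpus: int) -> str:
    if num_gpus == 1:
        return "cuda:0"  # single-GPU case: share the training GPU
    device_idx = num_training_processes + worker_rank
    if device_idx >= num_gpus:
        raise ValueError(f"cuda:{device_idx} requested but only {num_gpus} GPUs are visible")
    return f"cuda:{device_idx}"

# e.g. 4 training processes, vLLM worker rank 1, 8 GPUs -> "cuda:5"
assert auto_vllm_device(4, 1, 8) == "cuda:5"
```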
@@ -459,14 +464,17 @@ def data_collator(features): # No data collation is needed in GRPO
                         "If this is intentional, you may ignore this warning but should adjust "
                         "`vllm_gpu_memory_utilization` accordingly."
                     )
-                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
-                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
-                # setting (profiling_patch).
+                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) initialize the
+                # vLLM model on the desired device without conflicts (world_size_patch, get_rank_patch, new_group_patch,
+                # get_backend_patch) and (2) avoid a test that is not designed for our setting (profiling_patch).
                 world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                 profiling_patch = patch(
                     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                 )
-                with world_size_patch, profiling_patch:
+                get_rank_patch = patch("torch.distributed.get_rank", return_value=0)
+                new_group_patch = patch("torch.distributed.new_group", return_value=type("DummyGroup", (), {})())
+                get_backend_patch = patch("torch.distributed.get_backend", lambda _: "gloo")
+                with world_size_patch, get_rank_patch, new_group_patch, get_backend_patch, profiling_patch:
                     self.llm = LLM(
                         model=model.name_or_path,
                         device=vllm_device,
@@ -478,6 +486,7 @@ def data_collator(features): # No data collation is needed in GRPO
                         enable_prefix_caching=True,
                         max_model_len=self.args.vllm_max_model_len,
                     )
+                self.llm_device = vllm_device

                 # Guided decoding, if enabled
                 if args.vllm_guided_decoding_regex is not None:
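The extra patches make each vLLM worker see a fake single-process `torch.distributed` world while the `LLM` object is constructed, so it does not clash with the process group accelerate has already initialized. The snippet below is a standalone demonstration of how the stacked `unittest.mock.patch` context managers behave (an assumption-labeled demo, not trainer code); the targets are only replaced inside the `with` block and restored afterwards.

```python
from unittest.mock import patch

import torch.distributed as dist

world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
get_rank_patch = patch("torch.distributed.get_rank", return_value=0)
get_backend_patch = patch("torch.distributed.get_backend", lambda _: "gloo")

with world_size_patch, get_rank_patch, get_backend_patch:
    # Inside the block any caller sees a fake single-process world.
    assert dist.get_world_size() == 1
    assert dist.get_rank() == 0
    assert dist.get_backend("any-group") == "gloo"
# Outside the block the real torch.distributed functions are restored.
```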
@@ -644,14 +653,26 @@ def _move_model_to_vllm(self):
                 }
             else:
                 state_dict = unwrapped_model.state_dict()
-        if self.accelerator.is_main_process:
+        if self.accelerator.process_index < self.args.vllm_worker_num:
             llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
             llm_model.load_weights(state_dict.items())
         # Unmerge the adapter to restore the model to its original state.
         # This must be done after loading weights to ensure they correspond to the merged state.
         if is_peft_model(unwrapped_model):
             unwrapped_model.unmerge_adapter()

+    def _split_vllm_inputs(self, ordered_set_of_prompts, n_splits):
+        size = (len(ordered_set_of_prompts) + n_splits - 1) // n_splits  # ceil division so no prompt is dropped
+        data = [ordered_set_of_prompts[size * i : size * (i + 1)] for i in range(n_splits)]
+        return data
+
+    def _vllm_generation(self, ordered_set_of_prompts):
+        with torch.cuda.device(self.llm_device):
+            worker_outputs = self.llm.generate(
+                ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
+            )
+        return worker_outputs
+
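The split helper partitions the unique prompts into contiguous chunks, one per vLLM worker; with ceil division the last worker may receive a shorter (possibly empty) chunk, but no prompt is dropped. The standalone sketch below mirrors that logic with a hypothetical `split_prompts` function, purely to show the resulting chunking.

```python
def split_prompts(prompts: list[str], n_splits: int) -> list[list[str]]:
    size = (len(prompts) + n_splits - 1) // n_splits  # ceil division
    return [prompts[size * i : size * (i + 1)] for i in range(n_splits)]

prompts = [f"p{i}" for i in range(10)]
chunks = split_prompts(prompts, 4)
# -> [['p0', 'p1', 'p2'], ['p3', 'p4', 'p5'], ['p6', 'p7', 'p8'], ['p9']]
assert sum(len(c) for c in chunks) == len(prompts)
```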
     @profiling_decorator
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         mode = "eval" if self.control.should_evaluate else "train"
@@ -685,30 +706,33 @@ def _generate_and_score_completions(

         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
+            vllm_index = self.accelerator.process_index
+            n_vllms = self.args.vllm_worker_num
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step

             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_prompts_text = gather_object(prompts_text)
-            if self.accelerator.is_main_process:
+            if vllm_index < n_vllms:
                 # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
                 ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
-                all_outputs = self.llm.generate(
-                    ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
-                )
-                completion_ids = []
-                for outputs in all_outputs:
+                # Split the unique prompts into one subset per vLLM worker process
+                split_ordered_prompts = self._split_vllm_inputs(ordered_set_of_prompts, n_vllms)
+                subset_results = self._vllm_generation(split_ordered_prompts[vllm_index])
+                completion_ids_subset = []
+                for outputs in subset_results:
                     for output in outputs.outputs:
-                        completion_ids.append(output.token_ids)
+                        completion_ids_subset.append(output.token_ids)
             else:
-                completion_ids = [None] * len(all_prompts_text)
-            # Broadcast the completions from the main process to all processes, ensuring each process receives its
+                completion_ids_subset = [None]
+            # Gather the completions from the vLLM worker processes to all processes, ensuring each process receives its
             # corresponding slice.
-            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            completion_ids = gather_object(completion_ids_subset)
+            completion_ids = [seq for seq in completion_ids if seq is not None]
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
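In the gather-then-slice step above, every process contributes its local list: worker processes contribute their generated token-id sequences, the remaining processes contribute `[None]`, the concatenation is filtered, and each process keeps the slice matching its own prompts. The toy simulation below uses a hypothetical `fake_gather_object` as a stand-in for `accelerate.utils.gather_object` (which concatenates per-process lists in rank order) and assumes the gathered worker outputs line up with the per-process prompt order.

```python
def fake_gather_object(per_process_lists):
    # Stand-in for accelerate.utils.gather_object: concatenate lists in rank order.
    return [item for sublist in per_process_lists for item in sublist]

num_processes, prompts_per_process = 4, 2
# Ranks 0 and 1 act as vLLM workers; ranks 2 and 3 contribute placeholders.
per_process = [[[1, 2], [3], [4, 5], [6]], [[7], [8, 9], [10], [11]], [None], [None]]
completion_ids = [seq for seq in fake_gather_object(per_process) if seq is not None]

for rank in range(num_processes):
    process_slice = slice(rank * prompts_per_process, (rank + 1) * prompts_per_process)
    print(rank, completion_ids[process_slice])  # each rank recovers its own two completions
```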