
Commit 79e9af0

Main contributions: 1) support multi-GPU inference by allocating K vLLM instances across K training processes; 2) add an argument vllm_worker_num to control K; 3) avoid OOM in test_training_vllm_guided_decoding by setting vllm_gpu_memory_utilization.
1 parent 013d360 commit 79e9af0
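
In practice the three changes combine as sketched below. This is a hedged illustration, not part of the commit: output_dir, the argument values, and the launch command are assumptions.

from trl import GRPOConfig

# Hypothetical configuration: 4 training processes, each hosting its own vLLM instance
# on a spare GPU, with reduced per-instance memory utilization to avoid OOM.
training_args = GRPOConfig(
    output_dir="grpo-output",          # illustrative placeholder
    use_vllm=True,
    vllm_worker_num=4,                 # K vLLM instances, one per training process
    vllm_gpu_memory_utilization=0.4,   # leave headroom on each GPU
)
# Launched with, e.g.: accelerate launch --num_processes 4 train_grpo.py
# (requires vllm_worker_num + num_processes GPUs to be available)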

File tree

  tests/test_grpo_trainer.py
  trl/trainer/grpo_config.py
  trl/trainer/grpo_trainer.py

3 files changed: +52 -19 lines


tests/test_grpo_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -781,6 +781,7 @@ def test_training_vllm_guided_decoding(self):
             use_vllm=True,
             vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
             vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
+            vllm_gpu_memory_utilization=0.4,  # lower vLLM GPU memory utilization to avoid OOM in this test
         )
         trainer = GRPOTrainer(
             model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM

trl/trainer/grpo_config.py

Lines changed: 8 additions & 0 deletions
@@ -176,6 +176,14 @@ class GRPOConfig(TrainingArguments):
             "(`pip install vllm`)."
         },
     )
+    vllm_worker_num: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": "The number of vLLM workers used for inference. This number should be less than or equal to "
+            "the number of processes used for distributed training (i.e., `--num_processes`). In addition, please "
+            "ensure that `vllm_worker_num + num_processes <= world_size`."
+        },
+    )
     vllm_device: Optional[str] = field(
         default="auto",
         metadata={

trl/trainer/grpo_trainer.py

Lines changed: 43 additions & 19 deletions
@@ -22,7 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+from accelerate.utils import gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version
@@ -435,14 +435,19 @@ def data_collator(features): # No data collation is needed in GRPO
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                     "`pip install vllm` to use it."
                 )
-
-            if self.accelerator.is_main_process:
+            if self.args.vllm_worker_num > self.accelerator.num_processes:
+                raise ValueError(
+                    f"The requested number of vLLM workers ({self.args.vllm_worker_num}) should be no larger "
+                    f"than the number of training processes ({self.accelerator.num_processes})."
+                )
+            rank = self.accelerator.process_index
+            if rank < self.args.vllm_worker_num:
                 vllm_device = self.args.vllm_device
                 if vllm_device == "auto":
                     if torch.cuda.device_count() == 1:
                         vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
                     else:
-                        vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                        vllm_device = f"cuda:{self.accelerator.num_processes + rank}"
                 # Check that the requested device is available
                 if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                     raise ValueError(
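For intuition, the "auto" device mapping above can be traced by hand. The sketch below is illustrative only (the GPU count, process count, and worker count are assumptions, not values from this commit): training ranks occupy the first num_processes GPUs, and each rank below vllm_worker_num places its vLLM instance on GPU index num_processes + rank.

# Illustrative reproduction of the "auto" placement rule from the hunk above.
num_processes = 4      # training processes (assumed: accelerate launch --num_processes 4)
vllm_worker_num = 4    # K vLLM instances
num_gpus = 8           # assumed machine size; requires num_processes + vllm_worker_num <= num_gpus

for rank in range(num_processes):
    train_device = f"cuda:{rank}"  # where this rank trains
    vllm_device = f"cuda:{num_processes + rank}" if rank < vllm_worker_num else None
    print(rank, train_device, vllm_device)
# rank 0 trains on cuda:0 with vLLM on cuda:4, ..., rank 3 trains on cuda:3 with vLLM on cuda:7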
@@ -459,14 +464,17 @@ def data_collator(features): # No data collation is needed in GRPO
                         "If this is intentional, you may ignore this warning but should adjust "
                         "`vllm_gpu_memory_utilization` accordingly."
                     )
-                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
-                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
-                # setting (profiling_patch).
+                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) initialize the
+                # vLLM model on the desired device (world_size_patch, get_rank_patch, new_group_patch, get_backend_patch)
+                # without conflicts and (2) avoid a test that is not designed for our setting (profiling_patch).
                 world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                 profiling_patch = patch(
                     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                 )
-                with world_size_patch, profiling_patch:
+                get_rank_patch = patch("torch.distributed.get_rank", return_value=0)
+                new_group_patch = patch("torch.distributed.new_group", return_value=type("DummyGroup", (), {})())
+                get_backend_patch = patch("torch.distributed.get_backend", lambda _: "gloo")
+                with world_size_patch, get_rank_patch, new_group_patch, get_backend_patch, profiling_patch:
                     self.llm = LLM(
                         model=model.name_or_path,
                         device=vllm_device,
@@ -478,6 +486,7 @@ def data_collator(features): # No data collation is needed in GRPO
                         enable_prefix_caching=True,
                         max_model_len=self.args.vllm_max_model_len,
                     )
+                self.llm_device = vllm_device
 
                 # Guided decoding, if enabled
                 if args.vllm_guided_decoding_regex is not None:
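The extra patches can be understood outside the trainer: each process that hosts a vLLM instance temporarily makes torch.distributed report a standalone single-process world, so its private engine initializes without touching the accelerate process group. A minimal sketch of the same monkeypatching pattern (no vLLM involved, purely illustrative):

from unittest.mock import patch

import torch.distributed

# Stub the distributed introspection calls the same way the trainer does, so code
# running under these patches sees a rank-0, single-process world.
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
get_rank_patch = patch("torch.distributed.get_rank", return_value=0)

with world_size_patch, get_rank_patch:
    print(torch.distributed.get_world_size())  # 1, regardless of the real process group
    print(torch.distributed.get_rank())        # 0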
@@ -644,14 +653,26 @@ def _move_model_to_vllm(self):
            }
        else:
            state_dict = unwrapped_model.state_dict()
-        if self.accelerator.is_main_process:
+        if self.accelerator.process_index < self.args.vllm_worker_num:
            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            llm_model.load_weights(state_dict.items())
        # Unmerge the adapter to restore the model to its original state.
        # This must be done after loading weights to ensure they correspond to the merged state.
        if is_peft_model(unwrapped_model):
            unwrapped_model.unmerge_adapter()
 
+    def _split_vllm_inputs(self, ordered_set_of_prompts, n_splits):
+        # Ceiling division so that every prompt is assigned to exactly one worker.
+        size = (len(ordered_set_of_prompts) + n_splits - 1) // n_splits
+        data = [ordered_set_of_prompts[size * i : size * (i + 1)] for i in range(n_splits)]
+        return data
+
+    def _vllm_generation(self, ordered_set_of_prompts):
+        with torch.cuda.device(self.llm_device):
+            worker_outputs = self.llm.generate(
+                ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
+            )
+        return worker_outputs
+
    @profiling_decorator
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        mode = "eval" if self.control.should_evaluate else "train"
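The splitting helper can be exercised on its own. The sketch below assumes the ceiling-division form shown above and uses throwaway names:

# Standalone illustration of the prompt-splitting logic.
def split_inputs(prompts, n_splits):
    size = (len(prompts) + n_splits - 1) // n_splits  # ceiling division
    return [prompts[size * i : size * (i + 1)] for i in range(n_splits)]

prompts = [f"p{i}" for i in range(10)]
print(split_inputs(prompts, 4))
# [['p0', 'p1', 'p2'], ['p3', 'p4', 'p5'], ['p6', 'p7', 'p8'], ['p9']]

Note that the last worker may receive a shorter (possibly empty) chunk when the prompts do not divide evenly across workers.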
@@ -685,30 +706,33 @@ def _generate_and_score_completions(
 
        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
+            vllm_index = self.accelerator.process_index
+            n_vllms = self.args.vllm_worker_num
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step
 
            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
-            if self.accelerator.is_main_process:
+            if vllm_index < n_vllms:
                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
-                all_outputs = self.llm.generate(
-                    ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
-                )
-                completion_ids = []
-                for outputs in all_outputs:
+                # Split the prompts into N subsets according to the number of vLLM processes
+                split_ordered_prompts = self._split_vllm_inputs(ordered_set_of_prompts, n_vllms)
+                subset_results = self._vllm_generation(split_ordered_prompts[vllm_index])
+                completion_ids_subset = []
+                for outputs in subset_results:
                    for output in outputs.outputs:
-                        completion_ids.append(output.token_ids)
+                        completion_ids_subset.append(output.token_ids)
            else:
-                completion_ids = [None] * len(all_prompts_text)
-            # Broadcast the completions from the main process to all processes, ensuring each process receives its
+                completion_ids_subset = [None]
+            # Gather the completions from the vLLM processes to all processes, ensuring each process receives its
            # corresponding slice.
-            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            completion_ids = gather_object(completion_ids_subset)
+            completion_ids = [seq for seq in completion_ids if seq is not None]
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
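
The collection step can also be viewed in isolation. In the sketch below (illustrative only), plain list concatenation stands in for gather_object, which joins each process's list in rank order; ranks without a vLLM instance contribute [None], which is filtered out before every rank takes its own process_slice:

# Four processes, of which only the first two host vLLM instances.
per_process = [
    [[1, 2], [3, 4]],  # rank 0: token ids for its prompt subset
    [[5, 6], [7, 8]],  # rank 1: token ids for its prompt subset
    [None],            # rank 2: no vLLM instance
    [None],            # rank 3: no vLLM instance
]
gathered = [item for sublist in per_process for item in sublist]  # what gather_object yields on every rank
completion_ids = [seq for seq in gathered if seq is not None]
print(completion_ids)  # [[1, 2], [3, 4], [5, 6], [7, 8]] -- identical on every rank

Because the subsets are contiguous chunks of the same deduplicated prompt list and are gathered in rank order, the concatenated result preserves the prompt ordering that the per-process slice relies on.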
