1 change: 1 addition & 0 deletions tests/test_grpo_trainer.py
@@ -781,6 +781,7 @@ def test_training_vllm_guided_decoding(self):
use_vllm=True,
vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU
vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
vllm_gpu_memory_utilization=0.4, # reduce the memory utilization rate to reduce memory usage
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
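A minimal usage sketch (editor's illustration, not part of this diff) of how the new `vllm_worker_num` option introduced below could be combined with the existing vLLM settings. The output directory is a placeholder, and a single 4-GPU node launched with `--num_processes 2` is assumed.

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",         # placeholder path
    use_vllm=True,
    vllm_worker_num=2,                # new option: one vLLM engine per spare GPU
    vllm_gpu_memory_utilization=0.4,  # keep per-engine memory usage modest
)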
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -62,6 +62,10 @@ class GRPOConfig(TrainingArguments):
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_worker_num (`int`, *optional*, defaults to `1`):
            The number of vLLM workers used for inference. This number should be less than or equal to the
            number of processes used for distributed training (i.e., `--num_processes`). In addition, please
            ensure that `vllm_worker_num + num_processes <= world_size`.
vllm_device (`str`, *optional*, defaults to `"auto"`):
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
automatically select the next available GPU after the last one used for training. This assumes that
@@ -176,6 +180,14 @@ class GRPOConfig(TrainingArguments):
"(`pip install vllm`)."
},
)
vllm_worker_num: Optional[int] = field(
default=1,
metadata={
"help": "The number of vllm works used for inference. Notably, this number should be less or equal to "
"the number of process used for distributed training (i.e., `--num_processes`). In addition, please "
"ensure that `vllm_worker_num + num_processes <= world_size`."
},
)
vllm_device: Optional[str] = field(
default="auto",
metadata={
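A minimal sketch of the constraint documented above and of the rank-based device assignment used in `grpo_trainer.py` below, assuming a single node whose GPUs are indexed from 0.

num_gpus = 4          # world_size on a single node
num_processes = 2     # training ranks, placed on cuda:0 and cuda:1
vllm_worker_num = 2   # must satisfy vllm_worker_num + num_processes <= num_gpus
assert vllm_worker_num + num_processes <= num_gpus

# Ranks 0 .. vllm_worker_num - 1 each host a vLLM engine on the next free GPU,
# mirroring `vllm_device = f"cuda:{self.accelerator.num_processes+rank}"` below.
vllm_devices = [f"cuda:{num_processes + rank}" for rank in range(vllm_worker_num)]
print(vllm_devices)   # ['cuda:2', 'cuda:3']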
62 changes: 43 additions & 19 deletions trl/trainer/grpo_trainer.py
@@ -22,7 +22,7 @@
import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from accelerate.utils import gather, gather_object, is_peft_model, set_seed
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
@@ -435,14 +435,19 @@ def data_collator(features): # No data collation is needed in GRPO
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)

if self.accelerator.is_main_process:
if self.args.vllm_worker_num > self.accelerator.num_processes:
raise ValueError(
f"The requested number of workers for vllm (i.e., {self.args.vllm_worker_num}) should be no larger "
f"than the number of your training processes (i.e., {self.accelerator.num_processes}) "
)
rank = self.accelerator.process_index
if rank < self.args.vllm_worker_num:
vllm_device = self.args.vllm_device
if vllm_device == "auto":
if torch.cuda.device_count() == 1:
vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
else:
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
vllm_device = f"cuda:{self.accelerator.num_processes+rank}"
# Check that the requested device is available
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
raise ValueError(
@@ -459,14 +464,17 @@ def data_collator(features): # No data collation is needed in GRPO
"If this is intentional, you may ignore this warning but should adjust "
"`vllm_gpu_memory_utilization` accordingly."
)
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
# setting (profiling_patch).
                # vLLM is not compatible with accelerate, so we need to patch it to make sure we can (1) initialize
                # the vLLM model on the desired device without conflicts (world_size_patch, get_rank_patch,
                # new_group_patch, get_backend_patch) and (2) avoid a test that is not designed for our setting
                # (profiling_patch).
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
)
with world_size_patch, profiling_patch:
get_rank_patch = patch("torch.distributed.get_rank", return_value=0)
new_group_patch = patch("torch.distributed.new_group", return_value=type("DummyGroup", (), {})())
get_backend_patch = patch("torch.distributed.get_backend", lambda _: "gloo")
with world_size_patch, get_rank_patch, new_group_patch, get_backend_patch, profiling_patch:
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
@@ -478,6 +486,7 @@ def data_collator(features): # No data collation is needed in GRPO
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)
self.llm_device = vllm_device

# Guided decoding, if enabled
if args.vllm_guided_decoding_regex is not None:
@@ -644,14 +653,26 @@ def _move_model_to_vllm(self):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
if self.accelerator.process_index < self.args.vllm_worker_num:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
# Unmerge the adapter to restore the model to its original state.
# This must be done after loading weights to ensure they correspond to the merged state.
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

def _split_vllm_inputs(self, ordered_set_of_prompts, n_splits):
size = (len(ordered_set_of_prompts) + n_splits - 1) // n_splits
data = [ordered_set_of_prompts[size * i : size * (i + 1)] for i in range(n_splits)]
return data
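    # Note: the split above is a ceil-division chunking, e.g. 5 unique prompts over
    # 2 vLLM workers yields chunks of size 3 and 2; the last chunk may be shorter
    # (or empty when n_splits exceeds the number of prompts).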

def _vllm_generation(self, ordered_set_of_prompts):
with torch.cuda.device(self.llm_device):
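            # torch.cuda.device temporarily makes `self.llm_device` the current CUDA
            # device, so work issued during generation targets this rank's vLLM GPU
            # rather than its training GPU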
worker_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
return worker_outputs

@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
mode = "eval" if self.control.should_evaluate else "train"
@@ -685,30 +706,33 @@ def _generate_and_score_completions(

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
vllm_index = self.accelerator.process_index
n_vllms = self.args.vllm_worker_num
            # First, have the vLLM worker processes load the updated weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts, then split them across the vLLM worker processes
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if vllm_index < n_vllms:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
completion_ids = []
for outputs in all_outputs:
                # Split the prompts into N subsets according to the number of vLLM worker processes
split_ordered_prompts = self._split_vllm_inputs(ordered_set_of_prompts, n_vllms)
subset_results = self._vllm_generation(split_ordered_prompts[vllm_index])
completion_ids_subset = []
for outputs in subset_results:
for output in outputs.outputs:
completion_ids.append(output.token_ids)
completion_ids_subset.append(output.token_ids)
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
completion_ids_subset = [None]
            # Gather the completions from the vLLM worker processes to all processes, ensuring each process
            # receives its corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
completion_ids = gather_object(completion_ids_subset)
completion_ids = [seq for seq in completion_ids if seq is not None]
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),