
Commit d759c9c

Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution
1 parent 10d26ef commit d759c9c

File tree

2 files changed: +101 −5 lines changed

    trl/trainer/grpo_trainer.py
    trl/trainer/utils.py

trl/trainer/grpo_trainer.py

Lines changed: 85 additions & 3 deletions
@@ -453,13 +453,15 @@ def data_collator(features): # No data collation is needed in GRPO
                 )

             if self.accelerator.is_main_process:
-                self.vllm_client = VLLMClient(args.vllm_server_host, args.vllm_server_port, args.vllm_server_timeout)
+                self.vllm_client = VLLMClient(host=args.vllm_server_host, server_port=args.vllm_server_port, connection_timeout=args.vllm_server_timeout)

             # vLLM specific sampling arguments
             self.guided_decoding_regex = args.vllm_guided_decoding_regex

             self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
-
+            # When using LoRA, the initial model may be slightly different from the one in vLLM,
+            # so we force loading the model once at the beginning of training.
+            self.vllm_never_synced_before = True
             # When using vLLM, the main process is responsible for loading the model weights. This can cause process
             # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
             # synchronize all processes after vLLM has been fully initialized.
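The first change switches the `VLLMClient` construction to keyword arguments. A minimal stand-in sketch of why that matters; the parameter order below is an assumption about the client (roughly `(host, server_port, group_port, connection_timeout)`), not the real class:

    def fake_vllm_client(host="0.0.0.0", server_port=8000, group_port=51216, connection_timeout=0.0):
        """Stand-in with the assumed parameter order, used only to show the positional-binding hazard."""
        return {"host": host, "server_port": server_port, "group_port": group_port, "connection_timeout": connection_timeout}

    # Positional call, as in the old code: the timeout value silently lands in `group_port`.
    print(fake_vllm_client("10.0.0.1", 8000, 240.0))
    # Keyword call, as in the fix: every value binds to the intended parameter.
    print(fake_vllm_client(host="10.0.0.1", server_port=8000, connection_timeout=240.0))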
@@ -602,6 +604,11 @@ def _move_model_to_vllm(self):
         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
         gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext

+        # ZeRO-3 + PEFT + very large model is a special combination that needs its own path
+        if zero_stage_3 and is_peft_model(self.model):
+            self._move_model_to_vllm_especially_for_zero3_plus_peft()
+            return  # early return
+
         if is_peft_model(self.model):
             # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
             # adapters in a sharded manner is not supported.
@@ -635,6 +642,80 @@ def _move_model_to_vllm(self):
         # Reset cache on main process
         if self.accelerator.is_main_process:
             self.vllm_client.reset_prefix_cache()
+        return
+
+    def _move_model_to_vllm_especially_for_zero3_plus_peft(self):
+        """
+        Why is this special method needed for the (ZeRO-3 + PEFT + very large model) combination, and why is
+        that combination so nasty?
+
+        1. `model.merge_adapter()` must be executed after `unwrap_model_for_generation`.
+        2. `unwrap_model_for_generation` can cause GPU OOM if the model is very large.
+        3. Usually, that OOM can be resolved by setting `gather_deepspeed3_params=False`.
+        4. But guess what? `gather_deepspeed3_params=False` breaks `model.merge_adapter()`.
+
+        So we rewrite the `merge_adapter` logic: to avoid GPU OOM, the adapters are merged module by module.
+        The basic idea is:
+
+        1. First handle the LoRA weights.
+        2. Then handle whatever parameters are left behind.
+        """
+        from peft.tuners.tuners_utils import BaseTunerLayer, onload_layer
+
+        # Update the weights in vLLM. When using DeepSpeed ZeRO Stage 3, we need to gather the parameters
+        # before updating the weights.
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+        warning = (
+            "This special `_move_model_to_vllm` is designed only for zero3 + peft + very_large_model. "
+            "Only a nasty problem needs a nasty solution like this."
+        )
+        if not zero_stage_3 or not is_peft_model(self.model):
+            raise RuntimeError(warning)
+
+        parameter_to_transfer_map_id_name = {id(param): name for name, param in self.model.named_parameters()}
+        # 1. First handle the LoRA weights; this is the truly nasty part 😂
+        for module in self.model.modules():
+            # `modules()` yields not only leaf modules but also their parents
+            if isinstance(module, BaseTunerLayer):
+                # `onload_layer` temporarily moves any offloaded weights of this layer onto the device
+                with onload_layer(module):
+                    # get all the parameters of this small module
+                    param_list_of_this_small_module = [param for relative_name, param in module.named_parameters()]
+                    with deepspeed.zero.GatheredParameters(param_list_of_this_small_module) if zero_stage_3 else nullcontext():
+                        # the parameters must be gathered before `module.merge`
+                        module.merge(adapter_names=None)
+                        for relative_name, param in module.named_parameters():
+                            param_python_id = id(param)
+                            # get the absolute name of the parameter
+                            absolute_name = parameter_to_transfer_map_id_name[param_python_id]  # f"{module.prefix}.{relative_name}"
+                            # one less weight to worry about
+                            parameter_to_transfer_map_id_name.pop(param_python_id)
+                            # only the main process is responsible for transferring weights
+                            if self.accelerator.is_main_process:
+                                # When using PEFT, we need to recover the original parameter name and discard some parameters
+                                absolute_name = absolute_name.removeprefix("base_model.model.").replace(".base_layer", "")
+                                if self.model.prefix in absolute_name:
+                                    continue
+                                # For modules_to_save, remove its prefix and discard the original module
+                                if "original_module" in absolute_name:
+                                    continue
+                                absolute_name = absolute_name.replace("modules_to_save.default.", "")
+                                # Finally it is time to transfer the weight 🌟
+                                # print(f"Transferring: {absolute_name}")
+                                self.vllm_client.update_named_param(absolute_name, param.data)
+                        # and of course, we must unmerge before exiting `GatheredParameters`
+                        module.unmerge()
+        # 2. Then handle whatever parameters are left behind
+        remaining_param_list = [
+            (name, param) for name, param in self.model.named_parameters()
+            if name in parameter_to_transfer_map_id_name.values()
+        ]
+        for name, param in remaining_param_list:
+            # print(f"Transferring Part2: {name}")
+            with deepspeed.zero.GatheredParameters([param]) if zero_stage_3 else nullcontext():
+                if self.accelerator.is_main_process:
+                    name = name.removeprefix("base_model.model.").replace(".base_layer", "")
+                    if self.model.prefix in name:
+                        raise RuntimeError("Something is wrong: LoRA-related weights should already have been transferred.")
+                    if "original_module" in name:
+                        raise RuntimeError("Something is wrong: LoRA-related weights should already have been transferred.")
+                    if "modules_to_save.default." in name:
+                        raise RuntimeError("Something is wrong: LoRA-related weights should already have been transferred.")
+                    self.vllm_client.update_named_param(name, param.data)
+
+        # Reset the prefix cache after updating weights
+        if self.accelerator.is_main_process:
+            self.vllm_client.reset_prefix_cache()

     @profiling_decorator
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
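The trickiest part of the new method is recovering base-model parameter names from PEFT names before pushing weights to vLLM. A small self-contained sketch of that mapping, exercised on illustrative names (real names depend on the model and LoRA config; the `lora_` default mirrors what `self.model.prefix` is for LoRA models):

    from typing import Optional

    def recover_base_name(peft_name: str, lora_prefix: str = "lora_") -> Optional[str]:
        # strip the PEFT wrapper prefix and the LoRA `base_layer` indirection
        name = peft_name.removeprefix("base_model.model.").replace(".base_layer", "")
        if lora_prefix in name:          # adapter weights themselves are never pushed (their delta is already merged)
            return None
        if "original_module" in name:    # modules_to_save keeps a frozen copy that can be ignored
            return None
        return name.replace("modules_to_save.default.", "")

    assert recover_base_name("base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight") == "model.layers.0.self_attn.q_proj.weight"
    assert recover_base_name("base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight") is None
    assert recover_base_name("base_model.model.lm_head.modules_to_save.default.weight") == "lm_head.weight"
    assert recover_base_name("base_model.model.lm_head.original_module.weight") is None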
@@ -670,7 +751,8 @@ def _generate_and_score_completions(
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
-            if self.state.global_step != self._last_loaded_step:
+            if self.state.global_step != self._last_loaded_step or self.vllm_never_synced_before:
+                self.vllm_never_synced_before = False
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step

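Taken together, the memory trick in `_move_model_to_vllm_especially_for_zero3_plus_peft` is to never materialize the whole model, only one tuner layer at a time. A condensed sketch of that pattern, assuming `deepspeed` and `peft` are installed, `model` is a ZeRO-3-sharded PEFT model, and the hypothetical `push_weight(name, tensor)` stands in for `vllm_client.update_named_param` (the name remapping shown earlier is omitted here):

    import deepspeed
    from contextlib import nullcontext
    from peft.tuners.tuners_utils import BaseTunerLayer, onload_layer

    def stream_merged_lora_weights(model, push_weight, zero_stage_3=True):
        """Gather, merge, push, and unmerge one tuner layer at a time, so only a single layer
        is ever fully materialized on each rank instead of the whole model."""
        id_to_name = {id(p): n for n, p in model.named_parameters()}
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                with onload_layer(module):  # bring any CPU/disk-offloaded sub-weights onto the device
                    params = [p for _, p in module.named_parameters()]
                    gather = deepspeed.zero.GatheredParameters(params) if zero_stage_3 else nullcontext()
                    with gather:                           # un-shard just this layer
                        module.merge(adapter_names=None)   # bake the LoRA delta into the base weight
                        for _, param in module.named_parameters():
                            push_weight(id_to_name[id(param)], param.data)
                        module.unmerge()                   # restore the un-merged weight before re-sharding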
trl/trainer/utils.py

Lines changed: 16 additions & 2 deletions
@@ -1691,7 +1691,7 @@ def selective_log_softmax(logits, index):
     return per_token_logps


-def print_prompt_completions_sample(prompts: list[str], completions: list[str], rewards: list[int], step: int) -> None:
+def print_prompt_completions_sample(prompts: list[str], completions: list[str], rewards: list[int], step: int, max_display_rows: int = 2) -> None:
     """
     Print out a sample of model completions to the console.

17371737
table.add_column("Completion", style="bright_green")
17381738
table.add_column("Reward", style="bold cyan", justify="right")
17391739

1740-
for prompt, completion, reward in zip(prompts, completions, rewards):
1740+
n_rows = len(prompts)
1741+
1742+
for i, prompt, completion, reward in zip(range(n_rows), prompts, completions, rewards):
1743+
if i >= (max_display_rows - 1):
1744+
if i == n_rows-1:
1745+
# last row, always print normally
1746+
...
1747+
elif i == max_display_rows - 1:
1748+
# not last row, print ellipsis and average reward
1749+
prompt = "..."
1750+
completion = "..."
1751+
reward = sum(rewards) / n_rows
1752+
else:
1753+
# skip all other rows because of the max_display_rows limitation
1754+
continue
17411755
table.add_row(Text(prompt), Text(completion), f"{reward:.2f}") # Formatting reward to 2 decimal places
17421756
table.add_section() # Adds a separator between rows
17431757

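A usage sketch of the truncated table (assuming this patched `print_prompt_completions_sample` is importable from `trl.trainer.utils`; the exact console rendering comes from `rich` and is not reproduced here):

    from trl.trainer.utils import print_prompt_completions_sample

    prompts = [f"prompt {i}" for i in range(6)]
    completions = [f"completion {i}" for i in range(6)]
    rewards = [0.0, 0.5, 1.0, 0.0, 1.0, 0.5]

    # With max_display_rows=2 the table shows row 0, a "..." row carrying the mean reward (0.50), and the last row.
    print_prompt_completions_sample(prompts, completions, rewards, step=42, max_display_rows=2)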