Commit 8309ce8

Update grpo_trainer.py
1 parent b5f92ab commit 8309ce8

File tree: 1 file changed (+27, -61 lines)

trl/trainer/grpo_trainer.py (27 additions, 61 deletions)
@@ -14,11 +14,10 @@
 
 import os
 import textwrap
-import warnings
 from collections import defaultdict
-from contextlib import nullcontext
 from typing import Any, Callable, Optional, Sized, Union
 
+import deepspeed
 import torch
 import torch.utils.data
 import transformers
@@ -44,7 +43,7 @@
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
-from ..import_utils import is_deepspeed_available, is_rich_available, is_vllm_available
+from ..import_utils import is_rich_available, is_vllm_available
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
@@ -57,11 +56,8 @@
 )
 
 
-if is_deepspeed_available():
-    import deepspeed
-
 if is_peft_available():
-    from peft import PeftConfig, get_peft_model
+    from peft import PeftConfig, get_peft_model,get_peft_config
 
 
 if is_wandb_available():
@@ -273,6 +269,11 @@ def __init__(
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")
 
+        if peft_config:
+            self.lora_config=peft_config
+        elif is_peft_model(model):
+            self.lora_config=model.peft_config["default"]
+
         # Models
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
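Note on the new self.lora_config bookkeeping in this hunk: the trainer now keeps a handle to the adapter config so it can later be forwarded to the vLLM server. A minimal sketch (assuming peft and transformers are installed) of how an adapter config can be recovered from an already-wrapped PEFT model via the same model.peft_config["default"] lookup; the model name and LoRA settings below are illustrative assumptions, not part of this commit:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative base model and adapter settings (assumptions, not from the commit)
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]))

# A PeftModel stores adapter configs in a dict keyed by adapter name;
# "default" is the name used when none is given, which is what the trainer reads.
lora_config = peft_model.peft_config["default"]
print(type(lora_config).__name__, lora_config.r)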
@@ -389,8 +390,7 @@ def data_collator(features): # No data collation is needed in GRPO
 
         # Multi-step
         self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
-        self.epsilon_low = args.epsilon
-        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+        self.epsilon = args.epsilon
         # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
         self._step = 0
         # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
@@ -407,7 +407,6 @@ def data_collator(features): # No data collation is needed in GRPO
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
-        self._total_train_tokens = 0
         self.log_completions = args.log_completions
 
         super().__init__(
@@ -454,9 +453,7 @@ def data_collator(features): # No data collation is needed in GRPO
             )
 
             if self.accelerator.is_main_process:
-                self.vllm_client = VLLMClient(
-                    args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
-                )
+                self.vllm_client = VLLMClient(args.vllm_server_host, args.vllm_server_port)
 
             # vLLM specific sampling arguments
             self.guided_decoding_regex = args.vllm_guided_decoding_regex
@@ -599,45 +596,20 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
+    def _move_lora_to_vllm(self):
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-
-        if is_peft_model(self.model):
-            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-            # adapters in a sharded manner is not supported.
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                for name, param in self.model.named_parameters():
-                    # When using PEFT, we need to recover the original parameter name and discard some parameters
-                    name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-                    if self.model.prefix in name:
-                        continue
-                    # When module to save, remove its prefix and discard the original module
-                    if "original_module" in name:
-                        continue
-                    name = name.replace("modules_to_save.default.", "")
-
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
-        else:
-            # For non-PEFT models, simply gather and update each parameter individually.
-            for name, param in self.model.named_parameters():
-                with gather_if_zero3([param]):
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-        # Reset cache on main process
         if self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
+            self.vllm_client.update_lora_params(self.model,self.lora_config)
+
+    @profiling_decorator
+    def _move_model_to_vllm(self):
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+        for name, param in self.model.named_parameters():
+            with deepspeed.zero.GatheredParameters([param], enabled=zero_stage_3):
+                if self.accelerator.is_main_process:
+                    self.vllm_client.update_named_param(name, param.data)
 
     @profiling_decorator
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
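Context for this hunk: the merge/unmerge adapter dance is replaced by two paths. The LoRA path defers to VLLMClient.update_lora_params, which the commit calls but does not define in this file, while the full-weight path gathers and pushes one parameter at a time, relying on deepspeed.zero.GatheredParameters([param], enabled=zero_stage_3) being a no-op when ZeRO-3 is inactive. A rough, self-contained sketch of that per-parameter "gather, then push" loop; StubClient and push_weights are illustrative stand-ins, not trl APIs:

import torch
import torch.nn as nn

class StubClient:
    # Stand-in for VLLMClient.update_named_param (illustrative only)
    def update_named_param(self, name: str, data: torch.Tensor) -> None:
        print(f"pushed {name}: shape={tuple(data.shape)}")

def push_weights(model: nn.Module, client: StubClient) -> None:
    # Without DeepSpeed, parameters are already materialized, so we just iterate
    # and push; under ZeRO-3 each param would first be gathered with
    # deepspeed.zero.GatheredParameters([param], enabled=True).
    for name, param in model.named_parameters():
        client.update_named_param(name, param.data)

push_weights(nn.Linear(4, 2), StubClient())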
@@ -674,15 +646,15 @@ def _generate_and_score_completions(
         if self.args.use_vllm:
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm()
+                if is_peft_model(self.model):
+                    self._move_lora_to_vllm()
+                else:
+                    self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step
 
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_prompts_text = gather_object(prompts_text)
             if self.accelerator.is_main_process:
-                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
-                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
-                # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                 with profiling_context(self, "vLLM.generate"):
                     completion_ids = self.vllm_client.generate(
@@ -713,9 +685,7 @@
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
-            with unwrap_model_for_generation(
-                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-            ) as unwrapped_model:
+            with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
                 prompt_completion_ids = unwrapped_model.generate(
                     prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                 )
@@ -839,10 +809,6 @@ def _generate_and_score_completions(
         # Log the metrics
         mode = "eval" if self.control.should_evaluate else "train"
 
-        if mode == "train":
-            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
-            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
-
         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
         self._metrics[mode]["completion_length"].append(completion_length)
 
@@ -921,7 +887,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         # _generate_and_score_completions) and use per_token_logps.detach() instead.
         old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
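Context for the clipping change above: with a single epsilon, the importance ratio is bounded symmetrically on both sides, i.e. the PPO-style clip(r, 1 - ε, 1 + ε), instead of the separate epsilon_low / epsilon_high bounds that are removed. A small self-contained numerical sketch of that computation; all tensor values are made up for illustration:

import torch

# Toy values, chosen only to illustrate the computation
per_token_logps = torch.tensor([[-0.7, -1.4]])
old_per_token_logps = torch.tensor([[-1.0, -1.0]])
advantages = torch.tensor([0.5])
epsilon = 0.2

coef_1 = torch.exp(per_token_logps - old_per_token_logps)   # importance ratio r
coef_2 = torch.clamp(coef_1, 1 - epsilon, 1 + epsilon)      # symmetric clip of r
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
print(coef_1, coef_2, per_token_loss)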
