 
 import os
 import textwrap
-import warnings
 from collections import defaultdict
-from contextlib import nullcontext
 from typing import Any, Callable, Optional, Sized, Union
 
+import deepspeed
 import torch
 import torch.utils.data
 import transformers
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
-from ..import_utils import is_deepspeed_available, is_rich_available, is_vllm_available
+from ..import_utils import is_rich_available, is_vllm_available
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
 )
 
 
-if is_deepspeed_available():
-    import deepspeed
-
 if is_peft_available():
-    from peft import PeftConfig, get_peft_model
+    from peft import PeftConfig, get_peft_model, get_peft_config
 
 
 if is_wandb_available():
@@ -273,6 +269,11 @@ def __init__(
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")
 
+        if peft_config:
+            self.lora_config = peft_config
+        elif is_peft_model(model):
+            self.lora_config = model.peft_config["default"]
+
         # Models
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
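
The block added above decides which adapter configuration the trainer remembers: an explicitly passed peft_config wins, otherwise the default adapter of a model that is already a PeftModel is reused. This self.lora_config is what later gets forwarded to the vLLM client. A minimal sketch of supplying it, assuming the standard peft.LoraConfig; the rank, alpha, and target modules below are illustrative values, not taken from this commit:

from peft import LoraConfig

# Illustrative adapter config; nothing in this commit prescribes these values.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
# Passed as GRPOTrainer(..., peft_config=peft_config); the trainer stores it as
# self.lora_config and later hands it to the vLLM client for LoRA weight sync.
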
@@ -389,8 +390,7 @@ def data_collator(features):  # No data collation is needed in GRPO
 
         # Multi-step
         self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
-        self.epsilon_low = args.epsilon
-        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+        self.epsilon = args.epsilon
         # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
         self._step = 0
         # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
@@ -407,7 +407,6 @@ def data_collator(features):  # No data collation is needed in GRPO
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
-        self._total_train_tokens = 0
         self.log_completions = args.log_completions
 
         super().__init__(
@@ -454,9 +453,7 @@ def data_collator(features):  # No data collation is needed in GRPO
                 )
 
             if self.accelerator.is_main_process:
-                self.vllm_client = VLLMClient(
-                    args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
-                )
+                self.vllm_client = VLLMClient(args.vllm_server_host, args.vllm_server_port)
 
             # vLLM specific sampling arguments
             self.guided_decoding_regex = args.vllm_guided_decoding_regex
@@ -599,45 +596,20 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
+    def _move_lora_to_vllm(self):
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-
-        if is_peft_model(self.model):
-            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-            # adapters in a sharded manner is not supported.
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                for name, param in self.model.named_parameters():
-                    # When using PEFT, we need to recover the original parameter name and discard some parameters
-                    name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-                    if self.model.prefix in name:
-                        continue
-                    # When module to save, remove its prefix and discard the original module
-                    if "original_module" in name:
-                        continue
-                    name = name.replace("modules_to_save.default.", "")
-
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
-        else:
-            # For non-PEFT models, simply gather and update each parameter individually.
-            for name, param in self.model.named_parameters():
-                with gather_if_zero3([param]):
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-        # Reset cache on main process
         if self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
+            self.vllm_client.update_lora_params(self.model, self.lora_config)
+
+    @profiling_decorator
+    def _move_model_to_vllm(self):
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+        for name, param in self.model.named_parameters():
+            with deepspeed.zero.GatheredParameters([param], enabled=zero_stage_3):
+                if self.accelerator.is_main_process:
+                    self.vllm_client.update_named_param(name, param.data)
 
     @profiling_decorator
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
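
The rewritten sync path splits into two methods: _move_lora_to_vllm hands the whole PeftModel plus the stored lora_config to vllm_client.update_lora_params (a client method introduced alongside this change), while _move_model_to_vllm streams full weights one parameter at a time, gathering each ZeRO-3 shard with deepspeed.zero.GatheredParameters, whose enabled flag turns the context into a no-op outside stage 3. Unlike the removed version, neither path resets the vLLM prefix cache afterwards. A standalone sketch of the per-parameter gather-and-stream pattern, with names mirroring the diff (the client is assumed to expose update_named_param as shown above):

import deepspeed

def stream_weights(model, client, zero_stage_3: bool, is_main_process: bool):
    # Gather one ZeRO-3 shard at a time, push it from the main process, then let
    # DeepSpeed re-partition it when the context exits. With enabled=False the
    # context manager does nothing, so the same loop works without ZeRO-3.
    for name, param in model.named_parameters():
        with deepspeed.zero.GatheredParameters([param], enabled=zero_stage_3):
            if is_main_process:
                client.update_named_param(name, param.data)
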
@@ -674,15 +646,15 @@ def _generate_and_score_completions(
         if self.args.use_vllm:
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm()
+                if is_peft_model(self.model):
+                    self._move_lora_to_vllm()
+                else:
+                    self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step
 
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_prompts_text = gather_object(prompts_text)
             if self.accelerator.is_main_process:
-                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
-                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
-                # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                 with profiling_context(self, "vLLM.generate"):
                     completion_ids = self.vllm_client.generate(
@@ -713,9 +685,7 @@ def _generate_and_score_completions(
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
-            with unwrap_model_for_generation(
-                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-            ) as unwrapped_model:
+            with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
                 prompt_completion_ids = unwrapped_model.generate(
                     prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                 )
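
The comment deleted in the first hunk above explained the slicing trick that remains in the code: the gathered prompts contain num_generations consecutive duplicates, so taking every num_generations-th element recovers the unique prompts, and a single vLLM call then produces num_generations completions per unique prompt. A tiny illustration with made-up values:

num_generations = 2
all_prompts_text = ["p0", "p0", "p1", "p1", "p2", "p2"]  # gathered, with duplicates
ordered_set_of_prompts = all_prompts_text[::num_generations]
print(ordered_set_of_prompts)  # ['p0', 'p1', 'p2'] -> one generate call, 2 completions each
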
@@ -839,10 +809,6 @@ def _generate_and_score_completions(
         # Log the metrics
         mode = "eval" if self.control.should_evaluate else "train"
 
-        if mode == "train":
-            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
-            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
-
         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
         self._metrics[mode]["completion_length"].append(completion_length)
 
@@ -921,7 +887,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         # _generate_and_score_completions) and use per_token_logps.detach() instead.
         old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
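
With a single epsilon the importance-ratio clip becomes symmetric: coef_1 is the per-token ratio exp(logp_new - logp_old), coef_2 is that ratio clamped to [1 - ε, 1 + ε], and the loss keeps the more pessimistic of the two terms, i.e. PPO-style clipping; the removed epsilon_low/epsilon_high pair allowed an asymmetric range [1 - ε_low, 1 + ε_high]. A standalone numeric check of the clipped per-token objective (all values arbitrary):

import torch

eps = 0.2
per_token_logps = torch.tensor([[-1.0, -2.0]])
old_per_token_logps = torch.tensor([[-1.5, -1.5]])
advantages = torch.tensor([1.0])  # one advantage per sequence, broadcast over tokens

coef_1 = torch.exp(per_token_logps - old_per_token_logps)   # ~[[1.65, 0.61]]
coef_2 = torch.clamp(coef_1, 1 - eps, 1 + eps)               # [[1.20, 0.80]]
per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1),
                            coef_2 * advantages.unsqueeze(1))
print(per_token_loss)  # tensor([[-1.2000, -0.6065]])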