
Conversation


@maoulee maoulee commented Mar 22, 2025

What does this PR do?

This PR modifies vLLM, through patching, to directly load the weights of a PEFT model and apply them as LoRA adapters during inference. This avoids the need to merge the entire model and transfer it to the generation server during online reinforcement learning algorithms like GRPO and DPO.

This provides the following benefits:

  1. Supports Online RL for Quantized Models: Enables GRPO and Online DPO training on quantized models (e.g., QLoRA models), which are typically not mergeable with the base model.
  2. Avoids OOM Issues: Sending both the base model parameters and PEFT parameters to the generation server for parameter updates requires twice the VRAM of the base model. Transferring only the LoRA parameters requires significantly less VRAM, making it easier to deploy models using PEFT.
  3. Avoids PEFT Load/Unload Issues: Eliminates potential errors associated with merging and unloading PEFT models.

Limitations:

  1. Testing with ZeRO-3 has not yet been performed.
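
To make the intended flow concrete, here is a minimal sketch of the trainer-side sync under this approach (update_lora_param and apply_lora are the client methods added by this PR; the other names are illustrative):

```python
# Minimal sketch: push only the LoRA adapter weights to the vLLM server instead of
# merging them into the base model and transferring everything.
for name, param in peft_model.named_parameters():
    if ".lora_A." in name or ".lora_B." in name:
        # Strip PEFT's adapter suffix so the server sees the plain parameter name.
        vllm_client.update_lora_param(name.replace(".default", ""), param.data)

# Register the transferred tensors as a LoRA adapter on the server.
vllm_client.apply_lora(peft_model.peft_config["default"])
```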

@maoulee maoulee changed the title from "Enable direct loading of LoRA adapters in vLLM to streamline GRPO/Online DPO training" to "Add GRPO/Online DPO support for quantized models when using vLLM as the inference backend" Mar 23, 2025
Author

maoulee commented Mar 24, 2025

@qgallouedec Could you find time to test this PR?

@qgallouedec
Member

Thanks @maoulee

> This avoids the need to merge the entire model and transfer it to the generation server during online reinforcement learning algorithms like GRPO and DPO.

I understand the motivation, but does it really matter here? I don't think the merge is the bottleneck. Have you done any benchmarking?

> 2. Avoids OOM Issues: Sending both the base model parameters and PEFT parameters to the generation server for parameter updates requires twice the VRAM of the base model. Transferring only the LoRA parameters requires significantly less VRAM, making it easier to deploy models using PEFT.

I'm not sure I get this one: why would you send both the adapters and the model? And even if you do, the adapter is significantly smaller than the model, so how do you end up with "twice the VRAM"?
Speaking of VRAM, the models are already loaded and nothing is duplicated, so I don't understand where you save any VRAM.

> 3. Avoids PEFT Load/Unload Issues: Eliminates potential errors associated with merging and unloading PEFT models.

I'm not aware of such errors. Do you mean numerical errors? Do you have any reports or pointers?

I haven't looked in detail at the changes you're proposing, because it seems to me that your PR still needs to be cleaned up: the majority of the changes don't seem to be related to what you're proposing, so please limit the number of lines changed to make the review possible. 🙏

@binary-husky
Contributor

@qgallouedec I think there are two cases: non-zero3 / zero3 cases

In non-zero3 cases, PEFT may not consume too much memory.

However, in zero3 cases, the zero3 param gather may cause GPU OOM:

  • Experiment: 72B model + machine 1 [8 * A100 running training] + machine 2 [4 * A100 running vLLM] = GPU memory OOM during self._move_model_to_vllm


  • In theory, this PR proposes a better solution because it not only solves the param-gather OOM but also improves communication speed; it is worth considering, but it also needs a lot of tests.
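
For rough intuition on the sizes involved, a back-of-the-envelope sketch (bf16 assumed; the adapter parameter count is an assumption, not a measurement from this setup):

```python
# Rough memory needed to materialize the weights pushed to vLLM, assuming bf16
# (2 bytes per parameter). Gathering the full 72B model on one rank already
# exceeds a single 80 GB A100; a LoRA adapter is orders of magnitude smaller.
def weight_gib(num_params: int, bytes_per_param: int = 2) -> float:
    return num_params * bytes_per_param / 1024**3

print(f"full 72B model : ~{weight_gib(72_000_000_000):.0f} GiB")  # ~134 GiB
print(f"LoRA adapter   : ~{weight_gib(50_000_000) * 1024:.0f} MiB (assuming ~50M adapter params)")
```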

Author

maoulee commented Mar 25, 2025

> @qgallouedec I think there are two cases: non-zero3 / zero3 cases
>
> In non-zero3 cases, PEFT may not consume too much memory.
>
> However, in zero3 cases, the zero3 param gather may cause GPU OOM:
>
> • Experiment: 72B model + machine 1 [8 * A100 running training] + machine 2 [4 * A100 running vLLM] = GPU memory OOM during self._move_model_to_vllm
>
> • In theory, this PR proposes a better solution because it not only solves the param-gather OOM but also improves communication speed; it is worth considering, but it also needs a lot of tests.

I've tested this PR and it works well with GRPO on r1-32b-int4! Unfortunately, my current school workload and limited GPU access (just 2 * A100 40GB) have prevented me from testing other setups thoroughly.
If you're able to give it a try, that would be great! The core functionality is the same as the standard vllm_serve and vllm_client; this PR adds apply_lora functionality and includes vllm_patch.py, which monkey-patches the vLLM LoRA management module.
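
For context, the monkey-patching is just the usual pattern of re-binding a callable on the target class or module before it is used; a generic, hypothetical sketch (not the actual contents of vllm_patch.py):

```python
# Generic monkey-patching pattern (hypothetical example, NOT the PR's vllm_patch.py):
# keep a reference to the original callable, wrap it, and reassign it.

class LoRAManager:  # stand-in for the real vLLM LoRA manager class
    def load_adapter(self, path):
        return f"loaded adapter from {path}"

_original_load_adapter = LoRAManager.load_adapter

def _patched_load_adapter(self, path_or_tensors):
    # New behavior: accept in-memory tensors (a dict) instead of only a checkpoint path.
    if isinstance(path_or_tensors, dict):
        return f"loaded adapter from {len(path_or_tensors)} in-memory tensors"
    return _original_load_adapter(self, path_or_tensors)

LoRAManager.load_adapter = _patched_load_adapter  # apply the patch

print(LoRAManager().load_adapter({"lora_A": 1, "lora_B": 2}))  # uses the patched path
```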


Pullerz commented Apr 2, 2025

Hey @maoulee, I've been giving your implementation a try, as I've had issues around merging adapters, particularly when using different quantization techniques. What did you modify in the GRPO trainer to call your code correctly, particularly in the _move_model_to_vllm method? I'm currently getting vLLM segfaults when I call the update_lora_params method.

Any help on this or code snippets you could share would be hugely appreciated if you have the time!

Edit: The issue seems to be related to using tensor_parallel_size > 1 for me

Author

maoulee commented Apr 20, 2025

> Hey @maoulee, I've been giving your implementation a try, as I've had issues around merging adapters, particularly when using different quantization techniques. What did you modify in the GRPO trainer to call your code correctly, particularly in the _move_model_to_vllm method? I'm currently getting vLLM segfaults when I call the update_lora_params method.
>
> Any help on this or code snippets you could share would be hugely appreciated if you have the time!
>
> Edit: The issue seems to be related to using tensor_parallel_size > 1 for me

Sorry, my email only reminded me just now. I have updated the code again; you can use _move_lora_to_vllm to update the adapter parameters.
In grpo_trainer.py:

```python
@profiling_decorator
def _move_lora_to_vllm(self):
    # For DeepSpeed ZeRO-3, we need to gather parameters before touching them.
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
    gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext

    if is_peft_model(self.model):
        # With PEFT and DeepSpeed ZeRO Stage 3, we gather the full model at once,
        # since collecting the adapter weights in a sharded manner is not supported.
        with gather_if_zero3(list(self.model.parameters())):
            # Send only the LoRA weights to vLLM while parameters are gathered.
            for name, param in self.model.named_parameters():
                # Recover the original parameter name and skip everything that is not a LoRA weight.
                if ".lora_A." in name or ".lora_B." in name:
                    name = name.replace(".default", "")
                    if self.accelerator.is_main_process:
                        self.vllm_client.update_lora_param(name, param.data)

            if self.accelerator.is_main_process:
                self.vllm_client.apply_lora(self.lora_config)
    else:
        # For non-PEFT models, simply gather and update each parameter individually.
        for name, param in self.model.named_parameters():
            with gather_if_zero3([param]):
                if self.accelerator.is_main_process:
                    self.vllm_client.update_named_param(name, param.data)

    # Reset the prefix cache on the main process.
    if self.accelerator.is_main_process:
        self.vllm_client.reset_prefix_cache()
```

In vllm_client.py:

```python
def update_lora_param(self, name: str, weights: torch.Tensor):
    """
    Updates a specific named LoRA parameter on the server and broadcasts it to the other processes.

    Args:
        name (`str`):
            Name of the layer whose weights are being updated.
        weights (`torch.Tensor`):
            Tensor containing the updated weights.
    """
    dtype, shape = str(weights.dtype), tuple(weights.shape)
    url = f"http://{self.host}:{self.server_port}/update_lora_param/"
    response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
    if response.status_code != 200:
        raise Exception(f"Request failed: {response.status_code}, {response.text}")

    # Broadcast the weights to the other processes.
    self.pynccl_comm.broadcast(weights, src=self.rank, stream=torch.cuda.current_stream())
    self.pynccl_comm.group.barrier()


def apply_lora(self, config):
    """Sends the PEFT LoRA config to the server so it can build a LoRA request from the received weights."""
    url = f"http://{self.host}:{self.server_port}/apply_lora/"
    config_dict = config.to_dict()
    # Sets (e.g. `target_modules`) are not JSON-serializable, so convert them to lists.
    for key, value in config_dict.items():
        if isinstance(value, set):
            config_dict[key] = list(value)
    response = self.session.post(url, json={"lora_config": config_dict})
    if response.status_code != 200:
        raise Exception(f"Request failed: {response.status_code}, {response.text}")
```

In vllm_serve.py:

```python
def update_lora_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None:
    """
    Receives updated LoRA weights from the client process and stores them under the given parameter name.

    Args:
        name (`str`):
            Name of the weight tensor being updated.
        dtype (`torch.dtype`):
            Data type of the weight tensor (e.g., `torch.float32`).
        shape (`Sequence[int]`):
            Shape of the weight tensor.
    """
    if self.pynccl_comm is None:
        raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")

    # Allocate memory for the incoming weight tensor on the correct device.
    weight = torch.empty(shape, dtype=dtype, device=self.device)

    # Use NCCL to broadcast the updated weights from the client (src) to all workers.
    self.pynccl_comm.broadcast(weight, src=self.client_rank, stream=torch.cuda.current_stream())
    self.pynccl_comm.group.barrier()

    # Buffer the received LoRA weights until `apply_lora` builds a LoRA request from them.
    self.lora_weight[name] = weight


@app.post("/update_lora_param/")
async def update_lora_param(request: UpdateWeightsRequest, background_tasks: BackgroundTasks):
    """
    Updates a LoRA weight tensor on all server workers.

    Once this endpoint is called, the client process should broadcast the updated weights to all server workers.

    Args:
        request (`UpdateWeightsRequest`):
            - `name` (`str`): Name of the weight tensor being updated.
            - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
            - `shape` (list of `int`): Shape of the weight tensor.
    """
    # The worker method is called this way: update_lora_param(name="name", dtype=torch.float32, shape=(10, 10))
    # So with collective_rpc we need to call it this way:
    # llm.collective_rpc("update_lora_param", args=("name", torch.float32, (10, 10)))
    # And with background_tasks.add_task we need to call it this way:
    # background_tasks.add_task(llm.collective_rpc, "update_lora_param", args=("name", torch.float32, (10, 10)))
    dtype = torch.__getattribute__(request.dtype.split(".")[-1])
    background_tasks.add_task(llm.collective_rpc, "update_lora_param", args=(request.name, dtype, request.shape))
    return {"message": "Request received, updating LoRA parameter"}


class ApplyLoraRequest(BaseModel):
    lora_config: dict


@app.post("/apply_lora/")
def apply_lora(request: ApplyLoraRequest, background_tasks: BackgroundTasks):
    """Builds a LoRA request from the buffered adapter weights and the client-provided LoRA config."""
    lora_worker = llm.llm_engine.model_executor.driver_worker
    lora_weights = lora_worker.lora_weight
    lora_config = request.lora_config
    from vllm.lora.request import LoRARequest

    applied_id = lora_worker.lora_id
    lora_request = LoRARequest(
        lora_name=str(applied_id),
        lora_int_id=applied_id,
        lora_tensors=lora_weights,
        lora_config=lora_config,
    )
    lora_worker.lora_id = applied_id + 1
    lora_worker.lora_requests = lora_request
    return {"message": f"LoRA applied with ID: {applied_id}", "lora_id": applied_id}
```
