Skip to content

Commit 4576678

Browse files
[rollout, vllm] fix: handle lora request when base_sync_done is false initially (volcengine#3907)
### What does this PR do? This PR fixes a remaining issue that was not fully addressed by volcengine#3821 and led to downstream errors such as volcengine#3882. The bug occurs when the configuration `rollout.load_format` is set to `"dummy"`, which causes the initial `base_sync_done` flag to be set to `False`: <https://github.com/volcengine/verl/blob/ecdaa8d9af75ab064bcbda0d797986a198d752b0/verl/workers/fsdp_workers.py#L619-L620> As a result, during the first invocation of `update_weights`, the `TensorLoRARequest` is **not** successfully added to the engine: <https://github.com/volcengine/verl/blob/ecdaa8d9af75ab064bcbda0d797986a198d752b0/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L613-L625> Later, when `generate()` is called, the code assumes that the LoRA with the internal ID `VLLM_LORA_INT_ID` is already loaded. This assumption fails in the scenario above, and the engine attempts to reload it using the placeholder `VLLM_LORA_PATH`, leading to `FileNotFoundError` exceptions. > **Note:** The correct LoRA weights should always be loaded from the custom `TensorLoRARequest`. This is made possible because `verl` manually overrides (`"hijacks"`) the `_load_adapter` method in `vllm`’s `WorkerLoraManager`. --- ### Proposed Fix We add a safety check in the `generate` method to ensure the LoRA is loaded before constructing the request: ```python # Add LoRA request lora_request = None # Ensure the LoRA is already loaded in the engine if self.model_config.lora_rank > 0: lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras() # <-- NEW if lora_loaded: lora_request = LoRARequest( lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH, ) generator = self.engine.generate( prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, ) ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: volcengine#3821 - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. I've manually reproduced the bug mentioned in volcengine#3882, and confirmed that the above fix resolve this bug. For interested, some experiment runs on [wandb](https://wandb.ai/listar2000/solver-judge-workflow) using the code **after the proposed fix** (and with `load_format = "dummy"`) -- even though this PR has nothing to do with LoRA performances. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **N/A** - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **N/A: see tests above** - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent d495541 commit 4576678

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,12 @@ async def generate(
380380
# Add lora request
381381
lora_request = None
382382
if self.model_config.lora_rank > 0:
383-
lora_request = LoRARequest(lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH)
383+
# Make sure we also check that the lora is already loaded in the engine
384+
lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras()
385+
if lora_loaded:
386+
lora_request = LoRARequest(
387+
lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH
388+
)
384389

385390
generator = self.engine.generate(
386391
prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request

0 commit comments

Comments
 (0)