[Core] Add update_config RPC method #20095

Merged · 18 commits · Jul 14, 2025
7 changes: 5 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -438,11 +438,14 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     # model_runner_2 loads dummy weights first then load real weights inplace
     model_runner.load_model()
     original_load_format = model_runner_2.load_config.load_format
-    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
-    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.update_config(
+        {"load_config": {
+            "load_format": original_load_format
+        }})
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())
10 changes: 10 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
+import dataclasses
 import gc
 import time
 import weakref
@@ -1692,6 +1693,15 @@ def generate_draft_token_ids(
             draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
Collaborator:

this function feels a bit scary to be really honest, due to:
1/ not every config would be updatable even if it exists -- for example, updating parallel_config probably wouldn't work :(
2/ do we guarantee that the model runner always reads values from self.xxxx_config, not vllm_config.xxxx_config?
3/ how do we ensure the new config is a valid config for its type?

potentially we can limit updates to a small set of known-good configs first

Collaborator Author:

  1. very good point. I've updated the PR to restrict the change to load_config and model_config for now, to fulfill our purpose of model/weights updates
  2. this is messy in the model runner itself; we should perhaps clean it up in a separate PR.
  3. pydantic config validation still runs, since we use dataclasses.replace
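As a side note on point 3: `dataclasses.replace` builds a brand-new instance through `__init__`, so construction-time validation re-runs on every override. A minimal stdlib sketch (this `LoadConfig` and its allowed values are toy stand-ins, not vLLM's real class):

```python
import dataclasses

@dataclasses.dataclass
class LoadConfig:
    # Toy stand-in for vLLM's LoadConfig; the field values are illustrative.
    load_format: str = "auto"

    def __post_init__(self):
        # Runs on every construction -- including the new instance that
        # dataclasses.replace() builds -- so overrides are re-validated.
        if self.load_format not in ("auto", "dummy", "safetensors"):
            raise ValueError(f"invalid load_format: {self.load_format}")

cfg = LoadConfig()
new_cfg = dataclasses.replace(cfg, load_format="dummy")
assert new_cfg.load_format == "dummy"
try:
    dataclasses.replace(cfg, load_format="bogus")
except ValueError as exc:
    print(f"rejected: {exc}")  # invalid override caught at replace() time
```

The same mechanism applies to pydantic dataclasses, whose validators also run in `__init__`.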

+        for config_name, config_overrides in overrides.items():
+            try:
+                config = getattr(self, config_name)
+            except AttributeError as exc:
+                raise ValueError(f"Unknown config {config_name}") from exc
+            new_config = dataclasses.replace(config, **config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
5 changes: 4 additions & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -3,7 +3,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 import torch.distributed
@@ -184,6 +184,9 @@ def load_model(self) -> None:
         with context:
             self.model_runner.load_model()
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
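End to end, the worker method is a thin delegation. The control flow this PR adds can be sketched with stand-in classes (`FakeWorker`, `FakeModelRunner`, and the single-field `ModelConfig` are hypothetical mocks mirroring the diff, not vLLM's real classes):

```python
import dataclasses
from typing import Any

@dataclasses.dataclass
class ModelConfig:
    # Toy stand-in for vLLM's ModelConfig.
    model: str = "facebook/opt-125m"

class FakeModelRunner:
    # Stand-in for GPUModelRunner: only the pieces update_config touches.
    def __init__(self) -> None:
        self.model_config = ModelConfig()

    def update_config(self, overrides: dict[str, Any]) -> None:
        for config_name, config_overrides in overrides.items():
            config = getattr(self, config_name)
            setattr(self, config_name,
                    dataclasses.replace(config, **config_overrides))

class FakeWorker:
    # Mirrors the one-line delegation added to gpu_worker.py.
    def __init__(self) -> None:
        self.model_runner = FakeModelRunner()

    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)

worker = FakeWorker()
worker.update_config({"model_config": {"model": "my-new-model"}})
print(worker.model_runner.model_config.model)  # -> my-new-model
```

In vLLM itself the call would arrive at the worker via the engine's RPC fan-out rather than a direct method call.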
12 changes: 11 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import bisect
+import dataclasses
 import gc
 import time
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch
 
 import numpy as np
@@ -968,6 +969,15 @@ def execute_model(
 
         return model_runner_output
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
Collaborator:

why is the update_config logic different from the GPU one?

Collaborator Author:

updated, I missed it after a few rounds of changes. @Chenyaaang @yaochengji perhaps worth refactoring to inherit from a base class, as the TPU code shares a lot of common logic

Collaborator:

I think each hardware backend may have its own specific logic for validating the configuration, e.g. https://github.com/vllm-project/vllm/blob/main/vllm/platforms/tpu.py#L99

We can add a TODO here
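The base-class refactor floated in this thread could look roughly like the sketch below. The names (`ConfigUpdateMixin`, `ALLOWED_CONFIGS`, `TPUishRunner`) are illustrative assumptions, not code from this PR; the whitelist reflects the load_config/model_config restriction discussed above:

```python
import dataclasses
from typing import Any

class ConfigUpdateMixin:
    # Shared update_config logic that both model runners could inherit.
    # Restrict updates to known-good configs, per the review discussion.
    ALLOWED_CONFIGS = ("load_config", "model_config")

    def update_config(self, overrides: dict[str, Any]) -> None:
        for config_name, config_overrides in overrides.items():
            if config_name not in self.ALLOWED_CONFIGS:
                raise ValueError(
                    f"Config {config_name} is not supported for updates")
            config = getattr(self, config_name)
            new_config = dataclasses.replace(config, **config_overrides)
            setattr(self, config_name, new_config)

@dataclasses.dataclass
class LoadConfig:
    # Toy stand-in for vLLM's LoadConfig.
    load_format: str = "auto"

class TPUishRunner(ConfigUpdateMixin):
    # TODO: per-backend config validation could hook in here
    # (cf. vllm/platforms/tpu.py).
    def __init__(self) -> None:
        self.load_config = LoadConfig()

runner = TPUishRunner()
runner.update_config({"load_config": {"load_format": "dummy"}})
print(runner.load_config.load_format)  # -> dummy
```

A backend-specific subclass could then override `update_config` to add its own validation before delegating to the mixin.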

+        for config_name, config_overrides in overrides.items():
+            try:
+                config = getattr(self, config_name)
+            except AttributeError as exc:
+                raise ValueError(f"Unknown config {config_name}") from exc
+            new_config = dataclasses.replace(config, **config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
5 changes: 4 additions & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.distributed
@@ -248,6 +248,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def load_model(self) -> None:
         self.model_runner.load_model()
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()