
Commit 7a387a9

22quinn authored and Pradyun Ramadorai committed
[Core] Support model loader plugins (vllm-project#21067)
Signed-off-by: 22quinn <[email protected]>
1 parent 8f537ea commit 7a387a9

File tree: 9 files changed, +159 -86 lines changed

tests/fastsafetensors_loader/test_fastsafetensors_loader.py

Lines changed: 1 addition & 3 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadFormat
 
 test_model = "openai-community/gpt2"
 
@@ -17,7 +16,6 @@
 
 
 def test_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model,
-                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
+    with vllm_runner(test_model, load_format="fastsafetensors") as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs

tests/model_executor/model_loader/__init__.py

Whitespace-only changes.
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from torch import nn
+
+from vllm.config import LoadConfig, ModelConfig
+from vllm.model_executor.model_loader import (get_model_loader,
+                                              register_model_loader)
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+
+
+@register_model_loader("custom_load_format")
+class CustomModelLoader(BaseModelLoader):
+
+    def __init__(self, load_config: LoadConfig) -> None:
+        super().__init__(load_config)
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        pass
+
+
+def test_register_model_loader():
+    load_config = LoadConfig(load_format="custom_load_format")
+    assert isinstance(get_model_loader(load_config), CustomModelLoader)
+
+
+def test_invalid_model_loader():
+    with pytest.raises(ValueError):
+
+        @register_model_loader("invalid_load_format")
+        class InValidModelLoader:
+            pass

tests/runai_model_streamer_test/test_runai_model_streamer_loader.py

Lines changed: 4 additions & 3 deletions
@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadConfig, LoadFormat
+from vllm.config import LoadConfig
 from vllm.model_executor.model_loader import get_model_loader
 
+load_format = "runai_streamer"
 test_model = "openai-community/gpt2"
 
 prompts = [
@@ -18,7 +19,7 @@
 
 
 def get_runai_model_loader():
-    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    load_config = LoadConfig(load_format=load_format)
     return get_model_loader(load_config)
 
 
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
 
 
 def test_runai_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
+    with vllm_runner(test_model, load_format=load_format) as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs

vllm/config.py

Lines changed: 6 additions & 24 deletions
@@ -65,7 +65,7 @@
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
-    from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader import LoadFormats
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
@@ -78,6 +78,7 @@
     QuantizationConfig = Any
     QuantizationMethods = Any
     BaseModelLoader = Any
+    LoadFormats = Any
     TensorizerConfig = Any
     ConfigType = type
     HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@@ -1773,29 +1774,12 @@ def verify_with_parallel_config(
             logger.warning("Possibly too large swap space. %s", msg)
 
 
-class LoadFormat(str, enum.Enum):
-    AUTO = "auto"
-    PT = "pt"
-    SAFETENSORS = "safetensors"
-    NPCACHE = "npcache"
-    DUMMY = "dummy"
-    TENSORIZER = "tensorizer"
-    SHARDED_STATE = "sharded_state"
-    GGUF = "gguf"
-    BITSANDBYTES = "bitsandbytes"
-    MISTRAL = "mistral"
-    RUNAI_STREAMER = "runai_streamer"
-    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
-    FASTSAFETENSORS = "fastsafetensors"
-
-
 @config
 @dataclass
 class LoadConfig:
     """Configuration for loading the model weights."""
 
-    load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    load_format: Union[str, LoadFormats] = "auto"
     """The format of the model weights to load:\n
     - "auto" will try to load the weights in the safetensors format and fall
     back to the pytorch bin format if safetensors format is not available.\n
@@ -1816,7 +1800,8 @@ class LoadConfig:
     - "gguf" will load weights from GGUF format files (details specified in
     https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
     - "mistral" will load weights from consolidated safetensors files used by
-    Mistral models."""
+    Mistral models.
+    - Other custom values can be supported via plugins."""
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
@@ -1864,10 +1849,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        if isinstance(self.load_format, str):
-            load_format = self.load_format.lower()
-            self.load_format = LoadFormat(load_format)
-
+        self.load_format = self.load_format.lower()
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",

vllm/engine/arg_utils.py

Lines changed: 13 additions & 15 deletions
@@ -26,13 +26,12 @@
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat,
-                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LogprobsMode,
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.usage.usage_lib import UsageContext
 else:
     ExecutorBase = Any
     QuantizationMethods = Any
+    LoadFormats = Any
     UsageContext = Any
 
 logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    load_format: str = LoadConfig.load_format
+    load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -547,9 +548,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             title="LoadConfig",
             description=LoadConfig.__doc__,
         )
-        load_group.add_argument("--load-format",
-                                choices=[f.value for f in LoadFormat],
-                                **load_kwargs["load_format"])
+        load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
         load_group.add_argument("--model-loader-extra-config",
@@ -864,10 +863,9 @@ def create_model_config(self) -> ModelConfig:
 
         # NOTE: This is to allow model loading from S3 in CI
         if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
-                and self.model in MODELS_ON_S3
-                and self.load_format == LoadFormat.AUTO):  # noqa: E501
+                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
             self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
-            self.load_format = LoadFormat.RUNAI_STREAMER
+            self.load_format = "runai_streamer"
 
         return ModelConfig(
             model=self.model,
@@ -1299,7 +1297,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         #############################################################
         # Unsupported Feature Flags on V1.
 
-        if self.load_format == LoadFormat.SHARDED_STATE.value:
+        if self.load_format == "sharded_state":
             _raise_or_fallback(
                 feature_name=f"--load_format {self.load_format}",
                 recommend_to_remove=False)
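
Dropping the `choices=[f.value for f in LoadFormat]` restriction from `--load-format` is what lets plugin-registered formats through the CLI: argparse no longer constrains the value, and validation happens in the loader registry instead. A sketch under that assumption, where "my_loader" stands for a format some plugin has registered:

from vllm.engine.arg_utils import EngineArgs

# "my_loader" is hypothetical; a plugin must have called
# register_model_loader("my_loader") before the engine resolves it.
args = EngineArgs(model="openai-community/gpt2", load_format="my_loader")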

vllm/model_executor/model_loader/__init__.py

Lines changed: 87 additions & 27 deletions
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Literal, Optional
 
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -20,34 +21,92 @@
 from vllm.model_executor.model_loader.utils import (
     get_architecture_class_name, get_model_architecture, get_model_cls)
 
+logger = init_logger(__name__)
+
+# Reminder: Please update docstring in `LoadConfig`
+# if a new load format is added here
+LoadFormats = Literal[
+    "auto",
+    "bitsandbytes",
+    "dummy",
+    "fastsafetensors",
+    "gguf",
+    "mistral",
+    "npcache",
+    "pt",
+    "runai_streamer",
+    "runai_streamer_sharded",
+    "safetensors",
+    "sharded_state",
+    "tensorizer",
+]
+_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
+    "auto": DefaultModelLoader,
+    "bitsandbytes": BitsAndBytesModelLoader,
+    "dummy": DummyModelLoader,
+    "fastsafetensors": DefaultModelLoader,
+    "gguf": GGUFModelLoader,
+    "mistral": DefaultModelLoader,
+    "npcache": DefaultModelLoader,
+    "pt": DefaultModelLoader,
+    "runai_streamer": RunaiModelStreamerLoader,
+    "runai_streamer_sharded": ShardedStateLoader,
+    "safetensors": DefaultModelLoader,
+    "sharded_state": ShardedStateLoader,
+    "tensorizer": TensorizerLoader,
+}
+
+
+def register_model_loader(load_format: str):
+    """Register a customized vllm model loader.
+
+    When a load format is not supported by vllm, you can register a customized
+    model loader to support it.
+
+    Args:
+        load_format (str): The model loader format name.
+
+    Examples:
+        >>> from vllm.config import LoadConfig
+        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
+        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+        >>>
+        >>> @register_model_loader("my_loader")
+        ... class MyModelLoader(BaseModelLoader):
+        ...     def download_model(self):
+        ...         pass
+        ...
+        ...     def load_weights(self):
+        ...         pass
+        >>>
+        >>> load_config = LoadConfig(load_format="my_loader")
+        >>> type(get_model_loader(load_config))
+        <class 'MyModelLoader'>
+    """  # noqa: E501
+
+    def _wrapper(model_loader_cls):
+        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
+            logger.warning(
+                "Load format `%s` is already registered, and will be "
+                "overwritten by the new loader class `%s`.", load_format,
+                model_loader_cls)
+        if not issubclass(model_loader_cls, BaseModelLoader):
+            raise ValueError("The model loader must be a subclass of "
+                             "`BaseModelLoader`.")
+        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
+        logger.info("Registered model loader `%s` with load format `%s`",
+                    model_loader_cls, load_format)
+        return model_loader_cls
+
+    return _wrapper
+
 
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
-    if isinstance(load_config.load_format, type):
-        return load_config.load_format(load_config)
-
-    if load_config.load_format == LoadFormat.DUMMY:
-        return DummyModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.TENSORIZER:
-        return TensorizerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.SHARDED_STATE:
-        return ShardedStateLoader(load_config)
-
-    if load_config.load_format == LoadFormat.BITSANDBYTES:
-        return BitsAndBytesModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.GGUF:
-        return GGUFModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
-        return RunaiModelStreamerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
-        return ShardedStateLoader(load_config, runai_model_streamer=True)
-
-    return DefaultModelLoader(load_config)
+    load_format = load_config.load_format
+    if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
+        raise ValueError(f"Load format `{load_format}` is not supported")
+    return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
 
 
 def get_model(*,
@@ -66,6 +125,7 @@ def get_model(*,
     "get_architecture_class_name",
     "get_model_architecture",
     "get_model_cls",
+    "register_model_loader",
     "BaseModelLoader",
     "BitsAndBytesModelLoader",
     "GGUFModelLoader",
