-
-
Notifications
You must be signed in to change notification settings - Fork 10.4k
[Core] Support model loader plugins #21067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
5f0ac47
7d05267
7a74a1a
1379e3e
7660c9d
ea1cd1c
35f7354
97bd40d
479d912
a81d13d
37cd547
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
||
import pytest | ||
from torch import nn | ||
|
||
from vllm.config import LoadConfig, ModelConfig | ||
from vllm.model_executor.model_loader import get_model_loader | ||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader | ||
from vllm.model_executor.model_loader.registry import ModelLoaderRegistry | ||
|
||
|
||
class TestModelLoader(BaseModelLoader):
    """Minimal concrete loader used to exercise loader registration.

    Both hook methods are deliberate no-ops: the registry tests only check
    that this class gets resolved, never that weights are actually loaded.
    """

    def __init__(self, load_config: LoadConfig) -> None:
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        # Nothing to fetch for the stub loader.
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        # Leave the model untouched; tests only inspect the loader type.
        pass
|
||
|
||
@pytest.mark.parametrize(
    "load_format, loader_cls",
    [
        ("test_load_format",
         "tests.model_executor.model_loader.test_registry:TestModelLoader"),
        ("test_load_format", TestModelLoader),
        # Overwrite existing loader
        ("auto", TestModelLoader),
    ])
def test_customized_model_loader(load_format, loader_cls):
    """A loader registered by class or `<module>:<class>` path is resolvable."""
    ModelLoaderRegistry.register(load_format=load_format,
                                 loader_cls=loader_cls)
    resolved = get_model_loader(LoadConfig(load_format=load_format))
    assert type(resolved).__name__ == TestModelLoader.__name__
    assert load_format in ModelLoaderRegistry.get_supported_load_formats()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make a follow up PR with my suggestion to parse |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
||
import importlib | ||
from collections.abc import Set | ||
from typing import Union | ||
|
||
from vllm.config import LoadConfig | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader | ||
from vllm.model_executor.model_loader.bitsandbytes_loader import ( | ||
BitsAndBytesModelLoader) | ||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader | ||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader | ||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader | ||
from vllm.model_executor.model_loader.runai_streamer_loader import ( | ||
RunaiModelStreamerLoader) | ||
from vllm.model_executor.model_loader.sharded_state_loader import ( | ||
ShardedStateLoader) | ||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader | ||
|
||
logger = init_logger(__name__) | ||
|
||
# Built-in mapping from `LoadConfig.load_format` string to the loader class
# responsible for it.  Note several formats ("auto", "pt", "safetensors", ...)
# all map to `DefaultModelLoader` — presumably it handles the per-format
# differences internally; confirm against DefaultModelLoader.
DEFAULT_MODEL_LOADERS = {
    "auto": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}
|
||
|
||
class ModelLoaderRegistry:
    """Process-wide registry mapping load-format names to loader classes.

    Starts pre-populated with the built-in loaders and can be extended (or
    overridden) at runtime via :meth:`register`, e.g. by plugins.
    """

    # Copy the defaults so register() never mutates the module-level
    # DEFAULT_MODEL_LOADERS constant (the original aliased it directly).
    _model_loaders: dict[str, type[BaseModelLoader]] = dict(
        DEFAULT_MODEL_LOADERS)

    @classmethod
    def get_supported_load_formats(cls) -> Set[str]:
        """Return a set-like view of all currently registered format names."""
        return cls._model_loaders.keys()

    @classmethod
    def get_model_loader(cls, load_config: LoadConfig) -> BaseModelLoader:
        """Instantiate the loader registered for ``load_config.load_format``.

        Raises:
            ValueError: if no loader is registered for the format.
        """
        load_format = load_config.load_format
        if load_format not in cls._model_loaders:
            raise ValueError(f"load_format: {load_format} is not supported")
        return cls._model_loaders[load_format](load_config)

    @classmethod
    def register(
        cls,
        load_format: str,
        # Fixed annotation: callers pass a loader *class*, not an instance
        # (the code below does issubclass(loader_cls, BaseModelLoader)).
        loader_cls: Union[type[BaseModelLoader], str],
    ) -> None:
        """
        Register an external model loader to be used in vLLM.

        `loader_cls` can be either:

        - A class derived from `BaseModelLoader`
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model loader.

        Raises:
            TypeError: if `load_format` is not a string, or `loader_cls`
                does not resolve to a `BaseModelLoader` subclass.
            ValueError: if a string `loader_cls` is not `<module>:<class>`.
        """
        if not isinstance(load_format, str):
            msg = f"`load_format` should be a string, not a {type(load_format)}"
            raise TypeError(msg)

        if load_format in cls._model_loaders:
            logger.warning(
                "Load format %s is already registered, and will be "
                "overwritten by the new loader class %s.", load_format,
                loader_cls)

        if isinstance(loader_cls, str):
            split_str = loader_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
            module_name, class_name = split_str
            module = importlib.import_module(module_name)
            model_loader_cls = getattr(module, class_name)
            # Validate the lazily imported object so a bad plugin fails here,
            # at registration time, rather than at first use.
            if not (isinstance(model_loader_cls, type)
                    and issubclass(model_loader_cls, BaseModelLoader)):
                msg = (f"`{loader_cls}` does not resolve to a "
                       "`BaseModelLoader` subclass")
                raise TypeError(msg)
        elif isinstance(loader_cls, type) and issubclass(
                loader_cls, BaseModelLoader):
            model_loader_cls = loader_cls
        else:
            # Bug fix: the original built this message from
            # `type(model_loader_cls)`, a name never assigned on this branch,
            # so it raised NameError instead of the intended TypeError.  It
            # also misnamed the parameter as `model_cls`.
            msg = ("`loader_cls` should be a string or `BaseModelLoader`, "
                   f"not a {type(loader_cls)}")
            raise TypeError(msg)

        cls._model_loaders[load_format] = model_loader_cls
        logger.info("Registered `%s` with load format `%s`", model_loader_cls,
                    load_format)
Uh oh!
There was an error while loading. Please reload this page.