4 changes: 1 addition & 3 deletions tests/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import SamplingParams
from vllm.config import LoadFormat

test_model = "openai-community/gpt2"

@@ -17,7 +16,6 @@


def test_model_loader_download_files(vllm_runner):
with vllm_runner(test_model,
load_format=LoadFormat.FASTSAFETENSORS) as llm:
with vllm_runner(test_model, load_format="fastsafetensors") as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
43 changes: 43 additions & 0 deletions tests/model_executor/model_loader/test_registry.py
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.registry import ModelLoaderRegistry


class TestModelLoader(BaseModelLoader):

def __init__(self, load_config: LoadConfig) -> None:
super().__init__(load_config)

def download_model(self, model_config: ModelConfig) -> None:
pass

def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
pass


@pytest.mark.parametrize(
"load_format, loader_cls",
[
("test_load_format",
"tests.model_executor.model_loader.test_registry:TestModelLoader"),
("test_load_format", TestModelLoader),
# Overwrite existing loader
("auto", TestModelLoader),
])
def test_customized_model_loader(load_format, loader_cls):
ModelLoaderRegistry.register(
load_format=load_format,
loader_cls=loader_cls,
)
test_load_config = LoadConfig(load_format=load_format)
model_loader = get_model_loader(test_load_config)
assert type(model_loader).__name__ == TestModelLoader.__name__
assert load_format in ModelLoaderRegistry.get_supported_load_formats()
Original file line number Diff line number Diff line change
@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat
from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader

load_format = "runai_streamer"
test_model = "openai-community/gpt2"

prompts = [
@@ -18,7 +19,7 @@


def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
load_config = LoadConfig(load_format=load_format)
return get_model_loader(load_config)


@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():


def test_runai_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
with vllm_runner(test_model, load_format=load_format) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
32 changes: 10 additions & 22 deletions vllm/config.py
@@ -1730,29 +1730,12 @@ def verify_with_parallel_config(
logger.warning("Possibly too large swap space. %s", msg)


class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors"


@config
@dataclass
class LoadConfig:
"""Configuration for loading the model weights."""

load_format: Union[str, LoadFormat,
"BaseModelLoader"] = LoadFormat.AUTO.value
load_format: str = "auto"
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
@@ -1773,7 +1756,8 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
Mistral models.
- Other custom values can be supported via plugins."""
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
@@ -1818,9 +1802,13 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
if isinstance(self.load_format, str):
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
self.load_format = self.load_format.lower()
from vllm.model_executor.model_loader.registry import (
ModelLoaderRegistry)

assert self.load_format in \
ModelLoaderRegistry.get_supported_load_formats(), \
f"Load format `{self.load_format}` is not supported"

if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
15 changes: 6 additions & 9 deletions vllm/engine/arg_utils.py
Reviewer note: I'll make a follow-up PR with my suggestion to parse `Union[str, Literal[...]]` into `add_argument(..., metavar=",".join(get_args(literal)))`.
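Below is a rough, untested sketch of what that follow-up could look like; the `LoadFormats` alias and `add_load_format_arg` helper are hypothetical and list only a subset of formats:

```python
# Hypothetical sketch: build --load-format's help text from a Literal alias
# instead of hard-coding choices=[...], per the suggestion above.
import argparse
from typing import Literal, get_args

# Illustrative subset; the real set would come from ModelLoaderRegistry.
LoadFormats = Literal["auto", "pt", "safetensors", "runai_streamer"]


def add_load_format_arg(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--load-format",
        default="auto",
        # Advertise the known formats in --help without rejecting other
        # values, so plugin-registered formats still pass CLI parsing.
        metavar=",".join(get_args(LoadFormats)),
        help="Format of the model weights to load.",
    )


parser = argparse.ArgumentParser()
add_load_format_arg(parser)
args = parser.parse_args(["--load-format", "fastsafetensors"])
print(args.load_format)  # fastsafetensors
```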

@@ -26,8 +26,8 @@
DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, GuidedDecodingBackend,
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
KVTransferConfig, LoadConfig, LoRAConfig, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
@@ -545,9 +545,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
title="LoadConfig",
description=LoadConfig.__doc__,
)
load_group.add_argument("--load-format",
choices=[f.value for f in LoadFormat],
**load_kwargs["load_format"])
load_group.add_argument("--load-format", **load_kwargs["load_format"])
load_group.add_argument("--download-dir",
**load_kwargs["download_dir"])
load_group.add_argument("--model-loader-extra-config",
@@ -886,10 +884,9 @@ def create_model_config(self) -> ModelConfig:

# NOTE: This is to allow model loading from S3 in CI
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
and self.model in MODELS_ON_S3
and self.load_format == LoadFormat.AUTO): # noqa: E501
and self.model in MODELS_ON_S3 and self.load_format == "auto"):
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
self.load_format = LoadFormat.RUNAI_STREAMER
self.load_format = "runai_streamer"

return ModelConfig(
model=self.model,
@@ -1297,7 +1294,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
#############################################################
# Unsupported Feature Flags on V1.

if self.load_format == LoadFormat.SHARDED_STATE.value:
if self.load_format == "sharded_state":
_raise_or_fallback(
feature_name=f"--load_format {self.load_format}",
recommend_to_remove=False)
29 changes: 3 additions & 26 deletions vllm/model_executor/model_loader/__init__.py
@@ -5,13 +5,14 @@

from torch import nn

from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.registry import ModelLoaderRegistry
from vllm.model_executor.model_loader.runai_streamer_loader import (
RunaiModelStreamerLoader)
from vllm.model_executor.model_loader.sharded_state_loader import (
@@ -23,31 +24,7 @@

def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)

if load_config.load_format == LoadFormat.DUMMY:
return DummyModelLoader(load_config)

if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)

if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)

return DefaultModelLoader(load_config)
return ModelLoaderRegistry.get_model_loader(load_config)


def get_model(*,
18 changes: 9 additions & 9 deletions vllm/model_executor/model_loader/default_loader.py
@@ -13,7 +13,7 @@
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm import envs
from vllm.config import LoadConfig, LoadFormat, ModelConfig
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
@@ -104,19 +104,19 @@ def _prepare_weights(
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif (load_format == LoadFormat.SAFETENSORS
or load_format == LoadFormat.FASTSAFETENSORS):
elif (load_format == "safetensors"
or load_format == "fastsafetensors"):
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
elif load_format == "mistral":
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT:
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
elif load_format == "npcache":
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
@@ -178,7 +178,7 @@ def _get_weights_iterator(
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
if self.load_config.load_format == LoadFormat.NPCACHE:
if self.load_config.load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
Expand All @@ -189,7 +189,7 @@ def _get_weights_iterator(
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
98 changes: 98 additions & 0 deletions vllm/model_executor/model_loader/registry.py
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from collections.abc import Set
from typing import Union

from vllm.config import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.runai_streamer_loader import (
RunaiModelStreamerLoader)
from vllm.model_executor.model_loader.sharded_state_loader import (
ShardedStateLoader)
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader

logger = init_logger(__name__)

DEFAULT_MODEL_LOADERS = {
"auto": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
"runai_streamer": RunaiModelStreamerLoader,
"runai_streamer_sharded": ShardedStateLoader,
"safetensors": DefaultModelLoader,
"sharded_state": ShardedStateLoader,
"tensorizer": TensorizerLoader,
}


class ModelLoaderRegistry:
_model_loaders: dict[str, type[BaseModelLoader]] = DEFAULT_MODEL_LOADERS

@classmethod
def get_supported_load_formats(cls) -> Set[str]:
return cls._model_loaders.keys()

@classmethod
def get_model_loader(cls, load_config: LoadConfig) -> BaseModelLoader:
load_format = load_config.load_format
if load_format not in cls._model_loaders:
raise ValueError(f"load_format: {load_format} is not supported")
return cls._model_loaders[load_format](load_config)

@classmethod
def register(
cls,
load_format: str,
loader_cls: Union[type[BaseModelLoader], str],
) -> None:
"""
Register an external model loader to be used in vLLM.

`loader_cls` can be either:

- A class derived from `BaseModelLoader`
- A string in the format `<module>:<class>` which can be used to
lazily import the model loader.
"""
if not isinstance(load_format, str):
msg = f"`load_format` should be a string, not a {type(load_format)}"
raise TypeError(msg)

if load_format in cls._model_loaders:
logger.warning(
"Load format %s is already registered, and will be "
"overwritten by the new loader class %s.", load_format,
loader_cls)

if isinstance(loader_cls, str):
split_str = loader_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
module_name, class_name = split_str
module = importlib.import_module(module_name)
model_loader_cls = getattr(module, class_name)
elif isinstance(loader_cls, type) and issubclass(
loader_cls, BaseModelLoader):
model_loader_cls = loader_cls
else:
msg = ("`loader_cls` should be a string or `BaseModelLoader`, "
f"not a {type(loader_cls)}")

raise TypeError(msg)

cls._model_loaders[load_format] = model_loader_cls
logger.info("Registered `%s` with load format `%s`", model_loader_cls,
load_format)