vllm/model_executor/model_loader/loader.py (47 additions, 13 deletions)
```diff
@@ -1,14 +1,16 @@
 # ruff: noqa: SIM117
 import collections
 import copy
+import dataclasses
 import fnmatch
 import glob
 import json
 import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
+                    Type, cast)
 
 import gguf
 import huggingface_hub
@@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    @dataclasses.dataclass
+    class Source:
+        """A source for weights."""
+
+        model_or_path: str
+        """The model ID or path."""
+
+        revision: Optional[str]
+        """The optional model revision."""
+
+        prefix: str = ""
+        """A prefix to prepend to all weights."""
+
+        fall_back_to_pt: bool = True
+        """Whether .pt weights can be used."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
```
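The new `Source` dataclass makes each weight origin explicit: a model ID or path, an optional revision, a name prefix, and a `.pt` fallback policy. A minimal sketch of constructing one; the model ID is illustrative, not taken from this PR:

```python
from vllm.model_executor.model_loader.loader import DefaultModelLoader

# Hypothetical source: stream a Whisper checkpoint and prepend
# "audio_tower." to every tensor name so the names line up with the
# wrapping model's submodule. fall_back_to_pt defaults to True.
audio_source = DefaultModelLoader.Source(
    model_or_path="openai/whisper-small",  # illustrative model ID
    revision=None,
    prefix="audio_tower.",
)
```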
```diff
@@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_folder, hf_weights_files, use_safetensors
 
     def _get_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str],
-        fall_back_to_pt: bool
+        self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            model_name_or_path, revision, fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                model_name_or_path, self.load_config.download_dir, hf_folder,
+                source.model_or_path, self.load_config.download_dir, hf_folder,
                 hf_weights_files)
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
                     xm.mark_step()
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+
+        # Apply the prefix.
+        return ((source.prefix + name, tensor)
+                for (name, tensor) in weights_iterator)
+
+    def _get_all_weights(
+        self,
+        model_config: ModelConfig,
+        model: nn.Module,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+
+        primary_weights = DefaultModelLoader.Source(
+            model_config.model,
+            model_config.revision,
+            prefix="",
+            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
+                                    True))
+        yield from self._get_weights_iterator(primary_weights)
+
+        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
+                                 getattr(model, "secondary_weights", ()))
+        for source in secondary_weights:
+            yield from self._get_weights_iterator(source)
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
```
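`_get_all_weights` yields the primary checkpoint's tensors first, then whatever the model lists in an optional `secondary_weights` attribute. A hedged sketch of a model opting in; the class name and model ID are illustrative, while `secondary_weights` and `fall_back_to_pt_during_load` are the attributes the loader actually reads via `getattr`:

```python
import torch.nn as nn

from vllm.model_executor.model_loader.loader import DefaultModelLoader


class ToyMultiModalModel(nn.Module):  # hypothetical model class

    def __init__(self):
        super().__init__()
        # Read by _get_all_weights when it builds the primary Source.
        self.fall_back_to_pt_during_load = True
        # Extra checkpoints streamed after the primary weights.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path="org/extra-encoder",  # illustrative ID
                revision=None,
                prefix="encoder.",  # tensors arrive as "encoder.<name>"
            ),
        ]
```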
```diff
@@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, cache_config,
                                           scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
-                                           model_config.revision,
-                                           fall_back_to_pt=getattr(
-                                               model,
-                                               "fall_back_to_pt_during_load",
-                                               True)), )
+
+            model.load_weights(self._get_all_weights(model_config, model))
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
```
vllm/model_executor/models/ultravox.py (26 additions, 4 deletions)
```diff
@@ -25,6 +25,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import (flatten_bn,
@@ -334,14 +335,23 @@ def __init__(self,
         self.multi_modal_config = multimodal_config
         assert self.multi_modal_config
 
+        self.secondary_weights = []
+        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
         if config.audio_model_id is not None:
-            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                config.audio_model_id)
-        else:
-            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(
+                    model_or_path=config.audio_model_id,
+                    revision=None,
+                    prefix="audio_tower.",
+                ))
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
+        if config.text_model_id is not None:
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(model_or_path=config.text_model_id,
+                                          revision=None,
+                                          prefix="language_model."))
 
     def _audio_features_to_embeddings(
         self, input_features: torch.Tensor) -> torch.Tensor:
```
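The constructor now always builds the audio tower from the config; when `audio_model_id` is set, the pretrained tensors are streamed through the loader under the `audio_tower.` prefix instead of being loaded with `from_pretrained`. A small sketch of the renaming that `_get_weights_iterator` applies; the tensor name is illustrative:

```python
from typing import Iterable, Tuple

import torch


def apply_prefix(prefix: str, weights: Iterable[Tuple[str, torch.Tensor]]):
    # Mirrors the generator expression added to _get_weights_iterator.
    return ((prefix + name, tensor) for name, tensor in weights)


# Illustrative checkpoint entry, not taken from a real Whisper file:
weights = [("conv1.weight", torch.zeros(1))]
renamed = dict(apply_prefix("audio_tower.", weights))
assert "audio_tower.conv1.weight" in renamed
```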
```diff
@@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
         weights_group = group_weights_with_prefix(weights)
 
+        # load audio tower weights
+        audio_tower_weights = weights_group["audio_tower"]
+        audio_tower_params_dict = dict(
+            self.audio_tower.named_parameters(
+                prefix=self.audio_tower.base_model_prefix))
+        for name, loaded_weight in audio_tower_weights:
+            if name in audio_tower_params_dict:
+                param = audio_tower_params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
         # load projector weights
         projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
```
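Passing `prefix=self.audio_tower.base_model_prefix` to `named_parameters` makes the encoder report its parameters under the same names the grouped checkpoint tensors carry, so the dict lookup matches. A self-contained sketch of that `named_parameters` behavior, using a stand-in module and an example prefix rather than the real encoder:

```python
import torch.nn as nn

# named_parameters(prefix=...) prepends the prefix to every reported name.
encoder = nn.Linear(4, 4)  # stand-in for the Whisper encoder
names = [n for n, _ in encoder.named_parameters(prefix="model.encoder")]
assert names == ["model.encoder.weight", "model.encoder.bias"]
```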