@@ -13,15 +13,14 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData
 
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
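
Context for the change: the removed `load_weights` body split the checkpoint stream by top-level prefix (`vision_model`, `query_tokens`, `qformer`, `language_projection`, `language_model`) and loaded each component by hand. `AutoWeightsLoader` folds that boilerplate into one reusable helper. Below is a minimal sketch of the prefix-dispatch idea it relies on, assuming plain `torch.nn` modules; `load_by_prefix` is a hypothetical name and this is illustrative only, not vLLM's actual `AutoWeightsLoader` implementation.

```python
from typing import Dict, Iterable, List, Tuple

import torch
import torch.nn as nn


def load_by_prefix(module: nn.Module,
                   weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    """Dispatch checkpoint entries to submodules/parameters by name prefix."""
    children = dict(module.named_children())
    params = dict(module.named_parameters(recurse=False))
    grouped: Dict[str, List[Tuple[str, torch.Tensor]]] = {}

    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        if prefix in children:
            # Belongs to a submodule; strip its prefix and defer loading.
            grouped.setdefault(prefix, []).append((rest, tensor))
        elif prefix in params:
            # Direct parameter of this module (e.g. query_tokens). Honor a
            # custom weight_loader attribute if one is attached.
            param = params[prefix]
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, tensor)
            else:
                param.data.copy_(tensor)
        else:
            raise KeyError(f"No submodule or parameter matches {name!r}")

    for prefix, sub_weights in grouped.items():
        child = children[prefix]
        if hasattr(child, "load_weights"):
            # Submodules such as the vision tower or LLM backbone implement
            # their own (possibly shard-aware) load_weights.
            child.load_weights(sub_weights)
        else:
            load_by_prefix(child, sub_weights)
```

The payoff for model code is visible in the diff: each multimodal model no longer re-implements this walk, submodules that define their own `load_weights` keep full control over sharded loading, and plain parameters fall back to a direct copy.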