Commit 7902bba

DarkLight1337 authored and LeiWang1999 committed
[Bugfix] fix composite weight loading and EAGLE weight loading (vllm-project#9160)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent ad24380 commit 7902bba

File tree: 15 files changed, +244 -364 lines changed

vllm/model_executor/models/blip2.py

Lines changed: 3 additions & 34 deletions
@@ -13,15 +13,14 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData

 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
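Every file in this commit makes the same move: the hand-written loop that grouped checkpoint weights by prefix and fed each sub-module (vision encoder, Q-Former, projector, LLM backbone) separately is replaced by a single AutoWeightsLoader that dispatches weights over the module tree. As a rough sketch of the idea only (a simplified stand-in, not vllm's actual AutoWeightsLoader), the dispatch amounts to something like:

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch
import torch.nn as nn


def load_composite_weights(model: nn.Module,
                           weights: Iterable[Tuple[str, torch.Tensor]]):
    """Toy prefix dispatcher, NOT vllm's actual AutoWeightsLoader."""
    # Bucket checkpoint entries by their first dotted component, which is
    # what the removed group_weights_with_prefix calls did by hand.
    buckets: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        buckets[prefix].append((rest, tensor))

    for prefix, bucket in buckets.items():
        child = getattr(model, prefix)
        if hasattr(child, "load_weights"):
            # Sub-models (vision encoder, LLM backbone) load themselves.
            child.load_weights(bucket)
        elif isinstance(child, nn.Parameter):
            # Bare parameters such as blip2's query_tokens: name is "".
            child.data.copy_(bucket[0][1])
        else:
            params = dict(child.named_parameters())
            for name, tensor in bucket:
                param = params[name]
                weight_loader = getattr(param, "weight_loader",
                                        lambda p, w: p.data.copy_(w))
                weight_loader(param, tensor)

With that dispatch centralized, each composite model's load_weights collapses to the two lines shown above, and fixes to prefix handling land in one place instead of in every model.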

vllm/model_executor/models/fuyu.py

Lines changed: 3 additions & 16 deletions
@@ -31,7 +31,6 @@
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -42,8 +41,7 @@
                            SequenceData)

 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    merge_multimodal_embeddings)
+from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -349,16 +347,5 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision embeddings
-        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
-        for name, loaded_weight in weights_group["vision_embed_tokens"]:
-            param = vision_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/gemma2.py

Lines changed: 7 additions & 17 deletions
@@ -40,7 +40,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

 logger = init_logger(__name__)
@@ -447,19 +447,9 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            # NOTE: For now self.lm_head is not defined because
-            # tie_word_embeddings is assumed to the False
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)
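The skip_prefixes argument carries the logic of the deleted branch: the old code only loaded lm_head weights when tie_word_embeddings was false, because a tied model reuses the input embedding matrix as its output projection and defines no separate lm_head, so any lm_head.* entries a checkpoint might contain have no destination and are skipped. A minimal plain-PyTorch sketch (not vllm code) of why skipping them loses nothing:

import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Tied-embedding toy: logits reuse embed_tokens.weight, so there is
    no lm_head parameter for a checkpoint entry to load into."""

    def __init__(self, vocab_size: int = 8, hidden_size: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Output projection is the transposed embedding matrix.
        return hidden_states @ self.embed_tokens.weight.t()


model = TiedLM()
# Loading embed_tokens.weight already updates the only matrix involved in
# both embedding lookup and logits, so a stray "lm_head.weight" entry can
# be skipped without losing information.
model.load_state_dict({"embed_tokens.weight": torch.randn(8, 4)})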

vllm/model_executor/models/internvl.py

Lines changed: 4 additions & 19 deletions
@@ -20,7 +20,6 @@
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -32,8 +31,8 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -609,19 +608,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in weights_group["mlp1"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llama.py

Lines changed: 8 additions & 20 deletions
@@ -51,8 +51,7 @@
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)


@@ -564,25 +563,14 @@ def sample(self, logits: torch.Tensor,
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights = [
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            for name, loaded_weight in weights)

     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
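
The Llama version also relies on the loader's load_weights accepting any iterable of (name, tensor) pairs (the method signature above is Iterable[Tuple[str, torch.Tensor]]), so the pre-existing maybe_remap_mistral fix-up now runs lazily inside a generator expression instead of materializing a remapped list first. The same wrapping pattern, sketched with a hypothetical rename rule (the replace below is illustrative, not vllm's actual Mistral mapping):

from typing import Iterable, Iterator, Tuple

import torch


def remap_names(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Rewrite checkpoint names lazily so the downstream loader only ever
    sees the canonical layout (illustrative stand-in for
    maybe_remap_mistral)."""
    for name, tensor in weights:
        yield name.replace("tok_embeddings", "model.embed_tokens"), tensor


# Usage mirrors the diff: loader.load_weights(remap_names(weights)) streams
# the checkpoint through the rename without building an intermediate list.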

vllm/model_executor/models/llava.py

Lines changed: 4 additions & 19 deletions
@@ -13,7 +13,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -26,8 +25,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)


 class LlavaImagePixelInputs(TypedDict):
@@ -406,19 +405,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llava_next.py

Lines changed: 4 additions & 27 deletions
@@ -15,7 +15,6 @@
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -29,8 +28,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@@ -642,27 +641,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load newline
-        for name, loaded_weight in weights_group["image_newline"]:
-            assert name == ""
-            param = self.image_newline
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llava_next_video.py

Lines changed: 7 additions & 18 deletions
@@ -15,7 +15,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -28,7 +27,7 @@
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 # For profile run
@@ -458,19 +457,9 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(
+            self,
+            # This model doesn't support images for now
+            ignore_unexpected_prefixes=["image_newline"],
+        )
+        loader.load_weights(weights)
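llava_next_video is the one call site in this commit that needs ignore_unexpected_prefixes: checkpoints for this family carry an image_newline parameter, but (per the inline comment) the video-only model does not support images yet and defines no such attribute, so the loader is presumably told to tolerate checkpoint entries under that prefix that have no matching parameter rather than raising. The plain-PyTorch analogue of the problem, for intuition only (not vllm code):

import torch
import torch.nn as nn

# A checkpoint with an extra "image_newline" entry breaks strict loading;
# the extra key has to be tolerated explicitly, which is roughly what
# ignore_unexpected_prefixes opts into for just that prefix.
model = nn.Linear(4, 4)
ckpt = {
    "weight": torch.zeros(4, 4),
    "bias": torch.zeros(4),
    "image_newline": torch.zeros(4),
}

missing, unexpected = model.load_state_dict(ckpt, strict=False)
assert unexpected == ["image_newline"] and missing == []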

vllm/model_executor/models/llava_onevision.py

Lines changed: 4 additions & 19 deletions
@@ -20,7 +20,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -35,8 +34,8 @@
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 logger = init_logger(__name__)

@@ -872,19 +871,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
