@@ -15,7 +15,7 @@
 from transformers.utils import logging

 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@

 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model

 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ def __init__(self,
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None

         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
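The declarative hf_to_vllm_mapper above expresses, as data, the same renames the old hand-written load_weights performed imperatively. As a rough illustration of how such prefix/substring rules rewrite Hugging Face checkpoint names: the sketch below is a simplified stand-in, not vLLM's actual WeightsMapper implementation, and remap_name, its rule dicts, and the example parameter tail are invented for this example.

def remap_name(name: str, prefix_rules: dict, substr_rules: dict) -> str:
    # Hypothetical helper approximating the mapper's behaviour: try prefix
    # rules first (insertion order matters, which is why the diff lists the
    # more specific audio_projection.* prefixes before the generic
    # audio_embed. prefix), then apply substring replacements such as
    # stripping the LoRA "base_layer." wrapper.
    for old, new in prefix_rules.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in substr_rules.items():
        name = name.replace(old, new)
    return name

prefix_rules = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
    "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}
substr_rules = {"base_layer.": ""}

# Prints "vision_encoder.proj.weight" (".proj." is a made-up tail).
print(remap_name("model.embed_tokens_extend.image_embed.proj.base_layer.weight",
                 prefix_rules, substr_rules))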
@@ -1445,8 +1453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lora_config = lora_config

         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"

@@ -1686,44 +1692,6 @@ def merge_image_features_to_inputs_embeds(
         )
         return merged_embeds

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed." \
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
+        )
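For context on what the new load_weights delegates to: it filters out LoRA adapter tensors (vLLM applies LoRA separately at runtime), and AutoWeightsLoader walks the module tree, rewriting each checkpoint name through hf_to_vllm_mapper before matching it to a parameter. Below is a minimal, self-contained sketch of that flow using plain nn.Module.load_state_dict as a stand-in for the real loader; load_weights_sketch and the toy tensor names are invented for illustration and are not vLLM APIs.

from typing import Callable, Iterable, Tuple

import torch
from torch import nn


def load_weights_sketch(model: nn.Module,
                        weights: Iterable[Tuple[str, torch.Tensor]],
                        remap: Callable[[str], str]) -> set:
    # Drop LoRA adapter tensors before loading.
    filtered = ((name, data) for name, data in weights if "lora" not in name)
    # Rewrite checkpoint keys into the module layout, then load leniently.
    state_dict = {remap(name): data for name, data in filtered}
    result = model.load_state_dict(state_dict, strict=False)
    # Report which checkpoint keys actually matched a parameter.
    return set(state_dict) - set(result.unexpected_keys)


layer = nn.Linear(4, 4)
loaded = load_weights_sketch(
    layer,
    [("base_layer.weight", torch.zeros(4, 4)),
     ("base_layer.bias", torch.zeros(4)),
     ("lora_A.weight", torch.zeros(2, 4))],  # filtered out
    remap=lambda name: name.replace("base_layer.", ""),
)
print(loaded)  # {'weight', 'bias'} (order may vary)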