
Commit 9c7698a

Isotr0py authored and DamonFool committed
[VLM] Add TP support for Phi-4-MM (vllm-project#14453)
Signed-off-by: Isotr0py <[email protected]>
1 parent 8b8e7ad commit 9c7698a
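
The change removes the tensor-parallel guard from the Phi-4-MM model so it can be sharded across GPUs. As a rough usage sketch (not part of this diff; the model id and the world size of 2 are assumptions for illustration), tensor parallelism is requested from the offline API like this:

# Hedged usage sketch: with the TP assert removed, Phi-4-MM can be loaded
# with tensor_parallel_size > 1. Model id and the value 2 are assumed.
from vllm import LLM

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed HF model id
    trust_remote_code=True,
    tensor_parallel_size=2,  # shard the weights across two GPUs
)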

File tree

4 files changed: +50 -295 lines changed

examples/offline_inference/audio_language.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
         enable_lora=True,
         max_lora_rank=320,
         lora_extra_vocab_size=0,
+        limit_mm_per_prompt={"audio": audio_count},
     )
     lora_request = LoRARequest("speech", 1, speech_lora_path)
     # To maintain code compatibility in this script, we add LoRA here.
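
The one-line change above adds limit_mm_per_prompt, which caps how many items of each modality a single prompt may carry. A minimal sketch of the surrounding engine construction, with the model id assumed and the other arguments taken from the example script:

# Minimal sketch of the context around the diff above: the example builds the
# engine with the speech LoRA enabled and now also declares how many audio
# clips a prompt may contain. The model id is an assumption for illustration.
from vllm import LLM

audio_count = 2  # e.g. two audio clips per prompt

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed model id
    trust_remote_code=True,
    enable_lora=True,
    max_lora_rank=320,
    lora_extra_vocab_size=0,
    limit_mm_per_prompt={"audio": audio_count},
)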

vllm/model_executor/models/phi4mm.py

Lines changed: 24 additions & 49 deletions
@@ -15,7 +15,7 @@
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@
 
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ def __init__(self,
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
 
         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1445,8 +1453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lora_config = lora_config
 
         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
 
@@ -1686,44 +1692,6 @@ def merge_image_features_to_inputs_embeds(
         )
         return merged_embeds
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
+        )
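
The phi4mm.py change replaces the hand-written renaming loop in load_weights with a declarative hf_to_vllm_mapper consumed by AutoWeightsLoader, and skips checkpoint names containing "lora" in the new load_weights. The renaming itself is plain prefix/substring rewriting; the standalone sketch below (hypothetical helpers PREFIX_MAP and map_name, not vLLM code) mimics what the mapping table expresses:

# Standalone illustration of the renaming the WeightsMapper above declares:
# strip the "base_layer." fragment left by LoRA injection, then rewrite the
# HF checkpoint prefix to the vLLM module prefix. More specific prefixes are
# listed first so they win over the generic "audio_embed." rule.
PREFIX_MAP = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
        "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
        "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}


def map_name(name: str) -> str:
    """Apply the substring and prefix rules to one checkpoint name."""
    name = name.replace("base_layer.", "")
    for old, new in PREFIX_MAP.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


# Example: a LoRA-wrapped speech projection weight from the HF checkpoint.
print(map_name("model.embed_tokens_extend.audio_embed.audio_projection."
               "speech.0.base_layer.weight"))
# -> embed_tokens_extend.audio_projection.0.weight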
