
Commit 11d8a09

[Misc] Optimize Qwen2-VL LoRA test (#11663)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 365801f commit 11d8a09

2 files changed (+21, −4 lines)

tests/lora/test_qwen2vl.py
Lines changed: 2 additions & 3 deletions

@@ -7,7 +7,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform

-MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"
+MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"

 PROMPT_TEMPLATE = (
     "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
@@ -49,10 +49,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     # Print the outputs.
     generated_texts: List[str] = []
     for output in outputs:
-        prompt = output.prompt
         generated_text = output.outputs[0].text.strip()
         generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        print(f"Generated text: {generated_text!r}")
     return generated_texts

vllm/model_executor/models/qwen2_vl.py
Lines changed: 19 additions & 1 deletion

@@ -52,6 +52,7 @@
     GPTQMarlinConfig)
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
@@ -926,15 +927,23 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     }

     # LoRA specific attributes
-    # TODO Support LoRA for the visual encoder in the future.
     supported_lora_modules = [
         "qkv_proj",
         "o_proj",
         "gate_up_proj",
         "down_proj",
+        # vision tower
+        "qkv",
+        "attn.proj",  # Distinguish patch_embed.proj
+        "fc1",
+        "fc2",
+        # projector
+        "mlp.0",
+        "mlp.2"
     ]
     embedding_modules = {}
     embedding_padding_modules = []
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",
@@ -1231,3 +1240,12 @@ def load_weights(self, weights: Iterable[Tuple[str,

         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="visual.",
+            tower_model="visual.merger.")
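As a rough illustration of the new get_mm_mapping hook (not taken from the commit), the snippet below buckets parameter names by the prefixes the mapping returns, checking the more specific merger prefix before the broader visual prefix. The parameter names are invented for the example, and it assumes MultiModelKeys stores each field passed to from_string_field as a list of prefix strings.

from vllm.model_executor.models.module_mapping import MultiModelKeys

mm_keys = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="visual.",
    tower_model="visual.merger.")

# Invented Qwen2-VL-style parameter names, used purely for illustration.
names = [
    "language_model.model.layers.0.self_attn.qkv_proj.weight",
    "visual.blocks.0.attn.qkv.weight",
    "visual.merger.mlp.0.weight",
]

for name in names:
    # Check the longer "visual.merger." prefix first so merger weights are
    # not swallowed by the broader "visual." prefix.
    if any(name.startswith(p) for p in mm_keys.tower_model):
        group = "tower_model"
    elif any(name.startswith(p) for p in mm_keys.connector):
        group = "connector"
    elif any(name.startswith(p) for p in mm_keys.language_model):
        group = "language_model"
    else:
        group = "other"
    print(f"{group}: {name}")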
