Commit 757ac70

[Model] Rename MiniCPMVQwen2 to MiniCPMV2.6 (#7273)
1 parent 6dffa4b commit 757ac70

File tree

3 files changed: +51 -31 lines

  docs/source/models/supported_models.rst
  examples/offline_inference_vision_language.py
  vllm/model_executor/models/minicpmv.py

docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ Vision Language Models
     -
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
+    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
 
 .. note::

examples/offline_inference_vision_language.py

Lines changed: 35 additions & 16 deletions
@@ -22,26 +22,26 @@ def run_llava(question):
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
     llm = LLM(model="llava-hf/llava-1.5-7b-hf")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # LLaVA-1.6/LLaVA-NeXT
 def run_llava_next(question):
 
     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Fuyu
 def run_fuyu(question):
 
     prompt = f"{question}\n"
     llm = LLM(model="adept/fuyu-8b")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Phi-3-Vision
@@ -59,7 +59,8 @@ def run_phi3v(question):
         trust_remote_code=True,
         max_num_seqs=5,
     )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # PaliGemma
@@ -68,16 +69,17 @@ def run_paligemma(question):
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
     llm = LLM(model="google/paligemma-3b-mix-224")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Chameleon
 def run_chameleon(question):
 
     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b")
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # MiniCPM-V
@@ -89,13 +91,26 @@ def run_minicpmv(question):
     # model_name = "HwwwH/MiniCPM-V-2"
 
     # 2.5
-    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
+    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
+
+    # 2.6
+    model_name = "openbmb/MiniCPM-V-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     llm = LLM(
         model=model_name,
         trust_remote_code=True,
     )
+    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
+    # 2.0
+    # stop_token_ids = [tokenizer.eos_id]
+
+    # 2.5
+    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
+
+    # 2.6
+    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
     messages = [{
         'role': 'user',
@@ -104,7 +119,7 @@ def run_minicpmv(question):
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
-    return llm, prompt
+    return llm, prompt, stop_token_ids
 
 
 # InternVL
@@ -118,7 +133,8 @@ def run_internvl(question):
         trust_remote_code=True,
         max_num_seqs=5,
     )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # BLIP-2
@@ -128,7 +144,8 @@ def run_blip2(question):
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
     llm = LLM(model="Salesforce/blip2-opt-2.7b")
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 model_example_map = {
@@ -149,11 +166,13 @@ def main(args):
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")
 
-    llm, prompt = model_example_map[model](question)
+    llm, prompt, stop_token_ids = model_example_map[model](question)
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=64,
+                                     stop_token_ids=stop_token_ids)
 
     assert args.num_prompts > 0
     if args.num_prompts == 1:
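
For reference, a minimal end-to-end sketch of how the updated example is meant to be driven for MiniCPM-V-2.6, following the run_minicpmv() and main() changes above and assuming the offline LLM/SamplingParams API used by this example file; the image path and question are placeholders:

# Sketch only (not part of the commit); mirrors run_minicpmv() plus main().
from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "openbmb/MiniCPM-V-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = LLM(model=model_name, trust_remote_code=True)

# 2.6 stops on the chat terminators rather than a single EOS id.
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]

# Image placeholder format follows the existing MiniCPM-V example prompt.
messages = [{'role': 'user', 'content': '(<image>./</image>)\nWhat is in this image?'}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       sampling_params=sampling_params)
print(outputs[0].outputs[0].text)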

vllm/model_executor/models/minicpmv.py

Lines changed: 15 additions & 14 deletions
@@ -216,7 +216,6 @@ def __init__(
 
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
-
         if kv_dim is not None and kv_dim != embed_dim:
             self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
         else:
@@ -225,7 +224,6 @@ def __init__(
                 nn.Identity()(*args, **kwargs),
                 None,
             )
-
         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.ln_q = norm_layer(embed_dim)
         self.ln_kv = norm_layer(embed_dim)
@@ -261,7 +259,6 @@ def __init__(
                          norm_layer)
 
         self.adaptive = adaptive
-
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
                                                 grid_size,
                                                 version=(2, 0))
@@ -717,7 +714,7 @@ def is_default_weight_loading(self, name: str) -> bool:
         raise NotImplementedError
 
 
-class MiniCPMV2(MiniCPMVBaseModel):
+class MiniCPMV2_0(MiniCPMVBaseModel):
 
     def __init__(
         self,
@@ -890,10 +887,7 @@ def is_default_weight_loading(self, name: str) -> bool:
         return "resampler" in name
 
 
-# NOTE: Currently, information about this model is unavailable. We are
-# temporarily using `MiniCPMVQwen2` as it's name. The name may need
-# to be modified in the future.
-class MiniCPMVQwen2(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMVBaseModel):
 
     def __init__(
         self,
@@ -903,6 +897,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__(config, multimodal_config, cache_config, quant_config)
+        assert self.version == (2, 6)
 
     def init_llm(
         self,
@@ -930,6 +925,7 @@ def init_vision_module(self) -> nn.Module:
 
     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
         with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
             resampler = Resampler2_5(
                 num_queries=self.config.query_num,
                 embed_dim=embed_dim,
@@ -989,6 +985,13 @@ def is_default_weight_loading(self, name: str) -> bool:
         return "resampler" in name or "vpm" in name
 
 
+_SUPPORT_VERSION = {
+    (2, 0): MiniCPMV2_0,
+    (2, 5): MiniCPMV2_5,
+    (2, 6): MiniCPMV2_6
+}
+
+
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@@ -1016,11 +1019,9 @@ def __new__(
         version = str(config.version).split(".")
         version = tuple([int(x) for x in version])
         # Dispatch class based on version
-        if version == (2, 0):
-            instance_class = MiniCPMV2
-        elif version == (2, 5):
-            instance_class = MiniCPMV2_5
-        else:
-            instance_class = MiniCPMVQwen2
+        instance_class = _SUPPORT_VERSION.get(version, None)
+        if instance_class is None:
+            raise ValueError(
+                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
         return instance_class(config, multimodal_config, cache_config,
                               quant_config)
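
As a self-contained illustration of the dispatch change in __new__ above (a sketch, with strings standing in for the real model classes): the version string from the HF config is parsed into a tuple and looked up in the _SUPPORT_VERSION table, and unknown versions now raise instead of silently falling through to the old Qwen2 variant.

# Stand-alone sketch of the new dict-based version dispatch; strings stand in
# for MiniCPMV2_0 / MiniCPMV2_5 / MiniCPMV2_6.
_SUPPORT_VERSION = {
    (2, 0): "MiniCPMV2_0",
    (2, 5): "MiniCPMV2_5",
    (2, 6): "MiniCPMV2_6",
}

def dispatch(version_str: str) -> str:
    # config.version arrives as a string such as "2.6".
    version = tuple(int(x) for x in str(version_str).split("."))
    instance_class = _SUPPORT_VERSION.get(version, None)
    if instance_class is None:
        raise ValueError(
            "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
    return instance_class

assert dispatch("2.6") == "MiniCPMV2_6"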
