
Commit 840537a

sixsixcoder authored and Alvant committed
[Model] Add GLM-4v support and meet vllm==0.6.2 (vllm-project#9242)
Signed-off-by: Alvant <[email protected]>
1 parent 6ead5ae commit 840537a

File tree

7 files changed: +776, -72 lines


docs/source/models/supported_models.rst

Lines changed: 6 additions & 0 deletions

@@ -351,6 +351,12 @@ Text Generation
       - :code:`adept/fuyu-8b` etc.
       -
       - ✅︎
+    * - :code:`ChatGLMModel`
+      - GLM-4V
+      - Image
+      - :code:`THUDM/glm-4v-9b` etc.
+      -
+      - ✅︎
     * - :code:`InternVLChatModel`
       - InternVL2
       - Image\ :sup:`E+`

examples/offline_inference_vision_language.py

Lines changed: 16 additions & 0 deletions

@@ -300,6 +300,21 @@ def run_mllama(question: str, modality: str):
     return llm, prompt, stop_token_ids


+# GLM-4v
+def run_glm4v(question: str, modality: str):
+    assert modality == "image"
+    model_name = "THUDM/glm-4v-9b"
+
+    llm = LLM(model=model_name,
+              max_model_len=2048,
+              max_num_seqs=2,
+              trust_remote_code=True,
+              enforce_eager=True)
+    prompt = question
+    stop_token_ids = [151329, 151336, 151338]
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,

@@ -316,6 +331,7 @@ def run_mllama(question: str, modality: str):
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "mllama": run_mllama,
+    "glm4v": run_glm4v,
 }
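The hunk above only registers run_glm4v in model_example_map; the driver code of the example script, which actually consumes the returned llm, prompt, and stop_token_ids, is not part of this diff. Below is a minimal standalone sketch of that usage, assuming a local image file and the standard vLLM multi-modal prompt dict; the image path and sampling settings are illustrative and not taken from the commit.

from PIL import Image

from vllm import SamplingParams

# Hypothetical standalone use of run_glm4v from the example script.
llm, prompt, stop_token_ids = run_glm4v("What is in this picture?", "image")

# Illustrative image path; any PIL image works here.
image = Image.open("example.jpg")

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

# Pass the text prompt together with the image via multi_modal_data.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)

for output in outputs:
    print(output.outputs[0].text)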

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
from typing import List, Optional, Tuple, Type

import pytest

from vllm.multimodal.utils import rescale_image_size
from vllm.transformers_utils.tokenizer import patch_padding_side

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "What's the content of the image?",
    "cherry_blossom":
    "What is the season?",
})

models = ["THUDM/glm-4v-9b"]
target_dtype = "bfloat16"


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=2048,
                     max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        stop_token_ids = [151329, 151336, 151338]
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images,
                                                stop_token_ids=stop_token_ids)
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_processor = hf_model.processor
        patch_padding_side(hf_processor)

        def processor(*args, text="", images=None, **kwargs):
            if images is None:
                return hf_processor(*args, **kwargs)

            return hf_processor.apply_chat_template(
                [{
                    "role": "user",
                    "image": images,
                    "content": text
                }],
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )

        hf_model.processor = processor
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.transformer.output_layer
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
            ) for prompts, images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
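Once the commit is checked out, the new test can be invoked directly with pytest; the path below is a placeholder because the test file's name is not visible in this diff view, and the large_gpu_test(min_gb=48) marker is meant to skip the test on machines without a GPU of at least 48 GB of memory.

# placeholder path; substitute the actual location of the new GLM-4v test file
pytest -v <path/to/new/glm4v_test_file.py>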
