
Commit 22392d8

patrickvonplaten authored and ywang96 committed
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent 86ee231 commit 22392d8

File tree

8 files changed: +807 -9 lines changed

docs/source/models/supported_models.rst

Lines changed: 5 additions & 0 deletions
@@ -247,6 +247,11 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
+  * - :code:`PixtralForConditionalGeneration`
+    - Pixtral
+    - Image\ :sup:`+`
+    - :code:`mistralai/Pixtral-12B-2409`
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
     - Image\ :sup:`E`
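For reference, the superscripts here follow the legend used elsewhere in this table: :sup:`E` marks models that also accept pre-computed image embeddings as input, and :sup:`+` marks models that accept multiple images per text prompt, so Pixtral is registered as a multi-image model.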

examples/offline_inference_pixtral.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@

# ruff: noqa
import argparse

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
#     "model": "mistralai/Pixtral-12B-2409",
#     "messages": [
#       {
#         "role": "user",
#         "content": [
#           {"type": "text", "text": "Describe this image in detail please."},
#           {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#           {"type": "text", "text": "and this one as well. Answer in French."},
#           {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#         ]
#       }
#     ]
#   }'
# ```
#
# Usage:
#     python offline_inference_pixtral.py simple
#     python offline_inference_pixtral.py advanced


def run_simple_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    sampling_params = SamplingParams(max_tokens=8192)

    llm = LLM(model=model_name, tokenizer_mode="mistral")

    prompt = "Describe this image in one sentence."
    image_url = "https://picsum.photos/id/237/200/300"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)


def run_advanced_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    max_img_per_msg = 5
    max_tokens_per_img = 4096

    sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
    llm = LLM(
        model=model_name,
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": max_img_per_msg},
        max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    )

    prompt = "Describe the following image."

    url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
    url_2 = "https://picsum.photos/seed/picsum/200/300"
    url_3 = "https://picsum.photos/id/32/512/512"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": url_1}},
                {"type": "image_url", "image_url": {"url": url_2}},
            ],
        },
        {
            "role": "assistant",
            "content": "The images show nature.",
        },
        {
            "role": "user",
            "content": "More details please, and answer only in French!",
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": url_3}},
            ],
        },
    ]

    outputs = llm.chat(messages=messages, sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)


def main():
    parser = argparse.ArgumentParser(
        description="Run a demo in simple or advanced mode.")

    parser.add_argument(
        "mode",
        choices=["simple", "advanced"],
        help="Specify the demo mode: 'simple' or 'advanced'",
    )

    args = parser.parse_args()

    if args.mode == "simple":
        print("Running simple demo...")
        run_simple_demo()
    elif args.mode == "advanced":
        print("Running advanced demo...")
        run_advanced_demo()


if __name__ == "__main__":
    main()
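To complement the curl client in the header comment, here is a minimal sketch of the same request using the official openai Python package (the base URL, API key, and image URL are placeholders; it assumes the `vllm serve` command from the comment is already running):

from openai import OpenAI

# Placeholder endpoint and key for a locally running `vllm serve` instance.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in detail please."},
            {"type": "image_url",
             "image_url": {"url": "https://picsum.photos/id/237/200/300"}},
        ],
    }],
)
print(response.choices[0].message.content)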

requirements-common.txt

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ pyzmq
 msgspec
 gguf == 0.9.1
 importlib_metadata
-mistral_common >= 1.3.4
+mistral_common >= 1.4.0
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 einops # Required for Qwen2-VL.
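For an existing environment, this bump amounts to `pip install --upgrade "mistral_common>=1.4.0"`; the new floor presumably corresponds to the mistral_common release that added the image/tokenizer support Pixtral needs.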

tests/models/test_pixtral.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@

"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_pixtral.py`.
"""
import pytest

from vllm.sampling_params import SamplingParams

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]


@pytest.mark.skip(
    reason=
    "Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    image_urls = [
        "https://picsum.photos/id/237/200/300",
        "https://picsum.photos/seed/picsum/200/300"
    ]
    expected = [
        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."  # noqa
    ]
    prompt = "Describe the image in one short sentence."

    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

    with vllm_runner(model, dtype=dtype,
                     tokenizer_mode="mistral") as vllm_model:

        for i, image_url in enumerate(image_urls):
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                },
            ]

            outputs = vllm_model.model.chat(messages,
                                            sampling_params=sampling_params)
            assert outputs[0].outputs[0].text == expected[i]
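Note the unconditional `@pytest.mark.skip`: on CI this test never runs. To exercise it locally on a GPU with enough memory (the reason string mentions an A100), remove the skip decorator and invoke `pytest -m vlm tests/models/test_pixtral.py`, assuming the `vlm` marker is registered in the project's pytest configuration.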

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 1 deletion
@@ -148,7 +148,8 @@ def _placeholder_str(self, modality: ModalityStr,
             return f"<|image_{current_count}|>"
         if model_type == "minicpmv":
             return "(<image>./</image>)"
-        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
+        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
+                          "pixtral"):
             # These models do not use image tokens in the prompt
             return None
         if model_type == "qwen":
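Returning None here means no textual placeholder is spliced into Pixtral chat prompts: per the comment in the diff, these models do not use image tokens in the prompt, so images are carried as separate multimodal inputs rather than as tokens in the text.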

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,8 @@
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "UltravoxModel": ("ultravox", "UltravoxModel"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "PixtralForConditionalGeneration": ("pixtral",
+                                        "PixtralForConditionalGeneration"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                         "Qwen2VLForConditionalGeneration"),
 }
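Entries in this table map an architecture name (as reported by the model's config) to a (module, class) pair that is imported on demand. A minimal sketch of that lazy resolution, illustrative only rather than vLLM's actual loader:

import importlib

# Illustrative registry entry: architecture name -> (module name, class name),
# mirroring the mapping added in this commit.
_MODELS = {
    "PixtralForConditionalGeneration": ("pixtral",
                                        "PixtralForConditionalGeneration"),
}


def resolve_model_cls(arch: str):
    """Lazily import and return the model class registered for `arch`."""
    module_name, cls_name = _MODELS[arch]
    # The module is imported only when the architecture is actually requested,
    # keeping startup cheap when many architectures are registered.
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)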
