Commit 60464e8

DarkLight1337 and JC1DA
authored and committed
[Frontend] Use a proper chat template for VLM2Vec (vllm-project#9912)
Signed-off-by: Loc Huynh <[email protected]>
1 parent 58e5ad3 commit 60464e8

6 files changed: +78 −11 lines changed

docs/source/models/vlm.rst

Lines changed: 10 additions & 4 deletions
@@ -240,8 +240,7 @@ To consume the server, you can use the OpenAI client like in the example below:
     )
     print("Chat completion output:", chat_response.choices[0].message.content)
 
-
-A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.
+A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.
 
 .. tip::
     There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.

@@ -269,14 +268,19 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
 .. code-block:: bash
 
     vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
-        --trust-remote-code --max-model-len 4096
+        --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
 
 .. important::
 
     Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
     to run this model in embedding mode instead of text generation mode.
 
-Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
+.. important::
+
+    VLM2Vec does not expect chat-based input. We use a `custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`_
+    to combine the text and images together.
+
+Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
 
 .. code-block:: python
 

@@ -301,3 +305,5 @@ Since this schema is not defined by OpenAI client, we post a request to the serv
     response.raise_for_status()
     response_json = response.json()
     print("Embedding output:", response_json["data"][0]["embedding"])
+
+A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
examples/openai_chat_embedding_client_for_multimodal.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import requests

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model":
        "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role":
            "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                },
                {
                    "type": "text",
                    "text": "Represent the given image."
                },
            ],
        }],
        "encoding_format":
        "float",
    },
)
response.raise_for_status()
response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])
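As a hedged follow-up (not part of the commit): with the server from the docs section running, a few extra lines appended to this script can sanity-check the embedding's dimensionality. The value 3072 is borrowed from the updated test further below and is an assumption tied to TIGER-Lab/VLM2Vec-Full.

    # Optional sanity check appended to the example above. Assumption:
    # the 3072-dimensional hidden size comes from the updated test in
    # this commit, not from any model documentation.
    embedding = response_json["data"][0]["embedding"]
    assert len(embedding) == 3072
    print(f"Received a {len(embedding)}-dimensional embedding")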

examples/template_vlm2vec.jinja

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
{%- if messages | length > 1 -%}
    {{ raise_exception('Embedding models should only embed one message at a time') }}
{%- endif -%}

{% set vars = namespace(parts=[], next_image_id=1) %}
{%- for message in messages -%}
    {%- for content in message['content'] -%}
        {%- if content['type'] == 'text' -%}
            {%- set vars.parts = vars.parts + [content['text']] %}
        {%- elif content['type'] == 'image' -%}
            {%- set vars.parts = vars.parts + ['<|image_{i:d}|>'.format(i=vars.next_image_id)] %}
            {%- set vars.next_image_id = vars.next_image_id + 1 %}
        {%- endif -%}
    {%- endfor -%}
{%- endfor -%}
{{ vars.parts | join(' ') }}
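To see what this template produces, here is a minimal rendering sketch (not part of the commit) using the jinja2 package directly. It assumes the server-side parser has already reduced each multimodal part to a {"type": "image"} placeholder, which is what the wrap_dicts path in chat_utils.py below arranges for embedding-task multimodal models; the file path and the expected output line are assumptions.

    # A minimal sketch, assuming jinja2 is installed and the template is
    # available at examples/template_vlm2vec.jinja relative to the CWD.
    from jinja2 import Environment

    def raise_exception(message):
        # Stand-in for the raise_exception helper that chat-template
        # renderers conventionally expose to templates.
        raise ValueError(message)

    env = Environment()
    env.globals["raise_exception"] = raise_exception

    with open("examples/template_vlm2vec.jinja") as f:
        template = env.from_string(f.read())

    # One user message whose image part has already been reduced to a
    # {"type": "image"} placeholder (an assumption about the parsed form).
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Represent the given image."},
        ],
    }]

    print(template.render(messages=messages))
    # Expected: <|image_1|> Represent the given image.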

tests/entrypoints/openai/test_vision_embedding.py

Lines changed: 8 additions & 3 deletions
@@ -6,11 +6,14 @@
 
 from vllm.multimodal.utils import encode_image_base64, fetch_image
 
-from ...utils import RemoteOpenAIServer
+from ...utils import VLLM_PATH, RemoteOpenAIServer
 
 MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
 MAXIMUM_IMAGES = 2
 
+vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja"
+assert vlm2vec_jinja_path.exists()
+
 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
 TEST_IMAGE_URLS = [
     "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",

@@ -35,6 +38,8 @@ def server():
         "--trust-remote-code",
         "--limit-mm-per-prompt",
         f"image={MAXIMUM_IMAGES}",
+        "--chat-template",
+        str(vlm2vec_jinja_path),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -90,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
     assert len(embeddings["data"]) == 1
     assert len(embeddings["data"][0]["embedding"]) == 3072
     assert embeddings["usage"]["completion_tokens"] == 0
-    assert embeddings["usage"]["prompt_tokens"] == 771
-    assert embeddings["usage"]["total_tokens"] == 771
+    assert embeddings["usage"]["prompt_tokens"] == 762
+    assert embeddings["usage"]["total_tokens"] == 762

vllm/entrypoints/chat_utils.py

Lines changed: 11 additions & 4 deletions
@@ -156,6 +156,10 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
 
         self._items: List[_T] = []
 
+    @property
+    def model_config(self) -> ModelConfig:
+        return self._model_config
+
     @staticmethod
     @lru_cache(maxsize=None)
     def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:

@@ -491,10 +495,13 @@ def _parse_chat_message_content_parts(
     content: List[Union[str, Dict[str, str]]] = []
 
     mm_parser = mm_tracker.create_parser()
-    wrap_dicts = \
-        mm_tracker._model_config.hf_config.model_type in \
-        MODEL_KEEP_MULTI_MODAL_CONTENT or \
-        (chat_template_text_format == "openai")
+    model_config = mm_tracker.model_config
+
+    wrap_dicts = (chat_template_text_format == "openai"
+                  or (model_config.task == "embedding"
+                      and model_config.is_multimodal_model)
+                  or (model_config.hf_config.model_type
+                      in MODEL_KEEP_MULTI_MODAL_CONTENT))
 
     for part in parts:
         parse_res = _parse_chat_message_content_part(
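The behavioral change is concentrated in the wrap_dicts predicate: multimodal content parts are now kept as dicts rather than flattened to strings for embedding-task multimodal models too, which is what lets a template like template_vlm2vec.jinja iterate over typed parts. A standalone sketch of the updated predicate follows; the dataclass and the set contents are illustrative stand-ins, not vLLM's actual ModelConfig or MODEL_KEEP_MULTI_MODAL_CONTENT.

    # A standalone sketch of the predicate above. FakeModelConfig is a
    # stand-in for vLLM's ModelConfig; KEEP_MM_CONTENT stands in for
    # MODEL_KEEP_MULTI_MODAL_CONTENT (membership shown is illustrative).
    from dataclasses import dataclass

    KEEP_MM_CONTENT = {"mllama"}  # assumption, for illustration only

    @dataclass
    class FakeModelConfig:
        task: str
        is_multimodal_model: bool
        model_type: str

    def should_wrap_dicts(chat_template_text_format: str,
                          cfg: FakeModelConfig) -> bool:
        return (chat_template_text_format == "openai"
                or (cfg.task == "embedding" and cfg.is_multimodal_model)
                or cfg.model_type in KEEP_MM_CONTENT)

    # VLM2Vec served with --task embedding now takes the dict-preserving
    # path even with the default "string" text format ("phi3_v" is assumed
    # here, based on the docs noting it shares Phi-3.5-Vision's architecture):
    print(should_wrap_dicts("string",
                            FakeModelConfig("embedding", True, "phi3_v")))  # True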
