
Commit 63afb45

alex-jw-brooks authored and LeiWang1999 committed
[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (vllm-project#9131)
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 2c8129a commit 63afb45

21 files changed (+443 −121 lines)

examples/offline_inference_vision_language.py

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         max_num_seqs=2,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
     )
     stop_token_ids = None
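
As context for the note added above, here is a minimal offline sketch of the two levels at which mm_processor_kwargs can now be set. The prompt string and dummy image are illustrative, and the in-prompt "mm_processor_kwargs" key is assumed from the LLMInputs changes in this commit; only the num_crops kwarg itself comes from the diff.

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    # Init-time default applied to every request
    mm_processor_kwargs={"num_crops": 16},
)

image = Image.new("RGB", (336, 336))  # stand-in for a real image
outputs = llm.generate(
    {
        "prompt": "<|user|>\n<|image_1|>\nWhat is shown here?<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": image},
        # Inference-time override; takes precedence over the init-time value
        "mm_processor_kwargs": {"num_crops": 4},
    },
    SamplingParams(max_tokens=64),
)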

tests/multimodal/test_processor_kwargs.py

Lines changed: 70 additions & 40 deletions
@@ -74,11 +74,11 @@ def mm_model_cls():
 # lambda whose signature matches max token calcs extra & mapper + extra kwargs
 get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
 custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
-    "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
 }


-### Test for default processor logic & mm_processor_kwargs wrapping
+### Tests for default processor logic & mm_processor_kwargs wrapping
 def test_default_processor_is_a_noop():
     """Ensure that by default, there is no processor override."""
     dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
     assert proc_inputs is proc_outputs


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_processor_default_kwargs(use_processor_mock, num_crops):
-    """Ensure input processors can use processor kwargs."""
-    dummy_registry = InputRegistry()
+def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
+    """Get the init / inference kwargs and expected num_crops for this test."""
     # If we have a value for num_crops, pass the override value and make
     # sure we get that value as a return-value from out mock processor,
     # otherwise fall back to the default value
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    init_kwargs = None if init_num_crops is None else {
+        "num_crops": init_num_crops
     }
-    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
+    inference_kwargs = None if inference_num_crops is None else {
+        "num_crops": inference_num_crops
+    }
+    if inference_num_crops is not None:
+        expected_seq_count = inference_num_crops
+    elif init_num_crops is not None:
+        expected_seq_count = init_num_crops
+    else:
+        expected_seq_count = DEFAULT_NUM_CROPS
+    return init_kwargs, inference_kwargs, expected_seq_count
+
+
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_input_processor_kwargs(use_processor_mock, init_num_crops,
+                                inference_num_crops):
+    """Ensure input processors can use processor kwargs."""
+    dummy_registry = InputRegistry()
+
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)

-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
-    assert num_crops_val == expected_num_crops
+    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=inference_kwargs))
+    assert num_crops_val == expected_seq_count


 @pytest.mark.parametrize(
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                             mm_processor_kwargs):
     """Ensure that input processors filter out invalid mm_processor_kwargs"""
     dummy_registry = InputRegistry()
+    # Should filter out the init time kwargs
     ctx = build_model_context(DUMMY_MODEL_ID,
                               mm_processor_kwargs=mm_processor_kwargs)

     processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    # Should filter out the inference time kwargs
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=mm_processor_kwargs))
     assert num_crops_val == DEFAULT_NUM_CROPS
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
     assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
+                                       inference_num_crops):
     """Ensure custom mappers can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
-    }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
+
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
+                              mm_processor_kwargs=init_kwargs,
                               limit_mm_per_prompt={"image": 1})

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
+                                          inference_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@@ -316,24 +346,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                 mm_processor_kwargs):
     """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
+    # Should filter out the init time kwargs
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    # Should filter out the inference time kwargs
+    mapped_inputs = mm_registry.map_input(
+        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
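
The _get_num_crops_info helper above encodes the precedence rule these tests pin down: inference-time kwargs beat init-time kwargs, which beat the processor default. A standalone sketch of that resolution logic; the function name is hypothetical, and the constant values are assumed to match the test suite's:

from typing import Optional

DEFAULT_NUM_CROPS = 4  # assumed default from the test suite

def resolve_num_crops(init_num_crops: Optional[int],
                      inference_num_crops: Optional[int]) -> int:
    # Per-request value wins, then the engine-init value, then the default
    if inference_num_crops is not None:
        return inference_num_crops
    if init_num_crops is not None:
        return init_num_crops
    return DEFAULT_NUM_CROPS

assert resolve_num_crops(None, None) == DEFAULT_NUM_CROPS
assert resolve_num_crops(16, None) == 16
assert resolve_num_crops(4, 16) == 16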

tests/test_inputs.py

Lines changed: 26 additions & 0 deletions
@@ -2,6 +2,7 @@

 import pytest

+from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
 def test_parse_single_batch_string_slice(inputs_slice: slice):
     assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
         == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+
+
+# yapf: disable
+@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
+    (None, [{}, {}]),
+    ({}, [{}, {}]),
+    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
+    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
+])
+# yapf: enable
+def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
+    """Test mm_processor_kwargs init for zipping enc/dec prompts."""
+    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
+    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
+    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
+                                         mm_processor_kwargs)
+    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
+    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
+                                            expected_mm_kwargs,
+                                            zipped_prompts):
+        assert isinstance(zipped, dict)
+        assert len(zipped.keys()) == 3
+        assert zipped['encoder_prompt'] == enc
+        assert zipped['decoder_prompt'] == dec
+        assert zipped['mm_processor_kwargs'] == exp_kwargs
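
Reading the parametrized cases, zip_enc_dec_prompts normalizes mm_processor_kwargs so that None or a single dict fans out to every encoder/decoder pair, while a list is zipped element-wise. A rough sketch of that behavior, as an illustration rather than vLLM's actual implementation:

from typing import Any, Dict, List, Optional, Union

def zip_enc_dec_prompts_sketch(
    enc_prompts: List[str],
    dec_prompts: List[str],
    mm_processor_kwargs: Optional[Union[Dict[str, Any],
                                        List[Dict[str, Any]]]] = None,
) -> List[Dict[str, Any]]:
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    if isinstance(mm_processor_kwargs, dict):
        # Fan the same kwargs out to every encoder/decoder pair
        mm_processor_kwargs = [mm_processor_kwargs] * len(enc_prompts)
    return [{
        "encoder_prompt": enc,
        "decoder_prompt": dec,
        "mm_processor_kwargs": kw,
    } for enc, dec, kw in zip(enc_prompts, dec_prompts, mm_processor_kwargs)]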

tests/test_utils.py

Lines changed: 31 additions & 1 deletion
@@ -7,7 +7,7 @@
 import pytest

 from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
-                        get_open_port, merge_async_iterators)
+                        get_open_port, merge_async_iterators, supports_kw)

 from .utils import error_on_warning

@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
     with pytest.raises(ValueError):
         parser_with_config.parse_args(
             ['serve', '--config', './data/test_config.yaml'])
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
+    [
+        # Tests for positional argument support
+        (lambda foo: None, "foo", True, True, False),
+        (lambda foo: None, "foo", False, True, True),
+        # Tests for positional or keyword / keyword only
+        (lambda foo=100: None, "foo", True, True, False),
+        (lambda *, foo: None, "foo", False, True, True),
+        # Tests to make sure the names of variadic params are NOT supported
+        (lambda *args: None, "args", False, True, False),
+        (lambda **kwargs: None, "kwargs", False, True, False),
+        # Tests for if we allow var kwargs to add support
+        (lambda foo: None, "something_else", False, True, False),
+        (lambda foo, **kwargs: None, "something_else", False, True, True),
+        (lambda foo, **kwargs: None, "kwargs", True, True, False),
+        (lambda foo, **kwargs: None, "foo", True, True, False),
+    ])
+# yapf: disable
+def test_supports_kw(callable,kw_name,requires_kw_only,
+                     allow_var_kwargs,is_supported):
+    assert supports_kw(
+        callable=callable,
+        kw_name=kw_name,
+        requires_kw_only=requires_kw_only,
+        allow_var_kwargs=allow_var_kwargs
+    ) == is_supported
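
The test table documents supports_kw's contract: a named positional-or-keyword or keyword-only parameter counts as support (unless requires_kw_only demands keyword-only), the names of variadic parameters never do, and **kwargs can absorb unknown names when allow_var_kwargs is set. A sketch that satisfies the table above, written as an illustration rather than vLLM's implementation:

import inspect
from typing import Callable

def supports_kw_sketch(callable: Callable,
                       kw_name: str,
                       requires_kw_only: bool = False,
                       allow_var_kwargs: bool = True) -> bool:
    params = inspect.signature(callable).parameters
    param = params.get(kw_name)
    if param is not None:
        # Names of variadic params (*args / **kwargs) never count
        if param.kind in (inspect.Parameter.VAR_POSITIONAL,
                          inspect.Parameter.VAR_KEYWORD):
            return False
        if requires_kw_only:
            return param.kind == inspect.Parameter.KEYWORD_ONLY
        return param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                              inspect.Parameter.KEYWORD_ONLY)
    # Unknown name: fall back to **kwargs if allowed and no kw-only required
    return (allow_var_kwargs and not requires_kw_only and any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()))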

vllm/core/scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -1309,6 +1309,7 @@ def schedule(
                 # `multi_modal_data` will be None.
                 multi_modal_data=seq_group.multi_modal_data
                 if scheduler_outputs.num_prefill_groups > 0 else None,
+                mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
         else:

vllm/engine/llm_engine.py

Lines changed: 7 additions & 0 deletions
@@ -811,6 +811,13 @@ def add_request(
         )
         processed_inputs = self.input_processor(preprocessed_inputs)

+        # This is a bit of a hack - copy the mm_processor_kwargs that were
+        # used in the input processor to the processed output, since these
+        # kwargs are presumed to be immutable and the values should be aligned
+        # between the input processor (here) and the input mapper.
+        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
+            "mm_processor_kwargs")
+
         self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,

vllm/entrypoints/llm.py

Lines changed: 9 additions & 0 deletions
@@ -472,6 +472,7 @@ def chat(
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.
@@ -501,6 +502,8 @@ def chat(
             continue_final_message: If True, continues the final message in
                 the conversation instead of starting a new one. Cannot be `True`
                 if `add_generation_prompt` is also `True`.
+            mm_processor_kwargs: Multimodal processor kwarg overrides for this
+                chat request. Only used for offline requests.

         Returns:
             A list of ``RequestOutput`` objects containing the generated
@@ -522,6 +525,9 @@ def chat(
             tokenizer = self.get_tokenizer()
             model_config = self.llm_engine.get_model_config()

+            # NOTE: _parse_chat_message_content_parts() currently doesn't
+            # handle mm_processor_kwargs, since there is no implementation in
+            # the chat message parsing for it.
             conversation, mm_data = parse_chat_messages(
                 msgs, model_config, tokenizer)

@@ -554,6 +560,9 @@ def chat(
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data

+            if mm_processor_kwargs is not None:
+                prompt["mm_processor_kwargs"] = mm_processor_kwargs
+
             prompts.append(prompt)

         return self.generate(
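
Taken together, the chat() changes allow per-call multimodal processor overrides for offline requests. A hypothetical call, where the model name, image URL, and num_crops value are illustrative:

from vllm import LLM

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
          trust_remote_code=True)

outputs = llm.chat(
    [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/some_image.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
    # Applied to this chat request only; per the docstring above, this
    # parameter is only used for offline requests
    mm_processor_kwargs={"num_crops": 16},
)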
