Commit 609ef61
[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)
Signed-off-by: Isotr0py <[email protected]>
Parent: db84f5e

3 files changed: +59 / -33 lines

tests/multimodal/test_processing.py (1 addition, 1 deletion)

```diff
@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
```
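The test only needed the rename to the new decoder entry point. For context, a self-contained sketch of the pytest.raises pattern it relies on; the function and error message below are illustrative stand-ins, not vLLM code:

```python
# Illustrative stand-in, not vLLM code: pytest.raises as a context manager
# asserts that the wrapped call raises ValueError with a matching message,
# which is how the test above checks over-limit multimodal profiling.
import pytest


def check_mm_limit(num_items: int, limit: int = 1) -> None:
    if num_items > limit:
        raise ValueError("this model only supports 1 image per prompt")


def test_over_limit_raises() -> None:
    with pytest.raises(ValueError, match="this model only supports"):
        check_mm_limit(2)
```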

vllm/inputs/registry.py (4 additions, 2 deletions)

```diff
@@ -335,8 +335,10 @@ def dummy_data_for_profiling(
                                                  tokenizer,
                                                  disable_cache=True)
         profiler = MultiModalProfiler(processor)
-        dummy_data = profiler.get_dummy_data(
-            seq_len, is_encoder_data=is_encoder_data)
+        dummy_data_factory = (profiler.get_encoder_dummy_data
+                              if is_encoder_data else
+                              profiler.get_decoder_dummy_data)
+        dummy_data = dummy_data_factory(seq_len)
     else:
         model_cls, _ = get_model_architecture(model_config)
         if is_encoder_data:
```
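The new call site selects a bound method with a conditional expression and then invokes it, replacing the old is_encoder_data keyword flag. A standalone, runnable sketch of that dispatch pattern (plain Python; class and method names only loosely mirror the diff):

```python
# Plain-Python illustration of the dispatch above: a conditional expression
# picks one of two bound methods, and the chosen one is called with the same
# argument either way.
class DemoProfiler:
    def get_encoder_dummy_data(self, seq_len: int) -> str:
        return f"encoder dummy data for seq_len={seq_len}"

    def get_decoder_dummy_data(self, seq_len: int) -> str:
        return f"decoder dummy data for seq_len={seq_len}"


profiler = DemoProfiler()
for is_encoder_data in (True, False):
    dummy_data_factory = (profiler.get_encoder_dummy_data
                          if is_encoder_data else
                          profiler.get_decoder_dummy_data)
    print(dummy_data_factory(16))
```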

vllm/multimodal/profiling.py (54 additions, 30 deletions)

```diff
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
```
```diff
@@ -142,14 +143,10 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_dummy_data(
+    def get_and_validate_mm_inputs(
         self,
         seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
         mm_counts = self.get_mm_limits()
 
         info = self.processing_info
@@ -165,11 +162,6 @@ def get_dummy_data(
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
```
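These two hunks extract the shared setup and validation into get_and_validate_mm_inputs, which returns an (mm_inputs, total_placeholders_by_modality) tuple; each new public method then consumes only the parts it needs. A minimal runnable sketch of that shape (dummy values, no vLLM imports):

```python
# Minimal sketch of the refactor: one helper validates once and returns a
# tuple; the encoder path discards the placeholder counts, while the decoder
# path keeps them for its warning message. All values below are dummies.
def get_and_validate_mm_inputs(seq_len: int) -> tuple[dict, dict]:
    mm_inputs = {
        "prompt_token_ids": [1, 2, 3],
        "encoder_prompt_token_ids": [4, 5],
    }
    total_placeholders_by_modality = {"image": 2}
    return mm_inputs, total_placeholders_by_modality


def encoder_path(seq_len: int) -> list[int]:
    mm_inputs, _ = get_and_validate_mm_inputs(seq_len)  # counts unused
    return mm_inputs["encoder_prompt_token_ids"]


def decoder_path(seq_len: int) -> tuple[list[int], dict]:
    mm_inputs, totals = get_and_validate_mm_inputs(seq_len)
    return mm_inputs["prompt_token_ids"], totals


print(encoder_path(8))   # [4, 5]
print(decoder_path(8))   # ([1, 2, 3], {'image': 2})
```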
```diff
@@ -185,28 +177,60 @@ def get_dummy_data(
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
+        return mm_inputs, total_placeholders_by_modality
+
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
 
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
 
             return DummyData(
-                seq_data=SequenceData.from_seqs(prompt_token_ids),
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
```
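Two details in this hunk are worth spelling out. The encoder path pads with max(total_len, seq_len) - total_len zeros, which is zero whenever the prompt already reaches seq_len, so it pads but never truncates. And in the decoder's V0 truncation branch, the dummy sequence is now built from a single (token_id, count) pair via SequenceData.from_prompt_token_counts((0, seq_len)) rather than from a padded copy of the processed prompt token ids, which is presumably where the profiling OOM fix lands. A runnable sketch of both calculations (plain Python; the semantics of from_prompt_token_counts are inferred from its call site in this diff, not from vLLM's source):

```python
# Plain-Python sketch, no vLLM imports.

def num_tokens_to_pad(total_len: int, seq_len: int) -> int:
    # Encoder path: pad up to seq_len with token id 0; never truncate.
    return max(total_len, seq_len) - total_len

print(num_tokens_to_pad(5, 8))    # 3 -> pad three zeros
print(num_tokens_to_pad(8, 8))    # 0 -> already long enough
print(num_tokens_to_pad(12, 8))   # 0 -> longer than seq_len, left as-is


def from_prompt_token_counts(*token_counts: tuple[int, int]) -> list[int]:
    # Decoder V0 branch, as called above with a single (0, seq_len) run: the
    # dummy prompt is described by (token_id, count) pairs. This sketch
    # expands the runs for illustration; the real SequenceData presumably
    # stores them more compactly.
    return [tok for tok, count in token_counts for _ in range(count)]

print(from_prompt_token_counts((0, 6)))  # [0, 0, 0, 0, 0, 0]
```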
```diff
@@ -216,5 +240,5 @@ def get_dummy_data(
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
```
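One last detail from the new encoder path above: cast(MultiModalEncDecInputs, mm_inputs) is purely an annotation for static type checkers. A short demonstration that typing.cast performs no runtime conversion or validation:

```python
# typing.cast only re-labels the value for static type checkers; at runtime
# it returns its argument unchanged and performs no validation.
from typing import cast

x: object = {"encoder_prompt_token_ids": [1, 2, 3]}
d = cast(dict, x)
assert d is x  # same object, nothing converted
```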
