@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
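
For context on the new imports: `typing.cast` is only used further down to narrow the static type of the mapping returned by `get_and_validate_mm_inputs` to `MultiModalEncDecInputs`; it performs no runtime conversion or check. A minimal sketch of the pattern, using stand-in TypedDicts rather than vLLM's actual classes:

```python
from typing import TypedDict, cast


class Inputs(TypedDict):
    prompt_token_ids: list[int]


class EncDecInputs(Inputs):
    encoder_prompt_token_ids: list[int]


def load_inputs() -> Inputs:
    # Stand-in for get_and_validate_mm_inputs(): the encoder key exists at
    # runtime even though the declared return type does not mention it.
    return {"prompt_token_ids": [1, 2],
            "encoder_prompt_token_ids": [3, 4]}  # type: ignore[typeddict-item]


enc_dec = cast(EncDecInputs, load_inputs())  # no runtime check, no copy
assert enc_dec["encoder_prompt_token_ids"] == [3, 4]
```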
@@ -142,14 +143,10 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_dummy_data(
+    def get_and_validate_mm_inputs(
         self,
         seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
         mm_counts = self.get_mm_limits()
 
         info = self.processing_info
@@ -165,11 +162,6 @@ def get_dummy_data(
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -185,28 +177,60 @@ def get_dummy_data(
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
+        return mm_inputs, total_placeholders_by_modality
+
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
 
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
 
             return DummyData(
-                seq_data=SequenceData.from_seqs(prompt_token_ids),
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
            )
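
A quick worked check of the padding arithmetic used in `get_encoder_dummy_data` above (a standalone sketch, not vLLM code): `max(total_len, seq_len) - total_len` pads shorter encoder prompts with token id 0 up to `seq_len`, and leaves prompts that are already at least `seq_len` tokens long unchanged.

```python
# Standalone sketch of the padding arithmetic from get_encoder_dummy_data.
def pad_to_seq_len(token_ids: list[int], seq_len: int) -> list[int]:
    total_len = len(token_ids)
    num_tokens_to_pad = max(total_len, seq_len) - total_len
    return token_ids + [0] * num_tokens_to_pad


assert pad_to_seq_len([1, 2, 3], 5) == [1, 2, 3, 0, 0]              # padded up to seq_len
assert pad_to_seq_len([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 4, 5, 6]  # already long enough
```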
@@ -216,5 +240,5 @@ def get_dummy_data(
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )