
Commit 2db9044

[Bugfix] Fix auto dtype casting for BatchFeature (#19316)
Signed-off-by: Isotr0py <[email protected]>
1 parent 6fa718a commit 2db9044

7 files changed: +85 / -57 lines

tests/v1/engine/test_async_llm.py

Lines changed: 16 additions & 9 deletions

@@ -15,6 +15,7 @@
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import LoggingStatLogger

@@ -107,7 +108,8 @@ async def test_load(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

         NUM_REQUESTS = 100

@@ -154,7 +156,8 @@ async def test_abort(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

         NUM_REQUESTS = 100

@@ -226,7 +229,8 @@ async def test_finished_flag(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

         sampling_params = SamplingParams(

@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

         NUM_REQUESTS = 100

@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(
-            TEXT_ENGINE_ARGS,
-            stat_loggers=[MockLoggingStatLogger],
-        )
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(
+                TEXT_ENGINE_ARGS,
+                stat_loggers=[MockLoggingStatLogger],
+            )
         after.callback(engine.shutdown)

         await engine.do_log_stats()

@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
         after.callback(engine.shutdown)

         sampling_params = SamplingParams(max_tokens=100,

tests/v1/engine/test_engine_core.py

Lines changed: 13 additions & 9 deletions

@@ -12,6 +12,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor, UniProcExecutor

@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)

-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""

         # First request.

@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)

-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""
         # First request.
         request: EngineCoreRequest = make_request()

@@ -286,9 +289,10 @@ def shutdown(self):
             enforce_eager=True,
         )
         vllm_config = engine_args.create_engine_config()
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 log_stats=False,
-                                 executor_class=DummyExecutor)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     log_stats=False,
+                                     executor_class=DummyExecutor)
         assert engine_core.batch_queue is not None

         # Add two requests in a row. Each request have 12 prompt tokens.

tests/v1/engine/test_engine_core_client.py

Lines changed: 35 additions & 28 deletions

@@ -19,6 +19,7 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,

@@ -138,13 +139,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )

         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)

@@ -223,13 +226,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=True,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=True,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=True,
+            )

         try:
             MAX_TOKENS = 20

@@ -312,13 +317,14 @@ def test_kv_cache_events(
             UsageContext.UNKNOWN_CONTEXT)

         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
         subscriber = MockSubscriber(endpoint,
                                     topic=publisher_config.topic,

@@ -394,13 +400,14 @@ async def test_kv_cache_events_dp(
             UsageContext.UNKNOWN_CONTEXT)

         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         await asyncio.sleep(1)

         # Build endpoints for all DP ranks

vllm/inputs/registry.py

Lines changed: 3 additions & 1 deletion

@@ -168,10 +168,12 @@ def maybe_cast_dtype(x):
         try:
             output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
             # this emulates output.to(dtype=self.model_config.dtype)
-            cast_output = json_map_leaves(maybe_cast_dtype, output)
             if isinstance(output, BatchFeature):
+                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                 return BatchFeature(cast_output)

+            cast_output = json_map_leaves(maybe_cast_dtype, output)
+
             logger.warning_once(
                 f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                 "Make sure to match the behaviour of `ProcessorMixin` when "

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 4 additions & 5 deletions

@@ -965,9 +965,9 @@ def _process_image_input(
         grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.

@@ -985,10 +985,9 @@ def _process_video_input(
         grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos,
                                        grid_thw=grid_thw_list)

vllm/model_executor/models/qwen2_vl.py

Lines changed: 4 additions & 5 deletions

@@ -1208,9 +1208,9 @@ def _process_image_input(
         assert grid_thw.ndim == 2

         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.

@@ -1226,10 +1226,9 @@ def _process_video_input(
         assert grid_thw.ndim == 2

         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
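In both Qwen2-VL and Qwen2.5-VL the explicit .type(self.visual.dtype) calls are dropped on the assumption that, with the processor-level cast fixed above, multimodal tensors already arrive in the vision tower's dtype. A minimal plain-PyTorch illustration of why the call then adds nothing (per the PyTorch docs, Tensor.type() performs no copy when the tensor already has the requested type):

import torch

# Assume upstream processing has already produced tensors in the model dtype.
x = torch.randn(2, 1176, dtype=torch.bfloat16)

# Casting to the same dtype is a no-op here; the explicit cast removed in this
# commit mattered only while the processor-level cast was broken.
y = x.type(torch.bfloat16)
assert y.dtype == torch.bfloat16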

vllm/utils.py

Lines changed: 10 additions & 0 deletions

@@ -190,6 +190,16 @@
     torch.int64: np.int64,
 }

+
+@contextlib.contextmanager
+def set_default_torch_num_threads(num_threads: int):
+    """Sets the default number of threads for PyTorch to the given value."""
+    old_num_threads = torch.get_num_threads()
+    torch.set_num_threads(num_threads)
+    yield
+    torch.set_num_threads(old_num_threads)
+
+
 P = ParamSpec('P')
 T = TypeVar("T")
 U = TypeVar("U")
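A minimal usage sketch of the new helper, mirroring how the updated tests wrap engine construction (the engine call is left as a comment since it needs a real model and engine_args):

import torch
from vllm.utils import set_default_torch_num_threads

# Limit PyTorch to a single CPU thread only while the engine is constructed;
# the previous thread count is restored when the block exits.
with set_default_torch_num_threads(1):
    assert torch.get_num_threads() == 1
    # engine = AsyncLLM.from_engine_args(engine_args)  # as in the updated tests

print(torch.get_num_threads())  # back to the original setting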
