
Commit 67d1236

maleksan85 (Aleksandr Malyshev) authored and committed
[BUGFIX] Skip tokenization support for throughput benchmark (vllm-project#12712)
Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
1 parent 0095256 commit 67d1236

File tree: 2 files changed (+36, -23 lines)


benchmarks/benchmark_throughput.py

Lines changed: 33 additions & 22 deletions
@@ -7,7 +7,7 @@
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 import uvloop
@@ -20,7 +20,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -242,11 +245,14 @@ async def run_vllm_async(
                 " prompt_len and expected_output_len for all requests.")
 
         # Add the requests to the engine.
-        prompts: list[TextPrompt] = []
+        prompts: list[Union[TextPrompt, TokensPrompt]] = []
         sampling_params: list[SamplingParams] = []
         lora_requests: list[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,

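The core of the benchmark change is the conditional above: a request that already carries pre-tokenized ids (the --skip-tokenizer-init path) is wrapped in a TokensPrompt, while plain-text requests still go through TextPrompt. Below is a minimal stand-alone sketch of that selection logic; the helper name to_engine_prompt and the isinstance check are illustrative, not part of the commit.

# Illustrative sketch only: mirrors the prompt selection added in
# benchmarks/benchmark_throughput.py, with a hypothetical helper name.
from typing import Any, Optional, Union

from vllm.inputs import TextPrompt, TokensPrompt


def to_engine_prompt(
        prompt: Union[str, dict],
        multi_modal_data: Optional[Any] = None,
) -> Union[TextPrompt, TokensPrompt]:
    # Pre-tokenized requests (e.g. synthesized under --skip-tokenizer-init)
    # are stored as {"prompt_token_ids": [...]}; wrap them in TokensPrompt
    # so the engine never needs a tokenizer for them.
    if isinstance(prompt, dict) and "prompt_token_ids" in prompt:
        return TokensPrompt(prompt_token_ids=prompt["prompt_token_ids"],
                            multi_modal_data=multi_modal_data)
    # Everything else is treated as plain text and tokenized by the engine.
    return TextPrompt(prompt=prompt, multi_modal_data=multi_modal_data)


# Example: one pre-tokenized request and one text request.
print(to_engine_prompt({"prompt_token_ids": [1, 42, 7, 99]}))
print(to_engine_prompt("Describe the weather in one sentence."))
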
vllm/engine/arg_utils.py

Lines changed: 3 additions & 1 deletion
@@ -276,7 +276,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,

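The expanded help text describes the contract for --skip-tokenizer-init: the engine is constructed without a tokenizer or detokenizer, callers must supply prompt_token_ids (and no text prompt), and outputs come back as token ids rather than strings. A rough offline-inference sketch of that contract, assuming vLLM is installed; the model name and token ids are placeholders, not taken from the commit.

# Rough sketch, not from the commit: exercising skip_tokenizer_init directly.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# Model name is a placeholder; any model whose vocabulary covers the ids works.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

# With the tokenizer skipped, prompts must already be token ids.
prompt = TokensPrompt(prompt_token_ids=[2, 100, 200, 300])

# detokenize=False keeps the engine from trying to turn ids back into text.
params = SamplingParams(max_tokens=8, detokenize=False)

outputs = llm.generate(prompt, params)
# The generated output carries token ids only, not detokenized text.
print(outputs[0].outputs[0].token_ids)

The throughput benchmark exercises this same path when invoked with --skip-tokenizer-init, which is what this commit fixes.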