53 changes: 32 additions & 21 deletions benchmarks/benchmark_throughput.py
@@ -18,7 +18,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -171,10 +171,13 @@ def run_vllm(
     llm = LLM(**dataclasses.asdict(engine_args))
 
     # Add the requests to the engine.
-    prompts: List[TextPrompt] = []
+    prompts: List[TextPrompt | TokensPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
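For context, here is a minimal sketch of the two prompt forms this loop can now hand to vLLM. It is not part of the PR; the model name, token ids, and sampling settings are placeholders chosen only for illustration.

```python
# Minimal sketch: passing a text prompt and a pre-tokenized prompt to vLLM.
# Model name and token ids below are placeholders, not values from the PR.
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # assumed example model
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

prompts = [
    # Plain text: vLLM tokenizes it before scheduling.
    TextPrompt(prompt="The capital of France is"),
    # Pre-tokenized ids: vLLM skips tokenization, which is what makes
    # benchmarking with --skip-tokenizer-init possible.
    TokensPrompt(prompt_token_ids=[2, 133, 812, 9, 1470, 16]),
]

outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out.outputs[0].text)
```

Both `TextPrompt` and `TokensPrompt` are TypedDicts, so constructing them just builds plain dicts that `llm.generate` accepts in the same list.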
@@ -229,11 +232,14 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:
 
         # Add the requests to the engine.
-        prompts: List[TextPrompt] = []
+        prompts: List[TextPrompt | TokensPrompt] = []
         sampling_params: List[SamplingParams] = []
         lora_requests: List[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -362,24 +368,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
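Read as a sketch, the synthesis logic above amounts to a small helper: when tokenizer initialization is skipped, the randomly drawn ids are handed to the engine verbatim as `{"prompt_token_ids": ...}`; otherwise the decode/re-encode loop nudges the prompt toward the requested length. The function name and signature below are hypothetical and only restate the new code, not part of the benchmark itself.

```python
# Hypothetical helper mirroring the synthesis step above. `tokenizer` is any
# HF-style tokenizer with encode()/decode(); it is unused when skipping init.
import random


def synthesize_prompt(tokenizer, input_len, vocab_size, skip_tokenizer_init):
    candidate_ids = [
        random.randint(0, vocab_size - 1) for _ in range(input_len)
    ]
    candidate_prompt = {"prompt_token_ids": candidate_ids}
    if not skip_tokenizer_init:
        # The tokenizer may merge ids or add special tokens such as BOS, so
        # decode/re-encode and adjust until the encoded length matches.
        for _ in range(5):  # max attempts to correct
            candidate_prompt = tokenizer.decode(candidate_ids)
            tokenized_len = len(tokenizer.encode(candidate_prompt))
            if tokenized_len == input_len:
                break
            diff = input_len - tokenized_len
            if diff > 0:
                candidate_ids.extend(
                    random.randint(100, vocab_size - 100)
                    for _ in range(diff))
            else:
                candidate_ids = candidate_ids[:diff]
    return candidate_prompt
```

The resulting prompt is either a token-id dict or a plain string, which is exactly why `run_vllm` and `run_vllm_async` above check for `"prompt_token_ids" in request.prompt` when choosing between `TokensPrompt` and `TextPrompt`.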