@@ -7,7 +7,7 @@
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch
 import uvloop
@@ -20,7 +20,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -242,11 +245,14 @@ async def run_vllm_async(
                 " prompt_len and expected_output_len for all requests.")

        # Add the requests to the engine.
-        prompts: list[TextPrompt] = []
+        prompts: list[Union[TextPrompt, TokensPrompt]] = []
         sampling_params: list[SamplingParams] = []
         lora_requests: list[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
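
Note on the `run_vllm`/`run_vllm_async` hunks: a request may now carry either raw text or pre-tokenized ids, and the conditional expression picks the matching prompt type. Below is a minimal sketch of that dispatch, pulled out into a helper for readability; it is not part of the commit, the `SampleRequest` fields (`.prompt`, `.multi_modal_data`) are assumed from the surrounding benchmark file, and the explicit `isinstance` guard is a defensive variant of the bare `in` check used in the diff.

```python
# Hedged sketch of the TextPrompt/TokensPrompt dispatch above; not part of
# the commit. SampleRequest fields are assumed from the benchmark file.
from typing import Union

from vllm.inputs import TextPrompt, TokensPrompt


def to_engine_prompt(request) -> Union[TextPrompt, TokensPrompt]:
    prompt = request.prompt
    # Pre-tokenized prompts arrive as {"prompt_token_ids": [...]}.
    if isinstance(prompt, dict) and "prompt_token_ids" in prompt:
        return TokensPrompt(prompt_token_ids=prompt["prompt_token_ids"],
                            multi_modal_data=request.multi_modal_data)
    # Plain strings are tokenized inside the engine, as before.
    return TextPrompt(prompt=prompt,
                      multi_modal_data=request.multi_modal_data)
```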
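The `main` hunk is the producer side of the same change: the random prompt is now kept as `{"prompt_token_ids": candidate_ids}` and is only decoded back to text (with the BOS length-correction loop) when a tokenizer is actually available. A hedged end-to-end sketch of the tokenizer-free mode follows; the model name and token ids are placeholders, and my reading is that detokenization must be disabled because no tokenizer is loaded to decode with.

```python
# Hedged sketch: driving vLLM with pre-tokenized prompts only. Assumes
# skip_tokenizer_init=True loads no tokenizer, so inputs must already be
# token ids and outputs can only be read back as ids.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)  # placeholder model
params = SamplingParams(max_tokens=16, detokenize=False)  # nothing to decode with
outputs = llm.generate(TokensPrompt(prompt_token_ids=[11, 42, 7, 256]), params)
print(outputs[0].outputs[0].token_ids)  # generated ids; .text needs a tokenizer
```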