
Commit 4a39017

njhill authored and LeiWang1999 committed
[Benchmark] Add --async-engine option to benchmark_throughput.py (vllm-project#7964)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 5bd62ff commit 4a39017

3 files changed: +143, -19 lines

benchmarks/benchmark_throughput.py

Lines changed: 109 additions & 4 deletions
@@ -6,13 +6,16 @@
 from typing import List, Optional, Tuple
 
 import torch
+import uvloop
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
 
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
 def sample_requests(
@@ -135,6 +138,93 @@ def run_vllm(
     return end - start
 
 
+async def run_vllm_async(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: str,
+    quantization: Optional[str],
+    tensor_parallel_size: int,
+    seed: int,
+    n: int,
+    use_beam_search: bool,
+    trust_remote_code: bool,
+    dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
+    kv_cache_dtype: str,
+    quantization_param_path: Optional[str],
+    device: str,
+    enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
+    gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
+    download_dir: Optional[str] = None,
+    load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
+    disable_frontend_multiprocessing: bool = False,
+) -> float:
+    from vllm import SamplingParams
+    engine_args = AsyncEngineArgs(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
+        load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
+        worker_use_ray=False,
+        engine_use_ray=False,
+        disable_log_requests=True,
+    )
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, disable_frontend_multiprocessing) as llm:
+
+        # Add the requests to the engine.
+        prompts: List[str] = []
+        sampling_params: List[SamplingParams] = []
+        for prompt, _, output_len in requests:
+            prompts.append(prompt)
+            sampling_params.append(
+                SamplingParams(
+                    n=n,
+                    temperature=0.0 if use_beam_search else 1.0,
+                    top_p=1.0,
+                    use_beam_search=use_beam_search,
+                    ignore_eos=True,
+                    max_tokens=output_len,
+                ))
+
+        generators = []
+        start = time.perf_counter()
+        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            generator = llm.generate(prompt, sp, request_id=f"test{i}")
+            generators.append(generator)
+        all_gens = merge_async_iterators(*generators)
+        async for i, res in all_gens:
+            pass
+        end = time.perf_counter()
+        return end - start
+
+
 def run_hf(
     requests: List[Tuple[str, int, int]],
     model: str,
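
The heart of run_vllm_async is a fan-out-and-drain pattern: every request is submitted to the engine up front, the resulting output streams are merged with vllm.utils.merge_async_iterators, and they are consumed to completion before the timer stops. Below is a minimal, dependency-free sketch of that timing pattern; fake_generate is a made-up stand-in for llm.generate, and asyncio.gather plays the role of the merged-iterator drain.

import asyncio
import time
from typing import AsyncIterator, List, Tuple


async def fake_generate(prompt: str, max_tokens: int) -> AsyncIterator[str]:
    """Made-up stand-in for llm.generate(): yields one chunk per output token."""
    for i in range(max_tokens):
        await asyncio.sleep(0.001)  # pretend decode step
        yield f"{prompt}-token{i}"


async def run_async(requests: List[Tuple[str, int]]) -> float:
    async def drain(prompt: str, max_tokens: int) -> None:
        # Mirrors the benchmark's `async for i, res in all_gens: pass` --
        # the streams just need to run to completion; outputs are discarded.
        async for _ in fake_generate(prompt, max_tokens):
            pass

    start = time.perf_counter()
    # Submit every request before waiting, so decoding overlaps across requests.
    await asyncio.gather(*(drain(p, n) for p, n in requests))
    return time.perf_counter() - start


if __name__ == "__main__":
    elapsed = asyncio.run(run_async([("hello", 16), ("world", 32)]))
    print(f"elapsed: {elapsed:.3f}s")
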
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
                                 args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(
+        run_args = [
             requests, args.model, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
@@ -240,7 +330,14 @@ def main(args: argparse.Namespace):
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.num_scheduler_steps,
             args.use_v2_block_manager, args.download_dir, args.load_format,
-            args.disable_async_output_proc)
+            args.disable_async_output_proc
+        ]
+
+        if args.async_engine:
+            run_args.append(args.disable_frontend_multiprocessing)
+            elapsed_time = uvloop.run(run_vllm_async(*run_args))
+        else:
+            elapsed_time = run_vllm(*run_args)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -426,6 +523,14 @@ def main(args: argparse.Namespace):
                         action='store_true',
                         default=False,
                         help="Disable async output processor for vLLM backend.")
+    parser.add_argument("--async-engine",
+                        action='store_true',
+                        default=False,
+                        help="Use vLLM async engine rather than LLM class.")
+    parser.add_argument("--disable-frontend-multiprocessing",
+                        action='store_true',
+                        default=False,
+                        help="Disable decoupled async engine frontend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
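
In main(), the arguments shared by both backends are collected into run_args; when --async-engine is set, the extra disable_frontend_multiprocessing flag is appended and the coroutine is driven with uvloop.run(). A small self-contained sketch of that dispatch shape, with trivial stand-in workers instead of the real run_vllm/run_vllm_async and asyncio.run in place of uvloop.run to avoid the extra dependency:

import argparse
import asyncio
import time


def run_sync(label: str) -> float:
    start = time.perf_counter()
    time.sleep(0.1)  # stand-in for the blocking LLM-class benchmark path
    return time.perf_counter() - start


async def run_async(label: str, disable_frontend_multiprocessing: bool) -> float:
    start = time.perf_counter()
    await asyncio.sleep(0.1)  # stand-in for the AsyncLLMEngine benchmark path
    return time.perf_counter() - start


parser = argparse.ArgumentParser()
parser.add_argument("--async-engine", action="store_true", default=False)
parser.add_argument("--disable-frontend-multiprocessing",
                    action="store_true",
                    default=False)
args = parser.parse_args()

run_args = ["demo"]  # arguments common to both paths
if args.async_engine:
    # The extra flag only exists on the async path, so it is appended just
    # before dispatch, the same way main() extends run_args in the patch.
    run_args.append(args.disable_frontend_multiprocessing)
    elapsed = asyncio.run(run_async(*run_args))
else:
    elapsed = run_sync(*run_args)
print(f"elapsed: {elapsed:.3f}s")

Run as "python sketch.py --async-engine --disable-frontend-multiprocessing" and both flags flow through the same way they do in the benchmark.
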

vllm/entrypoints/openai/api_server.py

Lines changed: 30 additions & 15 deletions
@@ -67,7 +67,7 @@
 
 
 def model_is_embedding(model_name: str, trust_remote_code: bool,
-                       quantization: str) -> bool:
+                       quantization: Optional[str]) -> bool:
     return ModelConfig(model=model_name,
                        tokenizer=model_name,
                        tokenizer_mode="auto",
@@ -96,13 +96,6 @@ async def _force_log():
 @asynccontextmanager
 async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
-    """
-    Create AsyncEngineClient, either:
-        - in-process using the AsyncLLMEngine Directly
-        - multiprocess using AsyncLLMEngine RPC
-
-    Returns the Client or None if the creation failed.
-    """
 
     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
@@ -112,14 +105,37 @@ async def build_async_engine_client(
     # Backend itself still global for the silly lil' health handler
     global async_engine_client
 
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+
+        async_engine_client = engine  # type: ignore[assignment]
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> AsyncIterator[Optional[AsyncEngineClient]]:
+    """
+    Create AsyncEngineClient, either:
+        - in-process using the AsyncLLMEngine Directly
+        - multiprocess using AsyncLLMEngine RPC
+
+    Returns the Client or None if the creation failed.
+    """
+
     # If manually triggered or embedding model, use AsyncLLMEngine in process.
     # TODO: support embedding model via RPC.
-    if (model_is_embedding(args.model, args.trust_remote_code,
-                           args.quantization)
-            or args.disable_frontend_multiprocessing):
-        async_engine_client = AsyncLLMEngine.from_engine_args(
+    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
+                           engine_args.quantization)
+            or disable_frontend_multiprocessing):
+        engine_client = AsyncLLMEngine.from_engine_args(
             engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        yield async_engine_client
+        try:
+            yield engine_client
+        finally:
+            engine_client.shutdown_background_loop()
         return
 
     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -148,7 +164,6 @@ async def build_async_engine_client(
         # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
         rpc_client = AsyncEngineRPCClient(rpc_path)
-        async_engine_client = rpc_client  # type: ignore
 
         # Start RPCServer in separate process (holds the AsyncLLMEngine).
         context = multiprocessing.get_context("spawn")
@@ -174,7 +189,7 @@ async def build_async_engine_client(
                 yield None
                 return
 
-            yield async_engine_client
+            yield rpc_client  # type: ignore[misc]
         finally:
             # Ensure rpc server process was terminated
             rpc_server_process.terminate()
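
The api_server.py change is essentially an extract-method refactor on an @asynccontextmanager: engine construction and teardown move into build_async_engine_client_from_engine_args, which can now be reused outside the server (as the benchmark above does), while build_async_engine_client keeps only the module-global bookkeeping. A toy sketch of that shape, with made-up InProcessEngine/RPCClient classes standing in for AsyncLLMEngine and AsyncEngineRPCClient:

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class InProcessEngine:
    """Toy stand-in for AsyncLLMEngine."""
    def shutdown_background_loop(self) -> None:
        print("in-process engine shut down")


class RPCClient:
    """Toy stand-in for AsyncEngineRPCClient."""
    def close(self) -> None:
        print("rpc client closed")


@asynccontextmanager
async def build_client(disable_frontend_multiprocessing: bool = False
                       ) -> AsyncIterator[object]:
    """Yield either an in-process engine or an RPC client, mirroring the
    shape of build_async_engine_client_from_engine_args. Cleanup lives in
    the same function that created the resource."""
    if disable_frontend_multiprocessing:
        engine = InProcessEngine()
        try:
            yield engine
        finally:
            # The patch adds this try/finally so the background loop is
            # stopped even if the caller raises.
            engine.shutdown_background_loop()
        return

    client = RPCClient()
    try:
        yield client
    finally:
        client.close()


async def main() -> None:
    # Callers (the API server or benchmark_throughput.py) see a single
    # async context manager, whichever backend was chosen.
    async with build_client(disable_frontend_multiprocessing=True) as client:
        print("using", type(client).__name__)


asyncio.run(main())
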

vllm/entrypoints/openai/rpc/client.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@
 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -214,6 +215,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
 
         # Await the data from the Server.
         frame = await socket.recv(copy=False)
+        assert isinstance(frame, Frame)
         data = pickle.loads(frame.buffer)
 
         if isinstance(data, Exception):
@@ -247,6 +249,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
                     f"{self._data_timeout} ms")
 
             frame = await socket.recv(copy=False)
+            assert isinstance(frame, Frame)
             return pickle.loads(frame.buffer)
 
         # Make a new socket connection.
@@ -395,6 +398,7 @@ async def generate(
             # Stream back the results from the RPC Server.
             while not finished:
                 message = await socket.recv(copy=False)
+                assert isinstance(message, Frame)
                 request_output = pickle.loads(message.buffer)
 
                 if isinstance(request_output, Exception):
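
The rpc/client.py change is purely for type checking: pyzmq's recv(copy=False) returns a zmq.Frame rather than bytes, and the added asserts narrow the type before .buffer is read. A small blocking-socket sketch of the same recv-and-unpickle step (the real client uses zmq.asyncio; the PAIR/inproc wiring here is made up for the demo):

import pickle
import zmq
from zmq import Frame

ctx = zmq.Context.instance()
server = ctx.socket(zmq.PAIR)
server.bind("inproc://demo")
client = ctx.socket(zmq.PAIR)
client.connect("inproc://demo")

server.send(pickle.dumps({"answer": 42}))

# With copy=False pyzmq hands back a zmq.Frame (a zero-copy view of the
# message) instead of bytes. The isinstance assert is what the patch adds:
# it narrows the recv() return type for mypy before .buffer is accessed.
frame = client.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer)  # .buffer is a memoryview over the frame
print(data)

client.close()
server.close()
ctx.term()
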
