
Commit 311490a

Add script for benchmarking serving throughput (#145)
1 parent da5ddcd commit 311490a

10 files changed, +421 -415 lines changed

benchmarks/benchmark_async_llm_server.py

Lines changed: 5 additions & 3 deletions
@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
     prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
                for i in range(args.n_threads)]
 
+    api_url = f"http://{args.host}:{args.port}/generate"
     headers = {"User-Agent": "CacheFlow Benchmark Client"}
     ploads = [{
         "prompt": p,
@@ -19,8 +20,8 @@ def main(args: argparse.Namespace):
     } for p in prompts]
 
     def send_request(results, i):
-        response = requests.post(args.api_url, headers=headers,
-                                 json=ploads[i], stream=True)
+        response = requests.post(api_url, headers=headers, json=ploads[i],
+                                 stream=True)
         results[i] = response
 
     # use args.n_threads to prompt the backend
@@ -50,7 +51,8 @@ def send_request(results, i):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8001)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()
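With this change, benchmark_async_llm_server.py builds the request URL from separate --host and --port flags instead of a single --api-url. A typical invocation after this commit (the values shown are simply the script's defaults from the diff above) is:

    python benchmarks/benchmark_async_llm_server.py \
        --host localhost --port 8001 --max-tokens 128 --n-threads 128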

benchmarks/benchmark_latency.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+"""Benchmark the latency of processing a single batch of requests."""
 import argparse
 import time
 

benchmarks/benchmark_serving.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
+"""Benchmark online serving throughput.
+
+On the server side, run one of the following commands:
+    (CacheFlow backend)
+    python -m cacheflow.entrypoints.simple_fastapi_frontend \
+        --disable-log-requests --model <your_model>
+
+    (TGI backend)
+    ./launch_hf_server.sh <your_model>
+
+On the client side, run:
+    python benchmarks/benchmark_serving.py \
+        --backend <backend> \
+        --tokenizer <your_model> --dataset <target_dataset> \
+        --request-rate <request_rate>
+"""
+import argparse
+import asyncio
+import json
+import random
+import time
+from typing import AsyncGenerator, List, Tuple
+
+import aiohttp
+import numpy as np
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+
+# (prompt len, output len, latency)
+REQUEST_LATENCY: List[Tuple[int, int, float]] = []
+
+
+def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type == "llama":
+        # A workaround for potential protobuf errors.
+        model_name = "hf-internal-testing/llama-tokenizer"
+    return AutoTokenizer.from_pretrained(model_name)
+
+
+def sample_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+    # Filter out the conversations with less than 2 turns.
+    dataset = [
+        data for data in dataset
+        if len(data["conversations"]) >= 2
+    ]
+    # Only keep the first two turns of each conversation.
+    dataset = [
+        (data["conversations"][0]["value"], data["conversations"][1]["value"])
+        for data in dataset
+    ]
+
+    # Tokenize the prompts and completions.
+    prompts = [prompt for prompt, _ in dataset]
+    prompt_token_ids = tokenizer(prompts).input_ids
+    completions = [completion for _, completion in dataset]
+    completion_token_ids = tokenizer(completions).input_ids
+    tokenized_dataset = []
+    for i in range(len(dataset)):
+        output_len = len(completion_token_ids[i])
+        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+    # Filter out too long sequences.
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+        prompt_len = len(prompt_token_ids)
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            # This is because TGI causes errors when the input or output length
+            # is too short.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+        filtered_dataset.append((prompt, prompt_len, output_len))
+
+    # Sample the requests.
+    sampled_requests = random.sample(filtered_dataset, num_requests)
+    return sampled_requests
+
+
+async def get_request(
+    input_requests: List[Tuple[str, int, int]],
+    request_rate: float,
+) -> AsyncGenerator[Tuple[str, int, int], None]:
+    input_requests = iter(input_requests)
+    for request in input_requests:
+        yield request
+
+        if request_rate == float("inf"):
+            # If the request rate is infinity, then we don't need to wait.
+            continue
+        # Sample the request interval from the exponential distribution.
+        interval = np.random.exponential(1.0 / request_rate)
+        # The next request will be sent after the interval.
+        await asyncio.sleep(interval)
+
+
+async def send_request(
+    backend: str,
+    api_url: str,
+    prompt: str,
+    prompt_len: int,
+    output_len: int,
+    best_of: int,
+    use_beam_search: bool,
+) -> None:
+    request_start_time = time.time()
+
+    headers = {"User-Agent": "Benchmark Client"}
+    if backend == "cacheflow":
+        pload = {
+            "prompt": prompt,
+            "n": 1,
+            "best_of": best_of,
+            "use_beam_search": use_beam_search,
+            "temperature": 0.0 if use_beam_search else 1.0,
+            "top_p": 1.0,
+            "max_tokens": output_len,
+            "ignore_eos": True,
+            "stream": False,
+        }
+    elif backend == "tgi":
+        assert not use_beam_search
+        params = {
+            "best_of": best_of,
+            "max_new_tokens": output_len,
+            "do_sample": True,
+        }
+        pload = {
+            "inputs": prompt,
+            "parameters": params,
+        }
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+    timeout = aiohttp.ClientTimeout(total=3 * 3600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        while True:
+            async with session.post(api_url, headers=headers, json=pload) as response:
+                chunks = []
+                async for chunk, _ in response.content.iter_chunks():
+                    chunks.append(chunk)
+            output = b"".join(chunks).decode("utf-8")
+            output = json.loads(output)
+
+            # Re-send the request if it failed.
+            if "error" not in output:
+                break
+
+    request_end_time = time.time()
+    request_latency = request_end_time - request_start_time
+    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+
+
+async def benchmark(
+    backend: str,
+    api_url: str,
+    input_requests: List[Tuple[str, int, int]],
+    best_of: int,
+    use_beam_search: bool,
+    request_rate: float,
+) -> None:
+    tasks: List[asyncio.Task] = []
+    async for request in get_request(input_requests, request_rate):
+        prompt, prompt_len, output_len = request
+        task = asyncio.create_task(send_request(backend, api_url, prompt,
+                                                prompt_len, output_len,
+                                                best_of, use_beam_search))
+        tasks.append(task)
+    await asyncio.gather(*tasks)
+
+
+def main(args: argparse.Namespace):
+    print(args)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    api_url = f"http://{args.host}:{args.port}/generate"
+    tokenizer = get_tokenizer(args.tokenizer)
+    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+
+    benchmark_start_time = time.time()
+    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
+                          args.use_beam_search, args.request_rate))
+    benchmark_end_time = time.time()
+    benchmark_time = benchmark_end_time - benchmark_start_time
+    print(f"Total time: {benchmark_time:.2f} s")
+    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
+
+    # Compute the latency statistics.
+    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
+    print(f"Average latency: {avg_latency:.2f} s")
+    avg_per_token_latency = np.mean([
+        latency / (prompt_len + output_len)
+        for prompt_len, output_len, latency in REQUEST_LATENCY
+    ])
+    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
+    avg_per_output_token_latency = np.mean([
+        latency / output_len
+        for _, output_len, latency in REQUEST_LATENCY
+    ])
+    print("Average latency per output token: "
+          f"{avg_per_output_token_latency:.2f} s")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Benchmark the online serving throughput.")
+    parser.add_argument("--backend", type=str, default="cacheflow",
+                        choices=["cacheflow", "tgi"])
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--dataset", type=str, required=True,
+                        help="Path to the dataset.")
+    parser.add_argument("--tokenizer", type=str, required=True,
+                        help="Name or path of the tokenizer.")
+    parser.add_argument("--best-of", type=int, default=1,
+                        help="Generates `best_of` sequences per prompt and "
+                             "returns the best one.")
+    parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--num-prompts", type=int, default=1000,
+                        help="Number of prompts to process.")
+    parser.add_argument("--request-rate", type=float, default=float("inf"),
+                        help="Number of requests per second. If this is inf, "
+                             "then all the requests are sent at time 0. "
+                             "Otherwise, we use Poisson process to synthesize "
+                             "the request arrival times.")
+    parser.add_argument("--seed", type=int, default=0)
+    args = parser.parse_args()
+    main(args)
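The --request-rate flag drives get_request: for a finite rate, inter-arrival gaps are drawn from an exponential distribution with mean 1 / request_rate, which synthesizes the Poisson arrival process described in the flag's help text. A minimal sketch (the rate value below is illustrative and not part of the commit) to check that the sampled gaps average out to the requested rate:

    import numpy as np

    # Draw inter-arrival gaps from the same distribution get_request samples
    # and confirm their mean is ~1 / request_rate, i.e. ~request_rate requests/s.
    np.random.seed(0)
    request_rate = 4.0  # requests per second; illustrative value only
    intervals = np.random.exponential(1.0 / request_rate, size=10_000)
    print(f"Empirical rate: {1.0 / intervals.mean():.2f} requests/s")  # ~4.0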
