"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    (CacheFlow backend)
    python -m cacheflow.entrypoints.simple_fastapi_frontend \
        --disable-log-requests --model <your_model>

    (TGI backend)
    ./launch_hf_server.sh <your_model>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --tokenizer <your_model> --dataset <target_dataset> \
        --request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

# (prompt len, output len, latency)
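# Filled in by send_request() for each completed request; summarized in main().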
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama":
        # A workaround for potential protobuf errors.
        model_name = "hf-internal-testing/llama-tokenizer"
    return AutoTokenizer.from_pretrained(model_name)


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out conversations with fewer than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
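    # Calling the tokenizer on a list of strings tokenizes them as a batch and
    # returns one list of token ids per input string.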
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out sequences that are too short or too long.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
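    # random.sample draws without replacement, so the filtered dataset must
    # contain at least num_requests entries.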
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
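        # Exponential inter-arrival times with mean 1/request_rate make the
        # request stream a Poisson process with rate `request_rate`.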
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.time()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "cacheflow":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
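            # ignore_eos makes the server keep generating past EOS, so the
            # full output_len tokens are produced for every request.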
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
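        # Retry the request until the backend returns a response without an
        # "error" field.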
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.time()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
    backend: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
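    # Wait for every in-flight request to finish before returning.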
    await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
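    # Both backends (the CacheFlow frontend and TGI) are expected to serve
    # completions at the /generate endpoint.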
    tokenizer = get_tokenizer(args.tokenizer)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.time()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.time()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend", type=str, default="cacheflow",
                        choices=["cacheflow", "tgi"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of", type=int, default=1,
                        help="Generates `best_of` sequences per prompt and "
                             "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate", type=float, default=float("inf"),
                        help="Number of requests per second. If this is inf, "
                             "then all the requests are sent at time 0. "
                             "Otherwise, we use a Poisson process to "
                             "synthesize the request arrival times.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)