Commit e97ec4f

zhaoyang-star authored

Support FP8-E5M2 KV Cache (vllm-project#2279)

Co-authored-by: zhaoyang <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>

1 parent 71e87f2 · commit e97ec4f

26 files changed: +912 -196 lines

benchmarks/benchmark_latency.py

Lines changed: 8 additions & 0 deletions
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
     )

     sampling_params = SamplingParams(
@@ -117,6 +118,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--enforce-eager',
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8_e5m2'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     parser.add_argument(
         '--profile',
         action='store_true',
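
For reference, the new flag is simply forwarded to the LLM constructor (the benchmarks/benchmark_throughput.py change below makes that explicit), so the same setting can be exercised directly from Python. A minimal offline-inference sketch, with a placeholder model name; only the kv_cache_dtype argument is the behavior introduced by this commit:

    # Sketch: exercise the new kv_cache_dtype argument the same way the
    # benchmarks do. "facebook/opt-125m" is only a placeholder model.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",
        kv_cache_dtype="fp8_e5m2",  # or "auto" to keep the model's data type
    )
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0].outputs[0].text)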

benchmarks/benchmark_throughput.py

Lines changed: 11 additions & 1 deletion
@@ -71,6 +71,7 @@ def run_vllm(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    kv_cache_dtype: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -83,6 +84,7 @@ def run_vllm(
         dtype=dtype,
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
     )

     # Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager)
+                                args.max_model_len, args.enforce_eager,
+                                args.kv_cache_dtype)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ def main(args: argparse.Namespace):
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
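
The help text added here states that "auto" keeps the model's own data type, while "fp8_e5m2" switches the KV cache to an 8-bit float format (1 sign bit, 5 exponent bits, 2 mantissa bits). A rough sketch of how such a string could be resolved to a torch dtype; the helper name and the use of torch.uint8 as the storage type for FP8 values are illustrative assumptions, not code from this commit:

    import torch

    # Hypothetical helper mirroring the --kv-cache-dtype semantics described
    # in the help text; not part of this commit.
    def resolve_kv_cache_dtype(kv_cache_dtype: str,
                               model_dtype: torch.dtype) -> torch.dtype:
        if kv_cache_dtype == "auto":
            # "auto" falls back to the model's data type.
            return model_dtype
        if kv_cache_dtype == "fp8_e5m2":
            # No native FP8 dtype is assumed here, so the cache would be
            # stored as raw bytes (one uint8 per element).
            return torch.uint8
        raise ValueError(f"Unsupported kv cache dtype: {kv_cache_dtype}")

    print(resolve_kv_cache_dtype("auto", torch.float16))      # torch.float16
    print(resolve_kv_cache_dtype("fp8_e5m2", torch.float16))  # torch.uint8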

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 18 additions & 15 deletions
@@ -1,9 +1,11 @@
+from typing import Optional
 import argparse
 import random
 import time

 import torch

+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 from vllm._C import ops

 NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
     dtype: torch.dtype,
     seed: int,
     do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

     # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]

     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                kv_cache_dtype,
             )
         elif version == "v2":
             ops.paged_attention_v2(
@@ -123,6+122,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                kv_cache_dtype,
             )
         else:
             raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                         default="half")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     print(args)

     if args.num_query_heads % args.num_kv_heads != 0:
         raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
     main(
         version=args.version,
         num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
         head_size=args.head_size,
         block_size=args.block_size,
         use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
     )
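
The kernel benchmark now delegates cache allocation to create_kv_caches_with_random and maps the --dtype string through STR_DTYPE_TO_TORCH_DTYPE, both imported from vllm.utils. A rough, self-contained sketch of what such a helper has to do, inferred only from the call site above and the inline allocation code it replaces; the exact signature, the uint8 storage for fp8_e5m2, and the random-initialization details are assumptions:

    from typing import List, Optional, Tuple

    import torch

    STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }

    def create_kv_caches_with_random(
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        cache_dtype: Optional[str],
        model_dtype: torch.dtype,
        device: str = "cuda",
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # Resolve the cache dtype: "auto" (or None) keeps the model dtype;
        # "fp8_e5m2" is assumed to be stored as raw uint8 bytes.
        if cache_dtype in (None, "auto"):
            torch_dtype = model_dtype
        elif cache_dtype == "fp8_e5m2":
            torch_dtype = torch.uint8
        else:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

        # Shapes follow the inline code removed above: the key cache packs
        # 16 bytes worth of cache elements (x) along its last dimension.
        x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
        key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
        value_cache_shape = (num_blocks, num_heads, head_size, block_size)
        scale = head_size**-0.5

        key_caches: List[torch.Tensor] = []
        value_caches: List[torch.Tensor] = []
        for _ in range(num_layers):
            key_cache = torch.empty(key_cache_shape, dtype=torch_dtype, device=device)
            value_cache = torch.empty(value_cache_shape, dtype=torch_dtype, device=device)
            if torch_dtype == torch.uint8:
                # Random bytes stand in for random FP8 values.
                key_cache.random_(0, 256)
                value_cache.random_(0, 256)
            else:
                key_cache.uniform_(-scale, scale)
                value_cache.uniform_(-scale, scale)
            key_caches.append(key_cache)
            value_caches.append(value_cache)
        return key_caches, value_caches

The benchmark requests a single layer (the literal 1 in the call) and then takes key_caches[0] and value_caches[0], which matches how the helper is indexed in the hunk above.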

csrc/attention/attention_dtypes.h

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
