
Commit 0e63494

Add fp8 support to reshape_and_cache_flash (#6667)
1 parent ee81258 commit 0e63494

File tree

  csrc/cache.h
  csrc/cache_kernels.cu
  csrc/torch_bindings.cpp
  tests/kernels/test_cache.py
  vllm/_custom_ops.py
  vllm/attention/backends/flash_attn.py
  vllm/attention/backends/flashinfer.py
  vllm/utils.py

8 files changed: +98 -43 lines

csrc/cache.h
2 additions, 1 deletion

@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
-                             const std::string& kv_cache_dtype);
+                             const std::string& kv_cache_dtype,
+                             const double k_scale, const double v_scale);
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,

csrc/cache_kernels.cu
45 additions, 30 deletions

@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                          // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                          // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 }  // namespace vllm
@@ -278,40 +288,45 @@ void reshape_and_cache(
       CALL_RESHAPE_AND_CACHE)
 }
 
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
+      <<<grid, block, 0, stream>>>(                                   \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
 namespace vllm {
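In the non-auto branch, fp8::scaled_convert narrows each key/value element to the fp8 cache dtype using the per-tensor k_scale/v_scale. A rough Python sketch of the intended per-element behaviour, not the kernel code itself; it assumes the e4m3 format and the usual divide-on-store / multiply-on-load scale convention, and torch.float8_e4m3fn from PyTorch >= 2.1:

import torch

def fp8_store(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize: scale down, then narrow to 8-bit e4m3 floats.
    return (x / scale).to(torch.float8_e4m3fn)

def fp8_load(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Dequantize: widen back to fp16 and scale up.
    return x_fp8.to(torch.float16) * scale

x = torch.randn(8, dtype=torch.float16)
roundtrip = fp8_load(fp8_store(x, scale=1.0), scale=1.0)
# e4m3 keeps roughly two significant digits, so the round trip is only approximate.
assert torch.allclose(roundtrip, x, atol=0.01, rtol=0.1)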

csrc/torch_bindings.cpp
2 additions, 1 deletion

@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      " Tensor! key_cache,"
      " Tensor! value_cache,"
      " Tensor slot_mapping,"
-     " str kv_cache_dtype) -> ()");
+     " str kv_cache_dtype,"
+     " float k_scale, float v_scale) -> ()");
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);
 
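The registered schema gains two trailing float arguments, so anything invoking the op through the torch registry now has to pass the scales. A minimal sketch of a direct call with illustrative shapes; it assumes a CUDA build whose compiled extension is loaded by importing vllm._custom_ops (in practice callers go through the Python wrapper shown in vllm/_custom_ops.py below), and that fp8 caches are stored as raw uint8 bytes on this revision:

import torch
import vllm._custom_ops  # noqa: F401  (loads the compiled extension and its op registrations)

key = torch.randn(2, 8, 128, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
# [num_blocks, block_size, num_heads, head_size], uint8-backed fp8 storage
key_cache = torch.zeros(1, 16, 8, 128, dtype=torch.uint8, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(2, dtype=torch.int64, device="cuda")

torch.ops._C_cache_ops.reshape_and_cache_flash(
    key, value, key_cache, value_cache, slot_mapping, "fp8", 1.0, 1.0)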

tests/kernels/test_cache.py
34 additions, 8 deletions

@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    if kv_cache_dtype == "fp8":
-        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    key_cache, value_cache = key_caches[0].contiguous(
+    ), value_caches[0].contiguous()
+    del key_caches
+    del value_caches
 
     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Using default kv_scale
+    k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
         cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]
 
-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+    if kv_cache_dtype == "fp8":
+        assert torch.allclose(result_key_cache,
+                              cloned_key_cache,
+                              atol=0.001,
+                              rtol=0.1)
+        assert torch.allclose(result_value_cache,
+                              cloned_value_cache,
+                              atol=0.001,
+                              rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)

vllm/_custom_ops.py
4 additions, 1 deletion

@@ -426,10 +426,13 @@ def reshape_and_cache_flash(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype)
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)
 
 
 def copy_blocks(key_caches: List[torch.Tensor],
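Note that the two scale arguments are now required even when the cache is not fp8-quantized, so existing callers must be updated; passing 1.0 preserves the old behaviour. A minimal sketch of the unquantized ("auto") path through the wrapper, with illustrative shapes:

import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, head_size = 4, 8, 128
num_blocks, block_size = 2, 16

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
# "auto" keeps the cache in the model dtype.
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                            "auto", k_scale=1.0, v_scale=1.0)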

vllm/attention/backends/flash_attn.py
2 additions, 0 deletions

@@ -478,6 +478,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens

vllm/attention/backends/flashinfer.py
2 additions, 0 deletions

@@ -489,6 +489,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         query = query.contiguous(

vllm/utils.py
7 additions, 2 deletions

@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    assert cache_dtype != "fp8"
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
         key_value_cache = torch.empty(size=key_value_cache_shape,
                                       dtype=torch_dtype,
                                       device=device)
-        key_value_cache.uniform_(-scale, scale)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_value_cache[:, 0])
         value_caches.append(key_value_cache[:, 1])
     return key_caches, value_caches
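With the assertion removed, the test/benchmark helper can materialize fp8 KV caches directly. A short usage sketch; the positional argument order (num_blocks, block_size, num_layers, num_heads, head_size, cache_dtype, model_dtype) and the uint8 backing for "fp8" reflect my reading of this revision rather than something shown in the diff:

import torch
from vllm.utils import create_kv_caches_with_random_flash

# One layer of fp8 caches in the flash layout [num_blocks, block_size, num_heads, head_size].
key_caches, value_caches = create_kv_caches_with_random_flash(
    4, 16, 1, 8, 128, "fp8", torch.float16, device="cuda")
print(key_caches[0].shape, key_caches[0].dtype)  # e.g. torch.Size([4, 16, 8, 128]) torch.uint8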
