@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                          // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                          // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
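A quick host-side sketch of what `slot_mapping` means for this kernel: padded tokens carry slot -1 and are skipped (the actual early return, along with the `slot_idx / block_size` and `slot_idx % block_size` decomposition into `block_idx` and `block_offset`, lives in unchanged lines elided between these two hunks). The values below are made up for illustration.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int block_size = 16;
  // Three real tokens and one padded token; padded tokens map to slot -1
  // and the kernel writes nothing for them.
  std::vector<int64_t> slot_mapping = {32, 33, 47, -1};
  for (size_t token_idx = 0; token_idx < slot_mapping.size(); ++token_idx) {
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) {
      printf("token %zu: padded, skipped\n", token_idx);
      continue;
    }
    // The flash layout addresses the cache by (block, in-block offset).
    printf("token %zu -> block %lld, offset %lld\n", token_idx,
           (long long)(slot_idx / block_size),
           (long long)(slot_idx % block_size));
  }
  return 0;
}
```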
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 }  // namespace vllm
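The key piece of the new kernel is the `if constexpr` on `kv_dt`: the untaken branch is discarded at compile time, so the `kAuto` instantiation contains no fp8 conversion code at all. A minimal standalone sketch of that pattern follows; the enum member names mirror vLLM's `Fp8KVCacheDataType`, but `toy_scaled_convert` is an illustrative stand-in for `fp8::scaled_convert` (divide by the scale and narrow), not the real fp8 rounding, and `uint8_t` merely stands in for an fp8 byte.

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the enum used by the kernel; host-only so any C++17 compiler works.
enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

// Toy stand-in for fp8::scaled_convert: divide by the scale, then narrow.
template <typename cache_t, typename scalar_t, Fp8KVCacheDataType kv_dt>
cache_t toy_scaled_convert(scalar_t x, float scale) {
  return static_cast<cache_t>(static_cast<float>(x) / scale);
}

// Same if-constexpr shape as the kernel: the branch is resolved at compile
// time, so the kAuto instantiation carries no conversion code.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
void store_element(cache_t* dst, int64_t idx, scalar_t src, float scale) {
  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
    dst[idx] = src;  // cache_t == scalar_t in the kAuto case
  } else {
    dst[idx] = toy_scaled_convert<cache_t, scalar_t, kv_dt>(src, scale);
  }
}

int main() {
  float cache_auto[1];
  uint8_t cache_fp8[1];  // one byte standing in for an fp8 value
  store_element<float, float, Fp8KVCacheDataType::kAuto>(cache_auto, 0, 3.5f,
                                                         1.0f);
  store_element<float, uint8_t, Fp8KVCacheDataType::kFp8E4M3>(cache_fp8, 0,
                                                              3.5f, 0.5f);
  printf("auto: %.1f, quantized byte: %u\n", cache_auto[0],
         (unsigned)cache_fp8[0]);
  return 0;
}
```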
@@ -278,40 +288,45 @@ void reshape_and_cache(
       CALL_RESHAPE_AND_CACHE)
 }
 
+// KV_T is the data type of the key and value tensors.
+// CACHE_T is the stored data type of the kv-cache.
+// KV_DTYPE is the real data type of the kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
+      <<<grid, block, 0, stream>>>(                                   \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
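For readers tracing the macro: one possible expansion, assuming the dispatcher selects the plain `float` / `kAuto` combination when `kv_cache_dtype == "auto"` (the actual selection table lives in `DISPATCH_BY_KV_CACHE_DTYPE`, defined elsewhere in the tree). An fp8 cache dtype would swap `CACHE_T` for the fp8 storage type and pick a non-`kAuto` `KV_DTYPE` instead.

```cpp
// CALL_RESHAPE_AND_CACHE_FLASH(float, float, vllm::Fp8KVCacheDataType::kAuto)
// expands to an ordinary kernel launch:
vllm::reshape_and_cache_flash_kernel<float, float,
                                     vllm::Fp8KVCacheDataType::kAuto>
    <<<grid, block, 0, stream>>>(
        reinterpret_cast<float*>(key.data_ptr()),
        reinterpret_cast<float*>(value.data_ptr()),
        reinterpret_cast<float*>(key_cache.data_ptr()),
        reinterpret_cast<float*>(value_cache.data_ptr()),
        slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
        value_stride, num_heads, head_size, block_size, k_scale, v_scale);
```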
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }
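A host-side smoke test of the updated entry point might look like the sketch below. It is illustrative only: tensor shapes follow the signature comments above, `"auto"` is the pass-through cache dtype retained from the removed FIXME check, the scale values are arbitrary, and it assumes a CUDA build that links these kernels.

```cpp
#include <torch/torch.h>

void smoke_test() {
  const int num_tokens = 4, num_heads = 8, head_size = 64;
  const int num_blocks = 2, block_size = 16;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

  auto key = torch::randn({num_tokens, num_heads, head_size}, opts);
  auto value = torch::randn({num_tokens, num_heads, head_size}, opts);
  auto key_cache =
      torch::zeros({num_blocks, block_size, num_heads, head_size}, opts);
  auto value_cache =
      torch::zeros({num_blocks, block_size, num_heads, head_size}, opts);
  // One cache slot per token; slot = block_idx * block_size + block_offset.
  auto slot_mapping = torch::arange(
      num_tokens, torch::dtype(torch::kInt64).device(torch::kCUDA));

  // "auto" keeps the cache in the key/value dtype; an fp8 dtype string
  // would route the writes through fp8::scaled_convert with these scales.
  reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                          "auto", /*k_scale=*/1.0, /*v_scale=*/1.0);
}
```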
namespace vllm {