[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention #17139
@@ -283,7 +283,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -797,7 +798,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -1239,6 +1241,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
 
   // final write to tmp_out after vout accumulation
   if (warpid == 0) {
+    const float out_scale =
+        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
     _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     for (int qh = 0; qh < QHLOOP; qh++) {
@@ -1287,7 +1291,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1465,8 +1469,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@@ -1506,7 +1512,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1533,7 +1540,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1548,7 +1556,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1564,7 +1572,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                             \
   paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
@@ -1575,14 +1583,15 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                          \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
                                       PARTITION_SIZE, NPAR_LOOPS>    \
       <<<reduce_grid, reduce_block, 0, stream>>>(                    \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,        \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1603,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1626,6 +1635,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const auto fp8_out_scale_ptr =
+      fp8_out_scale
+          ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);

Review thread on this hunk:
- Should the …
- Also, should …
- This is ensured at https://github.com/vllm-project/vllm/pull/17139/files#diff-79b8261aa73f07cc7450e48c8e14150576656f19ccfb42ba972860092c1f5949R1779-R1786
- No, it should be the same type as query; it is used in the internal calculations.
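Separately, on the host-side plumbing added in this hunk: the launcher treats fp8_out_scale as optional and collapses it to a nullable raw pointer, which the kernels interpret as "no output quantization". A minimal standalone sketch of that pattern follows; the helper name is illustrative and not part of the PR.

```cpp
#include <optional>

#include <torch/all.h>

// Hypothetical helper mirroring the launcher code above: an absent scale
// tensor becomes nullptr, and the kernels then fall back to out_scale = 1.0f.
inline const float* fp8_out_scale_to_ptr(
    const std::optional<torch::Tensor>& fp8_out_scale) {
  return fp8_out_scale
             ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
             : nullptr;
}
```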
@@ -1736,33 +1750,54 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,  \
-                             ALIBI_ENABLED)                                 \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
-                                  PSIZE, ALIBI_ENABLED>(                    \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc,     \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
-                                   PSIZE)                                      \
-  if (alibi_slopes) {                                                          \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true);  \
-  } else {                                                                     \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,      \
+                             PSIZE, ALIBI_ENABLED)                             \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                                  PSIZE, ALIBI_ENABLED>(                       \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
+                                   OUTT, PSIZE)                               \
+  if (alibi_slopes) {                                                         \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,  \
+                         true);                                               \
+  } else {                                                                    \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,  \
+                         false);                                              \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)            \
-  switch (block_size) {                                                  \
-    case 16:                                                             \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256);  \
-      break;                                                             \
-    case 32:                                                             \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256);  \
-      break;                                                             \
-    default:                                                             \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);        \
-      break;                                                             \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)  \
+    if (fp8_out_scale) {                                                   \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");          \
+    } else {                                                               \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256);                                     \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)  \
+    if (fp8_out_scale) {                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,    \
+                                 uint8_t, 256);                            \
+    } else {                                                               \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256);                                     \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)      \
+  switch (block_size) {                                            \
+    case 16:                                                       \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);   \
+      break;                                                       \
+    case 32:                                                       \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);   \
+      break;                                                       \
+    default:                                                       \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);  \
+      break;                                                       \
   }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
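The new CALL_CUSTOM_LAUNCHER_OUT macros encode a simple dispatch: when an fp8 output scale is supplied, the launcher is instantiated with uint8_t as the output type (fp8 bytes); otherwise it keeps the query dtype T, and gfx90a builds reject the fp8 path outright. Below is a sketch of the same logic as plain (non-macro) C++, with an illustrative launch stub standing in for paged_attention_custom_launcher; the function names are not part of the PR.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Illustrative stand-in for paged_attention_custom_launcher<..., OUTT, ...>.
template <typename T, typename OUTT>
void launch_paged_attention() {
  std::printf("OUTT is %zu byte(s) wide\n", sizeof(OUTT));
}

template <typename T>
void dispatch_out_dtype(bool has_fp8_out_scale) {
#if defined(__HIPCC__) && defined(__gfx90a__)
  // Mirrors the TORCH_CHECK in the gfx90a branch of the macro above.
  if (has_fp8_out_scale)
    throw std::runtime_error("fp8 out scale unsupported for gfx90a");
  launch_paged_attention<T, /*OUTT=*/T>();
#else
  if (has_fp8_out_scale) {
    launch_paged_attention<T, /*OUTT=*/uint8_t>();  // quantized fp8 output
  } else {
    launch_paged_attention<T, /*OUTT=*/T>();        // output in the query dtype
  }
#endif
}
```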
@@ -1795,7 +1830,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
Review thread on this hunk:
- wondering where out_scale is used here?
- It is actually used in the reduction kernel launched after either of the attention kernels. The dereferencing here is indeed not needed, but it'll get optimized out. I'll make a note to clean it up.
- Could you just remove it in this PR?
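To make the answer above concrete: the reduce kernel folds the optional output scale into its epilogue, multiplying the normalized accumulator by the reciprocal scale before the value is converted to the output type. A host-side sketch of that arithmetic, taken from the reduce-kernel hunk in this diff (plain division standing in for __fdividef; the function name is illustrative):

```cpp
// Mirrors the reduce-kernel epilogue from the diff: normalize by the global
// softmax sum, then apply the reciprocal fp8 output scale when one was given.
inline float reduce_epilogue(float acc, float shared_global_exp_sum,
                             const float* fp8_out_scale_ptr) {
  const float inv_global_exp_sum = 1.0f / (shared_global_exp_sum + 1e-6f);
  const float out_scale =
      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
  acc *= inv_global_exp_sum;
  acc *= out_scale;  // no-op (1.0f) when no fp8 output scale is supplied
  return acc;        // subsequently converted to OUTT (fp8 or the query dtype)
}
```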