Commit 0598fa2

Add template for the case where seqlen_q equals seqlen_k with a causal mask (#23)

1 parent 0fa5933

2 files changed: +26 / -20 lines
csrc/flash_attn/src/flash_fwd_kernel.h (8 additions, 5 deletions)

@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -500,8 +500,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
                                   params.unscale_softmax);
             tPgMask.data() = tPgMask.data() + (-kBlockN);
         }
-
-        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_equal_seq_qk) {
+            softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        } else {
+            softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        }
 
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -609,7 +612,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -625,7 +628,7 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask, Is_equal_seq_qk>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
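The functional change is the new branch on Is_equal_seq_qk: when the sequence lengths match under a causal mask and no explicit attention mask is supplied, the kernel calls softmax_rescale_o without the Check_inf guard. That guard exists for rows whose running maximum is still -inf because every key they have encountered so far was masked out; making the condition a template parameter lets the compiler drop the check on a path where, by the commit's own gating, that situation is not expected to occur. The following is a simplified host-side sketch of what the guard protects against, not the kernel's actual softmax_rescale_o; rescale_row, acc, and block_max are illustrative names.

// Hypothetical, simplified rescaling helper (host C++), shown only to
// illustrate the role of the Check_inf template flag.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

template <bool Check_inf>
void rescale_row(float &row_max, float &row_sum, std::vector<float> &acc,
                 float block_max) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    const float prev_max = row_max;
    row_max = std::max(prev_max, block_max);
    // A row whose keys have all been masked out so far has row_max == -inf,
    // and exp(-inf - (-inf)) is NaN. The guarded variant pins the max to 0
    // before rescaling; that comparison and select are what Check_inf costs.
    const float cur_max = (Check_inf && row_max == neg_inf) ? 0.0f : row_max;
    const float scale = std::exp(prev_max - cur_max);
    row_sum *= scale;                    // rescale the running softmax denominator
    for (float &v : acc) { v *= scale; } // rescale the running output accumulator
}

With Is_equal_seq_qk set, the kernel instantiates the cheaper variant; the else branch keeps the original Check_inf behaviour for every other shape.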

csrc/flash_attn/src/flash_fwd_launch_template.h (18 additions, 15 deletions)

@@ -13,9 +13,9 @@
 #include "flash_fwd_kernel.h"
 #include "cuda_utils.h"
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask, Is_equal_seq_qk>(params);
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -35,23 +35,26 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
     const bool is_attn_mask = params.attn_mask_ptr != nullptr;
+    const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
     BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
-                   // Will only return softmax if dropout, to reduce compilation time.
-                   auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal>;
-                   // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                   if (smem_size >= 48 * 1024) {
-                       C10_CUDA_CHECK(cudaFuncSetAttribute(
-                           kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                   }
-                   int ctas_per_sm;
-                   cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                       &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                   // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                   kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                   C10_CUDA_KERNEL_LAUNCH_CHECK();
+                   BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
+                       // Will only return softmax if dropout, to reduce compilation time.
+                       auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;
+                       // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                       if (smem_size >= 48 * 1024) {
+                           C10_CUDA_CHECK(cudaFuncSetAttribute(
+                               kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                       }
+                       int ctas_per_sm;
+                       cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                           &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                       // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                       kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                       C10_CUDA_KERNEL_LAUNCH_CHECK();
+                   });
                });
            });
        });
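On the launch side, the runtime check is folded into a compile-time flag through the extra BOOL_SWITCH level, so the new Is_equal_seq_qk parameter selects a separate kernel instantiation rather than adding a runtime branch inside the kernel. The sketch below is a minimal re-implementation of that dispatch pattern for illustration only; the repository's real BOOL_SWITCH macro lives elsewhere in the source tree, and launch_kernel plus the sample sizes are made up.

// Minimal sketch of a BOOL_SWITCH-style dispatcher: a runtime bool is mapped
// onto a constexpr template parameter by instantiating both branches.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

// Hypothetical stand-in for the templated kernel launch.
template <bool Is_equal_seq_qk>
void launch_kernel() {
    std::printf("Is_equal_seq_qk = %d\n", int(Is_equal_seq_qk));
}

int main() {
    const int seqlen_q = 1024, seqlen_k = 1024;
    const bool is_causal = true, has_attn_mask = false;
    const bool is_equal_qk = (seqlen_q == seqlen_k) && is_causal && !has_attn_mask;
    BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
        launch_kernel<Is_equal_seq_qk>();  // Is_equal_seq_qk is a compile-time constant here
    });
    return 0;
}

Each BOOL_SWITCH in the stack doubles the number of template instantiations, the usual trade-off of this pattern: longer compile times and a larger binary in exchange for dead-code elimination inside the kernel.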
