#include "flash_fwd_kernel.h"
#include "cuda_utils.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask, Is_equal_seq_qk>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -35,23 +35,26 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    const bool is_attn_mask = params.attn_mask_ptr != nullptr;
+    const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
-                    // Will only return softmax if dropout, to reduce compilation time.
-                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal>;
-                    // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                    if (smem_size >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                    }
-                    int ctas_per_sm;
-                    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
+                        // Will only return softmax if dropout, to reduce compilation time.
+                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;
+                        // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        int ctas_per_sm;
+                        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
                });
            });
        });
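
For context on the dispatch pattern (not part of the diff above): BOOL_SWITCH, typically defined in static_switch.h in FlashAttention-style code, instantiates its lambda body once with the named constant set to true and once with false, so the runtime check is_equal_qk = (seqlen_q == seqlen_k) && Is_causal && !is_attn_mask ends up selecting a dedicated Is_equal_seq_qk kernel specialization at compile time. The standalone C++ sketch below approximates that mechanism; the macro body and launch_stub are illustrative assumptions, not code from this commit.

#include <cstdio>

// Sketch of a BOOL_SWITCH-style macro: the lambda body is instantiated twice,
// once with the constant true and once with false, so a runtime flag picks a
// fully specialized code path at compile time. (Assumed shape, not verbatim.)
#define BOOL_SWITCH(COND, CONST_NAME, ...)         \
    [&] {                                          \
        if (COND) {                                \
            constexpr bool CONST_NAME = true;      \
            return __VA_ARGS__();                  \
        } else {                                   \
            constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                  \
        }                                          \
    }()

// Hypothetical stand-in for taking &flash_fwd_kernel<..., Is_equal_seq_qk>
// and launching it with the chosen specialization.
template <bool Is_equal_seq_qk>
void launch_stub() {
    std::printf("Is_equal_seq_qk = %s\n", Is_equal_seq_qk ? "true" : "false");
}

int main() {
    const int seqlen_q = 128, seqlen_k = 128;
    const bool is_causal = true, is_attn_mask = false;
    // Mirrors the runtime check added in the diff.
    const bool is_equal_qk = (seqlen_q == seqlen_k) && is_causal && !is_attn_mask;
    BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
        launch_stub<Is_equal_seq_qk>();  // compiled for both true and false
    });
    return 0;
}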