
Commit 7326545

Implement deterministic backward (thanks to Meituan)
1 parent 2c7d7b7 commit 7326545

File tree

8 files changed: +367, -50 lines changed

README.md

Lines changed: 9 additions & 3 deletions
@@ -83,7 +83,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 
 ```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
-                          window_size=(-1, -1), alibi_slopes=None):
+                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -99,14 +99,16 @@ Arguments:
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
         the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
 ```
 
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                window_size=(-1, -1), alibi_slopes=None):
+                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -128,6 +130,8 @@ Arguments:
     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
         (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
         is added to the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -269,10 +273,12 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
-### 2.4: ALiBi (attention with linear bias)
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
 
 Implement ALiBi (Press et el., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
 
+Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
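
For context, a minimal usage sketch of the new flag (shapes and dtype are illustrative; it assumes a flash-attn build that includes this commit and a CUDA GPU):

```python
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) in bf16/fp16 on GPU; sizes here are illustrative.
q, k, v = [torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16,
                       requires_grad=True) for _ in range(3)]

# deterministic=True selects the new backward: repeated backward passes on the same
# inputs give bit-identical dq/dk/dv, at the cost of some speed and extra memory.
# The forward pass is deterministic either way.
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```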

csrc/flash_attn/flash_api.cpp

Lines changed: 42 additions & 10 deletions
@@ -150,7 +150,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       int window_size_left,
-                      int window_size_right) {
+                      int window_size_right,
+                      bool deterministic) {
 
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -192,6 +193,8 @@ void set_params_dgrad(Flash_bwd_params &params,
 
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
+
+    params.deterministic = deterministic;
 }
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
@@ -618,8 +621,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         params.alibi_slopes_ptr = nullptr;
     }
 
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    run_mha_fwd(params, stream);
+    if (max_seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
 
     at::Tensor out_padded = out;
     if (head_size_og % 8 != 0) {
@@ -668,6 +677,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
+        const bool deterministic,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
@@ -783,7 +793,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     at::Tensor dq_accum;
     at::Tensor dk_accum, dv_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        if (!deterministic) {
+            dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        } else {
+            const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+            dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        }
         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
     }
@@ -819,7 +834,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     deterministic);
+    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -857,8 +874,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         launch(params, stream, /*configure=*/false);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
-        dk.zero_();
-        dv.zero_();
+        dk_expanded.zero_();
+        dv_expanded.zero_();
         softmax_d.zero_();
     }
 
@@ -897,6 +914,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
+               const bool deterministic,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
 
@@ -1025,7 +1043,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
-        dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        if (!deterministic) {
+            dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        } else {
+            const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+            dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        }
     }
 
     at::Tensor dk_expanded, dv_expanded;
@@ -1064,7 +1087,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     deterministic);
+    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -1098,7 +1123,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         params.alibi_slopes_ptr = nullptr;
     }
 
-    launch(params, stream, /*configure=*/false);
+    if (max_seqlen_q > 0) {
+        launch(params, stream, /*configure=*/false);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk_expanded.zero_();
+        dv_expanded.zero_();
+        softmax_d.zero_();
+    }
 
     // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
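
The core of this change is how dq_accum is sized when deterministic is set. Below is a rough Python sketch of that sizing arithmetic only; the helper name `dq_accum_shape` is hypothetical and simply mirrors the allocation in mha_bwd. params.dq_accum_split_stride then records the stride of the leading split dimension (0 in the non-deterministic case), which the kernel uses to offset each thread block's atomicAdds.

```python
def dq_accum_shape(batch_size, num_heads, seqlen_q_rounded, head_size_rounded,
                   num_sms, deterministic):
    """Mirror of the dq_accum allocation in mha_bwd (hypothetical helper)."""
    if not deterministic:
        # Single fp32 buffer; all thread blocks atomicAdd into it, so the order of
        # floating-point additions depends on scheduling and dq can vary run to run.
        return (batch_size, seqlen_q_rounded, num_heads, head_size_rounded)
    # Deterministic path: ceil(num_SMs / (batch * heads)) independent buffers, one per
    # blockIdx.x of the seqlen_k-parallel kernel; they are summed later in convert_dQ.
    nsplits = (num_sms + batch_size * num_heads - 1) // (batch_size * num_heads)
    return (nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded)

# e.g. 108 SMs (A100), batch 2, 16 heads -> nsplits = ceil(108 / 32) = 4
print(dq_accum_shape(2, 16, 1024, 64, num_sms=108, deterministic=True))
```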

csrc/flash_attn/src/flash.h

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
+
+    bool deterministic;
+    index_t dq_accum_split_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 17 additions & 9 deletions
@@ -230,7 +230,7 @@ inline __device__ void clear_dKVaccum(const Params &params) {
 // Convert dQ from dQaccum (in float) to fp16/bf16.
 // This is used in the case where we want to parallelize the backward across seqlen_k.
 template<typename Kernel_traits, typename Params>
-inline __device__ void convert_dQ(const Params &params) {
+inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
@@ -285,11 +285,15 @@ inline __device__ void convert_dQ(const Params &params) {
     CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
 
     Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-    cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
-    #pragma unroll
-    for (int i = 0; i < size(acc_dq); ++i) {
-        acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
+    clear(acc_dq);
+    for (int s = 0; s < nsplits; ++s) {
+        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
+        #pragma unroll
+        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
+        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
     }
+    #pragma unroll
+    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
     // Convert acc_dq from fp32 to fp16
     Tensor rdQ = flash::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -466,7 +470,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
+        // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+        + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
         + (m_block_max - 1) * kBlockM;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -715,7 +721,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tdKsQt.data() = tdKsQt.data() + size(sQ);
     }
 
-    if (!Is_first && !Seq_parallel) { __syncthreads(); }
+    if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
 
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
@@ -1604,13 +1610,15 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
-    const int n_block = blockIdx.x;
     // The block index for the batch.
     const int bidb = blockIdx.y;
     // The block index for the head.
    const int bidh = blockIdx.z;
 
-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+    for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
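
Putting the kernel changes together: each thread block of the seqlen_k-parallel backward now strides over K blocks with step gridDim.x and atomicAdds its dq contributions into the split buffer selected by blockIdx.x, and convert_dQ then reduces the splits in a fixed order before rescaling. A rough Python reference of that reduction (the function name `convert_dq_reference` is hypothetical, for illustration only):

```python
import torch

def convert_dq_reference(dq_accum_splits, scale_softmax_rp_dropout, out_dtype=torch.bfloat16):
    """dq_accum_splits: (nsplits, batch, seqlen_q_rounded, nheads, headdim_rounded), fp32."""
    acc_dq = dq_accum_splits.sum(dim=0)         # fixed-order sum over the split buffers
    acc_dq = acc_dq * scale_softmax_rp_dropout  # apply the softmax/dropout rescaling once, at the end
    return acc_dq.to(out_dtype)                 # convert the fp32 accumulator to fp16/bf16
```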

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 15 additions & 5 deletions
@@ -35,8 +35,8 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params pa
 }
 
 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
-    flash::convert_dQ<Kernel_traits>(params);
+__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+    flash::convert_dQ<Kernel_traits>(params, nsplits);
 }
 
 template<typename Kernel_traits>
@@ -49,9 +49,18 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
     const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
-    dim3 grid_n(num_n_block, params.b, params.h);
+    int gridDimx = num_n_block;
+    if (params.deterministic) {
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
+    }
+    dim3 grid_n(gridDimx, params.b, params.h);
 
-    flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    if (!params.deterministic) {
+        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    } else {
+        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    }
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
@@ -69,6 +78,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
     auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
     if (smem_size_dq_dk_dv >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -86,7 +96,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
     }
-    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
+    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
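
The launch-side counterpart: when params.deterministic is set, the x-dimension of the dq/dk/dv grid is capped at roughly num_SMs / (batch * heads) instead of one block per K tile, and each block then loops over K tiles with stride gridDim.x (see compute_dq_dk_dv_seqk_parallel above). A small sketch of the grid sizing; the helper `grid_x` is hypothetical, and kBlockN = 128 and the SM count are illustrative:

```python
def grid_x(seqlen_k, kBlockN, batch, heads, num_sms, deterministic):
    num_n_block = (seqlen_k + kBlockN - 1) // kBlockN   # one block per K tile otherwise
    if not deterministic:
        return num_n_block
    # Deterministic path: just enough blocks to fill the GPU, each striding over K tiles.
    return (num_sms + batch * heads - 1) // (batch * heads)

# e.g. seqlen_k=4096, kBlockN=128 -> 32 K tiles; with 108 SMs, batch 2, 16 heads the
# deterministic grid uses 4 blocks in x, each covering 8 K tiles.
print(grid_x(4096, 128, 2, 16, num_sms=108, deterministic=False))  # 32
print(grid_x(4096, 128, 2, 16, num_sms=108, deterministic=True))   # 4
```

Note also that the deterministic path allocates dq_accum with torch::zeros in flash_api.cpp above rather than relying on the dot_do_o kernel to clear it, which appears to be why the boolean template argument on flash_bwd_dot_do_o_kernel flips from true to false here.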

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
     auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
     // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
     if (smem_size >= 48 * 1024) {
