Commit 28e7f4d

Merge pull request Dao-AILab#1155 from ipiszy/fix
Fix out-of-bound writes for var-seq-len zero-length KVs
2 parents bcd918f + 53537da

4 files changed: +21 −8 lines
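For context, the case this commit handles is a variable-sequence-length batch in which some sequences have zero key/value tokens. The sketch below is not from the PR; the tensors and lengths are made up. It only shows how such a batch looks at the cu_seqlens level: an empty sequence produces a repeated offset in cu_seqlens_k, so the kernel sees a KV length of 0 for that sequence.

    import torch
    import torch.nn.functional as F

    # Hypothetical varlen batch: the second sequence has no KV tokens at all.
    seqlens_q = torch.tensor([5, 3, 4], dtype=torch.int32)
    seqlens_k = torch.tensor([6, 0, 4], dtype=torch.int32)

    # Cumulative offsets as used by the varlen kernels; a zero-length sequence
    # shows up as a repeated offset.
    cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, 0, dtype=torch.int32), (1, 0))
    cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, 0, dtype=torch.int32), (1, 0))
    print(cu_seqlens_k.tolist())  # [0, 6, 6, 10]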

hopper/epilogue_fwd_sm90_tma.hpp (2 additions, 2 deletions)

@@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
+        if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }

};
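Why the epilogue change matters, as a sketch: the bound handed to flash::copy decides how many of the kBlockM output rows are actually written. The assumption here is mine, not stated in the diff, that under varlen packing layout_O spans all packed query tokens of the batch, while seqlen_traits_q.actual_seq_len is only the current sequence's length. The numbers below are made up; only the arithmetic is the point.

    # Illustrative arithmetic only; names mirror the kernel's variables.
    kBlockM = 128
    total_q = 4096         # assumed extent of layout_O over the packed batch
    actual_seq_len = 70    # tokens belonging to the current sequence
    m_block = 0

    old_bound = total_q - m_block * kBlockM         # 4096: all 128 rows pass the bound
    new_bound = actual_seq_len - m_block * kBlockM  # 70:   only the valid rows pass
    # With a zero-length sequence, actual_seq_len == 0 makes the new bound <= 0,
    # so nothing is written; the old bound would still allow out-of-bound rows.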

hopper/flash_fwd_kernel.h (2 additions, 2 deletions)

@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) {
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
             scheduler.prefetch_next_work(scheduler_params, work_tile_info);
             scheduler.broadcast_next_work(work_tile_info);
             continue;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
             collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
             continue;
         }
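A sketch of why the early-exit condition now also covers the non-causal varlen case. The helper below is only a rough model of get_n_block_max, an assumption on my part rather than the kernel's exact code: without causal masking it is essentially the number of KV tiles of the current sequence, which is 0 when the sequence has no keys.

    # Rough model, not the kernel's implementation.
    def n_block_max(actual_seqlen_k, kBlockN=128):
        return (actual_seqlen_k + kBlockN - 1) // kBlockN

    print(n_block_max(4096))  # 32 -> normal mainloop
    print(n_block_max(0))     # 0  -> no KV tiles: the block must take the early
                              #      exit, writing 0 to gO and -inf to gLSE,
                              #      instead of entering the mainloop.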

hopper/test_flash_attn.py (7 additions, 2 deletions)

@@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
     )

-    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
-    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

     (
@@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
         dk_ref,
         dv_ref,
     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
+    dk_ref.masked_fill_(zero_masking, 0.0)
+    dv_ref.masked_fill_(zero_masking, 0.0)
     (
         dq_pt,
         dk_pt,
         dv_pt,
     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    dk_pt.masked_fill_(zero_masking, 0.0)
+    dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")

tests/test_util.py (10 additions, 2 deletions)

@@ -5,16 +5,23 @@
 from flash_attn.bert_padding import pad_input, unpad_input


-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
     assert mode in ["full", "random", "third"]
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(
-            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
         )
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+
+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
@@ -251,4 +258,5 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
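Usage sketch for the extended helper (the shapes are made up, and it assumes the repo root is on the Python path so tests/test_util.py imports as tests.test_util): with zero_lengths=True, every fifth batch entry and the last entry get length 0, so their mask rows are all False.

    import torch
    from tests.test_util import generate_random_padding_mask

    mask = generate_random_padding_mask(
        max_seqlen=64, batch_size=6, device="cpu", mode="random", zero_lengths=True
    )
    print(mask.sum(dim=-1))  # per-batch lengths; entries 0 and 5 come out as 0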
