Skip to content

Commit 78ab9e8

Browse files
author
yuwei
committed
feat: Add TORCH_CHECK for unsupported deterministic mode and related test script
1 parent 164b798 commit 78ab9e8

File tree

3 files changed

+376
-0
lines changed

3 files changed

+376
-0
lines changed

hopper/flash_api.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso
12581258
}
12591259
// This is what we will template on
12601260
bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
1261+
TORCH_CHECK(!(seqused_q_.has_value() && deterministic), "FlashAttention backward does not support 'seqused_q' parameter when deterministic is true.");
1262+
12611263
#ifdef FLASHATTENTION_DISABLE_VARLEN
12621264
TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
12631265
#endif
@@ -1274,6 +1276,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso
12741276
int const num_heads_k = k.size(-2);
12751277
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
12761278
TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
1279+
TORCH_CHECK((head_size_v < 256 && head_size < 256) || !deterministic, "FlashAttention backward only supports deterministic when head dimension less than 256");
12771280
int const max_headdim = get_max_headdim();
12781281
TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
12791282
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1291,6 +1294,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso
12911294
int const head_size_v_rounded = head_size_rounded;
12921295
// Very important that these match the kernel configs
12931296
bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
1297+
TORCH_CHECK(
1298+
!deterministic || !is_local || !cu_seqlens_q_.has_value() || torch::equal(cu_seqlens_q_.value(), cu_seqlens_k_.value()),
1299+
"FlashAttention backward only supports deterministic when local is false"
1300+
);
12941301
int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
12951302
: (head_size_rounded <= 96 ? 64
12961303
: (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)

hopper/test_flash_attn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def test_flash_attn_output(
203203
attention_chunk=attention_chunk,
204204
softcap=softcap,
205205
pack_gqa=pack_gqa,
206+
deterministic=deterministic,
206207
num_splits=num_splits
207208
)
208209
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
@@ -222,6 +223,10 @@ def test_flash_attn_output(
222223
and not has_qv
223224
and not dv > 256
224225
and not attention_chunk != 0
226+
and not (
227+
deterministic == True
228+
and (dv >= 256 or dv == 64)
229+
)
225230
):
226231
g = torch.randn_like(out)
227232
do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
@@ -475,6 +480,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
475480
k_descale=k_descale, v_descale=v_descale,
476481
window_size=window_size,
477482
attention_chunk=attention_chunk,
483+
deterministic=deterministic,
478484
softcap=softcap,
479485
)
480486
out = output_pad_fn(out_unpad)
@@ -497,6 +503,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
497503
and not has_qv
498504
and not dv > 256
499505
and not attention_chunk != 0
506+
and not (
507+
deterministic == True
508+
and (dv >= 256 or local == True or seqused_k is None)
509+
)
500510
):
501511
g_unpad = torch.randn_like(out_unpad)
502512
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)

0 commit comments

Comments (0)