@GD06 GD06 commented Jul 21, 2025

The deterministic flag is dropped in the BWD launch template. This PR adds the flag back and turns off `Slice_dQKV_Mma` when deterministic mode is on for head_dim=256.

NOTE: This could increase register pressure and slow down the kernel for head_dim=256 when running in deterministic mode.

GD06 commented Jul 21, 2025

Hi @tridao @sgrigory, could one of you help review this PR?

We found that flash-attn v3 built from this year's commits does not produce deterministic results even when deterministic mode is on. We dug into the issue and found that the flag was dropped in the BWD template. This PR aims to fix that. Please let me know if you see any problems beyond the slowdown for head_dim=256.
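
For context on why a dropped flag can show up as run-to-run differences: floating-point addition is not associative, so if partial results are accumulated in a varying order, the final value can change. The sketch below is a plain-Python illustration of that effect only. The helper `accumulate` and the values in `parts` are hypothetical, and it assumes (this is not stated in the thread) that the non-deterministic path adds per-block contributions in whatever order they arrive.

```python
def accumulate(parts, order):
    """Sum `parts` in the given index order.

    Stand-in for several blocks adding their contribution into one output.
    """
    total = 0.0
    for i in order:
        total += parts[i]
    return total


# Hypothetical per-block contributions with very different magnitudes.
parts = [1e16, 1.0, -1e16, 1.0]

fixed_order = [0, 1, 2, 3]   # the same order on every run
other_order = [0, 2, 1, 3]   # same values, different arrival order

a = accumulate(parts, fixed_order)
b = accumulate(parts, other_order)

# Same inputs, different accumulation order, different result.
print(a, b, a == b)
```

Fixing the accumulation order makes the result reproducible, and that reproducibility is what a deterministic mode is expected to guarantee.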

@@ -607,7 +609,7 @@ struct CollectiveMainloopBwdSm90 {
 seqlen_info, n_block, bidb, params.window_size_left,
 params.window_size_right, 0 /*sink_token_length*/);
 // It's possible to have m_block_max <= m_block_min. Exit early
-if constexpr (Is_causal || Is_local || Varlen) {
+if constexpr ((Is_causal || Is_local || Varlen) && !Deterministic) {
Contributor Author

Hi @tridao @sgrigory, this one-line change fixes the backward kernels hanging in deterministic mode.

I think there are still some issues with head_dim=256 in deterministic mode, but most of the other cases I tested look fine.
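
One way to picture why an early exit could hang in deterministic mode is a toy model in which blocks must take turns in a fixed order. This is only a sketch under an assumption that the thread does not confirm: that the deterministic path serializes blocks through a shared turn counter. The function `run_in_turns` and its parameters are hypothetical and are not part of flash-attn.

```python
def run_in_turns(num_blocks, exits_early, signal_on_early_exit):
    """Toy model: block i may proceed only when the shared turn equals i.

    Returns the blocks that finished. A list shorter than num_blocks means
    the remaining blocks would wait forever for their turn.
    """
    turn = 0
    finished = []
    for block in range(num_blocks):
        if turn != block:
            break  # this block and all later ones would spin forever
        if block in exits_early and not signal_on_early_exit:
            finished.append(block)
            continue  # left early without passing the turn on
        turn += 1  # normal path: hand the turn to the next block
        finished.append(block)
    return finished


# Block 1 exits early and never passes the turn on: blocks 2 and 3 are stuck.
stuck = run_in_turns(4, exits_early={1}, signal_on_early_exit=False)

# With no early exit, every block gets its turn.
ok = run_in_turns(4, exits_early=set(), signal_on_early_exit=False)

print(stuck, ok)
```

In this model, disabling the early exit is the simplest way to keep the hand-off chain intact, at the cost of letting blocks with no work go through the full path.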

@sgrigory
Contributor

To check that the output is deterministic, maybe add this test to test_flash_attn.py?

def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
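
The body of the suggested test is not shown in this thread. As a rough sketch of the general idea, the helper below runs the same computation several times and requires exactly equal results. `assert_deterministic` is a hypothetical name. In a real test, `fn` would presumably wrap the backward pass and return the gradients, compared with an exact check such as `torch.equal` rather than a tolerance.

```python
def assert_deterministic(fn, repeats=5):
    """Call fn() repeatedly and require every result to equal the first exactly."""
    reference = fn()
    for i in range(1, repeats):
        result = fn()
        # Exact equality on purpose: a tolerance would hide run-to-run drift.
        assert result == reference, f"run {i} differs from run 0"
    return reference


# A fixed-order sum gives the same value on every call, so this passes.
value = assert_deterministic(lambda: sum([0.1, 0.2, 0.3]))
print(value)
```

Exact comparison is the point of such a test: results that are merely close across runs would still count as non-deterministic.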

@lisheng-spaghetti

There is another PR, #1678, that has the same fix but with more unit test coverage.
