
Fix/deterministic dk dv #1678


Open

yuWeiCute wants to merge 5 commits into main from fix/deterministic-dk-dv

Conversation

@yuWeiCute commented May 26, 2025

  • While inspecting dv_semaphore during debugging, we found that some of its values were not initialized to zero. Changing the allocation from torch::empty to torch::zeros resolved this, and the problem no longer occurred (see the sketch after this description).

  • The semaphore is initialized with the shape [seq_len / kBlockN, batch_size, num_head_kv], but during accumulation, num_batch = 1, leading to a mismatch in the data dimensions.

Fixes issue #1596.
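
For context, a minimal sketch of the zero-initialization fix (illustrative only; the actual change is on the C++ side, where the semaphore allocation switches from torch::empty to torch::zeros — the shape and names below are taken from the description above, the sizes are made up):

    import torch

    # torch.empty returns uninitialized memory, so semaphore counters can start at
    # arbitrary non-zero values; the wait/signal protocol assumes they start at 0.
    seq_len, kBlockN, batch_size, num_head_kv = 4096, 128, 2, 8  # example sizes

    # before (buggy): contents are whatever was left in the allocation
    dv_semaphore = torch.empty(seq_len // kBlockN, batch_size, num_head_kv,
                               dtype=torch.int32, device="cuda")

    # after (fixed): every counter starts at 0
    dv_semaphore = torch.zeros(seq_len // kBlockN, batch_size, num_head_kv,
                               dtype=torch.int32, device="cuda")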

@defei-coder

@tridao
I'm glad FA3 adopted the semaphore solution for deterministic backward computation (used for dq in all cases, and for dk & dv with GQA; similar to #722).
Our team (the Meituan MLP FastKernel team) found that FA3 already contains the deterministic code path, but the deterministic flag is not currently enabled (see issue #1596).
We fixed this by zero-initializing dk_accum & dv_accum and correcting the num_batch value in CollectiveEpilogueBwdGQA.
Could you help review our PR and give us some feedback?
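
For intuition, here is a small CPU-side illustration (not the kernel code) of why the accumulation order must be fixed for bitwise-reproducible dk/dv: floating-point addition is not associative, which is exactly what the semaphore ordering addresses.

    import torch

    # Partial dK/dV-style contributions reduced in different orders can differ in
    # the last bits; a fixed order (what the semaphore enforces) is reproducible.
    partials = torch.randn(64, 128, dtype=torch.float32)

    def reduce_in(order):
        acc = torch.zeros(128, dtype=torch.float32)
        for i in order:
            acc = acc + partials[i]
        return acc

    fixed = list(range(64))
    print(torch.equal(reduce_in(fixed), reduce_in(fixed)))              # True: same order, same bits
    print(torch.equal(reduce_in(fixed), reduce_in(torch.randperm(64)))) # usually False: order matters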

@evanluyifan

This runs correctly on our GQA case.
Please help review; we are waiting for your comments. @tridao

@xTayEx commented Jun 5, 2025

@yuWeiCute Hi, I cloned this PR and gave it a try, but I found that in test_flash_attn.py the call to flash_attn_3_cuda.bwd is commented out, so the deterministic argument is not actually tested; see below.

            # import flash_attn_3_cuda
            # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(
            #     g,
            #     q,
            #     k,
            #     v,
            #     out,
            #     lse,
            #     None,
            #     None,
            #     None,
            #     d ** (-0.5),
            #     causal,
            #     window_size[0], window_size[1],
            #     softcap,
            #     deterministic,
            #     0,  # sm_margin
            # )

https://github.com/yuWeiCute/flash-attention-hopper/blob/a9a3170fc98cbd22a4cc870937b390f3d483f1eb/hopper/test_flash_attn.py#L228-L245
Could you please provide a minimal example to try this PR?
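
The kind of determinism check I have in mind looks roughly like this (a sketch only; it assumes the hopper Python wrapper flash_attn_interface.flash_attn_func exposes a deterministic flag and returns either out or (out, lse)):

    import torch
    from flash_attn_interface import flash_attn_func  # hopper/ FA3 wrapper

    torch.manual_seed(0)
    device, dtype = "cuda", torch.bfloat16
    # GQA shapes: 32 query heads sharing 8 KV heads, head_dim 128
    q = torch.randn(2, 1024, 32, 128, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(2, 1024, 8, 128, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(2, 1024, 8, 128, device=device, dtype=dtype, requires_grad=True)
    g = torch.randn(2, 1024, 32, 128, device=device, dtype=dtype)

    def run_backward():
        q.grad = k.grad = v.grad = None
        out = flash_attn_func(q, k, v, causal=True, deterministic=True)
        if isinstance(out, (tuple, list)):  # some versions also return lse
            out = out[0]
        out.backward(g)
        return q.grad.clone(), k.grad.clone(), v.grad.clone()

    dq0, dk0, dv0 = run_backward()
    for _ in range(5):
        dq, dk, dv = run_backward()
        assert torch.equal(dq, dq0) and torch.equal(dk, dk0) and torch.equal(dv, dv0), \
            "gradients differ across runs -> backward is not deterministic"
    print("dq/dk/dv are bitwise identical across runs")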

@xTayEx commented Jun 9, 2025

@tridao We at Tencent have tried this approach, and it proved to be deterministic and efficient in our cases. Could you please share when this PR might be merged? We are looking forward to the official deterministic GQA backward!

@lygztq commented Jul 3, 2025

Well done! Looking forward to this new approach for deterministic backward.

@chrisHuxi

@tridao Hi, we at ByteDance have identified this as a valuable feature for future releases. Is there an estimated timeline for when it might be merged? Thanks.

@tridao (Member) commented Jul 4, 2025

Cool will review & merge this weekend

@yuWeiCute force-pushed the fix/deterministic-dk-dv branch from 50c1d49 to 78ab9e8 on July 5, 2025 13:34
@yuWeiCute (Author)

Hi @tridao

Thanks for your support. I really appreciate your effort to review this code change.

I noticed that deterministic mode still isn't supported in some cases, in particular when the head dimension equals 256. To address this, I've added a new commit that:

  1. Adds validation checks for the unsupported cases (a rough sketch follows below)
  2. Adds additional test cases
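
A rough sketch of the kind of guard described in item 1 (illustrative only; the exact check and error message in the commit may differ — head_dim == 256 is the case called out above):

    def check_deterministic_support(deterministic: bool, head_dim: int) -> None:
        # Reject combinations the deterministic backward path does not cover yet.
        if deterministic and head_dim == 256:
            raise ValueError(
                "deterministic backward is not supported for head_dim == 256"
            )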

Let me know if you have any feedback on these updates.

@defei-coder

@tridao Hi Tri, any suggestions for this PR?

@evanluyifan

> Cool will review & merge this weekend

Any updates?
