Fix/deterministic dk dv #1678
Conversation
@tridao This can run correctly on our GQA case.
@yuWeiCute Hi, I have cloned this PR and given it a try, but I found that the following block in test_flash_attn.py is now commented out:
# import flash_attn_3_cuda
# dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(
# g,
# q,
# k,
# v,
# out,
# lse,
# None,
# None,
# None,
# d ** (-0.5),
# causal,
# window_size[0], window_size[1],
# softcap,
# deterministic,
# 0, # sm_margin
# )
https://github.com/yuWeiCute/flash-attention-hopper/blob/a9a3170fc98cbd22a4cc870937b390f3d483f1eb/hopper/test_flash_attn.py#L228-L245
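For context, here is a minimal sketch of the determinism check this feature targets: run the backward pass twice on identical inputs and require bit-identical dq/dk/dv. This assumes the hopper flash_attn_interface.flash_attn_func accepts a deterministic flag mirroring the deterministic argument of the raw flash_attn_3_cuda.bwd call above; the exact import path, signature, and return value may differ between builds.

```python
import torch
from flash_attn_interface import flash_attn_func  # hopper build; import path assumed

def run_bwd(q, k, v, g, deterministic=True):
    # Fresh leaf tensors so each run accumulates gradients from scratch.
    q, k, v = (t.detach().clone().requires_grad_() for t in (q, k, v))
    res = flash_attn_func(q, k, v, causal=True, deterministic=deterministic)
    out = res[0] if isinstance(res, tuple) else res  # some builds also return the LSE
    out.backward(g)
    return q.grad, k.grad, v.grad

torch.manual_seed(0)
device, dtype = "cuda", torch.bfloat16
# GQA layout: 32 query heads sharing 8 KV heads, head dim 128.
q = torch.randn(2, 4096, 32, 128, device=device, dtype=dtype)
k = torch.randn(2, 4096, 8, 128, device=device, dtype=dtype)
v = torch.randn(2, 4096, 8, 128, device=device, dtype=dtype)
g = torch.randn_like(q)

dq1, dk1, dv1 = run_bwd(q, k, v, g)
dq2, dk2, dv2 = run_bwd(q, k, v, g)
# With deterministic=True, dk/dv (and dq) should be reproduced bit-for-bit.
assert torch.equal(dk1, dk2) and torch.equal(dv1, dv2) and torch.equal(dq1, dq2)
```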
@tridao We at Tencent have tried this approach; it proved to be deterministic and efficient in our cases. Could you please share when this PR might be merged? We are looking forward to the official GQA bwd deterministic version!
Well done! Looking forward to this new approach for deterministic bwd.
@tridao Hi, we at ByteDance have identified this as a valuable feature for future releases. Is there an estimated timeline for when it might be merged? Thanks.
Cool, will review & merge this weekend.
Force-pushed from 50c1d49 to 78ab9e8
Hi @tridao, thanks for your support; I really appreciate your effort in reviewing this code change. I noticed that deterministic mode still isn't supported in some cases, particularly when the head dimension equals 256. To fix this, I've added a new commit.
Let me know if you have any feedback on these updates.
@tridao Hi Tri, any suggestions for this PR?
Any updates?
Upon inspecting dv_semaphore during debugging, it was found that some dv_semaphore values were not initialized to zero. The issue was resolved by changing torch::empty to torch::zeros, after which the problem no longer occurred.
The semaphore is allocated with shape [seq_len / kBlockN, batch_size, num_head_kv], but during accumulation num_batch = 1, leading to a mismatch in the data dimensions.
Fixes #1596
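To illustrate the allocation pitfall described above: torch::empty returns uninitialized memory, so a semaphore used to coordinate cross-block accumulation of dk/dv can start from stale nonzero counters, while torch::zeros guarantees every counter starts at 0. The actual fix lives in the C++ extension; below is a Python sketch of the same idea, with illustrative values for kBlockN and the tensor dimensions, using the shape quoted above.

```python
import math
import torch

# Illustrative values only; kBlockN is the kernel's N-tile size.
seq_len, batch_size, num_head_kv, kBlockN = 4096, 2, 8, 128
shape = (math.ceil(seq_len / kBlockN), batch_size, num_head_kv)

# torch.empty leaves the buffer uninitialized: stale nonzero values here would
# put the semaphore-gated accumulation of dk/dv into an inconsistent state.
bad_semaphore = torch.empty(shape, dtype=torch.int32, device="cuda")

# torch.zeros guarantees every counter starts at 0, which is what the fix does
# on the C++ side (torch::empty -> torch::zeros).
dv_semaphore = torch.zeros(shape, dtype=torch.int32, device="cuda")
```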