Closed
Description
Hi, @tridao
I'm glad you used the semaphore solution to solve the problem of deterministic backward computation (which is used for dq, or for dk and dv under GQA). I found that the code has already been written (commit), but it is not working now. In the code:
```cpp
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
            // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
            // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
            run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
            // });
        });
    });
}
```

From the kernel code, the deterministic path looks ready, so I am a little confused about why this feature is not enabled. I can help if needed.
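To make the semaphore idea concrete: when several thread blocks accumulate into the same dq tile (or the same dk/dv tile under GQA), floating-point addition is not associative, so the result depends on the order in which the blocks happen to write. A semaphore that lets block i write only after block i-1 has written fixes that order. Below is a minimal CPU sketch of the general idea, assuming `std::thread` as a stand-in for thread blocks and a single `float` as a stand-in for a tile; `ordered_accumulate` and `reference_sum` are hypothetical names, not the kernel's code.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical CPU sketch: each std::thread stands in for a CUDA thread block
// that contributes one partial value to the same shared tile (a single float here).
// A semaphore (an atomic counter) forces the writes to happen in block-index order.
float ordered_accumulate(const std::vector<float>& partials) {
    const int n = static_cast<int>(partials.size());
    float acc = 0.0f;               // shared accumulator, stand-in for a dq accumulation tile
    std::atomic<int> semaphore{0};  // index of the block that may write next

    std::vector<std::thread> blocks;
    blocks.reserve(partials.size());
    for (int i = 0; i < n; ++i) {
        blocks.emplace_back([&, i] {
            // Spin until every lower-indexed block has written its partial.
            while (semaphore.load(std::memory_order_acquire) != i) {
                std::this_thread::yield();
            }
            acc += partials[i];  // exactly one block is in this section at a time
            // Hand the tile to the next block in the fixed order.
            semaphore.store(i + 1, std::memory_order_release);
        });
    }
    for (auto& t : blocks) t.join();
    assert(semaphore.load() == n);  // every block took its turn exactly once
    return acc;
}

// Plain sequential sum in the same index order, for comparison.
float reference_sum(const std::vector<float>& partials) {
    float acc = 0.0f;
    for (float p : partials) acc += p;
    return acc;
}
```

With partials `{1e8f, 1.0f, -1e8f, 1.0f}`, index order gives `1.0f` on every run (the first `1.0f` is absorbed by `1e8f`), matching `reference_sum` bit for bit, while an unordered, atomicAdd-style accumulation could produce `0.0f` or `2.0f` depending on timing. The price is that writes to the shared tile are serialized.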
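For reference, the dispatch above relies on switch macros that turn a runtime flag into a compile-time template argument, so both instantiations are compiled and one is chosen at run time. Below is a minimal sketch of what re-enabling the commented-out `params.deterministic` switch would amount to, assuming a macro modeled on that pattern; `BOOL_SWITCH_SKETCH`, `BwdParamsSketch`, `run_bwd_sketch` and `dispatch_sketch` are all hypothetical stand-ins, not the repository's definitions.

```cpp
#include <cassert>
#include <string>

// Hypothetical BOOL_SWITCH-style macro: turns a runtime bool into a constexpr bool
// named CONST_NAME that the lambda body can use as a template argument.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)       \
    [&] {                                               \
        if (COND) {                                     \
            constexpr static bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            constexpr static bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()

// Hypothetical stand-in for Flash_bwd_params (only the fields the switches read).
struct BwdParamsSketch {
    int h;               // number of query heads
    int h_k;             // number of key/value heads
    bool deterministic;  // the flag the commented-out switch would read
};

// Hypothetical stand-in for run_flash_bwd: reports which instantiation ran.
template <bool GQA, bool Deterministic>
std::string run_bwd_sketch() {
    return std::string(GQA ? "GQA" : "MHA") + (Deterministic ? "/deterministic" : "/default");
}

// Same nesting as run_mha_bwd_dispatch, with the Deterministic switch re-enabled.
std::string dispatch_sketch(const BwdParamsSketch& params) {
    return BOOL_SWITCH_SKETCH(params.h != params.h_k, GQA, [&] {
        return BOOL_SWITCH_SKETCH(params.deterministic, Deterministic, [&] {
            return run_bwd_sketch<GQA, Deterministic>();
        });
    });
}
```

Every combination is compiled up front and the runtime flag only selects which one runs, so adding the `Deterministic` level doubles the number of instantiations under each `GQA`/`Varlen` branch, which costs compile time and binary size.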