Description
As a Windows user, I tried to compile this and found that the problem lies in two files under "./flash-attention/csrc/flash_attn/src": "flash_fwd_launch_template.h" and "flash_bwd_launch_template.h". Wherever a template references the constant "Headdim", MSVC raises error C2975 (invalid template argument, expected compile-time constant expression). I think this may be why compilation keeps failing on Windows. Here is how I solved it:
First, "flash_bwd_launch_template.h" contains many functions named "run_mha_bwd_hdimXX". Each one declares a constant "Headdim" equal to XX and calls templates such as run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure). What I did was replace every "Headdim" in these template arguments with its literal value. For example, run_mha_bwd_hdim128 declares Headdim as 128, so its templates become run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure). I did the same for the "run_mha_fwd_hdimXX" functions in "flash_fwd_launch_template.h" and their templates.
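To make the first step concrete, here is a minimal, compilable C++ sketch of the before/after shape. MockBwdTraits, run_mock_bwd, and the two run_mock_bwd_hdim128_* functions are hypothetical stand-ins, not the flash-attention API, and the lambda only approximates the one that the BOOL_SWITCH-style launch code wraps around the call. A conforming compiler accepts both versions; according to this report, MSVC rejects the first shape with C2975.

```cpp
#include <cassert>

// Hypothetical stand-in for Flash_bwd_kernel_traits: it only records a few of
// its non-type template arguments. The real traits take more parameters and T.
template <int kHeadDim_, int kBlockM_, int kBlockN_>
struct MockBwdTraits {
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
};

// Hypothetical stand-in for run_flash_bwd: reports the head dim it was built for.
template <typename Kernel_traits>
int run_mock_bwd() {
    return Kernel_traits::kHeadDim;
}

// Original shape: a named constant used as a template argument inside a lambda.
int run_mock_bwd_hdim128_named() {
    constexpr static int Headdim = 128;
    auto launch = [&] { return run_mock_bwd<MockBwdTraits<Headdim, 64, 128>>(); };
    return launch();
}

// Workaround shape from this issue: the literal 128 replaces Headdim.
int run_mock_bwd_hdim128_literal() {
    auto launch = [&] { return run_mock_bwd<MockBwdTraits<128, 64, 128>>(); };
    return launch();
}
```

Both functions instantiate the same launcher, so a check such as assert(run_mock_bwd_hdim128_named() == run_mock_bwd_hdim128_literal()) should hold; the only difference is how the head dimension is spelled at the call site.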
Second, another error comes from line 107 of "flash_fwd_launch_template.h". It is again a problem with referencing a constant, this time "kBlockM", in the if/else chain below it, so I rewrote that section as:
// kBlockM (rows per combine block): 4 if kHeadDim % 128 == 0, 8 if kHeadDim % 64 == 0, otherwise 16
if constexpr (Kernel_traits::kHeadDim % 128 == 0) {
    dim3 grid_combine((params.b * params.h * params.seqlen_q + 4 - 1) / 4);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 128) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
} else if constexpr (Kernel_traits::kHeadDim % 64 == 0) {
    dim3 grid_combine((params.b * params.h * params.seqlen_q + 8 - 1) / 8);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 128) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
} else {
    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 128) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
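The rewrite hard-codes in each branch the values that the original code derived from kBlockM, so it is easy to mistype one. As a cross-check, here is a small host-side C++ sketch of the same arithmetic. The helper names combine_block_m, combine_log_max_splits, and combine_grid_x are hypothetical and do not exist in flash-attention; they only restate the rules visible in the code above.

```cpp
#include <cassert>

// Rows per combine block, matching the three if constexpr branches:
// 4 when the head dim is a multiple of 128, 8 when a multiple of 64, else 16.
constexpr int combine_block_m(int head_dim) {
    return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}

// Third template argument of flash_fwd_splitkv_combine_kernel in the chain:
// the smallest k (1..7) with num_splits <= 2^k. Returns -1 when no branch
// matches (num_splits > 128), in which case the chain launches nothing.
constexpr int combine_log_max_splits(int num_splits) {
    int k = 1;
    while (k <= 7 && num_splits > (1 << k)) {
        ++k;
    }
    return k <= 7 ? k : -1;
}

// Blocks in grid_combine: ceil(b * h * seqlen_q / block_m).
constexpr long long combine_grid_x(long long b, long long h, long long seqlen_q, int block_m) {
    return (b * h * seqlen_q + block_m - 1) / block_m;
}

static_assert(combine_block_m(128) == 4, "first branch");
static_assert(combine_block_m(64) == 8, "second branch");
static_assert(combine_block_m(96) == 16, "else branch");
```

The static_asserts pin down one head dimension per branch; for example, a head dimension of 192 is a multiple of 64 but not of 128, so it lands in the middle branch with 8 rows per block.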
Third, in the function "run_mha_fwd_splitkv_dispatch" in "flash_fwd_launch_template.h" (line 194), you also have to replace "kBlockM" in the template arguments with its value, 64. After that, you can try compiling.
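Replacing a named constant with its literal value should not change which kernel is instantiated, so these edits should not alter runtime behavior as long as every literal matches the constant it replaces. The hypothetical C++ sketch below checks this with std::is_same; MockFwdTraits is a placeholder for Flash_fwd_kernel_traits, and the kBlockN values used in the checks are arbitrary, not the real configuration.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for Flash_fwd_kernel_traits; the real template also
// takes the warp count, two bool flags, and the element type T.
template <int kHeadDim_, int kBlockM_, int kBlockN_>
struct MockFwdTraits {};

// True when spelling kBlockM as the literal 64 names the same instantiation
// as going through a named constant, as run_mha_fwd_splitkv_dispatch does.
template <int Headdim, int kBlockN>
constexpr bool literal_block_m_is_same_type() {
    constexpr int kBlockM = 64;  // the named constant the workaround removes
    return std::is_same<MockFwdTraits<Headdim, kBlockM, kBlockN>,
                        MockFwdTraits<Headdim, 64, kBlockN>>::value;
}

static_assert(literal_block_m_is_same_type<128, 128>(), "same instantiation");
```

Because template identity depends only on the argument values, the named and literal forms produce one type; the workaround only changes how the value is spelled, which is what sidesteps the compiler error reported here.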
These fixes look crude, but they really did solve my problem: I successfully compiled flash_attn_2 on Windows. I still need some time to test it on other computers.
I have uploaded the files I rewrote: link.
There may well be a better solution, but this at least works for me.
Also, I didn't use Ninja; I compiled from source without it. Could someone try compiling it with Ninja?
EDIT: I used
- Python 3.11
- PyTorch 2.2+cu121 (nightly)
- CUDA 12.2
- Anaconda
- Windows 11 22H2