
A possible solution for building/compiling Flash Attention 2 on Windows #595

@Akatsuki030


As a Windows user, I tried to compile this and found that the problem lies in two files, "flash_fwd_launch_template.h" and "flash_bwd_launch_template.h", under "./flash-attention/csrc/flash_attn/src". Where the templates use the variable "Headdim" as a template argument, the compiler raises error C2975 (invalid template argument, expected compile-time constant expression). I think this may be why we always get compile errors on Windows. Here is how I solved the problem:

First, in "flash_bwd_launch_template.h" you will find many functions named "run_mha_bwd_hdimXX". Each one declares a constant Headdim equal to XX and contains template calls such as run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure). What I did was replace every "Headdim" in these template arguments with its literal value. For example, in run_mha_bwd_hdim128, where Headdim is 128, one of the calls then reads run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure). I did the same for the "run_mha_fwd_hdimXX" functions and their templates in "flash_fwd_launch_template.h".
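
To make the first change concrete, here is a self-contained C++ sketch of the two call styles. DemoBwdTraits, demo_run and DEMO_BOOL_SWITCH are hypothetical stand-ins loosely modeled on Flash_bwd_kernel_traits, run_flash_bwd and the BOOL_SWITCH macro; they are not the real flash-attention code. The "original" function uses a named constant Headdim as a template argument inside a lambda, which is the pattern the issue reports MSVC rejecting with C2975, and the "patched" function writes the literal 128 instead. Whether MSVC rejects this reduced form exactly as it rejects the real headers is an assumption.

```cpp
#include <cassert>

// Hypothetical stand-in for Flash_bwd_kernel_traits: it only records a few
// compile-time parameters so the two call styles can be compared.
template <int kHeadDim_, int kBlockM_, int kBlockN_, bool kIsEven_>
struct DemoBwdTraits {
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr bool kIsEven = kIsEven_;
};

// Hypothetical launcher: reports the head dimension baked into the traits,
// plus 1 when the even-K variant was selected.
template <typename Traits>
int demo_run() {
    return Traits::kHeadDim + (Traits::kIsEven ? 1 : 0);
}

// Hypothetical macro in the style of BOOL_SWITCH: turns a runtime bool into a
// compile-time constant CONST_NAME visible inside the lambda passed as __VA_ARGS__.
#define DEMO_BOOL_SWITCH(COND, CONST_NAME, ...)       \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Original style: the named constant Headdim is a template argument inside the
// lambda. Standard C++ allows this; per the issue, this is where MSVC reports C2975.
int demo_hdim128_original(bool is_even) {
    constexpr static int Headdim = 128;
    return DEMO_BOOL_SWITCH(is_even, IsEvenConst, [&] {
        return demo_run<DemoBwdTraits<Headdim, 64, 128, IsEvenConst>>();
    });
}

// Patched style from the issue: the literal 128 is written in place of Headdim.
int demo_hdim128_patched(bool is_even) {
    return DEMO_BOOL_SWITCH(is_even, IsEvenConst, [&] {
        return demo_run<DemoBwdTraits<128, 64, 128, IsEvenConst>>();
    });
}

// Usage check: both styles select the same instantiation for either flag value.
inline void demo_check() {
    assert(demo_hdim128_original(true) == demo_hdim128_patched(true));
    assert(demo_hdim128_original(false) == demo_hdim128_patched(false));
}
```

Because the literal and the constant name the same value, the patched code instantiates exactly the same template; only the spelling of the argument changes.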

Second, another error comes from "flash_fwd_launch_template.h", line 107. It is again caused by referencing the constant "kBlockM" as a template argument, this time in the if-else statement that selects the combine kernel. I rewrote it as follows:

		if constexpr(Kernel_traits::kHeadDim % 128 == 0){
			// kBlockM = 4 rows per combine block: kHeadDim is a multiple of 128
			dim3 grid_combine((params.b * params.h * params.seqlen_q + 4 - 1) / 4);
			BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
				if (params.num_splits <= 2) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 4) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 8) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 16) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 32) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 64) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 128) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				}
				C10_CUDA_KERNEL_LAUNCH_CHECK();
			});
		}else if constexpr(Kernel_traits::kHeadDim % 64 == 0){
			// kBlockM = 8: kHeadDim is a multiple of 64 but not of 128
			dim3 grid_combine((params.b * params.h * params.seqlen_q + 8 - 1) / 8);
			BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
				if (params.num_splits <= 2) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 4) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 8) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 16) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 32) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 64) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 128) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 8, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				}
				C10_CUDA_KERNEL_LAUNCH_CHECK();
			});
		}else{
			// kBlockM = 16 for any other head dimension
			dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
			BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
				if (params.num_splits <= 2) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 4) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 8) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 16) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 32) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 64) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				} else if (params.num_splits <= 128) {
					flash_fwd_splitkv_combine_kernel<Kernel_traits, 16, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
				}
				C10_CUDA_KERNEL_LAUNCH_CHECK();
			});
		}
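
The three branches above differ only in the number of rows per combine block (4, 8 or 16) and share the same num_splits ladder. The C++ sketch below reproduces that host-side arithmetic so a hand-expanded version can be checked at compile time. The helper names combine_block_m, combine_log_max_splits and combine_grid are hypothetical and not part of flash-attention, and reading the second template argument as the base-2 log of the largest split count a branch handles is an assumption drawn from the ladder's pattern. It does not model the CUDA launch itself.

```cpp
// Rows handled per combine block, following the three branches above:
// multiple of 128 -> 4, multiple of 64 -> 8, anything else -> 16.
constexpr int combine_block_m(int head_dim) {
    return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}

// Second template argument chosen by the num_splits ladder:
// <= 2 -> 1, <= 4 -> 2, ..., <= 128 -> 7 (smallest k with num_splits <= 2^k).
// Returns 0 for num_splits > 128, where the ladder launches no kernel.
constexpr int combine_log_max_splits(int num_splits) {
    for (int k = 1; k <= 7; ++k) {
        if (num_splits <= (1 << k)) {
            return k;
        }
    }
    return 0;
}

// Size of grid_combine: ceil(b * h * seqlen_q / kBlockM).
constexpr int combine_grid(int b, int h, int seqlen_q, int head_dim) {
    return (b * h * seqlen_q + combine_block_m(head_dim) - 1) / combine_block_m(head_dim);
}

// Compile-time spot checks against the branches written out above.
static_assert(combine_block_m(128) == 4, "head dim 128 uses 4 rows");
static_assert(combine_block_m(192) == 8, "192 is a multiple of 64 but not 128");
static_assert(combine_block_m(96) == 16, "96 falls through to the else branch");
static_assert(combine_log_max_splits(3) == 2, "3 splits take the <= 4 branch");
static_assert(combine_grid(2, 16, 100, 128) == 800, "3200 rows / 4 rows per block");
```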

Third, in the function "run_mha_fwd_splitkv_dispatch" in "flash_fwd_launch_template.h" (line 194), you also have to replace "kBlockM" in the template arguments with the literal 64. Then you can try compiling it.
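
The third change fixes the row-tile size of the split-KV dispatch to the literal 64. The small C++ sketch below assumes, as the name suggests, that kBlockM is the number of query rows per tile; DemoFwdTraits, demo_num_m_blocks, demo_splitkv_dispatch and the column tile of 128 are hypothetical placeholders, not flash-attention code. It only illustrates that the literal is still a compile-time constant that downstream tile arithmetic can use.

```cpp
// Hypothetical stand-in for Flash_fwd_kernel_traits that records only the
// head dimension and the two tile sizes.
template <int kHeadDim_, int kBlockM_, int kBlockN_>
struct DemoFwdTraits {
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
};

// Query tiles needed to cover seqlen_q rows with kBlockM rows per tile
// (ceiling division).
template <typename Traits>
constexpr int demo_num_m_blocks(int seqlen_q) {
    return (seqlen_q + Traits::kBlockM - 1) / Traits::kBlockM;
}

// Patched dispatch: the row tile is the literal 64 rather than a named kBlockM.
// The column tile of 128 is an arbitrary placeholder for this sketch.
template <int Headdim>
constexpr int demo_splitkv_dispatch(int seqlen_q) {
    using Traits = DemoFwdTraits<Headdim, 64, 128>;
    return demo_num_m_blocks<Traits>(seqlen_q);
}

static_assert(demo_splitkv_dispatch<128>(1000) == 16, "ceil(1000 / 64) = 16");
```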
These fixes look crude, but they really solved my problem: I successfully compiled flash_attn_2 on Windows. I still need some time to test it on other computers.
I've uploaded the files I rewrote: link.
There may be a better solution, but for me it at least works.
Oh, and I didn't use Ninja; I compiled it from source. Could someone try compiling it with Ninja?
EDIT: My setup:

  • Python 3.11
  • PyTorch 2.2+cu121 Nightly
  • CUDA 12.2
  • Anaconda
  • Windows 11 22H2
