
Add torch.compile support to flash attention 3 #1769


Open · wants to merge 9 commits into main from guilhermeleobas/fa3-compile

Conversation

guilhermeleobas

Enable torch.compile support for FlashAttention and improve testing

  • Add support for torch.compile to recognize the FlashAttention forward/backward functions (see the sketch after this list).
  • Update tests to use torch._subclasses.fake_tensor.fake_check to validate the FakeTensor implementation.
  • Create a file (flash_attn_config.py) that exposes the flags used to build FlashAttention at runtime.
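
For context, the general pattern this enables looks roughly like the sketch below, using PyTorch's torch.library.custom_op and register_fake APIs. The op name flash_attn_3::example_fwd, the (q, k, v) signature, and the placeholder eager implementation are illustrative assumptions, not the PR's actual registration.

```python
import torch

# Sketch only: op name, signature, and eager placeholder are assumptions.
@torch.library.custom_op("flash_attn_3::example_fwd", mutates_args=())
def example_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in eager implementation; the real op dispatches to the FA3 CUDA extension.
    return torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2).contiguous()

@example_fwd.register_fake
def _(q, k, v):
    # Fake (meta) kernel: describe only the output's shape/dtype/device so
    # FakeTensor tracing under torch.compile works without running the kernel.
    return torch.empty_like(q)
```

With a fake kernel registered, torch.compile can trace calls to the op with FakeTensors and only dispatch to the real kernel at runtime.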

@guilhermeleobas guilhermeleobas marked this pull request as ready for review July 22, 2025 20:52
@guilhermeleobas (Author)

cc @zou3519 @anijain2305

@@ -1511,7 +1511,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso
         softmax_d.zero_();
     }

-    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
+    return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };

What's going on here? Adding additional clones is generally bad for performance

hopper/build.sh Outdated
Comment on lines 5 to 8
# Flash Attention Minimal Build Script for PHI-1 Reproducer
# Uses subshell to automatically clean up environment variables

# Run in subshell - variables are automatically cleaned up when it exits

do you have more context for this file? I don't see a mention of PHI-1 elsewhere

@guilhermeleobas (Author)

I'll remove this file once the PR gets approved. I only keep it as it makes it easier to build and test flash attention.

Comment on lines 61 to 63
if not DISABLE_FAKE_CHECK:
flash_attn_func = run_fake_check(flash_attn_func)
flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
@zou3519 (Jul 23, 2025)

A bit jank but I'm fine with it

@guilhermeleobas (Author)

@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.

Member

Would that slow down the tests much? Right now it takes 30-50 mins to run all the tests.

@guilhermeleobas (Author)

I suspect it could double the time to run the tests, since fake_check runs the function twice to compare the fake implementation against the actual one. I can change the flag to be opt-in instead, so fake_check only runs when the flag is explicitly enabled.
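
If the flag were flipped to opt-in, the gating could look roughly like this sketch; the environment variable name is an assumption, while run_fake_check, flash_attn_func, and flash_attn_varlen_func are the names already used in this file.

```python
import os

# Opt-in: only wrap the public entry points with the fake_check wrapper when the
# flag is explicitly enabled (the variable name here is hypothetical).
RUN_FAKE_CHECK = os.environ.get("FLASH_ATTN_RUN_FAKE_CHECK", "0") == "1"

if RUN_FAKE_CHECK:
    flash_attn_func = run_fake_check(flash_attn_func)
    flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
```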

        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper


Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?

@guilhermeleobas (Author)

I guess @tridao would be a better person to answer this one.

Member

we do test a lot of input shapes and different attn options (~100k tests iirc)

@zou3519 left a comment

Looks good to me, modulo the prints

@lantudou commented Aug 5, 2025

@guilhermeleobas Any updates? I tried your commits and found there are still some shape errors in the backward function; here is my test code.
```python
import torch
from flash_attn_interface import flash_attn_func, _flash_attn_forward
from torch import nn


class EfficienctMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True):
        """
        Initialize.

        Args:
            embed_size (int): Feature dimension of the input and output.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout rate applied to the attention weights.
            use_flash_attn (bool): Whether to try using FlashAttention (if available).
        """
        super().__init__()
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.use_flash_attn = use_flash_attn and (flash_attn_func is not None)

        # A single linear layer producing Q, K, V at once is more efficient
        self.qkv_proj = nn.Linear(embed_size, 3 * embed_size)
        # Final output projection
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.dropout = dropout

    def forward(self, x, attention_mask=None):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, seq_length, embed_size).
            attention_mask (torch.Tensor, optional): Attention mask.
                For SDPA, True positions in a boolean mask are *not* attended to.
                For FlashAttention, the causal argument is typically used instead of a mask.

        Returns:
            torch.Tensor: Output tensor of shape (N, seq_length, embed_size).
        """
        N, seq_length, _ = x.shape

        # 1. Project and split Q, K, V
        # (N, seq_length, embed_size) -> (N, seq_length, 3 * embed_size)
        qkv = self.qkv_proj(x)
        # (N, seq_length, 3 * embed_size) -> 3 x (N, seq_length, embed_size)
        q, k, v = qkv.chunk(3, dim=-1)

        # 2. Reshape Q, K, V for multi-head computation
        # (N, seq_length, embed_size) -> (N, seq_length, num_heads, head_dim)
        q = q.view(N, seq_length, self.num_heads, self.head_dim)
        k = k.view(N, seq_length, self.num_heads, self.head_dim)
        v = v.view(N, seq_length, self.num_heads, self.head_dim)

        # --- FlashAttention path ---
        if self.use_flash_attn and attention_mask is None:
            # flash_attn_func expects shape (batch, seqlen, nheads, headdim),
            # which matches our current shape exactly.
            # flash_attn_func handles dropout internally.
            # print("Using FlashAttention...")
            out = flash_attn_func(
                q, k, v
            )
        # 3. Merge the heads' outputs
        # (N, seq_length, num_heads, head_dim) -> (N, seq_length, embed_size)
        out = out.reshape(N, seq_length, self.embed_size)

        # 4. Final linear projection
        out = self.out_proj(out)

        return out


batch_size = 16
sequence_length = 256
embedding_dim = 2048

test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16()
test = torch.compile(test, mode='max-autotune')

out = test(input_tensor)
loss = out.sum()
loss.backward()
```

@guilhermeleobas (Author)

Hi @lantudou, thanks for the reproducer. Could you try it again?

@guilhermeleobas (Author)

@tridao, could you take a look one more time once you have some cycles to spare?

@m3rcuriel

I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C++. Would these conflict with FA2?

@guilhermeleobas (Author)

> I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C++. Would these conflict with FA2?

Is FlashAttention 3 an independent package from FA2, in the sense that FA2 will eventually be deprecated in favor of FA3?

@m3rcuriel

> Is FlashAttention 3 an independent package from FA2, in the sense that FA2 will eventually be deprecated in favor of FA3?

I don't know what the future plan is, but at least right now transformers thinks it can have both at the same time.

@guilhermeleobas (Author)

> I don't know what the future plan is, but at least right now transformers thinks it can have both at the same time.

Got it. I changed the namespace to flash_attn_3.

@Tomcli commented Aug 20, 2025

Thank you for the great work @guilhermeleobas. Just wondering, have you tried running training on the exported FX graph? It seems this implementation is still missing register_autograd, which is needed to run the training loop.

@guilhermeleobas (Author) commented Aug 21, 2025

Thanks for the feedback, @Tomcli. As for the second question: it probably won't work. I based my implementation on what is implemented for FA2, which doesn't use register_autograd either. I can do this in a follow-up PR if needed.

Edit: added in this one.
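
For reference, wiring the backward pass through torch.library.register_autograd looks roughly like the sketch below, building on the hypothetical flash_attn_3::example_fwd op from the earlier sketch; what gets saved for backward and the backward computation itself are placeholder assumptions, not the PR's actual code.

```python
import torch

def _setup_context(ctx, inputs, output):
    # Save whatever the backward kernel needs (exactly what to save is an assumption here).
    q, k, v = inputs
    ctx.save_for_backward(q, k, v, output)

def _backward(ctx, grad_out):
    q, k, v, out = ctx.saved_tensors
    # A real implementation would call the FA3 backward op here; zeros are just
    # placeholders to show the plumbing (one gradient per forward input).
    return torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

# Attach autograd to the (hypothetical) custom op registered earlier.
torch.library.register_autograd(
    "flash_attn_3::example_fwd", _backward, setup_context=_setup_context
)
```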

@guilhermeleobas (Author)

Hi @Tomcli, could you try the last commit, please?

Comment on lines 314 to 316
is_varlen_q = bool(cu_seqlens_q)
is_varlen_k = bool(cu_seqlens_k)
is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k)
@supercharleszhu (Aug 22, 2025)

Suggested change:

- is_varlen_q = bool(cu_seqlens_q)
- is_varlen_k = bool(cu_seqlens_k)
- is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k)
+ is_varlen_q = cu_seqlens_q is not None
+ is_varlen_k = cu_seqlens_k is not None
+ is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
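
For context, a quick illustration of why the `is not None` form is the safer presence check (the tensor values below are made up):

```python
import torch

cu_seqlens_q = torch.tensor([0, 128, 256])  # illustrative varlen offsets

try:
    bool(cu_seqlens_q)  # RuntimeError: truthiness of a multi-element tensor is ambiguous
except RuntimeError as e:
    print(e)

print(cu_seqlens_q is not None)  # True: checks presence, not tensor contents
```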

@guilhermeleobas guilhermeleobas force-pushed the guilhermeleobas/fa3-compile branch from 871a3cf to e52508f Compare August 22, 2025 20:27