
Conversation

guilhermeleobas

Enable torch.compile support for FlashAttention and improve testing

  • Add support for torch.compile to recognize FlashAttention forward/backward functions.
  • Update tests to use torch._subclasses.fake_tensor.fake_check to validate the FakeTensor implementation.
  • Create a file exposing the flags used to build FlashAttention at runtime (flash_attn_config.py).
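For context, a minimal sketch of what the added torch.compile support enables from the user side (flash_attn_interface and flash_attn_func are the names used elsewhere in this thread; shapes and dtypes here are illustrative):

import torch
from flash_attn_interface import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# With this PR, Dynamo can trace the custom op (and its backward) instead of graph-breaking.
compiled_attn = torch.compile(flash_attn_func)
out = compiled_attn(q, k, v, causal=True)
out.sum().backward()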

@guilhermeleobas guilhermeleobas marked this pull request as ready for review July 22, 2025 20:52
@guilhermeleobas
Author

cc @zou3519 @anijain2305

}

return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };

What's going on here? Adding additional clones is generally bad for performance

hopper/build.sh Outdated
Comment on lines 5 to 8
# Flash Attention Minimal Build Script for PHI-1 Reproducer
# Uses subshell to automatically clean up environment variables

# Run in subshell - variables are automatically cleaned up when it exits

do you have more context for this file? I don't see a mention of PHI-1 elsewhere

Author

I'll remove this file once the PR gets approved. I'm only keeping it because it makes it easier to build and test FlashAttention.

Comment on lines 61 to 63
if not DISABLE_FAKE_CHECK:
flash_attn_func = run_fake_check(flash_attn_func)
flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
@zou3519 zou3519 Jul 23, 2025

A bit jank but I'm fine with it

Author

@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.

Member

Would that slow down the tests much? Right now it takes 30-50 mins to run all the tests.

Author

I suspect it could double the time to run the tests, as fake_check runs the function twice to compare the fake version with the actual implementation. I can change this flag to be opt-in instead, so fake_check would only run when the flag is explicitly enabled.

fake_check(fn, args, kwargs)
return fn(*args, **kwargs)
return wrapper
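The fragment above is the tail of the wrapper from the diff; for readers skimming the thread, a self-contained sketch of the idea (the fake_check import path follows the PR summary, and the environment-variable name here is a made-up placeholder):

import os
from torch._subclasses.fake_tensor import fake_check  # import path as stated in the PR summary

DISABLE_FAKE_CHECK = os.environ.get("FA3_DISABLE_FAKE_CHECK", "0") == "1"  # placeholder flag name

def run_fake_check(fn):
    # Every call is first cross-checked against the op's FakeTensor implementation,
    # then the real kernel runs as usual, which is why enabling it roughly doubles test time.
    def wrapper(*args, **kwargs):
        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper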


Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?

Author

I guess @tridao would be a better person to answer this one.

Member

we do test a lot of input shapes and different attn options (~100k tests iirc)

@zou3519 zou3519 left a comment

Looks good to me, modulo the prints

@lantudou

lantudou commented Aug 5, 2025

@guilhermeleobas Any updates? I tried your commits and found there are still some shape errors in the backward function; here is my test code.

import torch
from flash_attn_interface import flash_attn_func, _flash_attn_forward
from torch import nn

class EfficienctMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True):
        """
        Initialize the module.

        Args:
            embed_size (int): Input and output feature dimension.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout rate applied to the attention weights.
            use_flash_attn (bool): Whether to try to use FlashAttention (if available).
        """
        super().__init__()
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.use_flash_attn = use_flash_attn and (flash_attn_func is not None)

        # A single linear layer produces Q, K, V at once, which is more efficient
        self.qkv_proj = nn.Linear(embed_size, 3 * embed_size)
        # Final output projection
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.dropout = dropout

    def forward(self, x, attention_mask=None):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, seq_length, embed_size)
            attention_mask (torch.Tensor, optional): Attention mask.
                For SDPA, True positions in a boolean mask are *not* attended to.
                For FlashAttention, the causal argument is normally used instead of a mask.

        Returns:
            torch.Tensor: Output tensor of shape (N, seq_length, embed_size)
        """
        N, seq_length, _ = x.shape

        # 1. Project and split Q, K, V
        # (N, seq_length, embed_size) -> (N, seq_length, 3 * embed_size)
        qkv = self.qkv_proj(x)
        # (N, seq_length, 3 * embed_size) -> 3 x (N, seq_length, embed_size)
        q, k, v = qkv.chunk(3, dim=-1)

        # 2. Reshape Q, K, V for multi-head computation
        # (N, seq_length, embed_size) -> (N, seq_length, num_heads, head_dim)
        q = q.view(N, seq_length, self.num_heads, self.head_dim)
        k = k.view(N, seq_length, self.num_heads, self.head_dim)
        v = v.view(N, seq_length, self.num_heads, self.head_dim)

        # --- FlashAttention path ---
        if self.use_flash_attn and attention_mask is None:
            # flash_attn_func expects shape (batch, seqlen, nheads, headdim),
            # which matches our current shape exactly.
            # flash_attn_func handles dropout internally.
            # print("Using FlashAttention...")
            out = flash_attn_func(
                q, k, v
            )
        # 3. Merge the multi-head outputs
        # (N, seq_length, num_heads, head_dim) -> (N, seq_length, embed_size)
        out = out.reshape(N, seq_length, self.embed_size)

        # 4. Final linear projection
        out = self.out_proj(out)

        return out

batch_size = 16

sequence_length = 256
embedding_dim = 2048

test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16()
test = torch.compile(test, mode='max-autotune')

out = test(input_tensor)
loss = out.sum()
loss.backward()

@guilhermeleobas
Author

Hi @lantudou, thanks for the reproducer. Could you try it again?

@guilhermeleobas
Author

@tridao, could you take a look one more time once you have some cycles to spare?

@m3rcuriel

I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C. Would these conflict with FA2?

@guilhermeleobas
Author

I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C. Would these conflict with FA2?

Is FlashAttention 3 an independent package from FA2, in the sense that in the future FA2 will be deprecated in favor of FA3?

@m3rcuriel

...

Is FlashAttention 3 an independent package from FA2, in the sense that in the future FA2 will be deprecated in favor of FA3?

idk what the future plan is but at least right now transformers thinks it can have both at the same time

@guilhermeleobas
Author

idk what the future plan is but at least right now transformers thinks it can have both at the same time

Got it. I changed the namespace to flash_attn_3.
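For readers following the namespace discussion, a rough sketch of registering an op under the flash_attn_3 namespace with torch.library.custom_op (the op name _demo_attn_forward and the SDPA stand-in body are illustrative, not the PR's actual schema; the real code dispatches into the FA3 CUDA extension):

import torch
import torch.nn.functional as F

@torch.library.custom_op("flash_attn_3::_demo_attn_forward", mutates_args=())
def _demo_attn_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in body so the sketch runs; inputs are (batch, seqlen, nheads, headdim).
    q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))  # SDPA wants (batch, nheads, seqlen, headdim)
    out = F.scaled_dot_product_attention(q_, k_, v_)
    return out.transpose(1, 2).contiguous()

@_demo_attn_forward.register_fake
def _(q, k, v):
    # Meta/FakeTensor rule: the output matches q's shape and dtype, which is all
    # torch.compile needs in order to trace through the op.
    return torch.empty_like(q)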

@Tomcli

Tomcli commented Aug 20, 2025

Thank you for the great work @guilhermeleobas. Just wondering, have you tried to run training on the exported FX graph? It seems like this implementation is still missing register_autograd, which is needed to run the training loop.

@guilhermeleobas
Author

guilhermeleobas commented Aug 21, 2025

Thanks for the feedback, @Tomcli. As for the second question: it probably won't work. I based my implementation on what is implemented for FA2, which doesn't use register_autograd either. I can do this in a follow-up PR if needed.

Edit: added in this one.
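As a rough illustration of what the register_autograd wiring involves, continuing the hypothetical _demo_attn_forward op sketched earlier in this thread (the real PR calls the FA3 backward kernel instead of re-deriving gradients through SDPA):

import torch
import torch.nn.functional as F

def _setup_context(ctx, inputs, output):
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)

def _backward(ctx, grad_out):
    q, k, v = ctx.saved_tensors
    # Stand-in: recompute the reference attention with autograd enabled and
    # differentiate it; a real implementation calls the FA3 backward kernel.
    with torch.enable_grad():
        q_, k_, v_ = (t.detach().requires_grad_() for t in (q, k, v))
        ref = F.scaled_dot_product_attention(
            q_.transpose(1, 2), k_.transpose(1, 2), v_.transpose(1, 2)
        ).transpose(1, 2)
        dq, dk, dv = torch.autograd.grad(ref, (q_, k_, v_), grad_out)
    return dq, dk, dv

# Register the backward so torch.compile/torch.export can trace training graphs.
_demo_attn_forward.register_autograd(_backward, setup_context=_setup_context)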

@guilhermeleobas
Author

Hi @Tomcli, could you try the last commit, please?

@OutofAi

OutofAi commented Aug 30, 2025

Thanks for fixing it @guilhermeleobas, the compile now works, but the compiled artifacts still break for me when using AOTInductor packaging.

@guilhermeleobas
Author

Thanks for trying this PR @OutofAi. Do you have a reproducer for this error?

@Turakar

Turakar commented Sep 2, 2025

I just want to share that for my workflow, based only on torch.compile() and not torch.export(), this PR works. Thanks a lot!

@guilhermeleobas
Author

guilhermeleobas commented Sep 2, 2025

Hi @OutofAi, I believe I fixed the torch.export bug you're seeing. Could you try again, please?

@yijianggit

Hi @guilhermeleobas

I tried torch.compile with flash_attn_with_kvcache, and received:

File "/opt/python/3.10/lib/python3.10/site-packages/flash_attn_interface.py", line 217, in _flash_attn_forward_fake
    raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")
torch._dynamo.exc.TorchRuntimeError: Failed running call_function flash_attn_3._flash_attn_forward.default(*(FakeTensor(..., device='cuda:0', size=(8, 1, 8, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(8, 4096, 8, 128), dtype=torch.bfloat16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 4096, 8, 128), dtype=torch.bfloat16,
           grad_fn=<Error>), None, None, None, None, None, None, None, None, FakeTensor(..., device='cuda:0', size=(8,), dtype=torch.int32), None, None, None, None, None, None, None, None, None, None, None, 0.08838834764831845), **{'causal': True, 'window_size_left': -1, 'window_size_right': -1, 'attention_chunk': 0, 'softcap': 0.0, 'rotary_interleaved': True, 'scheduler_metadata': None, 'num_splits': 0, 'pack_gqa': None, 'sm_margin': 0}):
tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got num_splits=0


  File "/opt/python/3.10/lib/python3.10/site-packages/flash_attn_interface.py", line 1031, in flash_attn_with_kvcache
    out, softmax_lse, *rest = _flash_attn_forward(
  File "/opt/python/3.10/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 641, in __call__
    return self._opoverload(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I wonder if flash_attn_with_kvcache is supported in this commit?

window_size_right: int = -1,
softcap: float = 0.0,
deterministic: bool = False,
deterministic: bool= False,

This looks malformatted

Author

Thanks!

@OutofAi

OutofAi commented Sep 3, 2025

@guilhermeleobas just saw your messages; it seems to be working now with your new changes. I gave it a try with my sort of unit test and it works now, thanks. For reference, this is the code:

import os, json, torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.export import Dim

import flash_attn_interface

# ---------- Norm + RoPE ----------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., D]
        scale = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * scale) * self.weight

def precompute_freqs_cis(head_dim: int, seq_len: int, device, dtype, theta: float = 10000.0):
    """
    Returns (cos, sin) with shapes:
      cos: [S, head_dim//2], sin: [S, head_dim//2]
    """
    half = head_dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=dtype) / half))
    t = torch.arange(seq_len, device=device, dtype=dtype)
    freqs = torch.einsum('s,f->sf', t, inv_freq)                         # [S, half]
    return torch.cos(freqs), torch.sin(freqs)                            # (cos, sin)

def _apply_rope_4d(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """
    x:   [B, S, N, Hd]
    cos: [1, S, 1, Hd//2]
    sin: [1, S, 1, Hd//2]
    Applies rotary on pairs in the last dim.
    """
    Hd = x.shape[-1]
    x_pair = x.view(*x.shape[:-1], Hd // 2, 2)                           # [..., half, 2]
    x_rot  = torch.stack((-x_pair[..., 1], x_pair[..., 0]), dim=-1)      # rotate_half
    out_pair = x_pair * cos[..., None] + x_rot * sin[..., None]
    return out_pair.view(*x.shape[:-1], Hd)

def rope_apply(x: torch.Tensor, freqs, num_heads: int) -> torch.Tensor:
    """
    x: [B, S, D]; freqs: (cos, sin) where cos/sin are [S, Hd//2]; returns [B, S, D]
    """
    cos, sin = freqs
    B, S, D = x.shape
    N = num_heads
    Hd = D // N
    x4 = rearrange(x, "b s (n d) -> b s n d", n=N)
    cos = cos[None, :, None, :]  # [1, S, 1, Hd//2]
    sin = sin[None, :, None, :]  # [1, S, 1, Hd//2]
    x4 = _apply_rope_4d(x4, cos, sin)
    return rearrange(x4, "b s n d -> b s (n d)", n=N)


def flash_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int,
    dropout_p: float = 0.0, causal: bool = False, softmax_scale: float | None = None
) -> torch.Tensor:
    """
    q,k,v: [B, S, D] where D = num_heads * head_dim
    returns: [B, S, D]
    """
    # rearrange to [B, S, N, Hd]
    q4 = rearrange(q, "b s (n d) -> b s n d", n=num_heads).contiguous()
    k4 = rearrange(k, "b s (n d) -> b s n d", n=num_heads).contiguous()
    v4 = rearrange(v, "b s (n d) -> b s n d", n=num_heads).contiguous()

    # FlashAttention call expects [B, S, N, Hd]
    x4 = flash_attn_interface.flash_attn_func(
        q4, k4, v4,
        softmax_scale=softmax_scale,
        causal=causal,
    )  # -> [B, S, N, Hd]

    # back to [B, S, D]
    x = rearrange(x4, "b s n d -> b s (n d)", n=num_heads).contiguous()
    return x


# ---------- Modules ----------
class AttentionModule(nn.Module):
    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)

class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.o = nn.Linear(dim, dim, bias=False)

        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, freqs) -> torch.Tensor:
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)

        # RoPE on q,k
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)

        x = self.attn(q, k, v)
        return self.o(x)

# ---------- Export + AOTInductor packaging ----------
INDUCTOR_CONFIGS = {
    "triton.cudagraphs": False,
    "max_autotune": True,
    "epilogue_fusion": True,
}

BATCH = Dim("batch", min=1, max=8)
SEQ   = Dim("seq",   min=2, max=4096)

def compile_and_export(model: nn.Module, x: torch.Tensor, freqs, package_name="self_attn_flash.pt2"):
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)

    dynamic_shapes = {
        # x: [B, S, D]
        "x": {0: BATCH, 1: SEQ},
        # freqs is a tuple (cos, sin), each [S, Hd//2]
        # Map only the leading dimension to the SAME SEQ Dim so exporter knows they match.
        "freqs": (
            {0: SEQ},   # cos
            {0: SEQ},   # sin
        ),
    }

    exported = torch.export.export(
        mod=model,
        args=(),
        kwargs={"x": x, "freqs": freqs},
        dynamic_shapes=dynamic_shapes,
    )

    pkg_path = os.path.join(os.getcwd(), package_name)
    torch._inductor.aoti_compile_and_package(  # type: ignore[attr-defined]
        exported,
        package_path=pkg_path,
        inductor_configs=INDUCTOR_CONFIGS,
    )

    ui = [
        s.arg.name
        for s in exported.graph_signature.input_specs
        if getattr(s, "kind", None) is not None and s.kind.name == "USER_INPUT"
    ]
    with open(package_name + ".inputs.json", "w") as f:
        json.dump(ui, f)

    print(f"[ok] Exported + packaged → {pkg_path}")
    print(f"[ok] Input order JSON   → {package_name}.inputs.json")

# ---------- Demo / Dummy data ----------
if __name__ == "__main__":
    assert torch.cuda.is_available(), "This demo expects CUDA for FlashAttention."
    device = "cuda"
    dtype = torch.bfloat16  # Flash-Attn loves bf16/fp16

    torch.set_float32_matmul_precision("high")

    B, S, D, H = 2, 128, 256, 8
    x = torch.randn(B, S, D, device=device, dtype=dtype)

    # Precompute RoPE freqs for the chosen S and head_dim
    Hd = D // H
    cos, sin = precompute_freqs_cis(head_dim=Hd, seq_len=S, device=device, dtype=dtype)
    freqs = (cos, sin)

    model = SelfAttention(dim=D, num_heads=H).to(device=device, dtype=dtype)

    with torch.inference_mode():
        y = model(x, freqs)
    print("Forward OK:", tuple(y.shape))  # (B, S, D)

    compile_and_export(model, x, freqs, package_name="self_attn_flash.pt2")

@guilhermeleobas
Author

guilhermeleobas commented Sep 3, 2025

Hi @yijianggit,

I wonder if flash_attn_with_kvcache is supported in this commit?

No, it is not. Implementing the flash_attn call with num_splits <= 0 requires implementing the num_splits_heuristic, and that is not a trivial task. I'll give it a shot. But in the meantime, could you share a reproducer for this error?

The heuristic for determining the number of splits depends on many variables, so replicating the exact C++ logic would be quite involved. While it’s technically feasible, a more practical approach might be to call the C++ code directly.
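In the meantime, one possible workaround is a sketch only: it assumes flash_attn_with_kvcache forwards a num_splits keyword, as the error message quoted earlier in the thread suggests, and it may leave performance on the table versus the heuristic-chosen split count.

import torch
from flash_attn_interface import flash_attn_with_kvcache

compiled_kvcache_attn = torch.compile(flash_attn_with_kvcache)

q = torch.randn(3, 1, 2, 8, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(3, 5, 2, 8, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn(3, 5, 2, 8, dtype=torch.bfloat16, device="cuda")

# Passing an explicit positive num_splits bypasses the "num_splits <= 0 not
# supported" check in the fake implementation during tracing.
out = compiled_kvcache_attn(q, k_cache, v_cache, cache_seqlens=2, causal=True, num_splits=1)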

Thanks for the repro @OutofAi

@yijianggit

Hi @guilhermeleobas,

Here is a code snippet to reproduce the issue:

from flash_attn_interface import flash_attn_func, flash_attn_with_kvcache
import torch

flash_attn_func = torch.compile(flash_attn_func)
flash_attn_with_kvcache = torch.compile(flash_attn_with_kvcache)

q = torch.randn((3, 5, 2, 8), dtype=torch.bfloat16, device='cuda')
k = torch.randn((3, 5, 2, 8), dtype=torch.bfloat16, device='cuda')
v = torch.randn((3, 5, 2, 8), dtype=torch.bfloat16, device='cuda')
o = flash_attn_func(q, k, v, causal=True)

mask = torch.zeros_like(k)
mask[:, 0:2, :, :] = 1
q1 = q[:, 0:2, :, :]
k1 = (k * mask)
v1 = (v * mask)
o1 = flash_attn_with_kvcache(q1, k1, v1, cache_seqlens=2, causal=True)
assert torch.allclose(o[:, 0:2, :, :], o1)

mask = torch.zeros_like(k)
mask[:, 0:3, :, :] = 1
q2 = q[:, 2:3, :, :]
k2 = (k * mask)
v2 = (v * mask)
o2 = flash_attn_with_kvcache(q2, k2, v2, cache_seqlens=3, causal=True)
assert torch.allclose(o[:, 2:3, :, :], o2)

Please let me know if there is anything I can do to help support this. Thanks!

@guilhermeleobas
Author

Hi @tridao, could you take a look at this PR again? It seems to be working for most use cases.

@haohaibo

haohaibo commented Sep 15, 2025

@guilhermeleobas This change does not work properly with float8_dynamic_activation_float8_weight(granularity=PerRow())

My code snippet

    from torchao.quantization import quantize_
    from torchao.quantization import float8_dynamic_activation_float8_weight
    from torchao.quantization.quant_api import PerRow
    quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()))
    quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()))

The output error

...
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0] failed while attempting to run meta for aten._scaled_mm.default
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0] Traceback (most recent call last):
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]     r = func(*args, **kwargs)
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 756, in __call__
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]     return self._op(*args, **kwargs)
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 6185, in meta_scaled_mm
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]     torch._check(
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]   File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 1660, in _check
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]     _check_with(RuntimeError, cond, message)
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]   File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 1642, in _check_with
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0]     raise error_type(message_evaluated)
E0915 14:32:19.279000 175002 torch/_subclasses/fake_tensor.py:2431] [10/0] RuntimeError: For non-tensorwise scaling, scale tensors must be 2D, but got scale_a.dim()=1 and scale_b.dim()=2
  0%|                                                                                                           | 0/10 [00:09<?, ?it/s]
Traceback (most recent call last):
  File "/work/image/inference.py", line 248, in <module>
    main()
  File "/work/image/inference.py", line 238, in main
    image = generate_image(pipe, **config)
  File "/work/image/inference.py", line 140, in generate_image
    images = pipe(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/work/image/diffusers/pipelines/image/pipeline.py", line 480, in __call__
    model_out = self.transformer.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/image/diffusers/models/transformers/transformer.py", line 596, in forward
    x_shard = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/image/diffusers/models/transformers/transformer.py", line 193, in forward
    x_shard = self.attn_forward(
  File "/work/image/diffusers/models/transformers/transformer.py", line 204, in attn_forward
    self.attention(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/image/diffusers/models/transformers/transformer.py", line 115, in forward
    output = flash_attn_varlen_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1213, in __call__
    result = self._inner_convert(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  ...

Contributor

Now that #1791 is landed can you also update the stable file to keep them in sync?

Author

done!

@guilhermeleobas
Author

@haohaibo, thanks for the repro but your code is missing some things. What is pipe?

@varunneal

varunneal commented Sep 28, 2025

Unfortunately, the flash_api_stable.cpp code does not compile for me on any nightly version. I wrote a fork here that is both ABI-compatible and includes the torch.compile compatibility, but it uses a modified flash_attn_interface compared to the pre-ABI-compatible code.

You can view the diff here.

EDIT: I've also packaged/built my fork and uploaded it on Huggingface for each Torch version here: https://huggingface.co/varunneal/flash-attention-3, just in case it is useful for anyone.

@janeyx99
Contributor

@varunneal when you say flash_api_stable.cpp does not compile for you—do you mean on main or on this branch?

If it's on main, that would be concerning and I'd like to find out more about your setup so we could fix it.

@varunneal

@janeyx99 Just on this branch

@StrongerXi

Having this would be quite helpful; what's blocking it from being merged, @guilhermeleobas?

@guilhermeleobas
Author

On it. I'm on PTO this week and can work on this next week.

Unfortunately, the flash_api_stable.cpp code does not compile for me on any nightly version. I wrote a fork here that is both ABI-compatible and includes the torch.compile compatibility, but it uses a modified flash_attn_interface compared to the pre-ABI-compatible code.

You can view the diff here.

Thanks for testing it. I was working on a different branch without the latest changes, so the failures slipped by me and I didn’t notice.

@guilhermeleobas
Author

@varunneal could you give this PR another try? Building flash_api_stable should work now.


is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

if arch < 90:


Hi, I do not think FlashAttention 3 relies on Hopper. It works fine on older GPUs like the A100. Can you remove this?

Author

Done!
