flash_attn/cute/flash_fwd_sm100.py (13 additions, 8 deletions)
@@ -167,7 +167,7 @@ def _setup_attributes(self):
- Configures pipeline stages for softmax, correction, and epilogue operations
"""

self.kv_stage = 4 if self.q_dtype.width == 8 else 3
self.kv_stage = 4 if (self.k_dtype.width == 8 or self.v_dtype.width == 8) else 3
self.acc_stage = 1
self.epi_stage = 2
# For hdim 192,128, we don't have enough smem to store all 3 stages of KV:
@@ -251,11 +251,13 @@ def __call__(
if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN):
raise RuntimeError("The layout of mV is not supported")

# check type consistency
if const_expr(self.q_dtype != self.k_dtype):
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}")
if const_expr(self.q_dtype != self.v_dtype):
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
# Allow mixed precision: Q(bf16/f16) with KV(bf16/f16/fp8)
if const_expr(self.q_dtype.width not in [16]):
raise TypeError(f"Q must be 16-bit precision, got {self.q_dtype}")
if const_expr(self.k_dtype.width not in [8, 16]):
raise TypeError(f"K must be 8-bit or 16-bit precision, got {self.k_dtype}")
if const_expr(self.v_dtype.width not in [8, 16]):
raise TypeError(f"V must be 8-bit or 16-bit precision, got {self.v_dtype}")
self._setup_attributes()
self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None
# This can be tuned
@@ -267,16 +269,19 @@ def __call__(
# the intermediate tensor p is from tmem & mK-major
p_source = tcgen05.OperandSource.TMEM
p_major_mode = tcgen05.OperandMajorMode.K
qk_dtype = self.q_dtype if const_expr(self.q_dtype.width >= self.k_dtype.width) else self.k_dtype
pv_dtype = self.q_dtype if const_expr(self.q_dtype.width >= self.v_dtype.width) else self.v_dtype

tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma(
self.q_dtype,
qk_dtype,
self.q_major_mode,
self.k_major_mode,
self.qk_acc_dtype,
cta_group,
self.mma_tiler_qk[:2],
)
tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma(
self.v_dtype,
pv_dtype,
p_major_mode,
self.v_major_mode,
self.pv_acc_dtype,
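For reference, the promotion rule added above can be read in isolation: the operand dtype handed to each tiled-MMA builder is the wider of its two inputs (the Q side wins ties), and the KV pipeline gains a stage when K or V is 8-bit. The sketch below is illustration only; it models dtypes as (name, width) pairs rather than the cutlass types the kernel actually uses.

```python
# Illustration only: dtypes modeled as (name, width) pairs, not cutlass types.
Q_DTYPE = ("BFloat16", 16)
K_DTYPE = ("Float8E4M3FN", 8)
V_DTYPE = ("Float8E4M3FN", 8)

def promote(a, b):
    # Mirrors qk_dtype / pv_dtype above: prefer the Q-side operand on ties.
    return a if a[1] >= b[1] else b

qk_dtype = promote(Q_DTYPE, K_DTYPE)  # ("BFloat16", 16)
pv_dtype = promote(Q_DTYPE, V_DTYPE)  # ("BFloat16", 16)
# One extra KV pipeline stage when K or V is 8-bit (smaller KV tiles leave smem headroom).
kv_stage = 4 if (K_DTYPE[1] == 8 or V_DTYPE[1] == 8) else 3
print(qk_dtype[0], pv_dtype[0], kv_stage)  # BFloat16 BFloat16 4
```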
flash_attn/cute/interface.py (28 additions, 16 deletions)
@@ -47,6 +47,8 @@ def maybe_contiguous(x):
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.float32: cutlass.Float32,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e5m2: cutlass.Float8E5M2,
}


@@ -111,8 +113,9 @@ def _flash_attn_fwd(
assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)"
assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)"
assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)"
assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
assert q.dtype in [torch.float16, torch.bfloat16], "Q must be float16 or bfloat16"
assert k.dtype in [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], "K must be float16, bfloat16, or float8"
assert v.dtype in [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], "V must be float16, bfloat16, or float8"
for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
if t is not None:
assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
@@ -123,7 +126,10 @@ def _flash_attn_fwd(
assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
if k.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or v.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
alignment = 32 # FP8 requires 32-byte alignment
else:
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
if softmax_scale is None:
@@ -142,17 +148,23 @@ def _flash_attn_fwd(
requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad else None

dtype = torch2cute_dtype_map[q.dtype]
q_tensor, k_tensor, v_tensor, o_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
for t in (q, k, v, out)
]
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None
cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor = [
from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
]
page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None
q_dtype = torch2cute_dtype_map[q.dtype]
k_dtype = torch2cute_dtype_map[k.dtype]
v_dtype = torch2cute_dtype_map[v.dtype]
o_dtype = torch2cute_dtype_map[out.dtype]
# Convert FP8 to BF16 for compatibility and create cute tensors
def to_cute(t, align=16, leading_dim=None):
if t is None: return None
leading_dim = t.ndim - 1 if leading_dim is None else leading_dim
t = t.to(torch.bfloat16) if t.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] else t
return from_dlpack(t.detach(), assumed_align=align).mark_layout_dynamic(leading_dim=leading_dim)

q_tensor, k_tensor, v_tensor, o_tensor = [to_cute(t) for t in (q, k, v, out)]
lse_tensor = to_cute(lse, 4, lse.ndim - 1) if lse is not None else None
cu_seqlens_q_tensor, cu_seqlens_k_tensor = [to_cute(t, 4, 0) for t in (cu_seqlens_q, cu_seqlens_k)]
seqused_q_tensor, seqused_k_tensor = [to_cute(t, 4, 0) for t in (seqused_q, seqused_k)]
learnable_sink_tensor = to_cute(learnable_sink, 4, 0)
page_table_tensor = to_cute(page_table, 4, 1)
if causal:
window_size_right = 0
local = window_size_left is not None or window_size_right is not None
@@ -174,7 +186,7 @@ def _flash_attn_fwd(
pack_gqa = False

compile_key = (
dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None,
q_dtype, k_dtype, v_dtype, o_dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None,
lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
page_table is not None,
window_size_left is not None, window_size_right is not None,
@@ -187,7 +199,7 @@ def _flash_attn_fwd(
assert page_table is None, "paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
fa_fwd = FlashAttentionForwardSm90(
dtype,
q_dtype,
head_dim,
head_dim_v,
qhead_per_kvhead,
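A small CPU-only sketch of the two host-side rules this file now enforces: head_dim alignment becomes 32 when either K or V is FP8, and FP8 inputs are upcast to bfloat16 before the cute tensors are built (as in the to_cute helper above). The helper names below are illustrative, not part of the interface.

```python
import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def head_dim_alignment(q, k, v):
    # FP8 K/V require head_dim divisible by 32; otherwise 16 bytes of Q elements.
    if k.dtype in FP8_DTYPES or v.dtype in FP8_DTYPES:
        return 32
    return 16 // q.element_size()  # e.g. 8 for fp16/bf16

def upcast_fp8(t):
    # Same conversion to_cute() applies before from_dlpack().
    return t.to(torch.bfloat16) if t.dtype in FP8_DTYPES else t

q = torch.randn(2, 128, 8, 128, dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
v = torch.randn(2, 128, 8, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
print(head_dim_alignment(q, k, v))  # 32
print(upcast_fp8(k).dtype)          # torch.bfloat16
```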
tests/cute/test_flash_attn.py (37 additions, 0 deletions)
@@ -1041,3 +1041,40 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False)
assert lse_no_lse is None, "LSE should be None when return_lse=False"
assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse"


@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)])
def test_flash_attn_mixed_precision_q_bf16_kv_fp8(seqlen_q, seqlen_k, d, causal):
"""Test Q(bfloat16) + KV(float8) + O(bfloat16) mixed precision."""
device = "cuda"
torch.random.manual_seed(66)
batch_size = 2
nheads = 8
nheads_kv = 8 # MHA for simplicity

# Generate Q in bfloat16
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16, requires_grad=True)

# Generate K, V in float8_e4m3fn
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.bfloat16)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)

# Reference computation in bfloat16
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().to(torch.bfloat16).requires_grad_(True)
v_ref = v.detach().to(torch.bfloat16).requires_grad_(True)

out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal)
out_pt, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True)

# FlashAttention mixed precision computation
out, lse = flash_attn_func(q, k, v, causal=causal)

mult = 4 # Higher tolerance for FP8 mixed precision
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
mult_mean = 3
assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()