93 changes: 48 additions & 45 deletions flash_attn/cute/interface.py
@@ -149,55 +149,58 @@ def _flash_attn_fwd(
     if not causal and not local:
         n_block_size = 192
 
-    compile_key = (
-        dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None,
-        lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
-        window_size_left is not None, window_size_right is not None,
-        learnable_sink is not None,
-        m_block_size, n_block_size, num_threads,
-        compute_capability,
-    )
-    if compile_key not in _flash_attn_fwd.compile_cache:
-        if compute_capability == 9:
-            assert learnable_sink is None, "Sm90 doesn't support additive sink"
-            # fa_fwd = FlashAttentionForwardSm80(
-            fa_fwd = FlashAttentionForwardSm90(
-                dtype,
-                head_dim,
-                head_dim_v,
-                qhead_per_kvhead,
-                is_causal=causal,
-                is_local=local,
-                pack_gqa=False,
-                m_block_size=m_block_size,
-                n_block_size=n_block_size,
-                # num_stages=1,
-                num_stages=2,
-                num_threads=num_threads,
-                Q_in_regs=False,
-            )
-        elif compute_capability == 10:
-            fa_fwd = FlashAttentionForwardSm100(
-                head_dim,
-                head_dim_v,
-                is_causal=causal,
-                is_local=local,
-                qhead_per_kvhead=qhead_per_kvhead,
-                is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None,
-            )
-        else:
-            raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x")
-        # TODO: check @can_implement
-        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
-            fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
-            cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
-            softcap, window_size_left, window_size_right,
-        )
-    _flash_attn_fwd.compile_cache[compile_key](
-        q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
-        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
-        softcap, window_size_left, window_size_right, additive_sink_tensor,
-    )
+    with torch.cuda.device(q.device.index):
+        compile_key = (
+            dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None,
+            lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
+            window_size_left is not None, window_size_right is not None,
+            m_block_size, n_block_size, num_threads,
+            compute_capability,
+        )
+
+        if compile_key not in _flash_attn_fwd.compile_cache:
+            if compute_capability == 9:
+                # fa_fwd = FlashAttentionForwardSm80(
+                fa_fwd = FlashAttentionForwardSm90(
+                    dtype,
+                    head_dim,
+                    head_dim_v,
+                    qhead_per_kvhead,
+                    is_causal=causal,
+                    is_local=local,
+                    pack_gqa=False,
+                    m_block_size=m_block_size,
+                    n_block_size=n_block_size,
+                    # num_stages=1,
+                    num_stages=2,
+                    num_threads=num_threads,
+                    Q_in_regs=False,
+                )
+            elif compute_capability == 10:
+                fa_fwd = FlashAttentionForwardSm100(
+                    head_dim,
+                    head_dim_v,
+                    is_causal=causal,
+                    is_local=local,
+                    qhead_per_kvhead=qhead_per_kvhead,
+                    is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None,
+                )
+            else:
+                raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x")
+            # TODO: check @can_implement
+            _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
+                fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
+                cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
+                softcap, window_size_left, window_size_right,
+            )
+
+        _flash_attn_fwd.compile_cache[compile_key](
+            q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
+            cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
+            softcap, window_size_left, window_size_right, additive_sink_tensor,
+        )
 
     return out, lse
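Why the `with torch.cuda.device(q.device.index):` wrapper matters: without it, compilation and the cached kernel launch run against whatever CUDA device is current, which can differ from the device holding `q` in multi-GPU setups. A minimal sketch of the same compile-cache-under-device-context pattern; `_build_kernel` is a hypothetical placeholder standing in for `cute.compile`, not the repo's API:

import torch

def _build_kernel(dtype, head_dim, causal):
    # Placeholder for cute.compile: returns a callable specialized on the key.
    def kernel(q):
        return torch.empty_like(q)  # stands in for the real attention output
    return kernel

_compile_cache = {}

def fwd(q, causal=False):
    # The key must capture every property the kernel is specialized on,
    # mirroring compile_key above (dtype, head dim, feature flags, capability).
    key = (q.dtype, q.shape[-1], causal, torch.cuda.get_device_capability(q.device)[0])
    # Enter q's device before compiling and launching, so context creation and
    # stream queries target q's GPU instead of the current device (often 0).
    with torch.cuda.device(q.device.index):
        if key not in _compile_cache:
            _compile_cache[key] = _build_kernel(q.dtype, q.shape[-1], causal)
        return _compile_cache[key](q)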
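Note on the cache key: entries such as `softcap is not None` and `cu_seqlens_q is None` record only whether a feature is enabled, not its value, so changing the softcap value reuses an already-compiled kernel while toggling the feature forces a fresh `cute.compile`. A standalone illustration of that distinction (not repo code):

softcap_a, softcap_b, softcap_off = 0.1, 0.2, None
key_a = ("bf16", 128, softcap_a is not None)
key_b = ("bf16", 128, softcap_b is not None)
assert key_a == key_b    # value change only: cache hit, no recompilation
key_c = ("bf16", 128, softcap_off is not None)
assert key_a != key_c    # feature toggled: new key, new compilation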
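The dispatch assumes `compute_capability` holds the SM major version: 9 selects `FlashAttentionForwardSm90` (Hopper) and 10 selects `FlashAttentionForwardSm100` (Blackwell), matching the "Supported: 9.x, 10.x" error message. One common way to obtain such a value with the standard PyTorch API (this part is not shown in the diff, so treat it as an assumption):

import torch
major, minor = torch.cuda.get_device_capability(q.device)
compute_capability = major  # 9 -> Sm90 path, 10 -> Sm100 path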

