Skip to content
Merged
91 changes: 91 additions & 0 deletions benchmarks/bench_blackwell_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
from triton.testing import do_bench

import flashinfer


def bench_fmha_blackwell(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
):
    """Time ``flashinfer.prefill.fmha_varlen`` and print the achieved throughput.

    Every one of the ``batch_size`` sequences has the same length ``qkv_len``
    for both queries and keys/values, so the packed q/k/v tensors have
    ``batch_size * qkv_len`` rows. The result is printed, nothing is returned.

    Args:
        batch_size: Number of sequences packed into one call.
        qkv_len: Length of each sequence (shared by q, k and v).
        num_heads: Number of heads (the same count is used for q, k and v).
        head_dim: Per-head feature dimension.
        causal: Forwarded to ``fmha_varlen``; also selects the FLOP factor.
        dtype: Torch dtype of the random q/k/v tensors.
    """
    packed_shape = (batch_size * qkv_len, num_heads, head_dim)
    # Created in q, k, v order so the random stream is consumed as before.
    q, k, v = (
        torch.randn(*packed_shape, dtype=dtype, device="cuda") for _ in range(3)
    )

    def uniform_segment_offsets():
        # Segment boundaries of equally sized sequences: 0, qkv_len, 2 * qkv_len, ...
        boundaries = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32)
        return boundaries * qkv_len

    # Two separate tensors, one for the query side and one for the key/value side.
    qo_segment_offsets = uniform_segment_offsets()
    kv_segment_offsets = uniform_segment_offsets()

    def run_attention():
        return flashinfer.prefill.fmha_varlen(
            q, k, v, qo_segment_offsets, kv_segment_offsets, causal=causal
        )

    # One untimed call before measuring; its outputs are not used.
    _o, _lse = run_attention()

    ms = do_bench(run_attention)

    # The causal case is counted with half the work of the full case (2 vs 4).
    flop_factor = 2 if causal else 4
    # FLOPs / ms / 1e9 == TFLOPs/s when the timing is in milliseconds.
    tflops = (
        batch_size * qkv_len * qkv_len * num_heads * head_dim * flop_factor / ms / 1e9
    )

    print(
        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {tflops:.3f} TFLOPs/s"
    )


if __name__ == "__main__":
    # Sweep eight (batch_size, qkv_len) pairs whose product stays at 65536 tokens:
    # (128, 512), (64, 1024), ..., (1, 65536). The whole non-causal sweep runs
    # first, followed by the causal sweep.
    for causal in (False, True):
        for shift in range(8):
            bench_fmha_blackwell(
                128 >> shift, 512 << shift, 32, 128, causal, torch.bfloat16
            )
101 changes: 101 additions & 0 deletions csrc/fmha_cutlass_sm100.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/cutlass_utils.cuh>

#include "pytorch_extension_utils.h"

// Turns the runtime value `mask_mode` into a compile-time constant.
// Inside the callable passed as the trailing argument, `MASK_MODE` names a
// `constexpr MaskMode` equal to the matched mode. The expression yields the
// callable's result, or false when `mask_mode` is neither kNone nor kCausal.
#define DISPATCH_mask_mode(mask_mode, MASK_MODE, ...)   \
  [&]() -> bool {                                       \
    if (mask_mode == MaskMode::kCausal) {               \
      constexpr MaskMode MASK_MODE = MaskMode::kCausal; \
      return __VA_ARGS__();                             \
    }                                                   \
    if (mask_mode == MaskMode::kNone) {                 \
      constexpr MaskMode MASK_MODE = MaskMode::kNone;   \
      return __VA_ARGS__();                             \
    }                                                   \
    return false;                                       \
  }()

// Turns the runtime (head_dim_qk, head_dim_vo) pair into compile-time ints.
// Supported pairs are (192, 128) and (128, 128). Inside the callable passed as
// the trailing argument, `HEAD_DIM_QK` / `HEAD_DIM_VO` name the matching
// `constexpr int` values. The expression yields the callable's result, or
// false for any other pair.
#define DISPATCH_head_dim(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, ...) \
  [&]() -> bool {                                                                  \
    if (head_dim_vo == 128) {                                                      \
      if (head_dim_qk == 192) {                                                    \
        constexpr int HEAD_DIM_QK = 192;                                           \
        constexpr int HEAD_DIM_VO = 128;                                           \
        return __VA_ARGS__();                                                      \
      }                                                                            \
      if (head_dim_qk == 128) {                                                    \
        constexpr int HEAD_DIM_QK = 128;                                           \
        constexpr int HEAD_DIM_VO = 128;                                           \
        return __VA_ARGS__();                                                      \
      }                                                                            \
    }                                                                              \
    return false;                                                                  \
  }()

// Binds C++ element types for the input and output tensors.
// Only identical input/output dtypes are handled: `c_type_in` is resolved by
// DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16 (defined outside this file) and
// `c_type_out` is aliased to it. The expression yields false when
// `in_dtype` and `out_dtype` differ, otherwise whatever the inner dispatch
// returns.
#define DISPATCH_DTYPE_IN_OUT(in_dtype, out_dtype, c_type_in, c_type_out, ...) \
  [&]() -> bool {                                                              \
    if (in_dtype != out_dtype) {                                               \
      return false;                                                            \
    }                                                                          \
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(in_dtype, c_type_in, [&] {     \
      using c_type_out = c_type_in;                                            \
      return __VA_ARGS__();                                                    \
    });                                                                        \
  }()

// Chains the mask-mode, dtype and head-dim dispatchers and runs the callable
// passed as the trailing argument with all compile-time names bound
// (`DTypeIn`, `DTypeOut`, `HEAD_DIM_QK`, `HEAD_DIM_VO`, `MASK_MODE`).
//
// The expansion reads these variables from the enclosing scope: `mask_mode`,
// `scalar_type_in`, `scalar_type_out`, `head_dim_qk`, `head_dim_vo`.
//
// Each dispatcher yields false when no specialization matches. That result is
// checked here so an unsupported configuration raises an error instead of
// returning without ever running the callable.
#define DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, ...)          \
  {                                                                                            \
    const bool fmha_dispatch_ok_ = DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {              \
      return DISPATCH_DTYPE_IN_OUT(scalar_type_in, scalar_type_out, DTypeIn, DTypeOut, [&] {   \
        return DISPATCH_head_dim(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO,           \
                                 [&] { return __VA_ARGS__(); });                               \
      });                                                                                      \
    });                                                                                        \
    TORCH_CHECK(fmha_dispatch_ok_, "fmha_cutlass_sm100: unsupported configuration (mask_mode=", \
                static_cast<int>(mask_mode), ", head_dim_qk=", head_dim_qk,                    \
                ", head_dim_vo=", head_dim_vo, ", dtype_in=", scalar_type_in,                  \
                ", dtype_out=", scalar_type_out, ")");                                         \
  }

using namespace flashinfer;

// Host-side launcher for the CUTLASS SM100 FMHA forward kernel.
//
// Selects a `run_fmha_fwd` specialization from the runtime mask mode, the
// q/o dtypes and the (head_dim_qk, head_dim_vo) pair, then forwards every
// argument to it unchanged. Results are written through `o` and, if provided,
// `maybe_lse` by `run_fmha_fwd`.
//
// A single input element type is used for q, k and v, so all three must share
// one dtype; this is validated up front with TORCH_CHECK.
void FMHACutlassSM100Run(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens,
                         at::Tensor kv_lens, at::Tensor qo_segment_offsets,
                         at::Tensor kv_segment_offsets, at::Tensor o,
                         std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                         double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
                         int64_t head_dim_qk, int64_t head_dim_vo, int64_t batch_size,
                         int64_t total_qo_len, int64_t total_kv_len, int64_t max_qo_len,
                         int64_t max_kv_len) {
  // The kernel is instantiated with one input type, so v must match q and k too.
  TORCH_CHECK(q.scalar_type() == k.scalar_type() && q.scalar_type() == v.scalar_type(),
              "FMHACutlassSM100Run: q, k and v must have the same dtype, got q=", q.scalar_type(),
              ", k=", k.scalar_type(), ", v=", v.scalar_type());
  auto scalar_type_in = q.scalar_type();
  auto scalar_type_out = o.scalar_type();
  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
  DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, [&] {
    using cutlass_type_in = cutlass_dtype_t<DTypeIn>;
    using cutlass_type_out = cutlass_dtype_t<DTypeOut>;
    // Fixed tile extents: 256 along the query axis, 128 along the key/value axis.
    using TILE_Q = _256;
    using TILE_KV = _128;
    using D_QK = cute::Int<HEAD_DIM_QK>;
    using D_VO = cute::Int<HEAD_DIM_VO>;
    using TileShapeQK = Shape<TILE_Q, TILE_KV, D_QK>;
    using TileShapePV = Shape<TILE_Q, D_VO, TILE_KV>;
    // Any mode other than kCausal is mapped to ResidualMask.
    using CutlassMaskMode =
        typename std::conditional<MASK_MODE == MaskMode::kCausal, CausalMask, ResidualMask>::type;
    run_fmha_fwd<cutlass_type_in, cutlass_type_out, TileShapeQK, TileShapePV, CutlassMaskMode>(
        q, k, v, qo_lens, kv_lens, qo_segment_offsets, kv_segment_offsets, o, maybe_lse,
        mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, batch_size,
        total_qo_len, total_kv_len, max_qo_len, max_kv_len);

    // Tell the dispatch chain that a specialization was found and launched.
    return true;
  });
}
27 changes: 27 additions & 0 deletions csrc/fmha_cutlass_sm100_pybind.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2023-2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"

// Forward declaration of the FMHA launcher that is registered as an operator
// further down in this file; no definition is provided here.
// NOTE(review): the definition is expected in csrc/fmha_cutlass_sm100.cu —
// keep both signatures in sync.
void FMHACutlassSM100Run(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_lens,
at::Tensor kv_lens, at::Tensor qo_segment_offsets,
at::Tensor kv_segment_offsets, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t batch_size,
int64_t total_qo_len, int64_t total_kv_len, int64_t max_qo_len,
int64_t max_kv_len);

// Registers the launcher as operator "run" in the TORCH_EXTENSION_NAME library.
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("run", FMHACutlassSM100Run);
}
1 change: 1 addition & 0 deletions flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .attention import (
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .attention import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module
from .attention import gen_pod_module as gen_pod_module
from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding
from .attention import gen_single_decode_module as gen_single_decode_module
Expand Down
1 change: 1 addition & 0 deletions flashinfer/jit/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .pytorch import (
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .pytorch import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module
from .pytorch import gen_pod_module as gen_pod_module
from .pytorch import gen_single_decode_module as gen_single_decode_module
from .pytorch import gen_single_prefill_module as gen_single_prefill_module
Expand Down
62 changes: 61 additions & 1 deletion flashinfer/jit/attention/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jinja2
import torch

from ..core import load_cuda_ops, logger, sm90a_nvcc_flags
from ..core import load_cuda_ops, logger, sm90a_nvcc_flags, sm100a_nvcc_flags
from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from ..utils import (
dtype_map,
Expand Down Expand Up @@ -1340,3 +1340,63 @@ def gen_customize_batch_prefill_module(
)
else:
raise ValueError(f"Invalid backend: {backend}")


def get_fmha_cutlass_sm100a_uri(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    dtype_idx: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    pos_encoding_mode: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
) -> str:
    """Return the URI that names the CUTLASS SM100a FMHA JIT module.

    All arguments are currently ignored: one fixed URI is returned for every
    configuration. The parameters are kept so the signature mirrors
    ``gen_fmha_cutlass_sm100a_module``.
    """
    # NOTE(Zihao): switch to a per-configuration URI once customized attention
    # is supported. The planned form appends, in this order, the q / kv / o /
    # idx dtypes (via filename_safe_dtype_map), head_dim_qk, head_dim_vo,
    # posenc, use_swa and use_logits_cap to the "fmha_cutlass_sm100a" prefix.
    uri = "fmha_cutlass_sm100a"
    return uri


def gen_fmha_cutlass_sm100a_module(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    dtype_idx: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    pos_encoding_mode: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
):
    """Load the CUTLASS SM100a FMHA extension through ``load_cuda_ops``.

    The arguments are used only to derive the module URI. The two CUDA
    sources are taken from ``FLASHINFER_CSRC_DIR`` and compiled with
    ``sm100a_nvcc_flags`` as extra CUDA flags.

    Returns:
        Whatever ``load_cuda_ops`` returns for this URI and source list.
    """
    module_uri = get_fmha_cutlass_sm100a_uri(
        dtype_q=dtype_q,
        dtype_kv=dtype_kv,
        dtype_o=dtype_o,
        dtype_idx=dtype_idx,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=pos_encoding_mode,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
    )

    # Kernel launcher first, then the file that registers it as a torch op.
    cuda_sources = [
        FLASHINFER_CSRC_DIR / filename
        for filename in ("fmha_cutlass_sm100.cu", "fmha_cutlass_sm100_pybind.cu")
    ]
    return load_cuda_ops(
        module_uri,
        cuda_sources,
        extra_cuda_cflags=sm100a_nvcc_flags,
    )
2 changes: 2 additions & 0 deletions flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def remove_unwanted_pytorch_nvcc_flags():
remove_unwanted_pytorch_nvcc_flags()

# Extra nvcc arguments selecting an architecture-specific ("a"-suffixed) target;
# meant to be passed as extra CUDA compile flags when building a module.
# compute_90a / sm_90a:
sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"]
# compute_100a / sm_100a:
sm100a_nvcc_flags = ["-gencode", "arch=compute_100a,code=sm_100a"]


def load_cuda_ops(
Expand Down Expand Up @@ -113,6 +114,7 @@ def load_cuda_ops(
"-lineinfo",
"--ptxas-options=-v",
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
"-DCUTLASS_DEBUG_TRACE_LEVEL=2",
]
else:
# non debug mode
Expand Down
Loading