hopper/benchmark_mla_decode.py (19 changes: 18 additions & 1 deletion)
```diff
@@ -7,7 +7,10 @@
 # See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487
 
 import time
+import random
+import itertools
 import torch
+import triton
 import torch.nn.functional as F
 
 from triton.testing import do_bench, do_bench_cudagraph
@@ -59,9 +62,23 @@
 
     print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}")
 
-    for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]:
+    # for seqlen in [s * 1024 for s in [8]]:
+    for var_len, seqlen in itertools.product([False, True], [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]):
         cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)
 
+        if var_len:
+            if not should_run_flashmla:
+                continue
+            for i in range(batch_size):
+                cache_seqlens[i] = max(random.normalvariate(seqlen, seqlen / 2), seqlen_q)
+
+            max_seqlen = cache_seqlens.max().item()
+            max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+            seqlen = max_seqlen_pad
+            total_seqlens = cache_seqlens.sum().item()
+            mean_seqlens = cache_seqlens.float().mean().int().item()
+            print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
+
         num_splits = 0
         q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device)
         try:
```
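For reference, the variable-length path added by this diff can be exercised on its own. The sketch below reproduces its sampling and padding logic; the values of `batch_size`, `seqlen`, and `seqlen_q` and the CPU device are assumed for illustration (the benchmark derives them from its configuration and runs on CUDA). Each batch entry draws a cache length from a normal distribution centered on the nominal `seqlen` with standard deviation `seqlen / 2`, clamped below by `seqlen_q`, and the KV buffer length is then padded up to a multiple of 256.

```python
# Standalone sketch of the variable-length cache setup used in the diff.
# batch_size, seqlen, and seqlen_q are assumed example values, not the
# benchmark's actual configuration.
import random

import torch
import triton

batch_size, seqlen, seqlen_q = 8, 8192, 1
device = "cpu"  # the benchmark uses CUDA; CPU suffices to show the logic

# Start with every batch entry at the nominal sequence length.
cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)

# Draw each entry from N(seqlen, (seqlen / 2)^2), clamped below by seqlen_q
# so every cached sequence is at least as long as the query.
for i in range(batch_size):
    cache_seqlens[i] = max(random.normalvariate(seqlen, seqlen / 2), seqlen_q)

# Pad the longest sampled length up to a multiple of 256 to size the KV buffer.
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

print(f"{cache_seqlens.tolist()=}")
print(f"{max_seqlen=}, {max_seqlen_pad=}")
```

Clamping at `seqlen_q` guarantees every sequence is at least as long as the query, and padding the maximum to a multiple of 256 gives the cache a uniformly aligned upper bound regardless of the sampled lengths.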