Skip to content

Commit e724e25

Browse files
authored
[Cute,Fwd,Sm100] Implement SplitKV (#1940)
* Implement split KV
* Remove modal bench harness
* Fixes
1 parent 6c9eef9 commit e724e25

13 files changed: +755 -523 lines changed

flash_attn/cute/block_info.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,19 @@ class BlockInfo:
1515
tile_n: cutlass.Constexpr[int]
1616
is_causal: cutlass.Constexpr[bool]
1717
is_local: cutlass.Constexpr[bool] = False
18+
is_split_kv: cutlass.Constexpr[bool] = False
1819
window_size_left: Optional[Int32] = None
1920
window_size_right: Optional[Int32] = None
2021
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
2122

2223
@cute.jit
23-
def get_n_block_min_max(self, seqlen_info: SeqlenInfoQK, m_block: Int32) -> Tuple[Int32, Int32]:
24+
def get_n_block_min_max(
25+
self,
26+
seqlen_info: SeqlenInfoQK,
27+
m_block: Int32,
28+
split_idx: cutlass.Int32 = 0,
29+
num_splits: cutlass.Int32 = 1,
30+
) -> Tuple[Int32, Int32]:
2431
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
2532
if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
2633
m_idx_max = (m_block + 1) * self.tile_m
@@ -37,6 +44,14 @@ def get_n_block_min_max(self, seqlen_info: SeqlenInfoQK, m_block: Int32) -> Tupl
3744
n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
3845
n_idx_left = n_idx - self.window_size_left
3946
n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
47+
if cutlass.const_expr(self.is_split_kv):
48+
num_n_blocks_per_split = (
49+
cutlass.Int32(0)
50+
if n_block_max <= n_block_min
51+
else (n_block_max - n_block_min + num_splits - 1) // num_splits
52+
)
53+
n_block_min = n_block_min + split_idx * num_n_blocks_per_split
54+
n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
4055
return n_block_min, n_block_max
4156

4257
@cute.jit

flash_attn/cute/flash_bwd.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,7 @@ def __call__(
405405
num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
406406
num_head=num_head,
407407
num_batch=num_batch,
408+
num_splits=1,
408409
seqlen_k=0,
409410
headdim=mK.shape[2],
410411
headdim_v=mV.shape[2],
@@ -505,10 +506,10 @@ def kernel(
505506
tile_scheduler = TileScheduler.create(tile_sched_params)
506507
work_tile = tile_scheduler.initial_work_tile_info()
507508

508-
n_block, head_idx, batch_idx = work_tile.tile_idx
509+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
509510

510511
if work_tile.is_valid_tile:
511-
seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
512+
seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
512513

513514
m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
514515
m_block_min = 0

flash_attn/cute/flash_bwd_postprocess.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,7 @@ def __call__(
242242
num_block=cute.ceil_div(mdQ.shape[1], self.tile_m),
243243
num_head=num_head,
244244
num_batch=num_batch,
245+
num_splits=1,
245246
seqlen_k=0,
246247
headdim=mdQ.shape[2],
247248
headdim_v=0,
@@ -317,14 +318,14 @@ def kernel(
317318
tile_scheduler = TileScheduler.create(tile_sched_params)
318319
work_tile = tile_scheduler.initial_work_tile_info()
319320

320-
m_block, num_head, batch_size = work_tile.tile_idx
321+
m_block, num_head, batch_size, _ = work_tile.tile_idx
321322

322323
if work_tile.is_valid_tile:
323324
# ///////////////////////////////////////////////////////////////////////////////
324325
# Get the appropriate tiles for this thread block.
325326
# ///////////////////////////////////////////////////////////////////////////////
326327

327-
seqlen = SeqlenInfoQK(
328+
seqlen = SeqlenInfoQK.create(
328329
batch_size,
329330
mdQ.shape[1],
330331
0,

flash_attn/cute/flash_bwd_preprocess.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,7 @@ def __call__(
160160
num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
161161
num_head=num_head,
162162
num_batch=num_batch,
163+
num_splits=1,
163164
seqlen_k=0,
164165
headdim=0,
165166
headdim_v=mO.shape[2],
@@ -212,13 +213,13 @@ def kernel(
212213

213214
tile_scheduler = TileScheduler.create(tile_sched_params)
214215
work_tile = tile_scheduler.initial_work_tile_info()
215-
m_block, num_head, batch_size = work_tile.tile_idx
216+
m_block, num_head, batch_size, _ = work_tile.tile_idx
216217

217218
if work_tile.is_valid_tile:
218219
# ///////////////////////////////////////////////////////////////////////////////
219220
# Get the appropriate tiles for this thread block.
220221
# ///////////////////////////////////////////////////////////////////////////////
221-
seqlen = SeqlenInfoQK(
222+
seqlen = SeqlenInfoQK.create(
222223
batch_size,
223224
mO.shape[1],
224225
0,

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -541,6 +541,7 @@ def __call__(
541541
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),
542542
cute.size(mQ.shape[2]), # num_heads = num_query_heads
543543
cute.size(mK.shape[3]),
544+
1, # num_splits
544545
cute.size(mK.shape[0]),
545546
mQ.shape[1],
546547
mV.shape[1],
@@ -927,12 +928,13 @@ def kernel(
927928
self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested
928929
self.is_causal,
929930
self.is_local,
931+
False, # is_split_kv
930932
None,
931933
None,
932934
qhead_per_kvhead_packgqa=1,
933935
)
934936
SeqlenInfoCls = partial(
935-
SeqlenInfoQK,
937+
SeqlenInfoQK.create,
936938
seqlen_q_static=mQ.shape[0],
937939
seqlen_k_static=mK.shape[0],
938940
mCuSeqlensQ=None,
@@ -1159,7 +1161,7 @@ def load(
11591161
tile_scheduler = TileSchedulerCls()
11601162
work_tile = tile_scheduler.initial_work_tile_info()
11611163
while work_tile.is_valid_tile:
1162-
n_block, head_idx, batch_idx = work_tile.tile_idx
1164+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
11631165
seqlen = SeqlenInfoCls(batch_idx)
11641166
m_block_min, m_block_max = block_info.get_m_block_min_max(
11651167
seqlen, n_block // self.cluster_shape_mnk[0]
@@ -1415,7 +1417,7 @@ def mma(
14151417
tile_scheduler = TileSchedulerCls()
14161418
work_tile = tile_scheduler.initial_work_tile_info()
14171419
while work_tile.is_valid_tile:
1418-
n_block, head_idx, batch_idx = work_tile.tile_idx
1420+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
14191421
seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k
14201422
m_block_min, m_block_max = block_info.get_m_block_min_max(
14211423
seqlen, n_block // self.cluster_shape_mnk[0]
@@ -1723,7 +1725,7 @@ def compute_loop(
17231725
tile_scheduler = TileSchedulerCls()
17241726
work_tile = tile_scheduler.initial_work_tile_info()
17251727
while work_tile.is_valid_tile:
1726-
n_block, head_idx, batch_idx = work_tile.tile_idx
1728+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
17271729
seqlen = SeqlenInfoCls(batch_idx)
17281730
m_block_min, m_block_max = block_info.get_m_block_min_max(
17291731
seqlen, n_block // self.cluster_shape_mnk[0]
@@ -1981,7 +1983,7 @@ def dQacc_reduce(
19811983
pipeline.PipelineUserType.Producer, self.sdQaccum_stage
19821984
)
19831985
while work_tile.is_valid_tile:
1984-
n_block, head_idx, batch_idx = work_tile.tile_idx
1986+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
19851987
seqlen = SeqlenInfoCls(batch_idx)
19861988
m_block_min, m_block_max = block_info.get_m_block_min_max(
19871989
seqlen, n_block // self.cluster_shape_mnk[0]

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -397,6 +397,7 @@ def __call__(
397397
cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
398398
cute.size(mK.shape[2]),
399399
cute.size(mK.shape[3]),
400+
1, # num_splits
400401
cute.size(mK.shape[0]),
401402
mQ.shape[1],
402403
mV.shape[1],
@@ -551,12 +552,13 @@ def kernel(
551552
self.tile_n,
552553
self.is_causal,
553554
self.is_local,
555+
False, # is_split_kv
554556
None,
555557
None,
556558
qhead_per_kvhead_packgqa=1,
557559
)
558560
SeqlenInfoCls = partial(
559-
SeqlenInfoQK,
561+
SeqlenInfoQK.create,
560562
seqlen_q_static=mQ.shape[0],
561563
seqlen_k_static=mK.shape[0],
562564
mCuSeqlensQ=None,
@@ -678,7 +680,7 @@ def load(
678680
tile_scheduler = TileSchedulerCls()
679681
work_tile = tile_scheduler.initial_work_tile_info()
680682
while work_tile.is_valid_tile:
681-
n_block, head_idx, batch_idx = work_tile.tile_idx
683+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
682684
seqlen = SeqlenInfoCls(batch_idx)
683685
mK_cur = mK[None, None, head_idx, batch_idx]
684686
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
@@ -932,7 +934,7 @@ def mma(
932934
tile_scheduler = TileSchedulerCls()
933935
work_tile = tile_scheduler.initial_work_tile_info()
934936
while work_tile.is_valid_tile:
935-
n_block, head_idx, batch_idx = work_tile.tile_idx
937+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
936938
seqlen = SeqlenInfoCls(batch_idx)
937939
mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
938940
mask_fn = partial(
@@ -1208,7 +1210,7 @@ def dQaccum_store(
12081210
tile_scheduler = TileSchedulerCls()
12091211
work_tile = tile_scheduler.initial_work_tile_info()
12101212
while work_tile.is_valid_tile:
1211-
n_block, head_idx, batch_idx = work_tile.tile_idx
1213+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
12121214
seqlen = SeqlenInfoCls(batch_idx)
12131215
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
12141216
gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))

flash_attn/cute/flash_fwd.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -759,11 +759,12 @@ def kernel(
759759
self.tile_n,
760760
self.is_causal,
761761
self.is_local,
762+
False, # is_split_kv
762763
window_size_left,
763764
window_size_right,
764765
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
765766
)
766-
seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
767+
seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
767768
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
768769
# TODO: return early if n_block_max == 0
769770
# if self.is_causal:
@@ -1459,6 +1460,7 @@ def __call__(
14591460
cute.size(mQ.shape[3])
14601461
if const_expr(mCuSeqlensQ is None)
14611462
else cute.size(mCuSeqlensQ.shape[0] - 1),
1463+
1, # num_splits
14621464
cute.size(mK.shape[0]),
14631465
mQ.shape[1],
14641466
mV.shape[1],
@@ -1652,12 +1654,13 @@ def kernel(
16521654
self.tile_n,
16531655
self.is_causal,
16541656
self.is_local,
1657+
False, # is_split_kv
16551658
window_size_left,
16561659
window_size_right,
16571660
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
16581661
)
16591662
SeqlenInfoCls = partial(
1660-
SeqlenInfoQK,
1663+
SeqlenInfoQK.create,
16611664
seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
16621665
seqlen_k_static=mK.shape[0],
16631666
mCuSeqlensQ=mCuSeqlensQ,
@@ -1764,7 +1767,7 @@ def load(
17641767
work_tile = tile_scheduler.initial_work_tile_info()
17651768
while work_tile.is_valid_tile:
17661769
# if work_tile.is_valid_tile:
1767-
m_block, head_idx, batch_idx = work_tile.tile_idx
1770+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
17681771
seqlen = SeqlenInfoCls(batch_idx)
17691772
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
17701773
head_idx_kv = (
@@ -2106,7 +2109,7 @@ def mma(
21062109
# if work_tile.is_valid_tile:
21072110

21082111
# shape: (atom_v_m * rest_m)
2109-
m_block, head_idx, batch_idx = work_tile.tile_idx
2112+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
21102113
seqlen = SeqlenInfoCls(batch_idx)
21112114
mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
21122115
mask_fn = partial(

flash_attn/cute/flash_fwd_combine.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,7 @@ class SharedStorage:
255255
# Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
256256
seqlen = mO_partial.shape[0]
257257
num_head = mO_partial.shape[3]
258-
batch_size = mO_partial.shape[4]
258+
batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
259259

260260
# Create FastDivmod objects for efficient division
261261
seqlen_divmod = FastDivmod.create(seqlen)
@@ -341,7 +341,7 @@ def kernel(
341341
else mLSE_partial.shape[1]
342342
)
343343
# Handle variable length sequences using SeqlenInfo
344-
seqlen_info = SeqlenInfo(
344+
seqlen_info = SeqlenInfo.create(
345345
batch_idx=batch_idx,
346346
seqlen_static=mO_partial.shape[0],
347347
cu_seqlens=cu_seqlens,

0 commit comments

Comments
 (0)