@@ -541,6 +541,7 @@ def __call__(
541541 cute .ceil_div (cute .size (mK .shape [0 ]), self .cta_tiler [0 ]),
542542 cute .size (mQ .shape [2 ]), # num_heads = num_query_heads
543543 cute .size (mK .shape [3 ]),
544+ 1 , # num_splits
544545 cute .size (mK .shape [0 ]),
545546 mQ .shape [1 ],
546547 mV .shape [1 ],
@@ -927,12 +928,13 @@ def kernel(
927928 self .tile_n * self .cluster_shape_mnk [0 ], # careful, this case is not very well-tested
928929 self .is_causal ,
929930 self .is_local ,
931+ False , # is_split_kv
930932 None ,
931933 None ,
932934 qhead_per_kvhead_packgqa = 1 ,
933935 )
934936 SeqlenInfoCls = partial (
935- SeqlenInfoQK ,
937+ SeqlenInfoQK . create ,
936938 seqlen_q_static = mQ .shape [0 ],
937939 seqlen_k_static = mK .shape [0 ],
938940 mCuSeqlensQ = None ,
@@ -1159,7 +1161,7 @@ def load(
11591161 tile_scheduler = TileSchedulerCls ()
11601162 work_tile = tile_scheduler .initial_work_tile_info ()
11611163 while work_tile .is_valid_tile :
1162- n_block , head_idx , batch_idx = work_tile .tile_idx
1164+ n_block , head_idx , batch_idx , _ = work_tile .tile_idx
11631165 seqlen = SeqlenInfoCls (batch_idx )
11641166 m_block_min , m_block_max = block_info .get_m_block_min_max (
11651167 seqlen , n_block // self .cluster_shape_mnk [0 ]
@@ -1415,7 +1417,7 @@ def mma(
14151417 tile_scheduler = TileSchedulerCls ()
14161418 work_tile = tile_scheduler .initial_work_tile_info ()
14171419 while work_tile .is_valid_tile :
1418- n_block , head_idx , batch_idx = work_tile .tile_idx
1420+ n_block , head_idx , batch_idx , _ = work_tile .tile_idx
14191421 seqlen = SeqlenInfoCls (batch_idx ) # must be seqlen_k
14201422 m_block_min , m_block_max = block_info .get_m_block_min_max (
14211423 seqlen , n_block // self .cluster_shape_mnk [0 ]
@@ -1723,7 +1725,7 @@ def compute_loop(
17231725 tile_scheduler = TileSchedulerCls ()
17241726 work_tile = tile_scheduler .initial_work_tile_info ()
17251727 while work_tile .is_valid_tile :
1726- n_block , head_idx , batch_idx = work_tile .tile_idx
1728+ n_block , head_idx , batch_idx , _ = work_tile .tile_idx
17271729 seqlen = SeqlenInfoCls (batch_idx )
17281730 m_block_min , m_block_max = block_info .get_m_block_min_max (
17291731 seqlen , n_block // self .cluster_shape_mnk [0 ]
@@ -1981,7 +1983,7 @@ def dQacc_reduce(
19811983 pipeline .PipelineUserType .Producer , self .sdQaccum_stage
19821984 )
19831985 while work_tile .is_valid_tile :
1984- n_block , head_idx , batch_idx = work_tile .tile_idx
1986+ n_block , head_idx , batch_idx , _ = work_tile .tile_idx
19851987 seqlen = SeqlenInfoCls (batch_idx )
19861988 m_block_min , m_block_max = block_info .get_m_block_min_max (
19871989 seqlen , n_block // self .cluster_shape_mnk [0 ]
0 commit comments