
Commit 9d3b08b (parent: 11b5c5c)

Fix input tensors
Disable splitk

Signed-off-by: kaixih <[email protected]>

3 files changed: 7 additions, 2 deletions

csrc/attention/mla/cutlass_mla_kernels.cu (1 addition, 1 deletion)

@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
       {static_cast<ElementOut*>(out.data_ptr()), stride_O,
        static_cast<ElementAcc*>(nullptr), stride_LSE},
       hw_info,
-      -1,       // split_kv
+      1,        // split_kv
       nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
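The change above pins split_kv to 1, which disables the split-KV path for now (the TODO about the split_kv=-1 heuristic remains in place). For readers unfamiliar with the term, below is a minimal PyTorch sketch of what split-KV decoding means in general: the KV sequence is processed in chunks, each chunk produces a partial output plus a log-sum-exp, and the partials are merged at the end. This is only an illustration of the concept; the names, shapes, and merge code are mine, not the CUTLASS kernel.

import torch

def attn_decode_split_kv(q, k, v, split_kv: int, scale: float):
    # Conceptual sketch, not the CUTLASS implementation.
    # q: [h, d] single decode query, k/v: [s, d] cached keys/values.
    s = k.shape[0]
    chunk = (s + split_kv - 1) // split_kv
    outs, lses = [], []
    for start in range(0, s, chunk):
        kc, vc = k[start:start + chunk], v[start:start + chunk]
        logits = (q @ kc.T) * scale                      # [h, chunk]
        m = logits.max(dim=-1, keepdim=True).values      # running max per head
        p = torch.exp(logits - m)
        lse = m.squeeze(-1) + torch.log(p.sum(dim=-1))   # per-chunk log-sum-exp
        outs.append((p @ vc) / p.sum(dim=-1, keepdim=True))
        lses.append(lse)
    # Merge the partial outputs, weighting each chunk by softmax over its lse.
    lse = torch.stack(lses, dim=0)                       # [splits, h]
    w = torch.softmax(lse, dim=0).unsqueeze(-1)          # [splits, h, 1]
    return (torch.stack(outs, dim=0) * w).sum(dim=0)     # [h, d]

q = torch.randn(16, 64)
k = torch.randn(1024, 64)
v = torch.randn(1024, 64)
ref = attn_decode_split_kv(q, k, v, split_kv=1, scale=64 ** -0.5)
split = attn_decode_split_kv(q, k, v, split_kv=4, scale=64 ** -0.5)
print(torch.allclose(ref, split, atol=1e-5))             # True up to rounding

With split_kv=1 the loop runs once and the merge weight is 1, which is exactly the single-pass behavior the commit falls back to.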

tests/kernels/test_cutlass_mla_decode.py (3 additions, 1 deletion)

@@ -75,7 +75,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

-    q = torch.randn(bs, h_q, d)
+    # Amplify input values to ensure test coverage of edge cases where CUTLASS
+    # kernel errors occur with split_k settings.
+    q = torch.randn(bs, h_q, d) * 100
     block_table = torch.randint(0,
                                 bs * block_num, (bs, block_num),
                                 dtype=torch.int32)
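The test comment only states that amplified inputs hit edge cases tied to split-k; it does not spell out the failure mode. As one hedged illustration of why large-magnitude queries are good at exposing such problems (my own example, not taken from the test): scaling q by 100 makes the attention logits large, so any path that skips or mis-merges the running-max / log-sum-exp bookkeeping overflows immediately, whereas small logits would mask the bug.

import torch

torch.manual_seed(0)
logits = torch.randn(8, 512) * 100           # large decode-time attention logits

naive = torch.exp(logits)                     # overflows to inf for logits > ~88
naive = naive / naive.sum(-1, keepdim=True)   # inf / inf -> NaN
stable = torch.softmax(logits, dim=-1)        # max-subtracted, stays finite

print("naive softmax finite:", torch.isfinite(naive).all().item())    # False
print("stable softmax finite:", torch.isfinite(stable).all().item())  # True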

vllm/v1/attention/backends/mla/cutlass_mla.py (3 additions, 0 deletions)

@@ -86,6 +86,9 @@ def _forward_decode(
                         device=q_nope.device)

         # Run MLA
+        # Clone q_nope and q_pe to make sure strides computation is correct.
+        q_nope = q_nope.clone()
+        q_pe = q_pe.clone()
         ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                                attn_metadata.decode.seq_lens,
                                attn_metadata.decode.block_table, self.scale)
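The new comment is terse, so here is the likely mechanism: q_nope and q_pe typically arrive as slices of a larger query tensor, so they are strided views over the parent buffer, while the commit message implies the kernel's argument setup derives strides from the tensor sizes as if the data were packed. A short sketch of the effect in plain PyTorch (shapes are assumed for illustration, not the real vLLM code path):

import torch

# Hypothetical shapes: a combined query tensor split into no-PE and PE parts.
bs, h_q, d_nope, d_pe = 4, 16, 512, 64
q = torch.randn(bs, h_q, d_nope + d_pe)
q_nope, q_pe = q[..., :d_nope], q[..., d_nope:]

# Views keep the parent's strides and are not contiguous.
print(q_nope.stride(), q_nope.is_contiguous())   # (9216, 576, 1) False
print(q_pe.stride(), q_pe.is_contiguous())       # (9216, 576, 1) False

# clone() on these non-dense views materializes compact, contiguous copies.
q_nope, q_pe = q_nope.clone(), q_pe.clone()
print(q_nope.stride(), q_nope.is_contiguous())   # (8192, 512, 1) True
print(q_pe.stride(), q_pe.is_contiguous())       # (1024, 64, 1) True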
