Skip to content

Commit 6ac28f4

Browse files
authored
bugfix: fix prefill kernels' lse result for empty kv-cache (#440)
Thanks to @hnyls2002 for spotting this bug.
1 parent c93f647 commit 6ac28f4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

include/flashinfer/attention/prefill.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1558,7 +1558,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
1558 1558
// normalize d
1559 1559
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
1560 1560

1561-
const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);
1561+
const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size);
1562 1562

1563 1563
// write back
1564 1564
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
@@ -1872,7 +1872,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
1872 1872
// normalize d
1873 1873
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
1874 1874

1875-
const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);
1875+
const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size);
1876 1876

1877 1877
// write_back
1878 1878
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(

0 commit comments

Comments
 (0)