Skip to content

Commit d2e83e4

Browse files
committed
rebase latest main branch
1 parent 9d55ef1 commit d2e83e4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/attention/backends/flashinfer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def _get_decode_wrapper(self):
                     self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
+                (1, 2, 4, 8)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 "NHD",
@@ -171,7 +172,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
                     self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
+                (1, 2, 4, 8)
             self._graph_decode_wrapper = \
                 CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                     self._graph_decode_workspace_buffer, _indptr_buffer,

0 commit comments

Comments (0)