File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -113,7 +113,8 @@ def _get_decode_wrapper(self):
113
113
self .runner .parallel_config ))
114
114
num_kv_heads = self .runner .model_config .get_num_kv_heads (
115
115
self .runner .parallel_config )
116
- use_tensor_cores = num_qo_heads // num_kv_heads >= 4
116
+ use_tensor_cores = (num_qo_heads // num_kv_heads ) not in
117
+ (1 , 2 , 4 , 8 )
117
118
self ._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper (
118
119
self ._get_workspace_buffer (),
119
120
"NHD" ,
@@ -171,7 +172,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
171
172
self .runner .parallel_config ))
172
173
num_kv_heads = self .runner .model_config .get_num_kv_heads (
173
174
self .runner .parallel_config )
174
- use_tensor_cores = num_qo_heads // num_kv_heads >= 4
175
+ use_tensor_cores = (num_qo_heads // num_kv_heads ) not in
176
+ (1 , 2 , 4 , 8 )
175
177
self ._graph_decode_wrapper = \
176
178
CUDAGraphBatchDecodeWithPagedKVCacheWrapper (
177
179
self ._graph_decode_workspace_buffer , _indptr_buffer ,
You can’t perform that action at this time.
0 commit comments