@@ -403,7 +403,6 @@ def model_specific_adjustment(self):
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
-                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -416,9 +415,7 @@ def model_specific_adjustment(self):
                     )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3() and (
-                    not server_args.enable_hierarchical_cache
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -506,6 +503,27 @@ def model_specific_adjustment(self):
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85
 
+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # Fix for the compatibility issue with FlashAttention3 decoding and the HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 