Commit 19bc77f

[Fix] Fix hicache backend (#8991)
1 parent 86497d9 · commit 19bc77f

File tree

2 files changed: +23, -10 lines

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 6 deletions
@@ -611,12 +611,7 @@ def init_memory_pool_and_cache(self):
     hicache_ratio=server_args.hicache_ratio,
     hicache_size=server_args.hicache_size,
     hicache_write_policy=server_args.hicache_write_policy,
-    hicache_io_backend=(
-        "direct"
-        if server_args.attention_backend == "fa3"  # hot fix for incompatibility
-        else server_args.hicache_io_backend
-    ),
+    hicache_io_backend=server_args.hicache_io_backend,
     hicache_mem_layout=server_args.hicache_mem_layout,
     hicache_storage_backend=server_args.hicache_storage_backend,
     hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
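
Net effect: the scheduler no longer second-guesses the I/O backend. It forwards server_args.hicache_io_backend as given, and the fa3 compatibility handling moves into ModelRunner.model_specific_adjustment (second file below). A minimal before/after sketch, using a SimpleNamespace as a hypothetical stand-in for the parsed server arguments:

    from types import SimpleNamespace

    # Hypothetical stand-in for sglang's parsed server arguments.
    server_args = SimpleNamespace(attention_backend="fa3", hicache_io_backend="kernel")

    # Before this commit the scheduler silently overrode the flag:
    io_backend_before = (
        "direct"
        if server_args.attention_backend == "fa3"  # hot fix for incompatibility
        else server_args.hicache_io_backend
    )

    # After this commit the flag is forwarded unchanged; compatibility is
    # resolved earlier, in ModelRunner.model_specific_adjustment:
    io_backend_after = server_args.hicache_io_backend

    print(io_backend_before, io_backend_after)  # -> direct kernel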

python/sglang/srt/model_executor/model_runner.py

Lines changed: 22 additions & 4 deletions
@@ -403,7 +403,6 @@ def model_specific_adjustment(self):
     is_hopper_with_cuda_12_3()
     and is_no_spec_infer_or_topk_one(server_args)
     and is_fa3_default_architecture(self.model_config.hf_config)
-    and (not server_args.enable_hierarchical_cache)
 ):
     server_args.attention_backend = "fa3"
 elif _is_hip:
@@ -416,9 +415,7 @@ def model_specific_adjustment(self):
     )
 else:
     # MLA architecture
-    if is_hopper_with_cuda_12_3() and (
-        not server_args.enable_hierarchical_cache
-    ):
+    if is_hopper_with_cuda_12_3():
         server_args.attention_backend = "fa3"
     elif is_sm100_supported():
         server_args.attention_backend = "flashinfer"
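
These two hunks drop the `not server_args.enable_hierarchical_cache` guard, so fa3 can again be auto-selected on Hopper with CUDA 12.3 even when hierarchical cache is enabled; the incompatibility is handled later and more narrowly (next hunk). A standalone sketch of the MLA branch after the change — the function name is hypothetical, and the real code mutates server_args.attention_backend in place with further branches not shown in these hunks:

    from typing import Optional

    def pick_mla_attention_backend(
        hopper_with_cuda_12_3: bool, sm100_supported: bool
    ) -> Optional[str]:
        # Hypothetical standalone version of the MLA selection ladder above.
        if hopper_with_cuda_12_3:
            # No longer gated on `not server_args.enable_hierarchical_cache`.
            return "fa3"
        if sm100_supported:
            return "flashinfer"
        return None  # the real function has further branches not shown here

    print(pick_mla_attention_backend(True, False))  # -> fa3, even with hicache on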
@@ -506,6 +503,27 @@ def model_specific_adjustment(self):
     if self.model_config.context_len > 8192:
         self.mem_fraction_static *= 0.85

+    if (
+        server_args.enable_hierarchical_cache
+        and server_args.hicache_io_backend == "kernel"
+    ):
+        # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+        if server_args.decode_attention_backend is None:
+            if not self.use_mla_backend:
+                server_args.decode_attention_backend = (
+                    "flashinfer" if is_flashinfer_available() else "triton"
+                )
+            else:
+                server_args.decode_attention_backend = (
+                    "flashinfer" if is_sm100_supported() else "triton"
+                )
+        elif server_args.decode_attention_backend == "fa3":
+            server_args.hicache_io_backend = "direct"
+            logger.warning(
+                "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+            )
+
 def init_torch_distributed(self):
     logger.info("Init torch distributed begin.")
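
This new block resolves the incompatibility at startup instead of at cache construction: with hierarchical cache plus the "kernel" I/O backend, an unset decode backend is steered to flashinfer or triton, while an explicitly requested fa3 decode backend downgrades the I/O path to "direct". A runnable sketch under assumptions — ServerArgs is a hypothetical dataclass here, and the hardware/library probes are plain booleans rather than sglang's is_flashinfer_available() / is_sm100_supported():

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerArgs:
        # Hypothetical subset of sglang's server arguments.
        enable_hierarchical_cache: bool = True
        hicache_io_backend: str = "kernel"
        decode_attention_backend: Optional[str] = None

    def resolve_hicache_compat(
        args: ServerArgs, use_mla: bool, flashinfer_ok: bool, sm100_ok: bool
    ) -> None:
        # Sketch of the logic added in this commit.
        if not (args.enable_hierarchical_cache and args.hicache_io_backend == "kernel"):
            return
        if args.decode_attention_backend is None:
            # Pick a decode backend known to work with the kernel I/O path.
            if not use_mla:
                args.decode_attention_backend = "flashinfer" if flashinfer_ok else "triton"
            else:
                args.decode_attention_backend = "flashinfer" if sm100_ok else "triton"
        elif args.decode_attention_backend == "fa3":
            # User explicitly asked for fa3 decoding: fall back to direct I/O.
            args.hicache_io_backend = "direct"

    args = ServerArgs(decode_attention_backend="fa3")
    resolve_hicache_compat(args, use_mla=False, flashinfer_ok=True, sm100_ok=False)
    print(args.hicache_io_backend)  # -> direct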
