Skip to content

Commit 9d76117

Browse files
ispobock hanming-lu
authored and committed
EAGLE cache fix for SWARadixCache (sgl-project#11231)
Co-authored-by: Hanming Lu <[email protected]>
1 parent 3a7a9c0 commit 9d76117

File tree

8 files changed

+248
-31
lines changed

8 files changed

+248
-31
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ def init_memory_pool_and_cache(self):
777777
sliding_window_size=self.sliding_window_size,
778778
page_size=self.page_size,
779779
disable=server_args.disable_radix_cache,
780+
is_eagle=self.spec_algorithm.is_eagle(),
780781
)
781782
elif server_args.enable_lmcache:
782783
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (

python/sglang/srt/mem_cache/allocator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,15 @@ def free_swa(self, free_index: torch.Tensor):
274274
self.full_to_swa_index_mapping[free_index] = 0
275275

276276
def backup_state(self):
277-
raise NotImplementedError
277+
return [
278+
self.full_attn_allocator.backup_state(),
279+
self.swa_attn_allocator.backup_state(),
280+
]
278281

279282
def restore_state(self, state):
280-
raise NotImplementedError
283+
assert len(state) == 2
284+
self.full_attn_allocator.restore_state(state[0])
285+
self.swa_attn_allocator.restore_state(state[1])
281286

282287
def clear(self):
283288
self.swa_attn_allocator.clear()

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def __init__(
749749
self,
750750
size: int,
751751
size_swa: int,
752+
dtype: torch.dtype,
752753
swa_attention_layer_ids: List[int],
753754
full_attention_layer_ids: List[int],
754755
enable_kvcache_transpose: bool,
@@ -757,6 +758,7 @@ def __init__(
757758
):
758759
self.size = size
759760
self.size_swa = size_swa
761+
self.dtype = dtype
760762
self.swa_layer_nums = len(swa_attention_layer_ids)
761763
self.full_layer_nums = len(full_attention_layer_ids)
762764
kwargs["page_size"] = 1
@@ -766,11 +768,13 @@ def __init__(
766768

767769
self.swa_kv_pool = token_to_kv_pool_class(
768770
size=size_swa,
771+
dtype=dtype,
769772
layer_num=self.swa_layer_nums,
770773
**kwargs,
771774
)
772775
self.full_kv_pool = token_to_kv_pool_class(
773776
size=size,
777+
dtype=dtype,
774778
layer_num=self.full_layer_nums,
775779
**kwargs,
776780
)

python/sglang/srt/mem_cache/radix_cache.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ def cache_finished_req(self, req: Req):
326326

327327
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
328328
all_token_len = len(token_ids)
329+
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
330+
# The corresponding KV length must therefore also shrink by 1; the result is actual_kv_len, which the later calculations and slicing use.
329331
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
330332
kv_indices = self.req_to_token_pool.req_to_token[
331333
req.req_pool_idx, :all_token_len
@@ -349,7 +351,8 @@ def cache_finished_req(self, req: Req):
349351

350352
old_prefix_len = len(req.prefix_indices)
351353
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
352-
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
354+
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
355+
# Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
353356
old_prefix_len -= 1
354357

355358
# Radix Cache takes one ref in memory pool
@@ -370,7 +373,8 @@ def cache_unfinished_req(self, req: Req, chunked=False):
370373

371374
token_ids = req.fill_ids
372375
all_token_len = len(token_ids)
373-
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
376+
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
377+
# The corresponding KV length must therefore also shrink by 1; the result is actual_kv_len, which the later calculations and slicing use.
374378
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
375379
kv_indices = self.req_to_token_pool.req_to_token[
376380
req.req_pool_idx, :all_token_len
@@ -393,7 +397,8 @@ def cache_unfinished_req(self, req: Req, chunked=False):
393397

394398
old_prefix_len = len(req.prefix_indices)
395399
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
396-
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
400+
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
401+
# Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
397402
old_prefix_len -= 1
398403

399404
# Radix Cache takes one ref in memory pool

python/sglang/srt/mem_cache/swa_radix_cache.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
3333
from sglang.srt.mem_cache.radix_cache import (
3434
RadixKey,
35+
_convert_to_bigram_key,
3536
_key_match_page_size1,
3637
_key_match_paged,
3738
get_child_key,
@@ -327,12 +328,14 @@ def __init__(
327328
sliding_window_size: int,
328329
page_size: int,
329330
disable: bool = False,
331+
is_eagle: bool = False,
330332
):
331333
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
332334
self.req_to_token_pool = req_to_token_pool
333335
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
334336
self.page_size = page_size
335337
self.disable = disable
338+
self.is_eagle = is_eagle
336339

337340
if self.token_to_kv_pool_allocator:
338341
self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ def __init__(
346349
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
347350
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
348351

352+
if is_eagle:
353+
self.key_convert_fn = _convert_to_bigram_key
354+
else:
355+
self.key_convert_fn = lambda key: key
356+
349357
self.sliding_window_size = sliding_window_size
350358
self.reset()
351359

@@ -376,6 +384,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
376384
The last node creates a new child if the prefix is shorter
377385
than the last node's value.
378386
"""
387+
key.token_ids = self.key_convert_fn(key.token_ids)
388+
379389
if self.disable or len(key) == 0:
380390
return MatchResult(
381391
device_indices=torch.empty(
@@ -406,8 +416,15 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
406416
if self.disable:
407417
return 0
408418

419+
key.token_ids = self.key_convert_fn(key.token_ids)
420+
409421
if value is None:
410422
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
423+
424+
if self.is_eagle:
425+
# Make sure the value length equals the EAGLE bigram key length
426+
value = value[: len(key)]
427+
411428
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
412429

413430
def cache_finished_req(self, req: Req) -> None:
@@ -422,25 +439,41 @@ def cache_finished_req(self, req: Req) -> None:
422439
return
423440

424441
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
442+
all_token_len = len(token_ids)
443+
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
444+
# The corresponding KV length must therefore also shrink by 1; the result is actual_kv_len, which the later calculations and slicing use.
445+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
425446
kv_indices = self.req_to_token_pool.req_to_token[
426-
req.req_pool_idx, : len(token_ids)
447+
req.req_pool_idx, :all_token_len
427448
]
428449

429450
if self.page_size != 1:
430-
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
451+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
431452
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
432453
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
433454
else:
434-
page_aligned_len = len(kv_indices)
455+
page_aligned_len = actual_kv_len
435456
page_aligned_kv_indices = kv_indices.clone()
457+
if self.is_eagle:
458+
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
459+
460+
page_aligned_token_len = (
461+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
462+
)
463+
464+
old_prefix_len = len(req.prefix_indices)
465+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
466+
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
467+
# Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
468+
old_prefix_len -= 1
436469

437470
# Radix Cache takes one ref in memory pool
438471
# insert the token_ids and kv_indices into the radix tree
439472
# Note: the insert function already frees the overlapped kv_indices
440473
new_prefix_len = self.insert(
441-
RadixKey(token_ids[:page_aligned_len], req.extra_key),
474+
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
442475
page_aligned_kv_indices,
443-
len(req.prefix_indices),
476+
old_prefix_len,
444477
)
445478

446479
# Remove req slot release the cache lock
@@ -459,39 +492,56 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
459492
return
460493

461494
token_ids = req.fill_ids
495+
all_token_len = len(token_ids)
496+
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so its length shrinks by 1 (len([(1,2), (2,3), (3,4)]) == len([1,2,3,4]) - 1).
497+
# The corresponding KV length must therefore also shrink by 1; the result is actual_kv_len, which the later calculations and slicing use.
498+
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
462499
kv_indices = self.req_to_token_pool.req_to_token[
463-
req.req_pool_idx, : len(token_ids)
500+
req.req_pool_idx, :all_token_len
464501
]
465502

466503
if self.page_size != 1:
467-
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
504+
page_aligned_len = actual_kv_len // self.page_size * self.page_size
468505
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
469506
else:
470-
page_aligned_len = len(kv_indices)
507+
page_aligned_len = actual_kv_len
471508
page_aligned_kv_indices = kv_indices.clone()
472-
page_aligned_token_ids = token_ids[:page_aligned_len]
509+
510+
# For EAGLE, page_aligned_len is the length of the bigram key; the normal (token) key length is one more.
511+
page_aligned_token_len = (
512+
page_aligned_len + 1 if self.is_eagle else page_aligned_len
513+
)
514+
page_aligned_token_ids = token_ids[:page_aligned_token_len]
515+
516+
old_prefix_len = len(req.prefix_indices)
517+
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
518+
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
519+
# Subtract 1 here so that the KV of the unmatched token is freed correctly, avoiding a memory leak.
520+
old_prefix_len -= 1
473521

474522
# Radix Cache takes one ref in memory pool
475523
# Note: the insert function already frees the overlapped kv_indices
476524
new_prefix_len = self.insert(
477525
RadixKey(page_aligned_token_ids, req.extra_key),
478526
page_aligned_kv_indices,
479-
len(req.prefix_indices),
527+
old_prefix_len,
480528
)
481529

482530
# The prefix indices could be updated, reuse it
483531
new_indices, new_last_node, _, _ = self.match_prefix(
484532
RadixKey(page_aligned_token_ids, req.extra_key)
485533
)
486-
assert len(req.prefix_indices) <= len(
534+
assert old_prefix_len <= len(
487535
new_indices
488536
), f"{req.prefix_indices=}, {new_indices=}"
489537
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
490538
self.req_to_token_pool.write(
491-
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
492-
new_indices[len(req.prefix_indices) :],
539+
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
540+
new_indices[old_prefix_len:],
493541
)
494542

543+
req.last_matched_prefix_len = len(new_indices)
544+
495545
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
496546
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
497547

@@ -501,7 +551,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
501551
[new_indices, kv_indices[len(new_indices) :]]
502552
)
503553
else:
504-
req.prefix_indices = new_indices
554+
if self.is_eagle:
555+
# Attach the KV index of the last token for EAGLE; it can be used in chunked prefill
556+
req.prefix_indices = torch.cat(
557+
[new_indices, kv_indices[actual_kv_len:]]
558+
)
559+
else:
560+
req.prefix_indices = new_indices
505561
req.last_node = new_last_node
506562
req.swa_uuid_for_lock = swa_uuid_for_lock
507563

python/sglang/srt/models/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727

2828
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
2929
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
30-
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
30+
return (
31+
_is_cuda
32+
and hasattr(forward_batch.token_to_kv_pool, "dtype")
33+
and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
34+
)
3135

3236

3337
def create_fused_set_kv_buffer_arg(

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class TestFile:
113113
TestFile("test_srt_engine.py", 261),
114114
TestFile("test_srt_endpoint.py", 130),
115115
TestFile("test_start_profile.py", 60),
116+
TestFile("test_swa_unittest.py", 1),
116117
TestFile("test_torch_compile.py", 76),
117118
TestFile("test_torch_compile_moe.py", 172),
118119
TestFile("test_torch_native_attention_backend.py", 123),

0 commit comments

Comments
 (0)