Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def init_memory_pool_and_cache(self):
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/srt/mem_cache/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,15 @@ def free_swa(self, free_index: torch.Tensor):
self.full_to_swa_index_mapping[free_index] = 0

def backup_state(self):
raise NotImplementedError
return [
self.full_attn_allocator.backup_state(),
self.swa_attn_allocator.backup_state(),
]

def restore_state(self, state):
raise NotImplementedError
assert len(state) == 2
self.full_attn_allocator.restore_state(state[0])
self.swa_attn_allocator.restore_state(state[1])

def clear(self):
self.swa_attn_allocator.clear()
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def __init__(
self,
size: int,
size_swa: int,
dtype: torch.dtype,
swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
Expand All @@ -755,6 +756,7 @@ def __init__(
):
self.size = size
self.size_swa = size_swa
self.dtype = dtype
self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids)
kwargs["page_size"] = 1
Expand All @@ -764,11 +766,13 @@ def __init__(

self.swa_kv_pool = token_to_kv_pool_class(
size=size_swa,
dtype=dtype,
layer_num=self.swa_layer_nums,
**kwargs,
)
self.full_kv_pool = token_to_kv_pool_class(
size=size,
dtype=dtype,
layer_num=self.full_layer_nums,
**kwargs,
)
Expand Down
79 changes: 65 additions & 14 deletions python/sglang/srt/mem_cache/swa_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
_convert_to_bigram_key,
_key_match_page_size1,
_key_match_paged,
get_child_key,
Expand Down Expand Up @@ -327,12 +328,14 @@ def __init__(
sliding_window_size: int,
page_size: int,
disable: bool = False,
is_eagle: bool = False,
):
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.is_eagle = is_eagle

if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
Expand All @@ -346,6 +349,11 @@ def __init__(
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)

if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key

self.sliding_window_size = sliding_window_size
self.reset()

Expand Down Expand Up @@ -376,6 +384,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
The last node creates a new child if the prefix is shorter
than the last node's value.
"""
key.token_ids = self.key_convert_fn(key.token_ids)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This line modifies the input key object in-place, which can be an unexpected side effect for callers. Additionally, when is_eagle is true, key.token_ids is converted from List[int] to List[Tuple[int, int]], which violates the type hint in the RadixKey class definition (List[int]). This makes the code harder to understand and maintain.

While creating a new RadixKey object might have performance implications, it would be safer. A less disruptive change would be to update the type hint for RadixKey.token_ids to List[Union[int, Tuple[int, int]]] and add a comment here explaining the in-place modification.


if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
Expand Down Expand Up @@ -406,8 +416,15 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
if self.disable:
return 0

key.token_ids = self.key_convert_fn(key.token_ids)

if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
Comment on lines 421 to 422
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic can lead to a bug when is_eagle is true. In that case, key.token_ids becomes a list of tuples (bigrams) after self.key_convert_fn. torch.tensor on a list of tuples will create a 2D tensor, while the rest of the code expects value to be a 1D tensor of KV cache indices. Although this code path (value is None) might not be triggered in production, it's a latent bug that can affect tests or future use cases.

I suggest raising an error for this case or creating a dummy 1D tensor if it's needed for tests.

Suggested change
if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
if value is None:
if self.is_eagle:
# This path is not expected in production for EAGLE.
# The value should be a 1D tensor of indices, but creating it from bigram keys is ambiguous.
raise NotImplementedError("insert with value=None is not supported for EAGLE mode.")
value = torch.tensor(key.token_ids, dtype=torch.int64)


if self.is_eagle:
# Make sure the value length equals the EAGLE bigram key length
value = value[: len(key)]

return self._insert_helper(self.root_node, key, value, prev_prefix_len)

def cache_finished_req(self, req: Req) -> None:
Expand All @@ -422,25 +439,38 @@ def cache_finished_req(self, req: Req) -> None:
return

token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have more comments on the reason behind this -1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
So the corresponding kv length should also be reduced by 1. That gives us actual_kv_len, which we use for the later calculation and slicing.

kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]

if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices may have the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
old_prefix_len -= 1
Comment on lines 460 to 468
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wishing for more comments here 🙏 The +1 and -1 logics and their implications are not straightforward to fully understand :(

Copy link
Collaborator Author

@ispobock ispobock Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the chunked prefill case, the chunked kv should be cached in the radix cache. But in the EAGLE case, the last token will not be inserted into the tree, because the bigram key is one element shorter than the token list. We still add it to req.prefix_indices (ref), since its kv is still in the sequence. Here we do old_prefix_len - 1 to make sure that additional kv is freed correctly; otherwise we would get a memory leak.


# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)

# Remove req slot release the cache lock
Expand All @@ -459,39 +489,54 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
return

token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids) - 1, since the EAGLE bigram key is one element shorter than the token list
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]

if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]

# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices may have the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
old_prefix_len -= 1

# Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
assert len(req.prefix_indices) <= len(
assert old_prefix_len <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)

req.last_matched_prefix_len = len(new_indices)

self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)

Expand All @@ -501,7 +546,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
if self.is_eagle:
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
req.swa_uuid_for_lock = swa_uuid_for_lock

Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
return (
_is_cuda
and hasattr(forward_batch.token_to_kv_pool, "dtype")
and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
)


def create_fused_set_kv_buffer_arg(
Expand Down
Loading