
Conversation

ispobock (Collaborator) commented Oct 5, 2025

Motivation

follow-up of #10846

Accept Length Test

Start the server with this line commented out:

python3 -m sglang.launch_server --model openai/gpt-oss-20b --attention-backend triton --speculative-algorithm EAGLE3 --speculative-draft-model-path zhuyksir/EAGLE3-gpt-oss-20b-bf16 --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3

Run the requests:

# requests to construct prefix tree
python3 -m sglang.test.send_one --prompt "You xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful, xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful, respectful xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful, respectful and xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful, respectful and honest xxx" --max-new-tokens 3
python3 -m sglang.test.send_one --prompt "You are a helpful, respectful and honest assistant xxx" --max-new-tokens 3

# test request
python3 -m sglang.test.send_one --prompt "You are a helpful, respectful and honest assistant." --max-new-tokens 20

This PR w/ radix cache:

acc_length=1.33
speed=217.98 token/s

main w/ radix cache:

acc_length=1.18
speed=189.46 token/s

main w/o radix cache:

acc_length=1.33
speed=219.72 token/s

gemini-code-assist (Contributor) commented
Summary of Changes

Hello @ispobock, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces several enhancements and fixes to the SWARadixCache and its associated memory management components. The primary goal is to ensure correct and robust operation when the cache is utilized with the EAGLE algorithm. Key changes include refining how token keys are processed for EAGLE's bigram structure, enabling state backup and restoration for memory allocators, and improving type consistency across KV cache pools to prevent potential issues.

Highlights

  • EAGLE-specific Radix Cache Logic: Introduced an is_eagle flag and a _convert_to_bigram_key function to SWARadixCache to correctly adapt key handling for the EAGLE algorithm's bigram token representation.
  • State Management for Allocator: Implemented backup_state and restore_state methods in SWATokenToKVPoolAllocator to enable proper state persistence and recovery for the attention allocators.
  • Improved KV Cache Type Handling: Added a dtype parameter to the SWATokenToKVPool initialization, ensuring that the data type is correctly propagated to the underlying SWA and full KV pools.
  • Refined Cache Operations for EAGLE: Adjusted cache_finished_req and cache_unfinished_req in SWARadixCache to accurately manage token and KV index lengths when operating with EAGLE's bigram keys, including specific handling for chunked prefill scenarios.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for the EAGLE speculative decoding algorithm within the SWARadixCache. The changes primarily adapt the cache to handle bigram keys used by EAGLE, which involves modifying key conversion, prefix matching, and insertion logic. Supporting changes include plumbing the is_eagle flag and dtype through various components. My review identifies a potential bug in the insert method when handling None values for EAGLE, and a maintainability concern regarding in-place modification of RadixKey objects which violates its type hint. Overall, the changes are logical and address the caching issue for EAGLE.

Comment on lines 421 to 422
if value is None:
    value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)


Severity: high

This logic can lead to a bug when is_eagle is true. In that case, key.token_ids becomes a list of tuples (bigrams) after self.key_convert_fn. torch.tensor on a list of tuples will create a 2D tensor, while the rest of the code expects value to be a 1D tensor of KV cache indices. Although this code path (value is None) might not be triggered in production, it's a latent bug that can affect tests or future use cases.

I suggest raising an error for this case or creating a dummy 1D tensor if it's needed for tests.

Suggested change

if value is None:
    value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)

if value is None:
    if self.is_eagle:
        # This path is not expected in production for EAGLE.
        # The value should be a 1D tensor of indices, but creating it from bigram keys is ambiguous.
        raise NotImplementedError("insert with value=None is not supported for EAGLE mode.")
    value = torch.tensor(key.token_ids, dtype=torch.int64)
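To make the dimensionality concern concrete: torch.tensor infers its shape from nested Python sequences, so a list of bigram tuples yields a 2D tensor rather than the 1D index tensor the cache expects. A dependency-free sketch that mimics that shape inference:

```python
def inferred_shape(data):
    # Mimic how torch.tensor infers a shape from nested Python sequences:
    # descend into the first element at each level, recording lengths.
    shape = []
    while isinstance(data, (list, tuple)) and len(data) > 0:
        shape.append(len(data))
        data = data[0]
    return tuple(shape)

print(inferred_shape([1, 2, 3, 4]))              # (4,)   -> 1D, as the cache expects
print(inferred_shape([(1, 2), (2, 3), (3, 4)]))  # (3, 2) -> 2D, the latent bug
```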

The last node creates a new child if the prefix is shorter
than the last node's value.
"""
key.token_ids = self.key_convert_fn(key.token_ids)

Severity: medium

This line modifies the input key object in-place, which can be an unexpected side effect for callers. Additionally, when is_eagle is true, key.token_ids is converted from List[int] to List[Tuple[int, int]], which violates the type hint in the RadixKey class definition (List[int]). This makes the code harder to understand and maintain.

While creating a new RadixKey object might have performance implications, it would be safer. A less disruptive change would be to update the type hint for RadixKey.token_ids to List[Union[int, Tuple[int, int]]] and add a comment here explaining the in-place modification.

@hanming-lu left a comment

At a high level this makes sense. I'm finding it not straightforward to fully understand the +1 and -1 logic and its implications. Was there a doc for the radix cache changes?


token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
Collaborator commented

Can we have more comments on the reason behind this -1?

ispobock (Collaborator, Author) replied

If we convert the key to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length is reduced by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
So the corresponding KV length should also be reduced by 1. That gives us actual_kv_len, which we use in the later calculations and slicing.

Comment on lines 458 to 465
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
old_prefix_len -= 1
Collaborator commented

Wishing for more comments here 🙏 The +1 and -1 logic and its implications are not straightforward to fully understand :(

ispobock (Collaborator, Author) replied Oct 7, 2025

In the chunked prefill case, the chunked KV should be cached in the radix cache. But in the EAGLE case, the last token is not inserted into the tree because the bigram key is one element shorter. We still add it to req.prefix_indices (ref), since its KV is still part of the sequence. Here we do old_prefix_len - 1 just to make sure that this additional KV slot is freed correctly; otherwise we would get a memory leak.
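The accounting above can be sketched as follows. This is a hypothetical illustration of the bookkeeping, not the PR's code; the function name and the standalone parameters are made up for the example:

```python
def adjusted_prefix_len(old_prefix_len, last_matched_prefix_len, is_eagle):
    # In the EAGLE chunked-prefill path, prefix_indices may carry one trailing
    # KV slot whose token could not be inserted into the bigram tree (the tree
    # key is one element shorter than the token sequence). Drop it from the
    # cached-prefix count so that slot is freed instead of leaked.
    if is_eagle and old_prefix_len > last_matched_prefix_len:
        old_prefix_len -= 1
    return old_prefix_len

print(adjusted_prefix_len(8, 7, is_eagle=True))   # 7: one extra slot to free
print(adjusted_prefix_len(8, 7, is_eagle=False))  # 8: non-EAGLE path unchanged
print(adjusted_prefix_len(7, 7, is_eagle=True))   # 7: nothing extra attached
```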

@ispobock ispobock merged commit 24bc3fb into sgl-project:main Oct 7, 2025
61 of 64 checks passed
ch-tiger1 pushed a commit to ch-tiger1/sglang that referenced this pull request Oct 9, 2025
