EAGLE cache fix for SWARadixCache #11231
Conversation
Co-authored-by: Hanming Lu <[email protected]>
Code Review
This pull request introduces a fix for the EAGLE speculative decoding algorithm within the `SWARadixCache`. The changes primarily adapt the cache to handle bigram keys used by EAGLE, which involves modifying key conversion, prefix matching, and insertion logic. Supporting changes include plumbing the `is_eagle` flag and `dtype` through various components. My review identifies a potential bug in the `insert` method when handling `None` values for EAGLE, and a maintainability concern regarding in-place modification of `RadixKey` objects, which violates its type hint. Overall, the changes are logical and address the caching issue for EAGLE.
```python
if value is None:
    value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
```
This logic can lead to a bug when `is_eagle` is true. In that case, `key.token_ids` becomes a list of tuples (bigrams) after `self.key_convert_fn`. `torch.tensor` on a list of tuples will create a 2D tensor, while the rest of the code expects `value` to be a 1D tensor of KV cache indices. Although this code path (`value is None`) might not be triggered in production, it's a latent bug that can affect tests or future use cases.
I suggest raising an error for this case, or creating a dummy 1D tensor if it's needed for tests.
```diff
 if value is None:
-    value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
+    if self.is_eagle:
+        # This path is not expected in production for EAGLE.
+        # The value should be a 1D tensor of indices, but creating it from bigram keys is ambiguous.
+        raise NotImplementedError("insert with value=None is not supported for EAGLE mode.")
+    value = torch.tensor(key.token_ids, dtype=torch.int64)
```
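For illustration, a quick standalone check of the shape issue mentioned above (plain PyTorch; the values are made up):

```python
import torch

# Bigram keys are a list of (token, next_token) tuples, e.g. derived from [1, 2, 3, 4].
bigram_key = [(1, 2), (2, 3), (3, 4)]

t = torch.tensor(bigram_key, dtype=torch.int64)
print(t.shape)  # torch.Size([3, 2]) -- 2D, not the 1D tensor of KV indices the cache expects
```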
```python
    The last node create a new child if the prefix is shorter
    than the last node's value.
    """
    key.token_ids = self.key_convert_fn(key.token_ids)
```
This line modifies the input `key` object in-place, which can be an unexpected side effect for callers. Additionally, when `is_eagle` is true, `key.token_ids` is converted from `List[int]` to `List[Tuple[int, int]]`, which violates the type hint in the `RadixKey` class definition (`List[int]`). This makes the code harder to understand and maintain.
While creating a new `RadixKey` object might have performance implications, it would be safer. A less disruptive change would be to update the type hint for `RadixKey.token_ids` to `List[Union[int, Tuple[int, int]]]` and add a comment here explaining the in-place modification.
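A minimal sketch of the type-hint adjustment suggested above (the `RadixKey` fields shown are illustrative; the real class in the repo may have more fields):

```python
from dataclasses import dataclass, field
from typing import List, Tuple, Union

@dataclass
class RadixKey:
    # After key conversion for EAGLE, entries may be (token, next_token) bigram
    # tuples instead of plain ints, so the hint reflects both shapes.
    token_ids: List[Union[int, Tuple[int, int]]] = field(default_factory=list)
```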
At a high level this makes sense. I'm finding it not straightforward to fully understand the +1 and -1 logic and its implications. Was there a doc for the radix cache changes?
```python
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
```
Can we have more comments on the reason behind this -1?
If we convert the key to a bigram key, e.g. `[1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]`, the length decreases by 1 (`len([(1, 2), (2, 3), (3, 4)]) == len([1, 2, 3, 4]) - 1`).
So the corresponding KV length should also decrease by 1. That gives us `actual_kv_len`, which is used for the later calculations and slicing.
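To make the length bookkeeping concrete, here is a small standalone sketch; `convert_to_bigram_key` is an illustrative helper name, not necessarily the function used in the PR:

```python
from typing import List, Tuple

def convert_to_bigram_key(token_ids: List[int]) -> List[Tuple[int, int]]:
    # Pair each token with its successor: [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]
    return list(zip(token_ids, token_ids[1:]))

token_ids = [1, 2, 3, 4]
bigram_key = convert_to_bigram_key(token_ids)

all_token_len = len(token_ids)      # 4
actual_kv_len = all_token_len - 1   # 3, matches len(bigram_key)
assert len(bigram_key) == actual_kv_len
```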
```python
page_aligned_token_len = (
    page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
```
```python
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
    # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
    old_prefix_len -= 1
```
Wishing for more comments here 🙏 The +1 and -1 logic and its implications are not straightforward to fully understand :(
In the chunked prefill case, the chunked KV should be cached in the radix cache. But in the EAGLE case, the last token will not be inserted into the tree, because the bigram key is one element shorter. We still add it to `req.prefix_indices` (ref), since its KV is still part of the sequence. Here we do `old_prefix_len - 1` to make sure that additional KV entry is freed correctly; otherwise we would get a memory leak.
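A toy sketch of the accounting described above; the numbers and attribute names mirror the snippet but are otherwise made up, not taken from the scheduler code:

```python
# Suppose a chunked-prefill request matched 8 tokens, but with EAGLE the bigram
# key is one element shorter, so only 7 entries were inserted into the tree.
last_matched_prefix_len = 7   # length actually matched/inserted in the radix tree
prefix_indices_len = 8        # tree prefix plus the one unmatched EAGLE token

is_eagle = True
old_prefix_len = prefix_indices_len
if is_eagle and old_prefix_len > last_matched_prefix_len:
    # Count the extra KV slot as "not cached by the tree" so it is freed with
    # the rest of the uncached tokens instead of leaking.
    old_prefix_len -= 1

assert old_prefix_len == last_matched_prefix_len
```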
Motivation
follow-up of #10846
Accept Length Test
start server with this line commented:
run requests:
This PR w/ radix cache:
main w/ radix cache:
main w/o radix cache: