-
Notifications
You must be signed in to change notification settings - Fork 3.4k
EAGLE cache fix for SWARadixCache #11231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
2bdd5a6
0a877e1
7ee7c3f
8d452b8
42ddcdc
d8c07bb
fe9b29d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |||||||||||||||||
| from sglang.srt.mem_cache.memory_pool import ReqToTokenPool | ||||||||||||||||||
| from sglang.srt.mem_cache.radix_cache import ( | ||||||||||||||||||
| RadixKey, | ||||||||||||||||||
| _convert_to_bigram_key, | ||||||||||||||||||
| _key_match_page_size1, | ||||||||||||||||||
| _key_match_paged, | ||||||||||||||||||
| get_child_key, | ||||||||||||||||||
|
|
@@ -327,12 +328,14 @@ def __init__( | |||||||||||||||||
| sliding_window_size: int, | ||||||||||||||||||
| page_size: int, | ||||||||||||||||||
| disable: bool = False, | ||||||||||||||||||
| is_eagle: bool = False, | ||||||||||||||||||
| ): | ||||||||||||||||||
| assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) | ||||||||||||||||||
| self.req_to_token_pool = req_to_token_pool | ||||||||||||||||||
| self.token_to_kv_pool_allocator = token_to_kv_pool_allocator | ||||||||||||||||||
| self.page_size = page_size | ||||||||||||||||||
| self.disable = disable | ||||||||||||||||||
| self.is_eagle = is_eagle | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.token_to_kv_pool_allocator: | ||||||||||||||||||
| self.device = self.token_to_kv_pool_allocator.device | ||||||||||||||||||
|
|
@@ -346,6 +349,11 @@ def __init__( | |||||||||||||||||
| self.key_match_fn = partial(_key_match_paged, page_size=page_size) | ||||||||||||||||||
| self.get_child_key_fn = partial(get_child_key, page_size=page_size) | ||||||||||||||||||
|
|
||||||||||||||||||
| if is_eagle: | ||||||||||||||||||
| self.key_convert_fn = _convert_to_bigram_key | ||||||||||||||||||
| else: | ||||||||||||||||||
| self.key_convert_fn = lambda key: key | ||||||||||||||||||
|
|
||||||||||||||||||
| self.sliding_window_size = sliding_window_size | ||||||||||||||||||
| self.reset() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -376,6 +384,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: | |||||||||||||||||
| The last node create a new child if the prefix is shorter | ||||||||||||||||||
| than the last node's value. | ||||||||||||||||||
| """ | ||||||||||||||||||
| key.token_ids = self.key_convert_fn(key.token_ids) | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.disable or len(key) == 0: | ||||||||||||||||||
| return MatchResult( | ||||||||||||||||||
| device_indices=torch.empty( | ||||||||||||||||||
|
|
@@ -406,8 +416,15 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: | |||||||||||||||||
| if self.disable: | ||||||||||||||||||
| return 0 | ||||||||||||||||||
|
|
||||||||||||||||||
| key.token_ids = self.key_convert_fn(key.token_ids) | ||||||||||||||||||
|
|
||||||||||||||||||
| if value is None: | ||||||||||||||||||
| value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) | ||||||||||||||||||
|
Comment on lines
421
to
422
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic can lead to a bug when I suggest raising an error for this case or creating a dummy 1D tensor if it's needed for tests.
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| if self.is_eagle: | ||||||||||||||||||
| # Make sure the value len equal to the EAGLE bigram key len | ||||||||||||||||||
| value = value[: len(key)] | ||||||||||||||||||
|
|
||||||||||||||||||
| return self._insert_helper(self.root_node, key, value, prev_prefix_len) | ||||||||||||||||||
|
|
||||||||||||||||||
| def cache_finished_req(self, req: Req) -> None: | ||||||||||||||||||
|
|
@@ -422,25 +439,38 @@ def cache_finished_req(self, req: Req) -> None: | |||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| token_ids = (req.origin_input_ids + req.output_ids)[:-1] | ||||||||||||||||||
| all_token_len = len(token_ids) | ||||||||||||||||||
| actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we have more comments on the reason behind this -1?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we convert the key to bigram key, e.g. |
||||||||||||||||||
| kv_indices = self.req_to_token_pool.req_to_token[ | ||||||||||||||||||
| req.req_pool_idx, : len(token_ids) | ||||||||||||||||||
| req.req_pool_idx, :all_token_len | ||||||||||||||||||
| ] | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.page_size != 1: | ||||||||||||||||||
| page_aligned_len = len(kv_indices) // self.page_size * self.page_size | ||||||||||||||||||
| page_aligned_len = actual_kv_len // self.page_size * self.page_size | ||||||||||||||||||
| page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() | ||||||||||||||||||
| self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) | ||||||||||||||||||
| else: | ||||||||||||||||||
| page_aligned_len = len(kv_indices) | ||||||||||||||||||
| page_aligned_len = actual_kv_len | ||||||||||||||||||
| page_aligned_kv_indices = kv_indices.clone() | ||||||||||||||||||
| if self.is_eagle: | ||||||||||||||||||
| self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) | ||||||||||||||||||
|
|
||||||||||||||||||
| page_aligned_token_len = ( | ||||||||||||||||||
| page_aligned_len + 1 if self.is_eagle else page_aligned_len | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| old_prefix_len = len(req.prefix_indices) | ||||||||||||||||||
| if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: | ||||||||||||||||||
| # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) | ||||||||||||||||||
| old_prefix_len -= 1 | ||||||||||||||||||
|
Comment on lines
460
to
468
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wishing for more comments here 🙏 The +1 and -1 logics and their implications are not straightforward to fully understand :(
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the chunked prefill case, the chunked kv should be cached in the Radix cache. But in EAGLE case, the last token will not be inserted into the tree due to the shorter length of bigram key. But we still add it to |
||||||||||||||||||
|
|
||||||||||||||||||
| # Radix Cache takes one ref in memory pool | ||||||||||||||||||
| # insert the token_ids and kv_indices into the radix tree | ||||||||||||||||||
| # Note: the insert function already frees the overlapped kv_indices | ||||||||||||||||||
| new_prefix_len = self.insert( | ||||||||||||||||||
| RadixKey(token_ids[:page_aligned_len], req.extra_key), | ||||||||||||||||||
| RadixKey(token_ids[:page_aligned_token_len], req.extra_key), | ||||||||||||||||||
| page_aligned_kv_indices, | ||||||||||||||||||
| len(req.prefix_indices), | ||||||||||||||||||
| old_prefix_len, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Remove req slot release the cache lock | ||||||||||||||||||
|
|
@@ -459,39 +489,54 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: | |||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| token_ids = req.fill_ids | ||||||||||||||||||
| all_token_len = len(token_ids) | ||||||||||||||||||
| # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key | ||||||||||||||||||
| actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len | ||||||||||||||||||
| kv_indices = self.req_to_token_pool.req_to_token[ | ||||||||||||||||||
| req.req_pool_idx, : len(token_ids) | ||||||||||||||||||
| req.req_pool_idx, :all_token_len | ||||||||||||||||||
| ] | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.page_size != 1: | ||||||||||||||||||
| page_aligned_len = len(kv_indices) // self.page_size * self.page_size | ||||||||||||||||||
| page_aligned_len = actual_kv_len // self.page_size * self.page_size | ||||||||||||||||||
| page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() | ||||||||||||||||||
| else: | ||||||||||||||||||
| page_aligned_len = len(kv_indices) | ||||||||||||||||||
| page_aligned_len = actual_kv_len | ||||||||||||||||||
| page_aligned_kv_indices = kv_indices.clone() | ||||||||||||||||||
| page_aligned_token_ids = token_ids[:page_aligned_len] | ||||||||||||||||||
|
|
||||||||||||||||||
| # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1 | ||||||||||||||||||
| page_aligned_token_len = ( | ||||||||||||||||||
| page_aligned_len + 1 if self.is_eagle else page_aligned_len | ||||||||||||||||||
| ) | ||||||||||||||||||
| page_aligned_token_ids = token_ids[:page_aligned_token_len] | ||||||||||||||||||
|
|
||||||||||||||||||
| old_prefix_len = len(req.prefix_indices) | ||||||||||||||||||
| if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: | ||||||||||||||||||
| # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) | ||||||||||||||||||
| old_prefix_len -= 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| # Radix Cache takes one ref in memory pool | ||||||||||||||||||
| # Note: the insert function already frees the overlapped kv_indices | ||||||||||||||||||
| new_prefix_len = self.insert( | ||||||||||||||||||
| RadixKey(page_aligned_token_ids, req.extra_key), | ||||||||||||||||||
| page_aligned_kv_indices, | ||||||||||||||||||
| len(req.prefix_indices), | ||||||||||||||||||
| old_prefix_len, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # The prefix indices could be updated, reuse it | ||||||||||||||||||
| new_indices, new_last_node, _, _ = self.match_prefix( | ||||||||||||||||||
| RadixKey(page_aligned_token_ids, req.extra_key) | ||||||||||||||||||
| ) | ||||||||||||||||||
| assert len(req.prefix_indices) <= len( | ||||||||||||||||||
| assert old_prefix_len <= len( | ||||||||||||||||||
| new_indices | ||||||||||||||||||
| ), f"{req.prefix_indices=}, {new_indices=}" | ||||||||||||||||||
| assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}" | ||||||||||||||||||
| self.req_to_token_pool.write( | ||||||||||||||||||
| (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), | ||||||||||||||||||
| new_indices[len(req.prefix_indices) :], | ||||||||||||||||||
| (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), | ||||||||||||||||||
| new_indices[old_prefix_len:], | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| req.last_matched_prefix_len = len(new_indices) | ||||||||||||||||||
|
|
||||||||||||||||||
| self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) | ||||||||||||||||||
| swa_uuid_for_lock = self.inc_lock_ref(new_last_node) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -501,7 +546,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: | |||||||||||||||||
| [new_indices, kv_indices[len(new_indices) :]] | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| req.prefix_indices = new_indices | ||||||||||||||||||
| if self.is_eagle: | ||||||||||||||||||
| # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill | ||||||||||||||||||
| req.prefix_indices = torch.cat( | ||||||||||||||||||
| [new_indices, kv_indices[actual_kv_len:]] | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| req.prefix_indices = new_indices | ||||||||||||||||||
| req.last_node = new_last_node | ||||||||||||||||||
| req.swa_uuid_for_lock = swa_uuid_for_lock | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line modifies the input
keyobject in-place, which can be an unexpected side effect for callers. Additionally, whenis_eagleis true,key.token_idsis converted fromList[int]toList[Tuple[int, int]], which violates the type hint in theRadixKeyclass definition (List[int]). This makes the code harder to understand and maintain.While creating a new
RadixKeyobject might have performance implications, it would be safer. A less disruptive change would be to update the type hint forRadixKey.token_idstoList[Union[int, Tuple[int, int]]]and add a comment here explaining the in-place modification.