fix unexcepted answer in EAGLE mode #9252
Conversation
Summary of Changes
Hello @zyksir, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses an issue in the EAGLE mode where incorrect filtering of batch data led to "unexpected answers." The primary fix involves refining the filter_batch logic in eagle_utils.py to correctly handle different filtering scenarios (e.g., after verification versus after draft extension) by introducing a new parameter. Additionally, a filtering step is added in eagle_worker.py to remove finished requests before further processing, ensuring data consistency and preventing erroneous outputs.
Highlights
- eagle_utils.py filter_batch modification: The filter_batch method in eagle_utils.py has been updated to accept a has_been_filtered boolean flag. This allows for different filtering behaviors: slicing ([:len(new_indices)]) when the batch has already been partially filtered (e.g., during verification), and direct indexing ([new_indices]) when it has not (e.g., after draft extend). This prevents loss of useful information or retention of bad request data.
- eagle_worker.py batch filtering: A new filtering mechanism has been introduced in eagle_worker.py within the forward_draft_extend function. This ensures that finished requests are properly identified and removed from the batch before the filter_batch method is called, specifically for cases where the batch has not yet been filtered by unfinished_index. This resolves issues where invalid JSON schemas or finished requests could lead to incorrect topk_p states.
- Bug Fix for EAGLE Mode: This PR directly addresses and fixes issue #8671, which reported "unexpected answers" in EAGLE mode. The changes ensure that the topk_p and related tensors are correctly managed and filtered throughout the speculative decoding process, leading to more accurate and reliable model outputs.
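The two filtering modes described above can be sketched as follows. This is a minimal illustration, not the actual sglang source; the class and field names are borrowed from the PR description:

```python
import torch

class EagleDraftInput:
    """Minimal sketch of the spec-info container described in the PR."""

    def __init__(self, topk_p: torch.Tensor):
        self.topk_p = topk_p  # one row of top-k probabilities per request

    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        if has_been_filtered:
            # The batch was already narrowed (e.g. during verification):
            # topk_p is aligned with the surviving requests, so slicing works.
            self.topk_p = self.topk_p[: len(new_indices)]
        else:
            # The batch still holds all requests (e.g. after draft extend):
            # gather the rows of the unfinished requests explicitly.
            self.topk_p = self.topk_p[new_indices]
```

With `has_been_filtered=False`, rows are selected by index, so a finished request in the middle of the batch is dropped rather than the last row.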
Code Review
This pull request addresses an issue with unexpected answers in EAGLE mode by correctly filtering finished requests after the draft_extend step. The main change involves adding a new filtering logic in forward_draft_extend and modifying EagleDraftInput.filter_batch to support it with a new has_been_filtered flag. The fix appears correct and targeted at the described problem. I have one suggestion to make the code for identifying finished requests more concise.
```python
has_finished, unfinished_req_index = False, []
for i, req in enumerate(batch.reqs):
    if req.finished():
        has_finished = True
    else:
        unfinished_req_index.append(i)
if has_finished:
    unfinished_index_device = torch.tensor(
        unfinished_req_index, dtype=torch.int64, device=batch.spec_info.topk_p.device
    )
    batch.spec_info.filter_batch(unfinished_index_device, has_been_filtered=False)
```
This block of code for finding and filtering finished requests can be made more concise and Pythonic using a list comprehension. This improves readability and reduces the number of lines.
```python
unfinished_req_index = [
    i for i, req in enumerate(batch.reqs) if not req.finished()
]
if len(unfinished_req_index) < len(batch.reqs):
    unfinished_index_device = torch.tensor(
        unfinished_req_index, dtype=torch.int64, device=batch.spec_info.topk_p.device
    )
    batch.spec_info.filter_batch(unfinished_index_device, has_been_filtered=False)
```
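As a sanity check, the suggested comprehension is behaviorally equivalent to the original loop. A self-contained comparison using a stubbed `Req` class (hypothetical, for illustration only):

```python
class Req:
    """Stub standing in for a real request object with a finished() method."""

    def __init__(self, done: bool):
        self._done = done

    def finished(self) -> bool:
        return self._done

reqs = [Req(False), Req(True), Req(False)]

# Original loop from the diff.
has_finished, idx_loop = False, []
for i, req in enumerate(reqs):
    if req.finished():
        has_finished = True
    else:
        idx_loop.append(i)

# Suggested comprehension; "some request finished" becomes a length check.
idx_comp = [i for i, req in enumerate(reqs) if not req.finished()]

assert idx_loop == idx_comp == [0, 2]
assert has_finished == (len(idx_comp) < len(reqs))
```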
Motivation
This PR fixes #8671.
Modifications
In `filter_batch`, we previously had `self.topk_p = self.topk_p[: len(new_indices)]`. This line should never be used for filtering when `len(new_indices) != len(self.topk_p)`: after verification, `len(topk_p)` might not equal `len(new_indices)`, so the slice filters out useful information while keeping the data that belongs to the bad request. In that case `self.topk_p = self.topk_p[new_indices]` should be used instead. My modification adds one filter after extend, which minimizes the change.
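The difference between the two indexing forms is visible on a tiny example (hypothetical tensor values; a batch of three requests where request 1 has finished):

```python
import torch

topk_p = torch.tensor([[0.9, 0.1],   # request 0 (unfinished, keep)
                       [0.7, 0.3],   # request 1 (finished, drop)
                       [0.6, 0.4]])  # request 2 (unfinished, keep)
new_indices = torch.tensor([0, 2])

# Correct for a not-yet-filtered batch: gathers rows 0 and 2.
kept = topk_p[new_indices]

# Wrong here: keeps the first len(new_indices) = 2 rows, i.e. rows 0 and 1,
# retaining the finished request's data and discarding request 2's.
sliced = topk_p[: len(new_indices)]

assert torch.equal(kept[1], topk_p[2])    # useful row survives
assert torch.equal(sliced[1], topk_p[1])  # stale row of the finished request
```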
Accuracy Tests
The test script can be found in #8671


Before the modification:
After the modification:
Benchmarking and Profiling
Checklist