-
Notifications
You must be signed in to change notification settings - Fork 30.9k
🚨🚨🚨 Optimize Top P Sampler and fix edge case #18984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
LysandreJik
merged 8 commits into
huggingface:main
from
ekagra-ranjan:fix/optimize-top-p-warper
Sep 15, 2022
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f25b527
init PR
ekagra-ranjan 7dd916a
optimize top p and add edge case
ekagra-ranjan b1cff0a
styling
ekagra-ranjan 485303e
style
ekagra-ranjan 644b9c2
revert tf and flax test
ekagra-ranjan 9b2ba23
add edge case test for FLAX and TF
ekagra-ranjan 2457d47
update doc with smallest set sampling for top p
ekagra-ranjan 8854977
make style
ekagra-ranjan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't
<=always a bit dangerous with float values? I'm not sure we can assure 100% backward compatibility hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly worried about that we'll silently break someone's PyTorch generation code that uses
top_pby default hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten there is indeed a change at the edge case -- before, if
top_pwas0.8and the input was[0.5, 0.3, 0.1, 0.1], the first three tokens would pass this filter, despite the first two summing up to 0.8 (and thus satisfying the top P conditions, according to the original paper and our docstrings).The behavior in TF and FLAX satisfies the edge case above, while PT does not. In practice, the impact will be negligible (this change filters one additional token when the sum of the logits is exactly
top_p), although it can change seeded test cases.Alternatively, we can change our docstrings (and TF+FLAX's implementation) to ignore this edge case :D
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten I believe you are referring to the floating point precision in the context of
<=being dangerous with float value. The Top P sampler intends to pick minimum elements which have cumulative dist>= top_p. So either we use the equality while selecting the mask or ignore it and then shift the mask to right/left.The proposed PT implementation uses
<=but it can be implemented in the same manner as TF and FLAX which do not have the equality operator explicitly but will need to clone a tensor and shifting values to right/left. This however will not prevent the issue of floating point precision.E.g., if we take input as
[0.5, 0.3, 0.1, 0.1]andtop_pas0.8then according to this:transformers/src/transformers/generation_flax_logits_process.py
Lines 142 to 146 in 693ba2c
the
cumulative_probscould be [0.5, 0.79995, 0.1, 0.1] due to floating point precision which will lead to Top P sampler picking 1st three elements instead of 1st two even though there is no equality operator involved.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanations @gante and @ekagra-ranjan - this makes sense to me!
Given the very high usage of
generateandtop_pwe need to clearly mark this as a "breaking behavior bug fix" with 🚨🚨🚨 in the PR description and also make sure it's mentioned in our release notes (cc @LysandreJik )But good for merge then for me