🚨🚨🚨 Optimize Top P Sampler and fix edge case #18984
Conversation
LGTM, will get my approval when the test change also gets added to TF and FLAX 👍
@gante the proposed PT implementation passes the edge case. I also added the edge case locally and verified that the existing FLAX implementation passes it with no change required in its implementation. However, the TF implementation does not pass the edge case as cleanly (see the investigation in my next comment). Can you please confirm whether just replacing 0.7 with 0.8 in this test succeeds on your local machine?
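To make the edge case being tested concrete, here is a rough sketch of what such a check could look like with the PyTorch warper. This is an illustrative snippet under the assumption that the public `TopPLogitsWarper` API is used directly; it is not the actual test added in this PR.

```python
# Illustrative sketch only: with top_p=0.8 and probabilities [0.5, 0.3, 0.1, 0.1],
# the fixed behavior should keep exactly the first two tokens, since they already
# sum to 0.8.
import torch
from transformers import TopPLogitsWarper

probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])
scores = torch.log(probs)                 # the warper expects logits / log-probabilities
warper = TopPLogitsWarper(top_p=0.8)      # default filter_value is -inf

filtered = warper(torch.tensor([[0]]), scores)  # input_ids are unused by this warper
kept = torch.isfinite(filtered)
print(kept)  # expected under the fixed behavior: [[True, True, False, False]]
```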
I was investigating TF's behavior and found the following. The input distribution to the test goes to TFTopPLogitsWrapper, which takes a cumsum of the softmaxed scores. Due to floating point precision, the cumulative sum does not land exactly on the threshold, and this causes an extra sample to get sampled in the 1st batch when `top_p` is 0.8.
How should we proceed? This issue of changing behavior is not there in PT and FLAX, so should we go ahead with just PT and FLAX for this PR and raise this as a separate TF issue in the transformers repo?
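For illustration, a minimal TF snippet of the kind of rounding being described. The exact values are hardware- and version-dependent; this is an assumption about where the instability creeps in, not a reproduction of the repo's test.

```python
# Sketch: float32 softmax + cumsum can land a hair away from 0.8, so a strict
# comparison against top_p=0.8 may keep one token more (or fewer) than intended.
import tensorflow as tf

scores = tf.math.log(tf.constant([[0.5, 0.3, 0.1, 0.1]]))
probs = tf.nn.softmax(scores, axis=-1)       # exp/normalize may not return 0.3 exactly
cumulative = tf.math.cumsum(probs, axis=-1)  # second entry may come out e.g. 0.79999995
print(cumulative.numpy())
```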
@ekagra-ranjan we could add an if/else to special-case this for TF. However, since this edge case has such low impact in practice, it's okay if we take the simpler path and simply set the test value accordingly.
P.S.: TF's softmax is known to have these minor numerical instabilities.
@gante Thank you for your reviews! Edge case tests for FLAX and TF have been added and are passing.
LGTM 👍 Thanks for addressing all the comments!
Isn't `<=` always a bit dangerous with float values? I'm not sure we can assure 100% backward compatibility here
Slightly worried that we'll silently break someone's PyTorch generation code that uses `top_p` by default here
@patrickvonplaten there is indeed a change at the edge case -- before, if `top_p` was `0.8` and the input was `[0.5, 0.3, 0.1, 0.1]`, the first three tokens would pass this filter, despite the first two summing up to 0.8 (and thus satisfying the top P conditions, according to the original paper and our docstrings).
The behavior in TF and FLAX satisfies the edge case above, while PT does not. In practice, the impact will be negligible (this change filters one additional token when the cumulative probability is exactly `top_p`), although it can change seeded test cases.
Alternatively, we can change our docstrings (and TF+FLAX's implementation) to ignore this edge case :D
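A small numerical sketch of the difference described above, simplifying the two masking styles and assuming probabilities already sorted in descending order; this is an illustration, not the libraries' exact code.

```python
# Edge case from the discussion: probs [0.5, 0.3, 0.1, 0.1], top_p = 0.8.
import torch

top_p = 0.8
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])   # already sorted descending
cumulative = torch.cumsum(probs, dim=-1)     # [0.5, 0.8, 0.9, 1.0]

# Old PT-style masking: drop tokens strictly above top_p, then shift the removal mask
# right so the token that crosses the threshold is still kept -> keeps 3 tokens here.
remove = cumulative > top_p                  # [False, False, True, True]
remove[1:] = remove[:-1].clone()
remove[0] = False
print((~remove).tolist())                    # [True, True, True, False]

# TF/FLAX-style masking: keep tokens strictly below top_p, then include one more token
# -> keeps only 2 tokens here, because 0.5 + 0.3 already reaches top_p.
keep = cumulative < top_p                    # [True, False, False, False]
keep = torch.roll(keep, 1)
keep[0] = True
print(keep.tolist())                         # [True, True, False, False]
```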
@patrickvonplaten I believe you are referring to floating point precision in the context of `<=` being dangerous with float values. The Top P sampler intends to pick the minimum number of elements whose cumulative distribution is >= `top_p`. So we either use the equality while selecting the mask, or ignore it and then shift the mask to the right/left.
The proposed PT implementation uses `<=`, but it could be implemented in the same manner as TF and FLAX, which do not use the equality operator explicitly but instead clone a tensor and shift values to the right/left. This, however, would not prevent the issue of floating point precision.
E.g., if we take the input as `[0.5, 0.3, 0.1, 0.1]` and `top_p` as `0.8`, then according to this:
transformers/src/transformers/generation_flax_logits_process.py, lines 142 to 146 in 693ba2c:

    score_mask = cumulative_probs < self.top_p
    # include the token that is higher than top_p as well
    score_mask = jnp.roll(score_mask, 1)
    score_mask |= score_mask.at[:, 0].set(True)
the `cumulative_probs` could come out as `[0.5, 0.79995, 0.89995, 0.99995]` due to floating point precision, which would lead to the Top P sampler picking the first three elements instead of the first two, even though there is no equality operator involved.
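A hedged illustration of that point: the 0.79995 value below is invented for demonstration and is not the output of any particular framework.

```python
# Even without an explicit `<=`, the roll-based mask flips when the cumulative sum
# lands just below top_p instead of exactly on it.
import jax.numpy as jnp

top_p = 0.8

def keep_mask(cumulative_probs):
    score_mask = cumulative_probs < top_p
    score_mask = jnp.roll(score_mask, 1)
    return score_mask.at[:, 0].set(True)

exact = jnp.array([[0.5, 0.8, 0.9, 1.0]])              # cumulative sum hits top_p exactly
noisy = jnp.array([[0.5, 0.79995, 0.89995, 0.99995]])  # a hair below top_p

print(keep_mask(exact))  # [[ True  True False False]] -> two tokens kept
print(keep_mask(noisy))  # [[ True  True  True False]] -> three tokens kept
```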
Thanks for the explanations @gante and @ekagra-ranjan - this makes sense to me!
Given the very high usage of `generate` and `top_p`, we need to clearly mark this as a "breaking behavior bug fix" with 🚨🚨🚨 in the PR description and also make sure it's mentioned in our release notes (cc @LysandreJik).
But good for merge then for me
Let's just make sure that users whose generation pipelines give different results find this PR as the explanation. It's a backwards-breaking bug fix for me that might, however, affect quite a few generation pipelines. IMO it's ok to merge with a big warning - wdyt @LysandreJik @sgugger?
I don't have a problem with fixing the behavior to match the documentation.
For the PR, before merging, could you:
- run `make style` to fix the quality issue
- rebase on main, which should take care of the test failures we see
6f25001 to 8854977
@sgugger Sure, done.
Looks good to me!
Will make sure this is very visible in the release notes, thank you for the 🚨
What does this PR do?
This PR does the following:
- Optimizes the Top P sampler implementation in PyTorch.
- Fixes the edge case where the top tokens' probabilities sum exactly to `top_p`: they now satisfy the top P condition described in the docstrings, and no extra token is kept (🚨 breaking behavior change for seeded generation with `top_p`).
- Adds the edge case to the Top P tests for PT, TF, and FLAX.
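For readers landing here because seeded `top_p` generations changed, the rule after the fix can be sketched as follows. This is one possible formulation of the `<=` comparison discussed in the review, written for illustration; it is not necessarily the merged code.

```python
# One way to express "keep the smallest set of tokens whose cumulative probability
# reaches top_p": sort ascending and drop the low-probability tail whose mass is at
# most 1 - top_p, so the kept tokens sum to at least top_p (equality included).
import torch

def top_p_filter(scores: torch.Tensor, top_p: float,
                 filter_value: float = -float("inf")) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Mark the ascending tail for removal, then map the mask back to the original order.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```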
Who can review?
@gante @patrickvonplaten