
Conversation

@ekagra-ranjan (Contributor) commented Sep 12, 2022

What does this PR do?

This PR does the following:

  1. Fixes #18976: Top_P sampling samples an extra token when the cumulative sum of probabilities is exactly equal to top_p.
  2. Optimizes the Top P sampler PyTorch implementation by removing the need to clone an intermediate tensor and shift values to the right (a sketch of the idea follows this list).
  3. Adds edge case tests to PT, TF, and FLAX.
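
To make the behavior change concrete, here is a minimal sketch of the idea behind the new PT masking (an illustration written for this description, not the exact transformers code):

import torch

def top_p_filter(scores, top_p, filter_value=-float("inf")):
    # Sort ascending and accumulate probability mass from the low-probability tail.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop every token whose cumulative mass still fits inside 1 - top_p.
    # The "<=" is what removes the extra token when the kept mass hits top_p exactly.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Map the mask back from sorted order to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)

# With probabilities [0.5, 0.3, 0.1, 0.1] and top_p=0.8, only the first two tokens
# survive (up to floating point precision), instead of three as before.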

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante @patrickvonplaten

@ekagra-ranjan changed the title from "Optimize Top P Sampler" to "Optimize Top P Sampler and fix edge case" on Sep 12, 2022
@HuggingFaceDocBuilderDev commented Sep 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case" to "Optimize Top P Sampler and fix edge case (pytorch)" on Sep 12, 2022
@gante (Member) left a comment

LGTM, will get my approval when the test change also gets added to TF and FLAX 👍

@ekagra-ranjan (Contributor, Author) commented Sep 12, 2022

@gante the proposed PT implementation passes the edge case. I also added the edge case locally and verified that the existing FLAX implementation passes it with no change required.

However, the TF implementation passes the edge case when use_xla is True but fails when it is False on my local machine. Hence, I reverted the addition of the edge case to TF and FLAX in my PR. It seems the behavior changes when using XLA with TF.

Can you please confirm whether just replacing 0.7 with 0.8 in this test succeeds on your local machine?

top_p_warp = TFTopPLogitsWarper(0.7)

@ekagra-ranjan (Contributor, Author) commented Sep 12, 2022

I was investigating TF's behavior and found the following:

This is the input distribution to the test:

dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))

The above goes to TFTopPLogitsWarper, which takes a cumsum here:

cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)

This cumulative_probs takes different values depending on whether use_xla is True or False in the unit test:

  1. When use_xla is True, cumulative_probs is [[0.5, 0.8, 0.90000004, 1.],
    [0.29999998, 0.59999996, 0.8499999, 0.99999994]]
  2. When use_xla is False, cumulative_probs is [[0.5, 0.79999995, 0.9, 1.],
    [0.3, 0.6, 0.85, 1.]]

This causes an extra token to be sampled in the first batch when use_xla is False, since 0.79999995 < 0.8.

How should we proceed? This behavior difference is not present in PT and FLAX, so should we go ahead with just PT and FLAX in this PR and raise the TF discrepancy as a separate issue in the transformers repo?
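
For reference, the comparison above can be reproduced with something like the following (a hedged sketch: tf.nn.softmax stands in for stable_softmax, and tf.function(..., jit_compile=True) stands in for the use_xla=True path):

import numpy as np
import tensorflow as tf

dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))

def cum_probs(logits):
    # Mimic the warper: sort via top_k, softmax, then cumulative sum.
    topk_scores = tf.math.top_k(logits, k=logits.shape[-1]).values
    return tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)

print(cum_probs(tf.constant(dist)))                                 # eager path (use_xla=False)
print(tf.function(cum_probs, jit_compile=True)(tf.constant(dist)))  # XLA path (use_xla=True)

Whether the two paths differ exactly as shown above may depend on the TF version and hardware.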

@gante (Member) commented Sep 12, 2022

@ekagra-ranjan we could add an if/else depending on whether use_xla is True or not, and set top_p to 0.8 or 0.79999995 accordingly.

However, since this edge case has such low impact in practice, it's okay if we take the simpler path and simply set top_p to 0.79999995. It won't test the edge case with XLA, but at least it is tested once (with eager execution, i.e. with use_xla=False).
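
In the test, that if/else would look roughly like this (a sketch; use_xla is assumed to be the existing test parameter):

top_p = 0.8 if use_xla else 0.79999995
top_p_warp = TFTopPLogitsWarper(top_p)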

P.S.: TF's softmax is known to have these minor numerical instabilities.

@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case (pytorch)" to "Optimize Top P Sampler and fix edge case" on Sep 12, 2022
@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case" to "Optimize Top P Sampler and fix edge case for Pytorch" on Sep 12, 2022
@ekagra-ranjan (Contributor, Author)

@gante Thank you for your reviews! Edge case tests for FLAX and TF have been added and are passing.

@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case for Pytorch" to "Optimize Top P Sampler and fix edge case" on Sep 12, 2022
@gante (Member) left a comment

LGTM 👍 Thanks for addressing all the comments!

A Contributor left a comment

Isn't <= always a bit dangerous with float values? I'm not sure we can ensure 100% backward compatibility here.

A Contributor left a comment

Slightly worried that we'll silently break someone's PyTorch generation code that uses top_p by default here.

A Member left a comment

@patrickvonplaten there is indeed a change at the edge case -- before, if top_p was 0.8 and the input was [0.5, 0.3, 0.1, 0.1], the first three tokens would pass this filter, despite the first two summing up to 0.8 (and thus satisfying the top P conditions, according to the original paper and our docstrings).

The behavior in TF and FLAX satisfies the edge case above, while PT does not. In practice, the impact will be negligible (this change filters out one additional token when the cumulative probability is exactly top_p), although it can change seeded test cases.
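
To spell out the example (a worked illustration, not library output):

probs (sorted)                      = [0.5, 0.3, 0.1, 0.1]
cumulative probs                    = [0.5, 0.8, 0.9, 1.0]
top_p                               = 0.8
smallest prefix with mass >= top_p -> {0.5, 0.3} (2 tokens)
old PT behavior                    -> {0.5, 0.3, 0.1} (3 tokens)
new PT / TF / FLAX behavior        -> {0.5, 0.3} (2 tokens)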

Alternatively, we can change our docstrings (and TF+FLAX's implementation) to ignore this edge case :D

@ekagra-ranjan (Contributor, Author) commented Sep 14, 2022

@patrickvonplaten I believe you are referring to floating point precision in the context of <= being dangerous with float values. The Top P sampler intends to pick the minimal set of elements whose cumulative probability is >= top_p. So we either use the equality while selecting the mask, or ignore it and then shift the mask to the right/left.

The proposed PT implementation uses <=, but it could be implemented in the same manner as TF and FLAX, which do not use the equality operator explicitly but instead clone a tensor and shift values to the right/left. That, however, would not prevent the floating point precision issue.

E.g., if we take the input as [0.5, 0.3, 0.1, 0.1] and top_p as 0.8, then according to this:

score_mask = cumulative_probs < self.top_p
# include the token that is higher than top_p as well
score_mask = jnp.roll(score_mask, 1)
score_mask |= score_mask.at[:, 0].set(True)

the cumulative_probs could come out as [0.5, 0.79999995, 0.9, 1.0] due to floating point precision, which will lead to the Top P sampler picking the first three elements instead of the first two even though there is no equality operator involved.
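
A small trace of that scenario (the cumulative values are hardcoded to mimic the float32 rounding described above, not taken from an actual run):

import jax.numpy as jnp

top_p = 0.8
cumulative_probs = jnp.array([0.5, 0.79999995, 0.9, 1.0])  # hypothetical float32 result
score_mask = cumulative_probs < top_p                  # [True, True, False, False]
score_mask = jnp.roll(score_mask, 1)                   # [False, True, True, False]
score_mask = score_mask | score_mask.at[0].set(True)   # [True, True, True, False] -> three tokens kept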

A Contributor left a comment

Thanks for the explanations @gante and @ekagra-ranjan - this makes sense to me!

Given the very high usage of generate and top_p, we need to clearly mark this as a "breaking behavior bug fix" with 🚨🚨🚨 in the PR description and also make sure it's mentioned in our release notes (cc @LysandreJik).

But then it's good for merge from my side.

@patrickvonplaten (Contributor) left a comment

Let's just make sure that users whose generation pipelines start giving different results can find this PR as the explanation. To me it's a backwards-breaking bug fix, but one that might affect quite a few generation pipelines. IMO it's ok to merge with a big warning - wdyt @LysandreJik @sgugger?

@sgugger (Collaborator) left a comment

I don't have a problem with fixing the behavior to match the documentation.

Before merging the PR, could you:

  1. run make style to fix the quality issue
  2. rebase on main, which should take care of the test failures we see

@ekagra-ranjan force-pushed the fix/optimize-top-p-warper branch from 6f25001 to 8854977 on September 14, 2022 20:52
@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case" to "🚨🚨🚨 Optimize Top P Sampler and fix edge case" on Sep 14, 2022
@ekagra-ranjan changed the title from "🚨🚨🚨 Optimize Top P Sampler and fix edge case" to "Optimize Top P Sampler and fix edge case" on Sep 14, 2022
@ekagra-ranjan changed the title from "Optimize Top P Sampler and fix edge case" to "🚨🚨🚨 Optimize Top P Sampler and fix edge case" on Sep 14, 2022
@ekagra-ranjan (Contributor, Author)

@sgugger Sure, done.

@LysandreJik (Member) left a comment

Looks good to me!

Will make sure this is very visible in the release notes, thank you for the 🚨

@LysandreJik merged commit 578e18e into huggingface:main on Sep 15, 2022