src/transformers/generation_flax_logits_process.py (2 additions, 2 deletions)
@@ -118,8 +118,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
 
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
src/transformers/generation_logits_process.py (6 additions, 9 deletions)
@@ -173,8 +173,8 @@ class TopPLogitsWarper(LogitsWarper):
 
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
@@ -191,17 +191,14 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens
         self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
 
         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > self.top_p
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)

Review thread on the `sorted_indices_to_remove` change above:

Contributor: Isn't `<=` always a bit dangerous with float values? I'm not sure we can ensure 100% backward compatibility here.

Contributor: Slightly worried that we'll silently break someone's PyTorch generation code that uses `top_p` by default here.

Member (@gante): @patrickvonplaten there is indeed a change at the edge case. Before, if `top_p` was 0.8 and the input was [0.5, 0.3, 0.1, 0.1], the first three tokens would pass this filter, despite the first two already summing to 0.8 (and thus satisfying the top-p condition, according to the original paper and our docstrings).

The behavior in TF and FLAX satisfies this edge case, while PT does not. In practice, the impact will be negligible (the change filters one additional token when the cumulative probability is exactly `top_p`), although it can change seeded test cases.

Alternatively, we can change our docstrings (and the TF and FLAX implementations) to ignore this edge case :D
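
For concreteness, here is a small NumPy sketch of the arithmetic in this edge case (illustrative only; it reproduces the old PyTorch masking and the docstring's intended selection on the example above, not the library code itself):

import numpy as np

probs = np.array([0.5, 0.3, 0.1, 0.1])   # already sorted, most probable first
top_p = 0.8
cumulative = probs.cumsum()               # [0.5, 0.8, 0.9, 1.0]

# Old PyTorch masking: drop tokens once the cumulative probability *exceeds*
# top_p, then shift the mask one position to the right.
remove_old = cumulative > top_p           # [False, False, True, True]
remove_old[1:] = remove_old[:-1].copy()   # shift right
remove_old[0] = False                     # -> [False, False, False, True]
print(probs[~remove_old])                 # [0.5 0.3 0.1]: three tokens kept

# Intended behavior per the docstring/paper: keep the smallest prefix whose
# mass is >= top_p (this mirrors the TF/FLAX mask-and-roll approach).
keep_new = cumulative < top_p             # [True, False, False, False]
keep_new[1:] = keep_new[:-1].copy()       # shift right to include the token
keep_new[0] = True                        #    that crosses the threshold
print(probs[keep_new])                    # [0.5 0.3]: only two tokens kept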

Contributor, PR author (@ekagra-ranjan), Sep 14, 2022: @patrickvonplaten I believe you are referring to floating-point precision when calling `<=` dangerous with float values. The top-p sampler is meant to pick the minimum set of elements whose cumulative probability is >= `top_p`, so we either use the equality when building the mask, or omit it and then shift the mask to the right/left.

The proposed PT implementation uses `<=`, but it could also be written the same way as TF and FLAX, which avoid an explicit equality operator at the cost of cloning a tensor and shifting values to the right/left. That, however, would not avoid the floating-point precision issue either.

For example, if we take the input [0.5, 0.3, 0.1, 0.1] and `top_p` = 0.8, then with the TF/FLAX-style masking:

score_mask = cumulative_probs < self.top_p
# include the token that is higher than top_p as well
score_mask = jnp.roll(score_mask, 1)
score_mask |= score_mask.at[:, 0].set(True)

the cumulative probability after the second token could come out as, say, 0.79995 instead of 0.8 due to floating-point precision, which would lead the top-p sampler to pick the first three elements instead of the first two, even though no equality operator is involved.
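
To see how easily this round trip drifts, one can recompute what the warpers actually see, i.e. the softmax of log-probabilities, in float32 and inspect the cumulative sum. A hedged sketch; exact values depend on the backend and precision, which is also why the TF test further down uses 0.79999995 rather than 0.8:

import numpy as np

p = np.array([0.3, 0.1, 0.1, 0.5], dtype=np.float32)  # same row as in the tests
logits = np.log(p)                        # the tests feed log-probabilities to the warpers
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()               # softmax, as the warpers compute internally
print(np.sort(probs)[::-1].cumsum())      # the second entry may land slightly below 0.8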

Contributor: Thanks for the explanations @gante and @ekagra-ranjan, this makes sense to me!

Given the very high usage of `generate` and `top_p`, we need to clearly mark this as a "breaking behavior bug fix" with 🚨🚨🚨 in the PR description and also make sure it's mentioned in our release notes (cc @LysandreJik).

But then this is good to merge for me.

         if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
 
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
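
As a usage sketch (not part of the diff), the updated warper can be exercised directly on the same distribution used in the updated PyTorch test further down. `input_ids` is unused by this warper, so `None` is passed here for brevity; the printed values should match the test's expectation up to numerical tolerance:

import torch
from transformers import TopPLogitsWarper

probs = torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]])
scores = torch.log(probs)                 # the warper operates on logits / log-probs

warper = TopPLogitsWarper(top_p=0.8)      # filtered tokens get -inf, so exp() maps them to 0
filtered = warper(input_ids=None, scores=scores)
print(torch.exp(filtered))
# expected (per the test): [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]]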
src/transformers/generation_tf_logits_process.py (2 additions, 2 deletions)
@@ -150,8 +150,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
 
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
src/transformers/generation_utils.py (2 additions, 2 deletions)
@@ -990,8 +990,8 @@ def generate(
             top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
                 The number of highest probability vocabulary tokens to keep for top-k-filtering.
             top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
+                If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
+                `top_p` or higher are kept for generation.
             typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
                 The amount of probability mass from the original distribution to be considered in typical decoding. If
                 set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
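
As context for the parameter documented above, a minimal sampling call that exercises `top_p` (the `gpt2` checkpoint is just an example; any causal LM with a tokenizer works the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# do_sample=True enables sampling; top_p routes the scores through TopPLogitsWarper,
# whose behavior changes in this PR.
output_ids = model.generate(**inputs, do_sample=True, top_p=0.8, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))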
tests/generation/test_generation_flax_logits_process.py (2 additions, 2 deletions)
@@ -110,10 +110,10 @@ def test_top_p_dist_warper(self):
         # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
 
-        top_p_warp = FlaxTopPLogitsWarper(0.7)
+        top_p_warp = FlaxTopPLogitsWarper(0.8)
         filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
 
-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
         self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
tests/generation/test_generation_logits_process.py (2 additions, 2 deletions)
@@ -169,10 +169,10 @@ def test_top_p_dist_warper(self):
             torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
         )
 
-        top_p_warp = TopPLogitsWarper(0.7)
+        top_p_warp = TopPLogitsWarper(0.8)
         filtered_dist = torch.exp(top_p_warp(input_ids, dist))
 
-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = torch.tensor(
             [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
tests/generation/test_generation_tf_logits_process.py (5 additions, 2 deletions)
@@ -189,12 +189,15 @@ def test_top_p_dist_warper(self, use_xla):
         # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
 
-        top_p_warp = TFTopPLogitsWarper(0.7)
+        # top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob
+        # However, due to the numerical instability of softmax in TF we choose this as the edge case
+        # top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984.
+        top_p_warp = TFTopPLogitsWarper(0.79999995)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
         filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
 
-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
         tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)