Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/transformers/generation_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,18 @@ def __init__(self, min_length: int, eos_token_id: int):
self.min_length = min_length
self.eos_token_id = eos_token_id

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
# generate is not XLA - compileable anyways
if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
    """Return `scores` with the EOS token's logit forced to -inf.

    Setting the EOS logit to -inf makes it impossible to sample/select the
    EOS token, which is how the minimum-length constraint is enforced.
    """
    vocab_positions = tf.range(scores.shape[-1])
    # Boolean mask over the vocabulary axis, broadcast to the full batch shape:
    # True exactly at the eos_token_id column.
    is_eos_position = tf.broadcast_to(vocab_positions == self.eos_token_id, scores.shape)
    return tf.where(is_eos_position, float("-inf"), scores)

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# Mask the EOS logit only while cur_len < min_length; tf.cond keeps this branch XLA-compatible
scores = tf.cond(
tf.less(cur_len, self.min_length),
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
Comment on lines +227 to +228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
self._apply_eos_token_mask(scores),
scores,

Would this work without the lambdas and identity call? I feel like it should but I'm not sure if I'm missing something obvious.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, it doesn't. Super unintuitive, but tf.cond expects a callable, not the output of each branch 😬 (docs)

It fails if we remove the lambda.

)
return scores


Expand Down
16 changes: 13 additions & 3 deletions tests/generation/test_generation_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

from parameterized import parameterized
from transformers import is_tf_available
from transformers.testing_utils import require_tf

Expand Down Expand Up @@ -47,12 +48,15 @@ def _get_uniform_logits(self, batch_size: int, length: int):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
return scores

def test_min_length_dist_processor(self):
@parameterized.expand([(False,), (True,)])
def test_min_length_dist_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0

min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
if use_xla:
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)

# check that min length is applied at length 5
cur_len = 5
Expand Down Expand Up @@ -256,12 +260,15 @@ def test_no_bad_words_dist_processor(self):
[[True, True, False, True, True], [True, True, True, False, True]],
)

def test_forced_bos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_bos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
bos_token_id = 0

logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)

# check that all scores are -inf except the bos_token_id score
cur_len = 1
Expand All @@ -280,13 +287,16 @@ def test_forced_bos_token_logits_processor(self):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))

def test_forced_eos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_eos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5

logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)

# check that all scores are -inf except the eos_token_id when max_length-1 is reached
cur_len = 4
Expand Down