
Commit 0c1d8f9

XLA min len, forced eos, and forced bos
1 parent 99c8226 commit 0c1d8f9

2 files changed: +24 −9 lines


src/transformers/generation_tf_logits_process.py

Lines changed: 11 additions & 6 deletions
@@ -215,13 +215,18 @@ def __init__(self, min_length: int, eos_token_id: int):
         self.min_length = min_length
         self.eos_token_id = eos_token_id
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-        # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
-        # generate is not XLA-compileable anyways
-        if cur_len < self.min_length:
-            eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
-            scores = tf.where(eos_token_id_mask, float("-inf"), scores)
+    def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
+        eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
+        scores = tf.where(eos_token_id_mask, float("-inf"), scores)
+        return scores
 
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        # applies eos token masking when the predicate (first argument to tf.cond) is true
+        scores = tf.cond(
+            tf.less(cur_len, self.min_length),
+            lambda: self._apply_eos_token_mask(scores),
+            lambda: tf.identity(scores),
+        )
         return scores
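The heart of the change: a Python-level `if` cannot branch on a tensor once `__call__` is traced for XLA, because `cur_len` has no concrete value at trace time. `tf.cond` sidesteps this by staging both branches into the compiled graph and selecting between them at run time. A minimal standalone sketch of the same pattern, with a hypothetical function name not taken from the library:

import tensorflow as tf

@tf.function(jit_compile=True)
def mask_eos_if_short(scores, cur_len, min_length=10, eos_token_id=0):
    # Hypothetical sketch: a plain `if cur_len < min_length:` would fail at
    # trace time here, since cur_len is a symbolic tensor; tf.cond stages
    # both branches into the XLA graph instead.
    def apply_mask():
        mask = tf.broadcast_to(tf.range(scores.shape[-1]) == eos_token_id, scores.shape)
        return tf.where(mask, float("-inf"), scores)

    return tf.cond(tf.less(cur_len, min_length), apply_mask, lambda: tf.identity(scores))

scores = tf.ones((2, 8)) / 8.0
print(mask_eos_if_short(scores, tf.constant(5))[0, 0])  # -inf while cur_len < min_length

Both branch callables close over `scores`, and `tf.identity(scores)` gives the no-op branch an output with the same shape and dtype, which `tf.cond` requires in order to merge the two branches.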
tests/generation/test_generation_tf_logits_process.py

Lines changed: 13 additions & 3 deletions
@@ -18,6 +18,7 @@
 
 import numpy as np
 
+from parameterized import parameterized
 from transformers import is_tf_available
 from transformers.testing_utils import require_tf
 
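`parameterized.expand` turns each tuple into its own test method, which is how the tests below run once eagerly (`use_xla=False`) and once XLA-compiled (`use_xla=True`). A minimal sketch of the mechanism, outside the transformers test suite:

import unittest
from parameterized import parameterized

class ExampleTest(unittest.TestCase):
    # expand() generates one method per tuple (roughly test_flag_0 and
    # test_flag_1 here), so each variant passes or fails independently.
    @parameterized.expand([(False,), (True,)])
    def test_flag(self, use_xla):
        self.assertIn(use_xla, (False, True))

if __name__ == "__main__":
    unittest.main()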
@@ -47,12 +48,15 @@ def _get_uniform_logits(self, batch_size: int, length: int):
         scores = tf.ones((batch_size, length), dtype=tf.float32) / length
         return scores
 
-    def test_min_length_dist_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_min_length_dist_processor(self, use_xla):
         vocab_size = 20
         batch_size = 4
         eos_token_id = 0
 
         min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
+        if use_xla:
+            min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
 
         # check that min length is applied at length 5
         cur_len = 5
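Because a logits processor is just a callable, `tf.function(processor, jit_compile=True)` compiles its `__call__` with XLA without touching the rest of the test body; the `tf.cond` rewrite above is what makes that compilation succeed. A sketch of the pattern, assuming `TFMinLengthLogitsProcessor` is exposed at the top level of `transformers`:

import tensorflow as tf
from transformers import TFMinLengthLogitsProcessor

processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=0)
# Same wrapping as the test: compile the processor's __call__ with XLA.
xla_processor = tf.function(processor, jit_compile=True)

input_ids = tf.zeros((4, 5), dtype=tf.int32)
scores = tf.ones((4, 20)) / 20.0
masked = xla_processor(input_ids, scores, 5)  # cur_len=5 < min_length=10
print(masked[:, 0])  # the eos_token_id column is -inf, so EOS cannot be sampled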
@@ -256,12 +260,15 @@ def test_no_bad_words_dist_processor(self):
             [[True, True, False, True, True], [True, True, True, False, True]],
         )
 
-    def test_forced_bos_token_logits_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_forced_bos_token_logits_processor(self, use_xla):
         vocab_size = 20
         batch_size = 4
         bos_token_id = 0
 
         logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
+        if use_xla:
+            logits_processor = tf.function(logits_processor, jit_compile=True)
 
         # check that all scores are -inf except the bos_token_id score
         cur_len = 1
@@ -280,13 +287,16 @@ def test_forced_bos_token_logits_processor(self):
         scores = logits_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
 
-    def test_forced_eos_token_logits_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_forced_eos_token_logits_processor(self, use_xla):
         vocab_size = 20
         batch_size = 4
         eos_token_id = 0
         max_length = 5
 
         logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
+        if use_xla:
+            logits_processor = tf.function(logits_processor, jit_compile=True)
 
         # check that all scores are -inf except the eos_token_id when max_length-1 is reached
         cur_len = 4
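The forced-BOS and forced-EOS tests follow the same recipe: wrap the processor, compile it, and check that every score except the forced token is -inf at the relevant step. A sketch for the EOS case, again assuming the top-level import is available:

import tensorflow as tf
from transformers import TFForcedEOSTokenLogitsProcessor

processor = TFForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=0)
xla_processor = tf.function(processor, jit_compile=True)

input_ids = tf.zeros((4, 4), dtype=tf.int32)
scores = tf.ones((4, 20)) / 20.0
# At cur_len == max_length - 1, everything but eos_token_id is masked to -inf,
# so generation is forced to end with EOS on the final step.
forced = xla_processor(input_ids, scores, 4)
print(tf.argmax(forced, axis=-1))  # all zeros: the eos_token_id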
