18
18
19
19
import numpy as np
20
20
21
+ from parameterized import parameterized
21
22
from transformers import is_tf_available
22
23
from transformers .testing_utils import require_tf
23
24
@@ -47,12 +48,15 @@ def _get_uniform_logits(self, batch_size: int, length: int):
47
48
scores = tf .ones ((batch_size , length ), dtype = tf .float32 ) / length
48
49
return scores
49
50
50
- def test_min_length_dist_processor (self ):
51
+ @parameterized .expand ([(False ,), (True ,)])
52
+ def test_min_length_dist_processor (self , use_xla ):
51
53
vocab_size = 20
52
54
batch_size = 4
53
55
eos_token_id = 0
54
56
55
57
min_dist_processor = TFMinLengthLogitsProcessor (min_length = 10 , eos_token_id = eos_token_id )
58
+ if use_xla :
59
+ min_dist_processor = tf .function (min_dist_processor , jit_compile = True )
56
60
57
61
# check that min length is applied at length 5
58
62
cur_len = 5
@@ -256,12 +260,15 @@ def test_no_bad_words_dist_processor(self):
256
260
[[True , True , False , True , True ], [True , True , True , False , True ]],
257
261
)
258
262
259
- def test_forced_bos_token_logits_processor (self ):
263
+ @parameterized .expand ([(False ,), (True ,)])
264
+ def test_forced_bos_token_logits_processor (self , use_xla ):
260
265
vocab_size = 20
261
266
batch_size = 4
262
267
bos_token_id = 0
263
268
264
269
logits_processor = TFForcedBOSTokenLogitsProcessor (bos_token_id = bos_token_id )
270
+ if use_xla :
271
+ logits_processor = tf .function (logits_processor , jit_compile = True )
265
272
266
273
# check that all scores are -inf except the bos_token_id score
267
274
cur_len = 1
@@ -280,13 +287,16 @@ def test_forced_bos_token_logits_processor(self):
280
287
scores = logits_processor (input_ids , scores , cur_len )
281
288
self .assertFalse (tf .math .reduce_any (tf .math .is_inf ((scores ))))
282
289
283
- def test_forced_eos_token_logits_processor (self ):
290
+ @parameterized .expand ([(False ,), (True ,)])
291
+ def test_forced_eos_token_logits_processor (self , use_xla ):
284
292
vocab_size = 20
285
293
batch_size = 4
286
294
eos_token_id = 0
287
295
max_length = 5
288
296
289
297
logits_processor = TFForcedEOSTokenLogitsProcessor (max_length = max_length , eos_token_id = eos_token_id )
298
+ if use_xla :
299
+ logits_processor = tf .function (logits_processor , jit_compile = True )
290
300
291
301
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
292
302
cur_len = 4
0 commit comments