
Commit 6d90d76

TF: rework XLA generate tests (#16866)
1 parent: 3b1bbef

File tree: 2 files changed (+76, -54 lines)
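The pattern both files exercise is XLA-compiled generation: `model.generate` is wrapped in a `tf.function` with `jit_compile=True`, and is expected to match eager greedy decoding exactly, while sampled output should stay sensible and seedable. A minimal sketch of that pattern, using the same `t5-small` checkpoint as the tests below (the standalone-script framing is illustrative, not part of the commit):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer(
    "Translate English to German: Today is a beautiful day.", return_tensors="tf"
).input_ids

# eager generation
output_ids = model.generate(input_ids)

# XLA-compiled generation: the same call, wrapped in a jit-compiled tf.function;
# the first call traces and compiles, later calls with compatible shapes reuse it
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids_xla = xla_generate(input_ids)

# greedy decoding is expected to produce identical text in both modes
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
print(tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True))
```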

tests/gpt2/test_modeling_tf_gpt2.py

Lines changed: 47 additions & 31 deletions
```diff
@@ -16,7 +16,7 @@
 import unittest
 
 from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import get_gpu_count, require_tf, slow
 
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -294,7 +294,7 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
         result = model(inputs)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
-    def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
+    def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
         config.eos_token_id = None
         config.max_length = 10
         model = TFGPT2LMHeadModel(config=config)
@@ -408,9 +408,9 @@ def test_gpt2_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
 
-    def test_gpt2_xla_generate(self):
+    def test_gpt2_xla_generate_fast(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
+        self.model_tester.create_and_check_gpt2_xla_generate_fast(*config_and_inputs)
 
     def test_gpt2_double_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -536,41 +536,57 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
         self.assertListEqual(output_strings, expected_output_string)
 
     @slow
-    def test_lm_generate_gpt2(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_greedy_xla(self):
+        # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
+        # the underlying problem)
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
 
-    @slow
-    def test_lm_generate_gpt2_xla_greedy(self):
-        """This test gives the exact same results as the non-xla test above"""
-        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        sentences = ["The dog"]
+        expected_output_strings = [
+            "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
 
-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = model.generate(input_ids, do_sample=False)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
 
+        xla_generate = tf.function(model.generate, jit_compile=True)
         output_ids = xla_generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
 
     @slow
-    def test_lm_generate_gpt2_xla_sample(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_sample_xla(self):
+        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
+        # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
+        # and that we can seed both versions.
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
-        # fmt: off
-        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        sentence = ["The dog"]
+        expected_output_string = [
+            "The dog must be well educated to do anything. If anything, this must be her best friend"
+        ]
+        expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
 
-        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string)
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string_xla)
```
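Note the two tokenizer settings the reworked GPT-2 tests add: GPT-2 has no pad token, so the EOS token is reused, and padding goes on the left so every prompt ends exactly where generation begins (right-padding would leave pad tokens between the prompt and its continuation in a decoder-only model). A small sketch of the effect, with a hypothetical two-prompt batch (per the TODO above, batches with different input lengths did not yet work with XLA generate at this commit, so this only illustrates the tokenizer side):

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

batch = ["The dog", "The dog was found in a field"]

tokenizer.padding_side = "right"
print(tokenizer(batch, padding=True).input_ids)
# the shorter prompt ends in pad (EOS) tokens, so generation would start after padding

tokenizer.padding_side = "left"  # the setting used in the tests above
print(tokenizer(batch, padding=True).input_ids)
# pads come first; every prompt now ends at the generation boundary
```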

tests/t5/test_modeling_tf_t5.py

Lines changed: 29 additions & 23 deletions
```diff
@@ -16,7 +16,7 @@
 import unittest
 
 from transformers import T5Config, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
 from transformers.utils import cached_property
 
 from ..test_configuration_common import ConfigTester
@@ -227,7 +227,7 @@ def create_and_check_t5_decoder_model_past_large_inputs(
         # test that outputs are equal for slice
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
-    def create_and_check_t5_xla_generate(self, config, input_ids, *args):
+    def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
         config.eos_token_id = None
         config.max_length = 10
         config.do_sample = False
@@ -297,9 +297,9 @@ def test_t5_decoder_model_past_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
 
-    def test_t5_model_xla_generate(self):
+    def test_t5_model_xla_generate_fast(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
+        self.model_tester.create_and_check_t5_xla_generate_fast(*config_and_inputs)
 
     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -481,12 +481,18 @@ def test_train_pipeline_custom_model(self):
 @require_tokenizers
 class TFT5GenerationIntegrationTests(unittest.TestCase):
     @slow
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_greedy_xla_generate_simple(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
         tokenizer = T5Tokenizer.from_pretrained("t5-small")
 
-        sentence = "Translate English to German: Today is a beautiful day."
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+        # two examples with different lengths to confirm that attention masks are operational in XLA
+        sentences = [
+            "Translate English to German: Today is a beautiful day.",
+            "Translate English to German: I have four cats, three dogs, two birds, and a horse.",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
 
         xla_generate = tf.function(model.generate, jit_compile=True)
 
@@ -496,7 +502,10 @@ def test_greedy_xla_generate_simple(self):
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
 
-        expected_output_string = ["Heute ist ein schöner Tag."]
+        expected_output_string = [
+            "Heute ist ein schöner Tag.",
+            "Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd.",
+        ]
 
         self.assertListEqual(expected_output_string, output_strings)
         self.assertListEqual(expected_output_string, output_strings_xla)
@@ -525,31 +534,28 @@ def test_greedy_generate(self):
         self.assertListEqual(expected_output_string, output_strings)
 
     @slow
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_sample_xla_generate_simple(self):
+        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
+        # output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
+        # results are sensible and that we can seed both versions.
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
         tokenizer = T5Tokenizer.from_pretrained("t5-small")
 
-        sentence = "Translate English to German: Today is a beautiful day."
+        sentence = "Translate English to German: I have two bananas"
         input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
-        # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
-        # divergences in generate -- especially with sampling.
-        expected_output_string = ["Heute ist ein schöner Tag."]
-        expected_output_string_xla = ["Heute ist ein schöne Tage."]
-        # However, notice that the first tokens are the same, for the same seed
-        assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
+        expected_output_string = ["Ich habe 2 Bananen"]
+        expected_output_string_xla = ["Ich habe 2 Bananen"]
 
-        # forces the generation to happen on CPU, to avoid GPU-related quirks
-        with tf.device(":/CPU:0"):
-            # seed set -> deterministic sampling sequence -> deterministic generation
-            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+        # seed set -> deterministic sampling sequence -> deterministic generation
+        output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(expected_output_string, output_strings)
 
-        # forces the generation to happen on CPU, to avoid GPU-related quirks
-        with tf.device(":/CPU:0"):
-            xla_generate = tf.function(model.generate, jit_compile=True)
-            # seed set -> deterministic sampling sequence -> deterministic generation
-            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        # seed set -> deterministic sampling sequence -> deterministic generation
+        output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
         self.assertListEqual(expected_output_string_xla, output_strings_xla)
 
```
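The `seed=[42, 0]` argument above pins the sampling RNG: it appears to follow the two-integer seed format of TF's stateless random ops, so repeated calls with the same seed reproduce the same sampled sequence within one execution mode, even though eager and XLA outputs may still differ slightly (as the NOTE in the diff explains). A minimal sketch, assuming the same `t5-small` checkpoint as the tests:

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer(
    "Translate English to German: I have two bananas", return_tensors="tf"
).input_ids

# same seed -> same sampled sequence, call after call, within eager mode
out_a = model.generate(input_ids, do_sample=True, seed=[42, 0])
out_b = model.generate(input_ids, do_sample=True, seed=[42, 0])
assert out_a.numpy().tolist() == out_b.numpy().tolist()

# the XLA-compiled version is seeded the same way; its output is equally
# reproducible, but may diverge from eager mode due to compiled-numerics drift
xla_generate = tf.function(model.generate, jit_compile=True)
out_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
print(tokenizer.batch_decode(out_xla, skip_special_tokens=True))
```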
