 import unittest
 
 from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import get_gpu_count, require_tf, slow
 
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -294,7 +294,7 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
         result = model(inputs)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
-    def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
+    def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
         config.eos_token_id = None
         config.max_length = 10
         model = TFGPT2LMHeadModel(config=config)
@@ -408,9 +408,9 @@ def test_gpt2_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
 
-    def test_gpt2_xla_generate(self):
+    def test_gpt2_xla_generate_fast(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
+        self.model_tester.create_and_check_gpt2_xla_generate_fast(*config_and_inputs)
 
     def test_gpt2_double_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -536,41 +536,57 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
         self.assertListEqual(output_strings, expected_output_string)
 
     @slow
-    def test_lm_generate_gpt2(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_greedy_xla(self):
+        # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
+        # the underlying problem)
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
 
-    @slow
-    def test_lm_generate_gpt2_xla_greedy(self):
-        """This test gives the exact same results as the non-xla test above"""
-        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        sentences = ["The dog"]
+        expected_output_strings = [
+            "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
 
-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = model.generate(input_ids, do_sample=False)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
 
+        xla_generate = tf.function(model.generate, jit_compile=True)
         output_ids = xla_generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
 
     @slow
-    def test_lm_generate_gpt2_xla_sample(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
+    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_sample_xla(self):
+        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
+        # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
+        # and that we can seed both versions.
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
-        # fmt: off
-        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        sentence = ["The dog"]
+        expected_output_string = [
+            "The dog must be well educated to do anything. If anything, this must be her best friend"
+        ]
+        expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
 
-        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string)
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string_xla)
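
For reference, below is a minimal, self-contained sketch of the XLA generation pattern these tests exercise. It is assembled from the calls shown in the diff (left padding with the EOS token, wrapping generate in tf.function(jit_compile=True), and the stateless seed argument for sampling); the prompt text is arbitrary and the decoded strings are not guaranteed to match the test expectations.

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# Left padding with EOS as the pad token, as in the tests, so generation
# continues from the right edge of each prompt.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer(["The dog"], return_tensors="tf", padding=True).input_ids

# Eager (non-XLA) greedy generation.
eager_ids = model.generate(input_ids, do_sample=False)

# XLA-compiled generation: wrap generate in tf.function with jit_compile=True.
xla_generate = tf.function(model.generate, jit_compile=True)
xla_ids = xla_generate(input_ids, do_sample=False)

print(tokenizer.batch_decode(eager_ids, skip_special_tokens=True))
print(tokenizer.batch_decode(xla_ids, skip_special_tokens=True))

# Seeded sampling: the same stateless seed can be passed to both versions,
# but (as the test's NOTE explains) XLA's small numerical differences mean
# the sampled text may differ between the eager and compiled calls.
sampled_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
xla_sampled_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])

Greedy search decodes to the same string in both the eager and compiled paths, which is what the greedy test asserts; the sampling test keeps separate expected strings for the two paths for the reason given in its NOTE comment.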