|
36 | 36 |
|
37 | 37 | logger = logging.get_logger(__name__)
|
38 | 38 |
|
39 |
| -_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B" |
| 39 | +_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj" |
40 | 40 | _CONFIG_FOR_DOC = "GPTJConfig"
|
41 | 41 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
42 | 42 |
|
| 43 | +_CHECKPOINT_FOR_QA = "ydshieh/tiny-random-gptj-for-question-answering" |
| 44 | +_QA_EXPECTED_OUTPUT = "' was Jim Henson?Jim Henson was a n'" |
| 45 | +_QA_EXPECTED_LOSS = 3.13 |
| 46 | + |
| 47 | +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/tiny-random-gptj-for-sequence-classification" |
| 48 | +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" |
| 49 | +_SEQ_CLASS_EXPECTED_LOSS = 0.76 |
| 50 | + |
| 51 | + |
43 | 52 | GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 | 53 | "EleutherAI/gpt-j-6B",
|
45 | 54 | # See all GPT-J models at https://huggingface.co/models?filter=gptj
|
@@ -892,9 +901,11 @@ def __init__(self, config):
|
892 | 901 | @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
893 | 902 | @add_code_sample_docstrings(
|
894 | 903 | processor_class=_TOKENIZER_FOR_DOC,
|
895 |
| - checkpoint=_CHECKPOINT_FOR_DOC, |
| 904 | + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, |
896 | 905 | output_type=SequenceClassifierOutputWithPast,
|
897 | 906 | config_class=_CONFIG_FOR_DOC,
|
| 907 | + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, |
| 908 | + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, |
898 | 909 | )
|
899 | 910 | def forward(
|
900 | 911 | self,
|
@@ -1017,9 +1028,11 @@ def __init__(self, config):
|
1017 | 1028 | @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1018 | 1029 | @add_code_sample_docstrings(
|
1019 | 1030 | processor_class=_TOKENIZER_FOR_DOC,
|
1020 |
| - checkpoint=_CHECKPOINT_FOR_DOC, |
| 1031 | + checkpoint=_CHECKPOINT_FOR_QA, |
1021 | 1032 | output_type=QuestionAnsweringModelOutput,
|
1022 | 1033 | config_class=_CONFIG_FOR_DOC,
|
| 1034 | + expected_output=_QA_EXPECTED_OUTPUT, |
| 1035 | + expected_loss=_QA_EXPECTED_LOSS, |
1023 | 1036 | )
|
1024 | 1037 | def forward(
|
1025 | 1038 | self,
|
|
0 commit comments