Commit f2b93bc

KMFODA authored and ydshieh committed
Doctest longformer (huggingface#16441)
* Add initial docstring changes
* make fixup
* Add TF doc changes
* fix seq classifier output
* fix quality errors
* t
* switch head to random init
* Fix expected outputs
* Update src/transformers/models/longformer/modeling_longformer.py

Co-authored-by: Yih-Dar <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
1 parent b5b6b3e commit f2b93bc

File tree: 3 files changed (+38, -16 lines)

src/transformers/models/longformer/modeling_longformer.py

Lines changed: 25 additions & 13 deletions
@@ -1782,23 +1782,31 @@ def forward(
 
         Returns:
 
-        Examples:
+        Mask filling example:
 
         ```python
-        >>> import torch
-        >>> from transformers import LongformerForMaskedLM, LongformerTokenizer
+        >>> from transformers import LongformerTokenizer, LongformerForMaskedLM
 
-        >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
         >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+        >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
+        ```
 
-        >>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000)  # long input document
-        >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1
+        Let's try a very long input.
 
-        >>> attention_mask = None  # default is local attention everywhere, which is a good choice for MaskedLM
-        >>> # check `LongformerModel.forward` for more details how to set *attention_mask*
-        >>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
-        >>> loss = outputs.loss
-        >>> prediction_logits = outputs.logits
+        ```python
+        >>> TXT = (
+        ...     "My friends are <mask> but they eat too many carbs."
+        ...     + " That's why I decide not to eat with them." * 300
+        ... )
+        >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+        >>> logits = model(input_ids).logits
+
+        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+        >>> probs = logits[0, masked_index].softmax(dim=0)
+        >>> values, predictions = probs.topk(5)
+
+        >>> tokenizer.decode(predictions).split()
+        ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 

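For reference, the new mask-filling doctest above runs as a single standalone script once the two fenced snippets from the diff are merged; this is just that consolidation (checkpoint and predicted tokens are taken verbatim from the diff, not re-verified here):

```python
from transformers import LongformerTokenizer, LongformerForMaskedLM

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

# Build a long input: one masked sentence followed by ~300 repeated filler sentences.
TXT = (
    "My friends are <mask> but they eat too many carbs."
    + " That's why I decide not to eat with them." * 300
)
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
logits = model(input_ids).logits

# Find the position of the <mask> token and rank the vocabulary at that position.
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)

print(tokenizer.decode(predictions).split())
# The diff pins the doctest output to: ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
```
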
@@ -1860,9 +1868,11 @@ def __init__(self, config):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="jpelhaw/longformer-base-plagiarism-detection",
         output_type=LongformerSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="'ORIGINAL'",
+        expected_loss=5.44,
     )
     def forward(
         self,
@@ -2127,9 +2137,11 @@ def __init__(self, config):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="brad1141/Longformer-finetuned-norm",
         output_type=LongformerTokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence']",
+        expected_loss=0.63,
     )
     def forward(
         self,
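
The two hunks above swap the generic checkpoint for task-specific ones and pin `expected_output` / `expected_loss`, which `add_code_sample_docstrings` uses to render a verifiable usage sample in the forward docstring. A minimal sketch of the kind of single-label sequence-classification sample this produces, assuming the standard template and the `jpelhaw/longformer-base-plagiarism-detection` checkpoint named in the diff (the input sentence here is illustrative, not necessarily the one the template uses):

```python
import torch
from transformers import LongformerTokenizer, LongformerForSequenceClassification

checkpoint = "jpelhaw/longformer-base-plagiarism-detection"  # from the diff above
tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
model = LongformerForSequenceClassification.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Map the highest-scoring logit back to its label name;
# the diff pins the doctest's expected output for this head to 'ORIGINAL'.
predicted_class_id = logits.argmax().item()
print(model.config.id2label[predicted_class_id])
```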

src/transformers/models/longformer/modeling_tf_longformer.py

Lines changed: 11 additions & 3 deletions
@@ -2102,10 +2102,12 @@ def get_prefix_bias_name(self):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="allenai/longformer-base-4096",
         output_type=TFLongformerMaskedLMOutput,
         config_class=_CONFIG_FOR_DOC,
         mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.44,
     )
     def call(
         self,
@@ -2198,6 +2200,8 @@ def __init__(self, config, *inputs, **kwargs):
         checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
         output_type=TFLongformerQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="' puppet'",
+        expected_loss=0.96,
     )
     def call(
         self,
@@ -2344,9 +2348,11 @@ def __init__(self, config, *inputs, **kwargs):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="hf-internal-testing/tiny-random-longformer",
         output_type=TFLongformerSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="'LABEL_1'",
+        expected_loss=0.69,
     )
     def call(
         self,
@@ -2582,9 +2588,11 @@ def __init__(self, config, *inputs, **kwargs):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="hf-internal-testing/tiny-random-longformer",
         output_type=TFLongformerTokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1']",
+        expected_loss=0.59,
     )
     def call(
         self,
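
On the TensorFlow side, the first hunk points the masked-LM sample at `allenai/longformer-base-4096` and pins its answer to `' Paris'`. A sketch of the masked-LM doc sample that the decorator typically renders, under the assumption that it follows the library's usual "The capital of France is <mask>." template:

```python
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForMaskedLM

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="tf")
logits = model(**inputs).logits

# Locate the <mask> position and decode the highest-scoring prediction;
# the diff pins the doctest's expected output for this head to ' Paris'.
mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0, 1]
predicted_token_id = int(tf.math.argmax(logits[0, mask_token_index]))
print(tokenizer.decode([predicted_token_id]))
```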

utils/documentation_tests.txt

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ src/transformers/models/glpn/modeling_glpn.py
 src/transformers/models/gpt2/modeling_gpt2.py
 src/transformers/models/gptj/modeling_gptj.py
 src/transformers/models/hubert/modeling_hubert.py
+src/transformers/models/longformer/modeling_longformer.py
+src/transformers/models/longformer/modeling_tf_longformer.py
 src/transformers/models/marian/modeling_marian.py
 src/transformers/models/mbart/modeling_mbart.py
 src/transformers/models/mobilebert/modeling_mobilebert.py
