
Commit 7d5174e

wiio12 authored and elusenji committed
Add doc about attention_mask on gpt2 (huggingface#16829)
* Add doc about `attention_mask` on gpt2

  Add a simple sentence describing how `attention_mask` needs to be constructed when `past_key_values` is used.

* Add doc about attention_mask on gpt2_tf
* clean up style
* remove empty line white spaces
* remove whitespace in empty line
1 parent 0b3508e commit 7d5174e

File tree

2 files changed: +8 −0 lines changed

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -565,6 +565,10 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
 
+            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
+            `past_key_values`. In other words, the `attention_mask` always has to have the length:
+            `len(past_key_values) + len(input_ids)`
+
             [What are attention masks?](../glossary#attention-mask)
         token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
```
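As a hedged illustration of the rule documented above (this sketch is not part of the commit, and all names and values in it are made up), the mask must cover both the cached tokens and the newly fed ones; real code would use `torch.Tensor` objects rather than plain lists:

```python
# Illustrative sketch: when past_key_values caches `past_length` tokens, the
# attention_mask must cover the cache AND the newly fed tokens.
# `past_length` and the token id below are hypothetical values.

past_length = 5      # tokens already cached in past_key_values
input_ids = [42]     # one new token fed this step

# Mask of 1s (i.e. "not masked") over the cached tokens plus the new input:
attention_mask = [1] * (past_length + len(input_ids))

# The length invariant stated in the docstring:
assert len(attention_mask) == past_length + len(input_ids)
print(len(attention_mask))  # 6
```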

src/transformers/models/gpt2/modeling_tf_gpt2.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -655,6 +655,10 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
 
+            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
+            `past_key_values`. In other words, the `attention_mask` always has to have the length:
+            `len(past_key_values) + len(input_ids)`
+
             [What are attention masks?](../glossary#attention-mask)
         token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
```
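The same length requirement applies step by step during incremental decoding. A hypothetical, framework-free simulation of the mask bookkeeping (token ids are invented; in real use `past_key_values` is returned by the model and the mask would be a tensor):

```python
# Hypothetical simulation of attention_mask bookkeeping across decoding steps.
# Token ids are illustrative only.

prompt = [464, 3290, 318]            # first pass feeds the whole prompt
attention_mask = [1] * len(prompt)   # mask covers only the prompt so far
cached = len(prompt)                 # length of past_key_values after the pass

for new_token in [257, 922]:         # each later step feeds one new token
    input_ids = [new_token]
    # Extend the mask so it covers past_key_values + the new input:
    attention_mask += [1] * len(input_ids)
    # The invariant from the docstring holds at every step:
    assert len(attention_mask) == cached + len(input_ids)
    cached += len(input_ids)

print(len(attention_mask))  # 5
```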

0 commit comments
