Commit b33ab4e

Add global_attention_mask to gen_kwargs (#16485)
If global_attention_mask is found in the model's inputs (used by certain models, like LED) in the prediction_step method of Seq2SeqTrainer, it is added to the gen_kwargs, which are passed to model.generate(). This allows the global attention to be set properly during generation.
1 parent 9fd5e6b commit b33ab4e

File tree

1 file changed (+3 lines, -1 line)


src/transformers/trainer_seq2seq.py

Lines changed: 3 additions & 1 deletion
@@ -163,9 +163,11 @@ def prediction_step(
 
         if "attention_mask" in inputs:
             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
 
         # prepare generation inputs
-        # some encoder-decoder models can have varying encder's and thus
+        # some encoder-decoder models can have varying encoder's and thus
         # varying model input names
         if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
             generation_inputs = inputs[self.model.encoder.main_input_name]
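
For context, below is an illustrative sketch (not part of the commit) of how a preprocessing function could attach a global_attention_mask to the model inputs, so that the updated prediction_step forwards it to model.generate() during evaluation. The checkpoint name and the "article"/"summary" column names are assumptions for the example.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")

def preprocess(examples):
    # Tokenize long source documents for LED.
    model_inputs = tokenizer(examples["article"], max_length=4096, truncation=True)
    # LED attends globally only where this mask is 1; a common convention is
    # to place global attention on the first (<s>) token of each sequence.
    model_inputs["global_attention_mask"] = [
        [1] + [0] * (len(ids) - 1) for ids in model_inputs["input_ids"]
    ]
    # Tokenize the targets as labels (LED shares one vocabulary between
    # encoder and decoder, so the same tokenizer is used here).
    model_inputs["labels"] = tokenizer(
        examples["summary"], max_length=256, truncation=True
    )["input_ids"]
    return model_inputs

Because the collated evaluation batches then contain global_attention_mask, the new check in prediction_step picks it up, adds it to gen_kwargs, and it reaches model.generate().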
