Skip to content

Commit f6f6e31

Browse files
jsnfly and elusenji
authored and committed
Allow passing encoder_outputs as tuple to EncoderDecoder Models (huggingface#16814)
* Add passing encoder_outputs as tuple to existing test

* Add check for tuple

* Add check for tuple also for speech and vision

Co-authored-by: jsnfly <[email protected]>
1 parent a241aa2 commit f6f6e31

File tree

4 files changed

+25
-3
lines changed

4 files changed

+25
-3
lines changed

src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.nn import CrossEntropyLoss
2323

2424
from ...configuration_utils import PretrainedConfig
25-
from ...modeling_outputs import Seq2SeqLMOutput
25+
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
2626
from ...modeling_utils import PreTrainedModel
2727
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
2828
from ..auto.configuration_auto import AutoConfig
@@ -494,6 +494,8 @@ def forward(
494494
return_dict=return_dict,
495495
**kwargs_encoder,
496496
)
497+
elif isinstance(encoder_outputs, tuple):
498+
encoder_outputs = BaseModelOutput(*encoder_outputs)
497499

498500
encoder_hidden_states = encoder_outputs[0]
499501

src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.nn import CrossEntropyLoss
2323

2424
from ...configuration_utils import PretrainedConfig
25-
from ...modeling_outputs import Seq2SeqLMOutput
25+
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
2626
from ...modeling_utils import PreTrainedModel
2727
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
2828
from ..auto.configuration_auto import AutoConfig
@@ -514,6 +514,8 @@ def forward(
514514
return_dict=return_dict,
515515
**kwargs_encoder,
516516
)
517+
elif isinstance(encoder_outputs, tuple):
518+
encoder_outputs = BaseModelOutput(*encoder_outputs)
517519

518520
encoder_hidden_states = encoder_outputs[0]
519521

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.nn import CrossEntropyLoss
2323

2424
from ...configuration_utils import PretrainedConfig
25-
from ...modeling_outputs import Seq2SeqLMOutput
25+
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
2626
from ...modeling_utils import PreTrainedModel
2727
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
2828
from ..auto.configuration_auto import AutoConfig
@@ -466,6 +466,8 @@ def forward(
466466
return_dict=return_dict,
467467
**kwargs_encoder,
468468
)
469+
elif isinstance(encoder_outputs, tuple):
470+
encoder_outputs = BaseModelOutput(*encoder_outputs)
469471

470472
encoder_hidden_states = encoder_outputs[0]
471473

tests/encoder_decoder/test_modeling_encoder_decoder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,22 @@ def check_encoder_decoder_model(
142142
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
143143
)
144144

145+
# Test passing encoder_outputs as tuple.
146+
encoder_outputs = (encoder_hidden_states,)
147+
outputs_encoder_decoder = enc_dec_model(
148+
encoder_outputs=encoder_outputs,
149+
decoder_input_ids=decoder_input_ids,
150+
attention_mask=attention_mask,
151+
decoder_attention_mask=decoder_attention_mask,
152+
)
153+
154+
self.assertEqual(
155+
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
156+
)
157+
self.assertEqual(
158+
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
159+
)
160+
145161
def check_encoder_decoder_model_from_pretrained_using_model_paths(
146162
self,
147163
config,

0 commit comments

Comments
 (0)