 
 from ..model_utils import PretrainedModel, register_base_model
 from ..nezha.modeling import ACT2FN
+from ..model_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqModelOutput,
+    Seq2SeqLMOutput,
+    BaseModelOutput,
+    ModelOutput,
+)
 
 __all__ = [
     'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
@@ -944,7 +951,8 @@ def forward(self,
                 cache=None,
                 use_cache=False,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         assert input_ids is not None, "input_ids can not be None"
         input_shape = input_ids.shape
         input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -1051,13 +1059,22 @@ def forward(self,
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states, )
 
-        return tuple(v for v in [
-            hidden_states,
-            present_key_value_states,
-            all_hidden_states,
-            all_attentions,
-            all_cross_attentions,
-        ] if v is not None)
+        if not return_dict:
+            return tuple(v for v in [
+                hidden_states,
+                present_key_value_states,
+                all_hidden_states,
+                all_attentions,
+                all_cross_attentions,
+            ] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
 
     def get_extended_attention_mask(self, attention_mask, input_shape):
         if attention_mask.ndim == 3:
@@ -1293,7 +1310,8 @@ def forward(self,
                 cache=None,
                 use_cache=True,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The T5Model forward method, overrides the `__call__()` special method.
 
@@ -1343,8 +1361,16 @@ def forward(self,
             output_hidden_states (bool, optional):
                 Whether or not to return the output of all hidden layers.
                 Defaults to `False`.
+            return_dict (bool, optional):
+                Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`.
+                If `False`, the output will be a tuple of tensors. Defaults to `False`.
+
 
         Returns:
+            An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` if `return_dict=True`.
+            Otherwise it returns a tuple of tensors corresponding to the ordered, not-`None` fields (depending on
+            the input arguments) of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`.
+
             tuple: Returns tuple (`last_hidden_state`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
             `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
 
@@ -1419,8 +1445,10 @@ def forward(self,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states)
-
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+        elif return_dict and not isinstance(encoder_output, BaseModelOutput):
+            encoder_output = convert_encoder_output(encoder_output)
         hidden_states = encoder_output[0]
 
         # Decode
@@ -1432,9 +1460,22 @@ def forward(self,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states)
-
-        return decoder_outputs + encoder_output
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict)
+
+        if not return_dict:
+            return decoder_outputs + encoder_output
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_output.last_hidden_state,
+            encoder_hidden_states=encoder_output.hidden_states,
+            encoder_attentions=encoder_output.attentions,
+        )
 
 
 class T5ForConditionalGeneration(T5PretrainedModel):
@@ -1490,7 +1531,8 @@ def forward(self,
                 labels=None,
                 use_cache=True,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         r"""
 
         Args:
@@ -1518,8 +1560,15 @@ def forward(self,
                 See :class:`T5Model`.
             output_hidden_states (bool, optional):
                 See :class:`T5Model`.
+            return_dict (bool, optional):
+                Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`.
+                If `False`, the output will be a tuple of tensors. Defaults to `False`.
 
         Returns:
+            An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` if `return_dict=True`.
+            Otherwise it returns a tuple of tensors corresponding to the ordered, not-`None` fields (depending on
+            the input arguments) of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`.
+
             tuple: Returns tuple (`loss`, `logits`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
             `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
 
@@ -1581,12 +1630,15 @@ def forward(self,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states)
-
-        if isinstance(encoder_output, (tuple, list)):
-            hidden_states = encoder_output[0]
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
         else:
-            hidden_states = encoder_output
+            if isinstance(encoder_output, paddle.Tensor):
+                encoder_output = (encoder_output, )
+            if return_dict and not isinstance(encoder_output, BaseModelOutput):
+                encoder_output = convert_encoder_output(encoder_output)
+
+        hidden_states = encoder_output[0]
 
         if labels is not None and decoder_input_ids is None:
             # get decoder inputs from shifting lm labels to the right
@@ -1610,7 +1662,8 @@ def forward(self,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states)
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict)
 
         sequence_output = decoder_outputs[0]
 
@@ -1631,11 +1684,21 @@ def forward(self,
             loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
                             labels.flatten())
 
-        if not isinstance(encoder_output, (list, tuple)):
-            encoder_output = (encoder_output, )
-
-        output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
-        return ((loss, ) + output) if loss is not None else output
+        if not return_dict:
+            output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
+            return ((loss, ) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_output.last_hidden_state,
+            encoder_hidden_states=encoder_output.hidden_states,
+            encoder_attentions=encoder_output.attentions,
+        )
 
     @staticmethod
     def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
@@ -1809,6 +1872,7 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
     ):
         encoder_outputs = self.encoder(
             input_ids=input_ids,
@@ -1819,9 +1883,25 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-        )
+            return_dict=return_dict)
 
         return encoder_outputs
 
 
 T5EncoderModel.base_model_class = T5EncoderModel
+
+
+def convert_encoder_output(encoder_output):
+    """
+    Convert a tuple-style encoder output to a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutput`.
+
+    Args:
+        encoder_output (tuple or ModelOutput):
+            The output of the encoder, a tuple consisting of `last_hidden_state`, `hidden_states` (optional) and `attentions` (optional).
+            The data type of `last_hidden_state` is float32 and its shape is [batch_size, sequence_length, hidden_size].
+    """
+    return BaseModelOutput(
+        last_hidden_state=encoder_output[0],
+        hidden_states=encoder_output[1] if len(encoder_output) > 1 else None,
+        attentions=encoder_output[2] if len(encoder_output) > 2 else None,
+    )
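
A quick sanity check of the helper above (not part of this diff), assuming it is importable from `paddlenlp.transformers.t5.modeling`, the module this commit edits; the tensor shape is illustrative:

```python
# Hedged sketch: wrapping a legacy tuple-style encoder output so the
# return_dict code paths can use attribute access.
import paddle
from paddlenlp.transformers.t5.modeling import convert_encoder_output

legacy = (paddle.zeros([1, 5, 8]), )  # (last_hidden_state, )
wrapped = convert_encoder_output(legacy)

assert wrapped.last_hidden_state.shape == [1, 5, 8]
assert wrapped.hidden_states is None and wrapped.attentions is None
```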