 from paddle.nn import TransformerEncoder
 
 from .. import PretrainedModel, register_base_model
+from ..model_outputs import CausalLMOutputWithCrossAttentions
 
 __all__ = [
     "UnifiedTransformerPretrainedModel",
@@ -343,7 +344,10 @@ def forward(self,
                 attention_mask=None,
                 use_cache=False,
                 cache=None,
-                role_ids=None):
+                role_ids=None,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The UnifiedTransformerModel forward method, overrides the special
         :meth:`__call__` method.
@@ -392,17 +396,25 @@ def forward(self,
                 Indices of role ids indicated different roles.
                 It's data type should be `int64` and has a shape of
                 [batch_size, sequence_length]. Defaults to None.
+            output_attentions (bool, optional):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                tensors for more detail. Defaults to `False`.
+            output_hidden_states (bool, optional):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                more detail. Defaults to `False`.
+            return_dict (bool, optional):
+                Whether to return a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` object.
+                If `False`, the output will be a tuple of tensors. Defaults to `False`.
 
         Returns:
-            Tensor|tuple: If `use_cache` is False, it is a tensor
-            representing the output of :class:`UnifiedTransformerModel`, with
+            An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if
+            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
+            to the ordered and not-None (depending on the input arguments) fields of
+            :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`.
+            In particular, when `return_dict=output_hidden_states=output_attentions=False` and `cache=None`,
+            it returns a tensor representing the output of :class:`UnifiedTransformerModel`, with
             shape [batch_size, sequence_length, hidden_size]. The data type is
-            float32 or float64. Otherwise, it is a tuple, besides the output of
-            :class:`UnifiedTransformerModel`, the tuple also includes the new
-            cache which is same as input `cache` but `incremental_cache` in it
-            has an incremental length.
-            See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
-            :meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
+            float32 or float64.
 
         Example:
             .. code-block::
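As a supplementary illustration of the three flags documented above, here is a minimal, hedged sketch. The `plato-mini` checkpoint and the `dialogue_encode` helper are assumptions carried over from the existing docstring examples; with `return_dict=True` the forward call yields a `BaseModelOutputWithPastAndCrossAttentions`, and the optional fields are populated only when the corresponding flags are set.

.. code-block::

    from paddlenlp.transformers import (UnifiedTransformerModel,
                                        UnifiedTransformerTokenizer)

    model = UnifiedTransformerModel.from_pretrained('plato-mini')
    tokenizer = UnifiedTransformerTokenizer.from_pretrained('plato-mini')

    # Encode one dialogue turn into paddle tensors (plato-mini is a Chinese
    # dialogue model, so the history here is a Chinese greeting).
    inputs = tokenizer.dialogue_encode('你好，最近过得怎么样？',
                                       return_tensors=True,
                                       is_split_into_words=False)

    outputs = model(**inputs,
                    output_attentions=True,
                    output_hidden_states=True,
                    return_dict=True)

    print(outputs.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]
    print(len(outputs.hidden_states))       # hidden states collected per layer
    print(outputs.attentions[0].shape)      # attention weights of the first layer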
@@ -429,16 +441,18 @@ def forward(self,
                                            token_type_ids,
                                            position_ids,
                                            role_ids=role_ids)
-        if use_cache:
-            if cache is None:
-                cache = self.encoder.gen_cache(embedding_output)
-            sequence_output, cache = self.encoder(embedding_output,
-                                                  attention_mask, cache)
-            return sequence_output, cache
-        else:
-            sequence_output = self.encoder(embedding_output, attention_mask)
+        if use_cache and cache is None:
+            cache = self.encoder.gen_cache(embedding_output)
 
-        return sequence_output
+        sequence_output = self.encoder(
+            embedding_output,
+            attention_mask,
+            cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return sequence_output
 
 
 class UnifiedTransformerLMHead(nn.Layer):
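With the cache and no-cache branches merged into a single encoder call, the tuple-style output is preserved when `return_dict=False`: the not-None fields of `BaseModelOutputWithPastAndCrossAttentions` come back in order, so `use_cache=True` should still yield a `(sequence_output, cache)` pair. A hedged sketch, reusing the `model` and `inputs` names assumed in the sketch above:

.. code-block::

    # With use_cache=True and the default return_dict=False, the output is the
    # tuple (last_hidden_state, past_key_values), matching the old behaviour.
    sequence_output, cache = model(**inputs, use_cache=True)

    # The returned cache can be passed back through the `cache` argument to
    # decode subsequent tokens incrementally.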
@@ -502,7 +516,11 @@ def forward(self,
                 masked_positions=None,
                 use_cache=False,
                 cache=None,
-                role_ids=None):
+                role_ids=None,
+                labels=None,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The UnifiedTransformerLMHeadModel forward method, overrides the special
         :meth:`__call__` method.
@@ -522,17 +540,26 @@ def forward(self,
                 See :class:`UnifiedTransformerModel`.
             role_ids: (Tensor, optional):
                 See :class:`UnifiedTransformerModel`.
+            labels (Tensor, optional):
+                Labels for computing the left-to-right language modeling loss. Indices should be in
+                `[-100, 0, ..., vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+                ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`.
+            output_attentions (bool, optional):
+                See :class:`UnifiedTransformerModel`.
+            output_hidden_states (bool, optional):
+                See :class:`UnifiedTransformerModel`.
+            return_dict (bool, optional):
+                See :class:`UnifiedTransformerModel`.
 
         Returns:
-            Tensor|tuple: If `use_cache` is False, it is a tensor
-            representing the output of :class:`UnifiedTransformerLMHeadModel`,
+            An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions` if
+            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
+            to the ordered and not-None (depending on the input arguments) fields of
+            :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions`.
+            In particular, when `return_dict=output_hidden_states=output_attentions=False` and `cache=labels=None`,
+            it returns a tensor representing the output of :class:`UnifiedTransformerLMHeadModel`,
             with shape [batch_size, sequence_length, vocab_size]. The data type
-            is float32 or float64. Otherwise, it is a tuple, besides the output
-            of :class:`UnifiedTransformerLMHeadModel`, the tuple also includes
-            the new cache which is same as input `cache` but `incremental_cache`
-            in it has an incremental length.
-            See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
-            :meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
+            is float32 or float64.
 
         Example:
             .. code-block::
@@ -551,20 +578,43 @@ def forward(self,
                 logits = model(**inputs)
         """
 
-        outputs = self.unified_transformer(input_ids,
-                                           token_type_ids,
-                                           position_ids,
-                                           attention_mask,
-                                           use_cache,
-                                           cache,
-                                           role_ids=role_ids)
-        sequence_output = outputs[0] if use_cache else outputs
+        outputs = self.unified_transformer(
+            input_ids,
+            token_type_ids,
+            position_ids,
+            attention_mask,
+            use_cache,
+            cache,
+            role_ids=role_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs if isinstance(outputs,
+                                                type(input_ids)) else outputs[0]
         logits = self.lm_head(sequence_output, masked_positions)
-        if use_cache:
-            cache = outputs[1]
-            return logits, cache
-        else:
-            return logits
+
+        lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            lm_loss = loss_fct(logits.reshape((-1, logits.shape[-1])),
+                               labels.reshape([-1]))
+        if not return_dict:
+            if isinstance(outputs, type(input_ids)):
+                return (lm_loss, logits) if lm_loss is not None else logits
+            else:
+                outputs = (logits, ) + outputs[1:]
+                return ((lm_loss, ) +
+                        outputs) if lm_loss is not None else outputs
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
 
     def prepare_faster_entry(self, kwargs):
         from paddlenlp.ops import FasterUnifiedTransformer
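Finally, a hedged sketch of the new `labels` path; the checkpoint name, the tokenizer call, and the label construction are illustrative assumptions, not part of this change. Passing `labels` together with `return_dict=True` yields a `CausalLMOutputWithCrossAttentions` whose `loss` field holds the cross entropy computed over positions whose label is not `-100` (the default `ignore_index` of `paddle.nn.CrossEntropyLoss`).

.. code-block::

    import paddle
    from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
                                        UnifiedTransformerTokenizer)

    model = UnifiedTransformerLMHeadModel.from_pretrained('plato-mini')
    tokenizer = UnifiedTransformerTokenizer.from_pretrained('plato-mini')

    inputs = tokenizer.dialogue_encode('你好，最近过得怎么样？',
                                       return_tensors=True,
                                       is_split_into_words=False)

    # Illustrative next-token labels: shift input_ids left by one position and
    # pad the last slot with -100 so it is ignored by the loss.
    labels = paddle.concat(
        [inputs['input_ids'][:, 1:],
         paddle.full([1, 1], -100, dtype='int64')], axis=-1)

    output = model(**inputs, labels=labels, return_dict=True)
    print(output.loss)          # scalar language modeling loss
    print(output.logits.shape)  # [batch_size, seq_len, vocab_size]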