Commit 34956ca

add unified transformer more output & loss (#3459)
1 parent ed266e7 commit 34956ca

3 files changed: +180 -95 lines changed


paddlenlp/transformers/unified_transformer/modeling.py

Lines changed: 90 additions & 40 deletions
@@ -19,6 +19,7 @@
 from paddle.nn import TransformerEncoder

 from .. import PretrainedModel, register_base_model
+from ..model_outputs import CausalLMOutputWithCrossAttentions

 __all__ = [
     "UnifiedTransformerPretrainedModel",
@@ -343,7 +344,10 @@ def forward(self,
                 attention_mask=None,
                 use_cache=False,
                 cache=None,
-                role_ids=None):
+                role_ids=None,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The UnifiedTransformerModel forward method, overrides the special
         :meth:`__call__` method.
@@ -392,17 +396,25 @@ def forward(self,
                 Indices of role ids indicated different roles.
                 It's data type should be `int64` and has a shape of
                 [batch_size, sequence_length]. Defaults to None.
+            output_attentions (bool, optional):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                tensors for more detail. Defaults to `False`.
+            output_hidden_states (bool, optional):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                more detail. Defaults to `False`.
+            return_dict (bool, optional):
+                Whether to return a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` object.
+                If `False`, the output will be a tuple of tensors. Defaults to `False`.

         Returns:
-            Tensor|tuple: If `use_cache` is False, it is a tensor
-            representing the output of :class:`UnifiedTransformerModel`, with
+            An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if
+            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
+            to ordered and not None (depending on the input arguments) fields of
+            :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`.
+            Especially, when `return_dict=output_hidden_states=output_attentions=False` and `cache=None`,
+            returns a tensor representing the output of :class:`UnifiedTransformerModel`, with
             shape [batch_size, sequence_length, hidden_size]. The data type is
-            float32 or float64. Otherwise, it is a tuple, besides the output of
-            :class:`UnifiedTransformerModel`, the tuple also includes the new
-            cache which is same as input `cache` but `incremental_cache` in it
-            has an incremental length.
-            See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
-            :meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
+            float32 or float64.

         Example:
             .. code-block::
@@ -429,16 +441,18 @@ def forward(self,
                                            token_type_ids,
                                            position_ids,
                                            role_ids=role_ids)
-        if use_cache:
-            if cache is None:
-                cache = self.encoder.gen_cache(embedding_output)
-            sequence_output, cache = self.encoder(embedding_output,
-                                                  attention_mask, cache)
-            return sequence_output, cache
-        else:
-            sequence_output = self.encoder(embedding_output, attention_mask)
+        if use_cache and cache is None:
+            cache = self.encoder.gen_cache(embedding_output)

-            return sequence_output
+        sequence_output = self.encoder(
+            embedding_output,
+            attention_mask,
+            cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return sequence_output


 class UnifiedTransformerLMHead(nn.Layer):
@@ -502,7 +516,11 @@ def forward(self,
                 masked_positions=None,
                 use_cache=False,
                 cache=None,
-                role_ids=None):
+                role_ids=None,
+                labels=None,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The UnifiedTransformerLMHeadModel forward method, overrides the special
         :meth:`__call__` method.
@@ -522,17 +540,26 @@ def forward(self,
                 See :class:`UnifiedTransformerModel`.
             role_ids: (Tensor, optional):
                 See :class:`UnifiedTransformerModel`.
+            labels: (Tensor, optional):
+                Labels for computing the left-to-right language modeling loss. Indices should be in
+                `[-100, 0, ..., vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
+                ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`.
+            output_attentions (bool, optional):
+                See :class:`UnifiedTransformerModel`.
+            output_hidden_states (bool, optional):
+                See :class:`UnifiedTransformerModel`.
+            return_dict (bool, optional):
+                See :class:`UnifiedTransformerModel`.

         Returns:
-            Tensor|tuple: If `use_cache` is False, it is a tensor
-            representing the output of :class:`UnifiedTransformerLMHeadModel`,
+            An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions` if
+            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
+            to ordered and not None (depending on the input arguments) fields of
+            :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions`.
+            Especially, when `return_dict=output_hidden_states=output_attentions=False` and `cache=labels=None`,
+            returns a tensor representing the output of :class:`UnifiedTransformerLMHeadModel`,
             with shape [batch_size, sequence_length, vocab_size]. The data type
-            is float32 or float64. Otherwise, it is a tuple, besides the output
-            of :class:`UnifiedTransformerLMHeadModel`, the tuple also includes
-            the new cache which is same as input `cache` but `incremental_cache`
-            in it has an incremental length.
-            See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
-            :meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
+            is float32 or float64.

         Example:
             .. code-block::
@@ -551,20 +578,43 @@ def forward(self,
                 logits = model(**inputs)
         """

-        outputs = self.unified_transformer(input_ids,
-                                           token_type_ids,
-                                           position_ids,
-                                           attention_mask,
-                                           use_cache,
-                                           cache,
-                                           role_ids=role_ids)
-        sequence_output = outputs[0] if use_cache else outputs
+        outputs = self.unified_transformer(
+            input_ids,
+            token_type_ids,
+            position_ids,
+            attention_mask,
+            use_cache,
+            cache,
+            role_ids=role_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs if isinstance(outputs,
+                                                type(input_ids)) else outputs[0]
         logits = self.lm_head(sequence_output, masked_positions)
-        if use_cache:
-            cache = outputs[1]
-            return logits, cache
-        else:
-            return logits
+
+        lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            lm_loss = loss_fct(logits.reshape((-1, logits.shape[-1])),
+                               labels.reshape([-1]))
+        if not return_dict:
+            if isinstance(outputs, type(input_ids)):
+                return (lm_loss, logits) if lm_loss is not None else logits
+            else:
+                outputs = (logits, ) + outputs[1:]
+                return ((lm_loss, ) +
+                        outputs) if lm_loss is not None else outputs
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )

     def prepare_faster_entry(self, kwargs):
         from paddlenlp.ops import FasterUnifiedTransformer
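Usage sketch (illustrative only, not part of the commit): the snippet below exercises the new `labels` and `return_dict` arguments of `UnifiedTransformerLMHeadModel.forward`. The checkpoint name "plato-mini", the `dialogue_encode` helper, and the toy labels are assumptions for illustration; any UnifiedTransformer checkpoint should behave the same way.

from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
                                    UnifiedTransformerTokenizer)

# Assumed checkpoint name, used for illustration only.
model = UnifiedTransformerLMHeadModel.from_pretrained("plato-mini")
tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-mini")

inputs = tokenizer.dialogue_encode("你好",
                                   return_tensors=True,
                                   is_split_into_words=False)
# Toy labels: reuse input_ids; positions set to -100 are ignored by the loss.
labels = inputs["input_ids"]

outputs = model(**inputs, labels=labels, return_dict=True)
print(outputs.loss)            # scalar left-to-right LM loss
print(outputs.logits.shape)    # [batch_size, sequence_length, vocab_size]

# With return_dict=False (the default) and labels given, the forward
# instead returns a tuple whose first element is the loss:
loss, logits = model(**inputs, labels=labels)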

tests/transformers/test_generation_utils.py

Lines changed: 2 additions & 3 deletions
@@ -433,6 +433,7 @@ def test_greedy_generate(self):
             )
             pretrained_model = self.all_generative_model_classes[model_class][
                 0](**config)
+            paddle.seed(128)
             model = model_class(pretrained_model)
             model.eval()

@@ -446,15 +447,13 @@ def test_greedy_generate(self):
                              output_generate[0].tolist())

     def test_sample_generate(self):
-        random.seed(128)
-        np.random.seed(128)
-        paddle.seed(128)

         for model_class in self.all_generative_model_classes.keys():
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(
             )
             pretrained_model = self.all_generative_model_classes[model_class][
                 0](**config)
+            paddle.seed(128)
             model = model_class(pretrained_model)
             model.eval()

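Side note (illustrative only, not part of the commit): calling `paddle.seed(128)` right before each model is constructed re-seeds Paddle's global RNG per model class, so the random weight initialization is reproducible regardless of what ran earlier in the loop. A minimal sketch of that effect:

import paddle

paddle.seed(128)
a = paddle.randn([2, 3])            # first random init

paddle.seed(128)                    # re-seed right before the next init ...
b = paddle.randn([2, 3])            # ... so it draws exactly the same values

print(bool(paddle.allclose(a, b)))  # True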
