
Commit 62f55d0

[FasterGeneration] MBart supports dy2sta (#3356)

1 parent 85d7ac6 commit 62f55d0

This commit lets FasterGeneration's MBART go through dygraph-to-static (dy2sta) conversion. It updates the MBART sample, adds an export script that saves FasterMBART as a static-graph inference model via paddle.jit.to_static, adds a Paddle Inference script that runs the exported model, and adjusts FasterMBART, InferMBartDecoding, and the MBART modeling code so that tracing succeeds.

File tree: 6 files changed, +264 -11 lines


faster_generation/samples/mbart_sample.py

Lines changed: 5 additions & 5 deletions
@@ -14,11 +14,10 @@
 import paddle
 from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer

-model_name = "mbart-large-50-one-to-many-mmt"
+model_name = "mbart-large-50-many-to-many-mmt"

-tokenizer = MBartTokenizer.from_pretrained(model_name)
-model = MBartForConditionalGeneration.from_pretrained(model_name,
-                                                      src_lang="en_XX")
+tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="en_XX")
+model = MBartForConditionalGeneration.from_pretrained(model_name)
 model.eval()

@@ -41,7 +40,7 @@ def postprocess_response(seq, bos_idx, eos_idx):

 inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
 input_ids = tokenizer(inputs)["input_ids"]
-input_ids = paddle.to_tensor(input_ids, dtype='int64').unsqueeze(0)
+input_ids = paddle.to_tensor(input_ids, dtype='int32').unsqueeze(0)

 outputs, _ = model.generate(input_ids=input_ids,
                             forced_bos_token_id=bos_id,
@@ -53,5 +52,6 @@ def postprocess_response(seq, bos_idx, eos_idx):
 result = postprocess_response(outputs[0].numpy().tolist(), bos_id, eos_id)

 print("Model input:", inputs)
+
 print("Result:", result)
 # PaddleNLP是一个强大的NLP库,具有超乎寻常的预训练模型和易于使用的接口,支持从研究到工业应用的广泛的NLP任务。
 # (Expected output: the Chinese translation of the English input above.)
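The hunk headers above reference postprocess_response and bos_id from parts of the sample the diff does not show. For orientation, that surrounding code plausibly looks like the sketch below, reconstructed from the hunk context and from the inference script added later in this commit; treat the exact bodies as an assumption, not as part of the diff.

# Assumed context of mbart_sample.py (not shown in this diff).
bos_id = tokenizer.lang_code_to_id["zh_CN"]   # forced target language
eos_id = model.mbart.config["eos_token_id"]

def postprocess_response(seq, bos_idx, eos_idx):
    """Truncate at the first EOS and strip special ids before detokenizing."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx]
    return tokenizer.convert_ids_to_string(seq)

The switch from int64 to int32 input ids matches what the new export script below feeds to the static model.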
New file (dy2sta export script for FasterMBART)

Lines changed: 147 additions & 0 deletions
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer
from paddlenlp.ops import FasterMBART
from paddlenlp.utils.log import logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path",
                        default="mbart-large-50-many-to-many-mmt",
                        type=str,
                        help="The name of the MBart model to use.")
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path to save the inference model of MBart.")
    parser.add_argument("--topk",
                        default=4,
                        type=int,
                        help="The number of candidates for top-k sampling.")
    parser.add_argument("--topp",
                        default=1.0,
                        type=float,
                        help="The probability threshold for top-p sampling.")
    parser.add_argument("--max_out_len",
                        default=64,
                        type=int,
                        help="Maximum output length.")
    parser.add_argument("--temperature",
                        default=1.0,
                        type=float,
                        help="The temperature to set.")
    parser.add_argument("--num_return_sequences",
                        default=1,
                        type=int,
                        help="The number of returned sequences.")
    parser.add_argument("--use_fp16_decoding",
                        action="store_true",
                        help="Whether to use fp16 decoding to predict.")
    parser.add_argument("--decoding_strategy",
                        default="beam_search",
                        choices=["sampling", "beam_search"],
                        type=str,
                        help="The main strategy to decode.")
    parser.add_argument("--num_beams",
                        default=5,
                        type=int,
                        help="The number of candidates (beam width) for beam search.")
    parser.add_argument("--diversity_rate",
                        default=0.0,
                        type=float,
                        help="The diversity rate for beam search.")
    parser.add_argument("--repetition_penalty",
                        default=1.0,
                        type=float,
                        help="The repetition penalty to set.")
    parser.add_argument("--length_penalty",
                        default=0.0,
                        type=float,
                        help="The length penalty to decode.")
    parser.add_argument("--early_stopping",
                        action="store_true",
                        help="Whether to do early stopping.")

    args = parser.parse_args()
    return args


def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    model = MBartForConditionalGeneration.from_pretrained(
        args.model_name_or_path, src_lang="en_XX")
    tokenizer = MBartTokenizer.from_pretrained(args.model_name_or_path)

    bos_id = tokenizer.lang_code_to_id["zh_CN"]
    eos_id = model.mbart.config["eos_token_id"]

    # Eval mode is required before enabling faster_encoder.
    model.eval()

    faster_mbart = FasterMBART(model=model,
                               use_fp16_decoding=args.use_fp16_decoding)
    # Set evaluate mode.
    faster_mbart.eval()

    # Convert the dygraph model to a static graph model.
    faster_mbart = paddle.jit.to_static(
        faster_mbart,
        input_spec=[
            # input_ids
            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
            # encoder_output
            None,
            # seq_len
            None,
            bos_id,  # forced_bos_token_id
            args.num_beams,  # num_beams
            args.topk,  # top_k
            args.topp,  # top_p
            args.decoding_strategy,  # decode_strategy
            tokenizer.bos_token_id,  # bos_token_id
            tokenizer.eos_token_id,  # eos_token_id
            tokenizer.pad_token_id,  # pad_token_id
            model.mbart.config["decoder_start_token_id"],  # decoder_start_token_id
            args.max_out_len,  # max_length
            args.diversity_rate,  # diversity_rate
            args.length_penalty,  # length_penalty
            args.temperature,  # temperature
            args.num_return_sequences,  # num_return_sequences
            args.early_stopping,  # early_stopping
            tokenizer.eos_token_id,  # forced_eos_token_id
        ])

    # Save the converted static graph model.
    paddle.jit.save(faster_mbart,
                    os.path.join(args.inference_model_dir, "mbart"))
    logger.info("MBART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
    args = parse_args()
    pprint(args)

    do_predict(args)
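The core of the script is paddle.jit.to_static with a mixed input_spec: tensor inputs are described by paddle.static.InputSpec with None for dimensions that stay dynamic, while plain Python values (the decoding hyper-parameters) are baked into the static program as constants. A minimal, self-contained sketch of the same pattern on a toy Layer (the toy names are illustrative, not part of the commit):

import paddle


class ToyScaler(paddle.nn.Layer):
    """Stand-in for FasterMBART: one tensor input plus one Python scalar."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 8)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


toy = ToyScaler()
toy.eval()
static_toy = paddle.jit.to_static(
    toy,
    input_spec=[
        # Tensor argument: dynamic batch dimension, like input_ids above.
        paddle.static.InputSpec(shape=[None, 8], dtype="float32"),
        # Python argument: fixed at export time, like num_beams etc. above.
        2.0,
    ])
paddle.jit.save(static_toy, "./toy_infer/toy")

Because these values become graph constants, changing any of them (say, --num_beams) requires re-exporting the model.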
New file (Paddle Inference script for the exported MBART model)

Lines changed: 97 additions & 0 deletions
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import MBartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
    """Setup arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path to the saved inference model of MBart.")

    args = parser.parse_args()

    return args


def postprocess_response(tokenizer, seq, bos_idx, eos_idx):
    """Post-process the decoded sequence."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx
    ]
    res = tokenizer.convert_ids_to_string(seq)
    return res


def infer(args):
    model_name = "mbart-large-50-many-to-many-mmt"
    tokenizer = MBartTokenizer.from_pretrained(model_name)

    bos_id = tokenizer.lang_code_to_id["zh_CN"]
    eos_id = tokenizer.eos_token_id

    inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
    input_ids = tokenizer(inputs)["input_ids"]
    input_ids = np.asarray(input_ids, dtype="int32").reshape(1, -1)

    # Load the FasterTransformer custom-op lib.
    load("FasterTransformer", verbose=True)

    config = paddle_infer.Config(
        os.path.join(args.inference_model_dir, "mbart.pdmodel"),
        os.path.join(args.inference_model_dir, "mbart.pdiparams"))

    config.enable_use_gpu(100, 0)
    config.disable_glog_info()
    predictor = paddle_infer.create_predictor(config)

    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])
    input_handle.copy_from_cpu(input_ids.astype("int32"))

    predictor.run()

    output_names = predictor.get_output_names()
    output_handle = predictor.get_output_handle(output_names[0])
    output_data = output_handle.copy_to_cpu()

    # The output ids come back time-major; after the transpose, [0][0]
    # selects the first generated sequence.
    result = postprocess_response(
        tokenizer,
        output_data.transpose([1, 2, 0]).tolist()[0][0], bos_id, eos_id)
    print("Model input:", inputs)
    print("Result:", result)


if __name__ == "__main__":
    args = setup_args()
    pprint(args)

    infer(args)
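For a quick sanity check without constructing a paddle.inference predictor, the exported program can presumably also be reloaded with paddle.jit.load; the FasterTransformer custom ops still have to be loaded first. A hedged sketch assuming the default save path from the export script (the token ids are placeholders, not real MBART ids):

import paddle
from paddlenlp.ops.ext_utils import load

load("FasterTransformer", verbose=True)  # custom ops used by the saved program

loaded = paddle.jit.load("./infer_model/mbart")
loaded.eval()
ids = paddle.to_tensor([[250004, 8274, 2]], dtype="int32")  # placeholder ids
print(loaded(ids))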

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 2 additions & 1 deletion
@@ -2515,7 +2515,8 @@ def __init__(self,
         self.pos_emb = [model.decoder.decoder_embed_positions.weight]
         self.word_emb = [model.decoder.embed_tokens.weight]

-        self.linear_weight = [model.lm_head_weight.t()]
+        setattr(self, "lm_head_weight_", model.lm_head_weight.t())
+        self.linear_weight = [getattr(self, "lm_head_weight_")]
         self.linear_bias = [model.final_logits_bias]

     def forward(self,
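The setattr/getattr indirection is a no-op in dygraph but matters for dy2sta: model.lm_head_weight.t() creates a derived tensor, and if it lives only inside a Python list, the static-graph converter can apparently fail to track it as a persistable parameter of the Layer. Binding it to a named attribute first keeps it in the saved program. A minimal sketch of the pattern (toy names, stated as an assumption about the converter's behavior):

import paddle


class TiedHead(paddle.nn.Layer):
    """Keep a derived (transposed) weight alive across dy2sta conversion."""

    def __init__(self, weight):
        super().__init__()
        # Register the derived tensor under a name the converter can see...
        setattr(self, "weight_t_", weight.t())
        # ...then reference it through the attribute, not the temporary.
        self.linear_weight = [getattr(self, "weight_t_")]

    def forward(self, x):
        return paddle.matmul(x, self.linear_weight[0])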

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 11 additions & 3 deletions
@@ -1379,8 +1379,13 @@ def forward(self,


 class FasterMBART(MBartPretrainedModel):
+    enable_faster_encoder_func = enable_faster_encoder

-    def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+    def __init__(self,
+                 model,
+                 decoding_lib=None,
+                 use_fp16_decoding=False,
+                 enable_faster_encoder=False):
         super(FasterMBART, self).__init__()
         self.use_fp16_decoding = use_fp16_decoding
         self._model = model
@@ -1393,13 +1398,18 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         self.encoder = model.mbart.get_encoder()
         self.decoder = model.mbart.get_decoder()
         self.pad_token_id = model.mbart.config['pad_token_id']
+        self.enable_faster_encoder = enable_faster_encoder

         self.decoding = InferMBartDecoding(
             model=self._model,
             decoding_lib=decoding_lib,
             use_fp16_decoding=use_fp16_decoding,
             hidden_act=model.mbart.config['activation_function'])

+        if self.enable_faster_encoder:
+            # `enable_faster_encoder` must be called in `__init__` when
+            # converting dygraph to static graph.
+            self.encoder = FasterMBART.enable_faster_encoder_func(self.encoder)
+
     def get_encoder(self):
         return self.encoder

@@ -1439,11 +1449,9 @@ def forward(self,

         # (gongenlei) Not enabling faster_encoder here temporarily.
         if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-            self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
             assert input_ids is not None, "You have to specify input_ids when generating seq_len."
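Previously, forward wrapped each encoder call in enable_faster_encoder/disable_faster_encoder, mutating the Layer on every invocation; such runtime mutation is exactly what paddle.jit.to_static cannot trace, so the swap now happens once, in __init__, behind the new flag. A usage sketch consistent with the new signature (the flag comes from this diff; the rest mirrors the export script):

from paddlenlp.transformers import MBartForConditionalGeneration
from paddlenlp.ops import FasterMBART

model = MBartForConditionalGeneration.from_pretrained(
    "mbart-large-50-many-to-many-mmt")
model.eval()

# The encoder is replaced by its faster counterpart once, at construction
# time, keeping forward() free of Layer mutation and thus traceable.
faster_mbart = FasterMBART(model=model,
                           use_fp16_decoding=False,
                           enable_faster_encoder=True)
faster_mbart.eval()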

paddlenlp/transformers/mbart/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -203,7 +203,7 @@ def forward(self, input_ids_shape, past_key_values_length=0):
         positions = paddle.arange(past_key_values_length,
                                   past_key_values_length + seq_len,
                                   dtype="int64")
-        return super().forward(positions + self.offset)
+        return Embedding.forward(self, positions + self.offset)


 class MBartEncoder(MBartPretrainedModel):
@@ -270,7 +270,7 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
         if input_ids is None:
             raise ValueError("Input_ids cannot be None.")
         inputs_embeds = self.d_model**0.5 * self.embed_tokens(input_ids)
-        inputs_embed_pos = self.encoder_embed_positions(input_ids.shape)
+        inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids))
         hidden_states = inputs_embeds + inputs_embed_pos
         hidden_states = self.encoder_layernorm_embedding(hidden_states)
         encoder_input = self.encoder_dropout(hidden_states)
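Both modeling changes are dy2sta accommodations. input_ids.shape is a Python list that may contain a placeholder (-1) for dynamic dimensions at trace time, whereas paddle.shape(input_ids) is an operator evaluated when the program runs; and calling Embedding.forward(self, ...) explicitly sidesteps super().forward(...), which the source-code transformer has apparently had trouble rewriting. A small sketch of the shape difference (a toy function, an assumption rather than the commit's code):

import paddle


@paddle.jit.to_static(
    input_spec=[paddle.static.InputSpec(shape=[None, None], dtype="int64")])
def seq_len_of(ids):
    # ids.shape[1] would be the -1 placeholder at trace time for a dynamic
    # dim; paddle.shape(ids) is computed at run time, so this returns the
    # true length.
    return paddle.shape(ids)[1]


print(seq_len_of(paddle.zeros([2, 7], dtype="int64")))  # -> 7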
