1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15- import os
16-
1715import argparse
18- import numpy as np
1916from functools import partial
2017
2118import paddle
22- from paddle import inference
23- from paddlenlp .data import Stack , Tuple , Pad , Vocab
24- from paddlenlp .transformers import ErnieTokenizer
25-
2619from utils import convert_example , parse_decode
2720
28- # yapf: disable
21+ from paddlenlp .data import Pad , Stack , Tuple , Vocab
22+ from paddlenlp .transformers import ErnieTokenizer
23+
2924parser = argparse .ArgumentParser (__doc__ )
30- parser .add_argument ("--model_file" , type = str , required = True , default = './static_graph_params.pdmodel' , help = "The path to model info in static graph." )
31- parser .add_argument ("--params_file" , type = str , required = True , default = './static_graph_params.pdiparams' , help = "The path to parameters in static graph." )
25+ parser .add_argument (
26+ "--model_file" ,
27+ type = str ,
28+ required = True ,
29+ default = "./static_graph_params.pdmodel" ,
30+ help = "The path to model info in static graph." ,
31+ )
32+ parser .add_argument (
33+ "--params_file" ,
34+ type = str ,
35+ required = True ,
36+ default = "./static_graph_params.pdiparams" ,
37+ help = "The path to parameters in static graph." ,
38+ )
3239parser .add_argument ("--batch_size" , type = int , default = 2 , help = "The number of sequences contained in a mini-batch." )
3340parser .add_argument ("--max_seq_len" , type = int , default = 64 , help = "Number of words of the longest seqence." )
34- parser .add_argument ("--device" , default = "gpu" , type = str , choices = ["cpu" , "gpu" ] ,help = "The device to select to train the model, is must be cpu/gpu." )
41+ parser .add_argument (
42+ "--device" ,
43+ default = "gpu" ,
44+ type = str ,
45+ choices = ["cpu" , "gpu" ],
46+ help = "The device to select to train the model, is must be cpu/gpu." ,
47+ )
3548parser .add_argument ("--pinyin_vocab_file_path" , type = str , default = "pinyin_vocab.txt" , help = "pinyin vocab file path" )
3649
3750args = parser .parse_args ()
38- # yapf: enable
3951
4052
4153class Predictor (object ):
@@ -51,6 +63,7 @@ def __init__(self, model_file, params_file, device, max_seq_length, tokenizer, p
5163 # such as enable_mkldnn, set_cpu_math_library_num_threads
5264 config .disable_gpu ()
5365 config .switch_use_feed_fetch_ops (False )
66+ config .delete_pass ("fused_multi_transformer_encoder_pass" )
5467 self .predictor = paddle .inference .create_predictor (config )
5568
5669 self .input_handles = [self .predictor .get_input_handle (name ) for name in self .predictor .get_input_names ()]
0 commit comments