 import mxnet as mx
 import gluonnlp as nlp
-from gluonnlp.model import get_model
-from model.classification import BERTClassifier, BERTRegression
+from gluonnlp.model import get_model, BERTClassifier
 from model.qa import BertForQA

 nlp.utils.check_version('0.8.1')
 parser.add_argument('--seq_length',
                     type=int,
-                    default=384,
+                    default=64,
                     help='The maximum total input sequence length after WordPiece tokenization.'
-                         'Sequences longer than this needs to be truncated, and sequences shorter '
-                         'than this needs to be padded. Default is 384')
+                         'Sequences longer than this need to be truncated, and sequences shorter '
+                         'than this need to be padded. Default is 64')
         pretrained=False,
         use_pooler=True,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
+        use_classifier=False)
     net = BERTClassifier(bert, num_classes=2, dropout=args.dropout)
 elif args.task == 'regression':
     bert, _ = get_model(
         name=args.model_name,
         dataset_name=args.dataset_name,
         pretrained=False,
         use_pooler=True,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
-    net = BERTRegression(bert, dropout=args.dropout)
+        use_classifier=False)
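+    # a single-output classifier head replaces the removed BERTRegression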
+    net = BERTClassifier(bert, num_classes=1, dropout=args.dropout)
 elif args.task == 'question_answering':
     bert, _ = get_model(
         name=args.model_name,
         dataset_name=args.dataset_name,
         pretrained=False,
         use_pooler=False,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
+        use_classifier=False)
     net = BertForQA(bert)
 else:
     raise ValueError('unknown task: %s' % args.task)
@@ -187,24 +183,34 @@ def export(batch, prefix):
     assert os.path.isfile(prefix + '-symbol.json')
     assert os.path.isfile(prefix + '-0000.params')

-def infer(batch, prefix):
+def infer(prefix):
     """Evaluate the model on a mini-batch."""
     log.info('Start inference ... ')

     # import with SymbolBlock. Alternatively, you can use Module.load APIs.
     imported_net = mx.gluon.nn.SymbolBlock.imports(prefix + '-symbol.json',
                                                    ['data0', 'data1', 'data2'],
                                                    prefix + '-0000.params')
-    tic = time.time()
+
+    # the exported model should be length-agnostic: a different seq_length should work
+    inputs = mx.nd.arange(test_batch_size * (seq_length + 10))
+    inputs = inputs.reshape(shape=(test_batch_size, seq_length + 10))
+    token_types = mx.nd.zeros_like(inputs)
+    valid_length = mx.nd.arange(test_batch_size)
+
     # run forward inference
-    inputs, token_types, valid_length = batch
+    imported_net(inputs, token_types, valid_length)
+    mx.nd.waitall()
+
+    # benchmark speed after warmup
+    tic = time.time()
     num_trials = 10
     for _ in range(num_trials):
         imported_net(inputs, token_types, valid_length)
     mx.nd.waitall()
     toc = time.time()
-    log.info('Inference time cost={:.2f}s, Thoughput={:.2f} samples/s'
-             .format(toc - tic, num_trials / (toc - tic)))
+    log.info('Batch size={}, Throughput={:.2f} batches/s'
+             .format(test_batch_size, num_trials / (toc - tic)))


###############################################################################
@@ -213,4 +219,4 @@ def infer(batch, prefix):
 if __name__ == '__main__':
     prefix = os.path.join(args.output_dir, args.task)
     export(batch, prefix)
-    infer(batch, prefix)
+    infer(prefix)
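
For reference, a minimal sketch of how the exported artifacts might be consumed from a standalone script, illustrating that inference is no longer tied to the export-time seq_length. The prefix path, batch size, and sequence length below are assumptions for illustration, not values from this diff:

import mxnet as mx

# assumed paths: exported with --task classification and --output_dir ./output
prefix = './output/classification'

# load the symbol and parameters written by export()
net = mx.gluon.nn.SymbolBlock.imports(prefix + '-symbol.json',
                                      ['data0', 'data1', 'data2'],
                                      prefix + '-0000.params')

# any sequence length should work now that seq_length is not baked into the graph
batch_size, seq_len = 4, 128
inputs = mx.nd.ones((batch_size, seq_len))          # token ids
token_types = mx.nd.zeros((batch_size, seq_len))    # segment ids
valid_length = mx.nd.array([seq_len] * batch_size)  # true length per sample

out = net(inputs, token_types, valid_length)
print(out.shape)  # (4, 2) for the two-class classification head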