This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 5a4eae3

eric-haibin-lin authored and leezu committed
[BUGFIX] Fix bugs in BERT export script (#944)
* Fix export script
* add doc
* Fix lint
* Fix lint
1 parent 66609ec commit 5a4eae3

File tree

2 files changed (+23, -17 lines)


scripts/bert/export.py

Lines changed: 22 additions & 16 deletions
```diff
@@ -39,8 +39,7 @@
 
 import mxnet as mx
 import gluonnlp as nlp
-from gluonnlp.model import get_model
-from model.classification import BERTClassifier, BERTRegression
+from gluonnlp.model import get_model, BERTClassifier
 from model.qa import BertForQA
 
 nlp.utils.check_version('0.8.1')
```
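With this change, BERTClassifier comes from gluonnlp.model itself rather than the script-local model.classification module, and BERTRegression disappears in favor of a single-output BERTClassifier (see the regression hunk below). A minimal sketch of the new construction path; the model and dataset names here are illustrative choices, not taken from this diff:

```python
from gluonnlp.model import get_model, BERTClassifier

# Build the BERT backbone; note that no seq_length is passed, so the
# resulting graph is not pinned to a fixed input length.
bert, vocab = get_model(name='bert_12_768_12',
                        dataset_name='book_corpus_wiki_en_uncased',
                        pretrained=False,
                        use_pooler=True,
                        use_decoder=False,
                        use_classifier=False)

# Regression is now expressed as a 1-class classifier.
net = BERTClassifier(bert, num_classes=1, dropout=0.1)
```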
```diff
@@ -84,7 +83,7 @@
 
 parser.add_argument('--seq_length',
                     type=int,
-                    default=384,
+                    default=64,
                     help='The maximum total input sequence length after WordPiece tokenization.'
                          'Sequences longer than this needs to be truncated, and sequences shorter '
                          'than this needs to be padded. Default is 384')
```
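Because the get_model calls below no longer pin seq_length into the network, this default appears to matter only for the shapes of the dummy batch used at export time; a longer length can still be requested explicitly, e.g. `python ./scripts/bert/export.py --task classification --seq_length 128` (the value 128 is illustrative).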
```diff
@@ -131,8 +130,7 @@
         pretrained=False,
         use_pooler=True,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
+        use_classifier=False)
     net = BERTClassifier(bert, num_classes=2, dropout=args.dropout)
 elif args.task == 'regression':
     bert, _ = get_model(
```
```diff
@@ -141,18 +139,16 @@
         pretrained=False,
         use_pooler=True,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
-    net = BERTRegression(bert, dropout=args.dropout)
+        use_classifier=False)
+    net = BERTClassifier(bert, num_classes=1, dropout=args.dropout)
 elif args.task == 'question_answering':
     bert, _ = get_model(
         name=args.model_name,
         dataset_name=args.dataset_name,
         pretrained=False,
         use_pooler=False,
         use_decoder=False,
-        use_classifier=False,
-        seq_length=args.seq_length)
+        use_classifier=False)
     net = BertForQA(bert)
 else:
     raise ValueError('unknown task: %s'%args.task)
```
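With seq_length dropped from all three get_model calls, the symbolic graph traced at export time should accept any input length. The export function itself is not shown in this diff, but judging from the assertions in the next hunk it follows the standard Gluon hybridize-and-export pattern; a rough sketch, assuming a hybridizable net and a dummy batch:

```python
import os

def export_sketch(net, batch, prefix):
    """Hybridize, trace with one forward pass, then serialize to disk."""
    inputs, token_types, valid_length = batch
    net.hybridize(static_alloc=True)
    net(inputs, token_types, valid_length)  # forward pass records the graph
    net.export(prefix, epoch=0)  # writes <prefix>-symbol.json and <prefix>-0000.params
    assert os.path.isfile(prefix + '-symbol.json')
    assert os.path.isfile(prefix + '-0000.params')
```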
```diff
@@ -187,24 +183,34 @@ def export(batch, prefix):
     assert os.path.isfile(prefix + '-symbol.json')
     assert os.path.isfile(prefix + '-0000.params')
 
-def infer(batch, prefix):
+def infer(prefix):
     """Evaluate the model on a mini-batch."""
     log.info('Start inference ... ')
 
     # import with SymbolBlock. Alternatively, you can use Module.load APIs.
     imported_net = mx.gluon.nn.SymbolBlock.imports(prefix + '-symbol.json',
                                                    ['data0', 'data1', 'data2'],
                                                    prefix + '-0000.params')
-    tic = time.time()
+
+    # exported model should be length-agnostic. Using a different seq_length should work
+    inputs = mx.nd.arange(test_batch_size * (seq_length + 10))
+    inputs = inputs.reshape(shape=(test_batch_size, seq_length + 10))
+    token_types = mx.nd.zeros_like(inputs)
+    valid_length = mx.nd.arange(test_batch_size)
+
     # run forward inference
-    inputs, token_types, valid_length = batch
+    imported_net(inputs, token_types, valid_length)
+    mx.nd.waitall()
+
+    # benchmark speed after warmup
+    tic = time.time()
     num_trials = 10
     for _ in range(num_trials):
         imported_net(inputs, token_types, valid_length)
     mx.nd.waitall()
     toc = time.time()
-    log.info('Inference time cost={:.2f} s, Thoughput={:.2f} samples/s'
-             .format(toc - tic, num_trials / (toc - tic)))
+    log.info('Batch size={}, Thoughput={:.2f} batches/s'
+             .format(test_batch_size, num_trials / (toc - tic)))
 
 
 ###############################################################################
```
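The rewritten infer no longer takes the export-time batch; it fabricates its own inputs at seq_length + 10 specifically to verify that the exported graph is length-agnostic, and it adds a warm-up pass before timing so first-call graph initialization is excluded from the throughput number. A standalone sketch of the same length check, assuming an already exported prefix (the path and lengths below are hypothetical):

```python
import mxnet as mx

prefix = 'output/classification'  # hypothetical export prefix
net = mx.gluon.nn.SymbolBlock.imports(prefix + '-symbol.json',
                                      ['data0', 'data1', 'data2'],
                                      prefix + '-0000.params')

# Feed the same graph at two different sequence lengths; both should run
# if the export is truly length-agnostic.
for seq_len in (32, 128):
    inputs = mx.nd.zeros((1, seq_len))
    token_types = mx.nd.zeros_like(inputs)
    valid_length = mx.nd.array([seq_len])
    out = net(inputs, token_types, valid_length)
    mx.nd.waitall()  # block until the asynchronous computation finishes
    print(seq_len, out.shape)
```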
```diff
@@ -213,4 +219,4 @@ def infer(batch, prefix):
 if __name__ == '__main__':
     prefix = os.path.join(args.output_dir, args.task)
     export(batch, prefix)
-    infer(batch, prefix)
+    infer(prefix)
```

scripts/tests/test_scripts.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -317,7 +317,7 @@ def test_finetune_train(early_stop, bert_model, dataset, dtype):
 @pytest.mark.integration
 @pytest.mark.parametrize('task', ['classification', 'regression', 'question_answering'])
 def test_export(task):
-    process = subprocess.check_call([sys.executable, './scripts/bert/export/export.py',
+    process = subprocess.check_call([sys.executable, './scripts/bert/export.py',
                                      '--task', task])
 
 @pytest.mark.serial
```
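The test now points at the script's new flat location under scripts/bert/. Mirroring test_export outside of pytest is straightforward; a sketch that exercises all three tasks:

```python
import subprocess
import sys

# Run the relocated export script once per task, as the test does.
for task in ('classification', 'regression', 'question_answering'):
    subprocess.check_call([sys.executable, './scripts/bert/export.py',
                           '--task', task])
```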
