This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 5a776bf

[Feature] Allow custom dropout, number of layers/units for BERT (#950)
1 parent ab9e353 commit 5a776bf
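In practice the new flag is used like this — a minimal sketch that mirrors the unit test added in this commit (the dataset name and the choice of num_layers=6 are illustrative, not mandated by the API):

import gluonnlp as nlp

# Build a 6-layer variant of BERT BASE from the 12-layer pretrained checkpoint.
# hparam_allow_override=True lifts the 'Cannot override predefined model settings.'
# assertion; ignore_extra=True drops the checkpoint parameters that the truncated
# model no longer defines (encoder layers 6-11).
model, vocab = nlp.model.get_model('bert_12_768_12',
                                   dataset_name='book_corpus_wiki_en_uncased',
                                   pretrained=True,
                                   hparam_allow_override=True,
                                   ignore_extra=True,
                                   num_layers=6)

Overriding dropout alone still works without the flag, since dropout has always been on the mutable whitelist.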

File tree (4 files changed, +110 -29 lines):

scripts/bert/pretraining_utils.py
scripts/text_generation/model/gpt.py
src/gluonnlp/model/bert.py
tests/unittest/test_models.py

scripts/bert/pretraining_utils.py

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ def get_model_loss(ctx, model, pretrained, dataset_name, vocab, dtype,
     """
     # model
     model, vocabulary = nlp.model.get_model(model, dataset_name=dataset_name, vocab=vocab,
-                                            pretrained=pretrained, ctx=ctx)
+                                            pretrained=pretrained, ctx=ctx,
+                                            hparam_allow_override=True)
 
     if not pretrained:
         model.initialize(init=mx.init.Normal(0.02), ctx=ctx)

scripts/text_generation/model/gpt.py

Lines changed: 1 addition & 1 deletion
@@ -421,7 +421,7 @@ def _get_gpt2_model(model_name=None, dataset_name=None, vocab=None, pretrained=T
     -------
     GPT2Model, gluonnlp.vocab.Vocab
     """
-    predefined_args = gpt2_hparams[model_name]
+    predefined_args = gpt2_hparams[model_name].copy()
     mutable_args = ['dropout']
     mutable_args = frozenset(mutable_args)
     assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
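The added .copy() guards against a mutation bug: predefined_args is later updated in place with the caller's kwargs, so without the copy the shared gpt2_hparams entry itself would change and a dropout override from one call would leak into every subsequent call. A minimal sketch of the aliasing problem with a stand-in hyper-parameter table (the names below are illustrative):

hparams = {'gpt2_117m': {'num_layers': 12, 'dropout': 0.0}}

def build(model_name, **kwargs):
    predefined_args = hparams[model_name]      # an alias, not a copy
    predefined_args.update(kwargs)             # mutates the shared defaults in place
    return predefined_args

build('gpt2_117m', dropout=0.5)
print(hparams['gpt2_117m']['dropout'])         # 0.5 -- the "default" has been corrupted

# With predefined_args = hparams[model_name].copy(), the shared table stays intact.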

src/gluonnlp/model/bert.py

Lines changed: 66 additions & 27 deletions
@@ -815,7 +815,8 @@ def hybrid_forward(self, F, inputs, valid_length=None):
 
 def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                    root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
-                   use_classifier=True, pretrained_allow_missing=False, **kwargs):
+                   use_classifier=True, pretrained_allow_missing=False,
+                   hparam_allow_override=False, **kwargs):
     """Generic BERT BASE model.
 
     The number of layers (L) is 12, number of units (H) is 768, and the
@@ -873,6 +874,9 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     The pretrained parameters for dataset_name
     'openwebtext_book_corpus_wiki_en_uncased' were obtained by running the
@@ -905,13 +909,15 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
     return get_bert_model(model_name='bert_12_768_12', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=pretrained_allow_missing, **kwargs)
+                          pretrained_allow_missing=pretrained_allow_missing,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), use_pooler=True,
                     use_decoder=True, use_classifier=True,
                     root=os.path.join(get_home_dir(), 'models'),
-                    pretrained_allow_missing=False, **kwargs):
+                    pretrained_allow_missing=False,
+                    hparam_allow_override=False, **kwargs):
     """Generic BERT LARGE model.
 
     The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -952,6 +958,9 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
@@ -960,12 +969,14 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
     return get_bert_model(model_name='bert_24_1024_16', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=pretrained_allow_missing, **kwargs)
+                          pretrained_allow_missing=pretrained_allow_missing,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                       use_decoder=True,
-                      root=os.path.join(get_home_dir(), 'models'), **kwargs):
+                      root=os.path.join(get_home_dir(), 'models'),
+                      hparam_allow_override=False, **kwargs):
     """Generic RoBERTa BASE model.
 
     The number of layers (L) is 12, number of units (H) is 768, and the
@@ -990,19 +1001,24 @@ def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu
         MXNET_HOME defaults to '~/.mxnet'.
     use_decoder : bool, default True
         Whether to include the decoder for masked language model prediction.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
     return get_roberta_model(model_name='roberta_12_768_12', vocab=vocab, dataset_name=dataset_name,
                              pretrained=pretrained, ctx=ctx,
-                             use_decoder=use_decoder, root=root, **kwargs)
+                             use_decoder=use_decoder, root=root,
+                             hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                        use_decoder=True,
-                       root=os.path.join(get_home_dir(), 'models'), **kwargs):
+                       root=os.path.join(get_home_dir(), 'models'),
+                       hparam_allow_override=False, **kwargs):
     """Generic RoBERTa LARGE model.
 
     The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -1027,19 +1043,22 @@ def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cp
         MXNET_HOME defaults to '~/.mxnet'.
     use_decoder : bool, default True
         Whether to include the decoder for masked language model prediction.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
     return get_roberta_model(model_name='roberta_24_1024_16', vocab=vocab,
                              dataset_name=dataset_name, pretrained=pretrained, ctx=ctx,
-                             use_decoder=use_decoder,
-                             root=root, **kwargs)
+                             use_decoder=use_decoder, root=root,
+                             hparam_allow_override=hparam_allow_override, **kwargs)
 
 def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                     root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
-                    use_classifier=True, **kwargs):
+                    use_classifier=True, hparam_allow_override=False, **kwargs):
     """Baidu ERNIE model.
 
     Reference:
@@ -1073,6 +1092,9 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
         Whether to include the decoder for masked language model prediction.
     use_classifier : bool, default True
         Whether to include the classifier for next sentence classification.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
@@ -1081,13 +1103,14 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
     return get_bert_model(model_name='ernie_12_768_12', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=False, **kwargs)
+                          pretrained_allow_missing=False,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
-                      use_decoder=True, output_attention=False,
-                      output_all_encodings=False, root=os.path.join(get_home_dir(), 'models'),
-                      **kwargs):
+                      use_decoder=True, output_attention=False, output_all_encodings=False,
+                      root=os.path.join(get_home_dir(), 'models'), ignore_extra=False,
+                      hparam_allow_override=False, **kwargs):
     """Any RoBERTa pretrained model.
 
     Parameters
@@ -1121,17 +1144,25 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
         Whether to include attention weights of each encoding cell to the output.
     output_all_encodings : bool, default False
         Whether to output encodings of all encoder cells.
+    ignore_extra : bool, default False
+        Whether to silently ignore parameters from the file that are not
+        present in this Block.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
-    predefined_args = bert_hparams[model_name]
-    mutable_args = ['use_residual', 'dropout', 'word_embed']
-    mutable_args = frozenset(mutable_args)
-    assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
-        'Cannot override predefined model settings.'
+    predefined_args = bert_hparams[model_name].copy()
+    if not hparam_allow_override:
+        mutable_args = ['use_residual', 'dropout', 'word_embed']
+        mutable_args = frozenset(mutable_args)
+        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
+            'Cannot override predefined model settings.'
     predefined_args.update(kwargs)
+
     # encoder
     encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                           num_layers=predefined_args['num_layers'],
@@ -1156,7 +1187,7 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
                        word_embed=predefined_args['word_embed'],
                        use_decoder=use_decoder)
     if pretrained:
-        ignore_extra = not use_decoder
+        ignore_extra = ignore_extra or not use_decoder
         _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
                                 allow_missing=False)
     return net, bert_vocab
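Note how the 'Cannot override predefined model settings.' assertion in get_roberta_model (and, below, get_bert_model) is now gated on hparam_allow_override: without the flag, only the whitelisted keys ('use_residual', 'dropout', 'word_embed') may be overridden through kwargs, exactly as before. A standalone sketch of the same check against an illustrative hyper-parameter dict:

predefined_args = {'num_layers': 12, 'units': 768, 'dropout': 0.1,
                   'use_residual': True, 'word_embed': None}

def check_overrides(hparam_allow_override=False, **kwargs):
    # Same predicate as in the diff: every predefined key must either be absent
    # from kwargs or be one of the mutable (whitelisted) keys.
    if not hparam_allow_override:
        mutable_args = frozenset(['use_residual', 'dropout', 'word_embed'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'

check_overrides(dropout=0.2)                               # OK: dropout is whitelisted
check_overrides(hparam_allow_override=True, num_layers=6)  # OK: the gate is off
check_overrides(num_layers=6)                              # AssertionError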
@@ -1165,7 +1196,8 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                    use_pooler=True, use_decoder=True, use_classifier=True, output_attention=False,
                    output_all_encodings=False, use_token_type_embed=True,
                    root=os.path.join(get_home_dir(), 'models'),
-                   pretrained_allow_missing=False, **kwargs):
+                   pretrained_allow_missing=False, ignore_extra=False,
+                   hparam_allow_override=False, **kwargs):
     """Any BERT pretrained model.
 
     Parameters
@@ -1228,16 +1260,23 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    ignore_extra : bool, default False
+        Whether to silently ignore parameters from the file that are not
+        present in this Block.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     BERTModel, gluonnlp.vocab.BERTVocab
     """
-    predefined_args = bert_hparams[model_name]
-    mutable_args = ['use_residual', 'dropout', 'word_embed']
-    mutable_args = frozenset(mutable_args)
-    assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
-        'Cannot override predefined model settings.'
+    predefined_args = bert_hparams[model_name].copy()
+    if not hparam_allow_override:
+        mutable_args = ['use_residual', 'dropout', 'word_embed']
+        mutable_args = frozenset(mutable_args)
+        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
+            'Cannot override predefined model settings.'
     predefined_args.update(kwargs)
     # encoder
     encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
@@ -1267,7 +1306,7 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                     use_classifier=use_classifier,
                     use_token_type_embed=use_token_type_embed)
     if pretrained:
-        ignore_extra = not (use_pooler and use_decoder and use_classifier)
+        ignore_extra = ignore_extra or not (use_pooler and use_decoder and use_classifier)
         _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
                                 allow_missing=pretrained_allow_missing)
     return net, bert_vocab
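The last hunk makes ignore_extra a caller choice as well: it is OR-ed with the old automatic condition, so shrinking the architecture (e.g. num_layers=6) can coexist with pretrained=True because surplus checkpoint parameters are simply skipped. The flag carries the usual Gluon ignore_extra semantics of load_parameters; a self-contained toy sketch of that behaviour (the two-layer network and file name are illustrative, not part of this commit):

import mxnet as mx
from mxnet.gluon import nn

# Save a 2-layer net, then load it into a 1-layer net. With ignore_extra=True the
# second layer's weights in the file are silently skipped; without it, loading fails.
big = nn.HybridSequential()
big.add(nn.Dense(4), nn.Dense(2))
big.initialize()
big(mx.nd.ones((1, 3)))                    # materialize parameter shapes
big.save_parameters('two_layer.params')    # illustrative file name

small = nn.HybridSequential()
small.add(nn.Dense(4))
small.load_parameters('two_layer.params', ignore_extra=True)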

tests/unittest/test_models.py

Lines changed: 41 additions & 0 deletions
@@ -223,6 +223,47 @@ def test_pretrained_bert_models(disable_missing_parameters):
     del model
     mx.nd.waitall()
 
+@pytest.mark.serial
+@pytest.mark.remote_required
+@pytest.mark.parametrize('hparam_allow_override', [False, True])
+def test_pretrained_bert_models_override(hparam_allow_override):
+    models = ['bert_12_768_12', 'bert_24_1024_16',
+              'roberta_12_768_12', 'roberta_24_1024_16']
+    pretrained = {
+        'bert_12_768_12': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased'],
+        'bert_24_1024_16': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased'],
+        'roberta_12_768_12': ['openwebtext_ccnews_stories_books_cased'],
+        'roberta_24_1024_16': ['openwebtext_ccnews_stories_books_cased']
+    }
+    ones = mx.nd.ones((2, 10))
+    valid_length = mx.nd.ones((2,))
+    positions = mx.nd.zeros((2, 3))
+    for model_name in models:
+        pretrained_datasets = pretrained.get(model_name)
+        for dataset in pretrained_datasets:
+            eprint('testing forward for %s on %s' % (model_name, dataset))
+
+            if hparam_allow_override:
+                model, vocab = nlp.model.get_model(model_name, dataset_name=dataset,
+                                                   pretrained=True,
+                                                   root='tests/data/model/',
+                                                   hparam_allow_override=hparam_allow_override,
+                                                   ignore_extra=True,
+                                                   num_layers=6)
+            else:
+                with pytest.raises(AssertionError):
+                    model, vocab = nlp.model.get_model(model_name, dataset_name=dataset,
+                                                       pretrained=True,
+                                                       root='tests/data/model/',
+                                                       num_layers=6)
+                continue
+            if 'roberta' in model_name:
+                output = model(ones, valid_length, positions)
+            else:
+                output = model(ones, ones, valid_length, positions)
+            output[0].wait_to_read()
+            del model
+            mx.nd.waitall()
 
 @pytest.mark.serial
 @pytest.mark.remote_required
