This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
3 changes: 2 additions & 1 deletion scripts/bert/pretraining_utils.py
@@ -70,7 +70,8 @@ def get_model_loss(ctx, model, pretrained, dataset_name, vocab, dtype,
"""
# model
model, vocabulary = nlp.model.get_model(model, dataset_name=dataset_name, vocab=vocab,
pretrained=pretrained, ctx=ctx)
pretrained=pretrained, ctx=ctx,
hparam_allow_override=True)

if not pretrained:
model.initialize(init=mx.init.Normal(0.02), ctx=ctx)
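For reference, the pretraining script's call now matches the new unit test added in this PR: with hparam_allow_override=True, keyword arguments such as num_layers can replace the predefined hyper-parameters. A minimal sketch of the call pattern (the dataset name and the num_layers=6 override are illustrative; ignore_extra=True is needed when the override drops parameters that exist in the pretrained file):

import mxnet as mx
import gluonnlp as nlp

# hparam_allow_override=True lets kwargs such as num_layers replace the
# predefined hyper-parameters; ignore_extra=True skips pretrained weights
# that the shrunken 6-layer model no longer has.
model, vocab = nlp.model.get_model('bert_12_768_12',
                                   dataset_name='book_corpus_wiki_en_uncased',
                                   pretrained=True, ctx=mx.cpu(),
                                   hparam_allow_override=True,
                                   ignore_extra=True,
                                   num_layers=6)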
2 changes: 1 addition & 1 deletion scripts/text_generation/model/gpt.py
@@ -421,7 +421,7 @@ def _get_gpt2_model(model_name=None, dataset_name=None, vocab=None, pretrained=T
-------
GPT2Model, gluonnlp.vocab.Vocab
"""
predefined_args = gpt2_hparams[model_name]
predefined_args = gpt2_hparams[model_name].copy()
mutable_args = ['dropout']
mutable_args = frozenset(mutable_args)
assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
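The added .copy() matters because the getter later calls predefined_args.update(kwargs) (the same pattern is visible in the bert.py hunks below): updating the dict returned by gpt2_hparams[model_name] in place would mutate the shared module-level hyper-parameter table, so an override from one call would leak into every later call. A small standalone sketch of the aliasing issue (the table contents here are illustrative, not the library's actual values):

# Shared table of predefined hyper-parameters (illustrative values).
gpt2_hparams = {'gpt2_117m': {'num_layers': 12, 'dropout': 0.0}}

# Without .copy(), the per-call update writes through to the shared table.
args = gpt2_hparams['gpt2_117m']
args.update({'dropout': 0.5})
print(gpt2_hparams['gpt2_117m']['dropout'])   # 0.5 -- override has leaked

# With .copy(), each call works on a private dict and the table is untouched.
gpt2_hparams['gpt2_117m']['dropout'] = 0.0    # reset for the demo
args = gpt2_hparams['gpt2_117m'].copy()
args.update({'dropout': 0.5})
print(gpt2_hparams['gpt2_117m']['dropout'])   # 0.0 -- unchanged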
93 changes: 66 additions & 27 deletions src/gluonnlp/model/bert.py
@@ -1004,7 +1004,8 @@ def hybrid_forward(self, F, inputs, valid_length=None):

def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
use_classifier=True, pretrained_allow_missing=False, **kwargs):
use_classifier=True, pretrained_allow_missing=False,
hparam_allow_override=False, **kwargs):
"""Generic BERT BASE model.

The number of layers (L) is 12, number of units (H) is 768, and the
@@ -1062,6 +1063,9 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
If pretrained_allow_missing=True, this will be ignored and the
parameters will be left uninitialized. Otherwise AssertionError is
raised.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

The pretrained parameters for dataset_name
'openwebtext_book_corpus_wiki_en_uncased' were obtained by running the
@@ -1094,13 +1098,15 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
return get_bert_model(model_name='bert_12_768_12', vocab=vocab, dataset_name=dataset_name,
pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
use_decoder=use_decoder, use_classifier=use_classifier, root=root,
pretrained_allow_missing=pretrained_allow_missing, **kwargs)
pretrained_allow_missing=pretrained_allow_missing,
hparam_allow_override=hparam_allow_override, **kwargs)


def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), use_pooler=True,
use_decoder=True, use_classifier=True,
root=os.path.join(get_home_dir(), 'models'),
pretrained_allow_missing=False, **kwargs):
pretrained_allow_missing=False,
hparam_allow_override=False, **kwargs):
"""Generic BERT LARGE model.

The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -1141,6 +1147,9 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
If pretrained_allow_missing=True, this will be ignored and the
parameters will be left uninitialized. Otherwise AssertionError is
raised.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
@@ -1149,12 +1158,14 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
return get_bert_model(model_name='bert_24_1024_16', vocab=vocab, dataset_name=dataset_name,
pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
use_decoder=use_decoder, use_classifier=use_classifier, root=root,
pretrained_allow_missing=pretrained_allow_missing, **kwargs)
pretrained_allow_missing=pretrained_allow_missing,
hparam_allow_override=hparam_allow_override, **kwargs)


def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
use_decoder=True,
root=os.path.join(get_home_dir(), 'models'), **kwargs):
root=os.path.join(get_home_dir(), 'models'),
hparam_allow_override=False, **kwargs):
"""Generic RoBERTa BASE model.

The number of layers (L) is 12, number of units (H) is 768, and the
@@ -1179,19 +1190,24 @@ def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu
MXNET_HOME defaults to '~/.mxnet'.
use_decoder : bool, default True
Whether to include the decoder for masked language model prediction.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
RoBERTaModel, gluonnlp.vocab.Vocab
"""
return get_roberta_model(model_name='roberta_12_768_12', vocab=vocab, dataset_name=dataset_name,
pretrained=pretrained, ctx=ctx,
use_decoder=use_decoder, root=root, **kwargs)
use_decoder=use_decoder, root=root,
hparam_allow_override=hparam_allow_override, **kwargs)


def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
use_decoder=True,
root=os.path.join(get_home_dir(), 'models'), **kwargs):
root=os.path.join(get_home_dir(), 'models'),
hparam_allow_override=False, **kwargs):
"""Generic RoBERTa LARGE model.

The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -1216,19 +1232,22 @@ def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cp
MXNET_HOME defaults to '~/.mxnet'.
use_decoder : bool, default True
Whether to include the decoder for masked language model prediction.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
RoBERTaModel, gluonnlp.vocab.Vocab
"""
return get_roberta_model(model_name='roberta_24_1024_16', vocab=vocab,
dataset_name=dataset_name, pretrained=pretrained, ctx=ctx,
use_decoder=use_decoder,
root=root, **kwargs)
use_decoder=use_decoder, root=root,
hparam_allow_override=hparam_allow_override, **kwargs)

def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
use_classifier=True, **kwargs):
use_classifier=True, hparam_allow_override=False, **kwargs):
"""Baidu ERNIE model.

Reference:
@@ -1262,6 +1281,9 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
Whether to include the decoder for masked language model prediction.
use_classifier : bool, default True
Whether to include the classifier for next sentence classification.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
@@ -1270,13 +1292,14 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
return get_bert_model(model_name='ernie_12_768_12', vocab=vocab, dataset_name=dataset_name,
pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
use_decoder=use_decoder, use_classifier=use_classifier, root=root,
pretrained_allow_missing=False, **kwargs)
pretrained_allow_missing=False,
hparam_allow_override=hparam_allow_override, **kwargs)


def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
use_decoder=True, output_attention=False,
output_all_encodings=False, root=os.path.join(get_home_dir(), 'models'),
**kwargs):
use_decoder=True, output_attention=False, output_all_encodings=False,
root=os.path.join(get_home_dir(), 'models'), ignore_extra=False,
hparam_allow_override=False, **kwargs):
"""Any RoBERTa pretrained model.

Parameters
@@ -1310,17 +1333,25 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
Whether to include attention weights of each encoding cell to the output.
output_all_encodings : bool, default False
Whether to output encodings of all encoder cells.
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
RoBERTaModel, gluonnlp.vocab.Vocab
"""
predefined_args = bert_hparams[model_name]
mutable_args = ['use_residual', 'dropout', 'word_embed']
mutable_args = frozenset(mutable_args)
assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
'Cannot override predefined model settings.'
predefined_args = bert_hparams[model_name].copy()
if not hparam_allow_override:
mutable_args = ['use_residual', 'dropout', 'word_embed']
[Review thread on this line]
Contributor: Not changed in this PR, but why is embed_size not part of this?
Member Author: I copied some from the transformer get_model function. Actually, I'm not sure whether we want to have a whitelist of mutable args. It makes the function harder to maintain, since we have the hparam_allow_override flag.
Contributor (@leezu, Jan 29, 2020): Do you want to remove the whitelist then? In this PR or a separate one?
Member Author: I'd prefer a separate one.

mutable_args = frozenset(mutable_args)
assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
'Cannot override predefined model settings.'
predefined_args.update(kwargs)

# encoder
encoder = BERTEncoder(num_layers=predefined_args['num_layers'],
units=predefined_args['units'],
@@ -1342,7 +1373,7 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
word_embed=predefined_args['word_embed'],
use_decoder=use_decoder)
if pretrained:
ignore_extra = not use_decoder
ignore_extra = ignore_extra or not use_decoder
_load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
allow_missing=False)
return net, bert_vocab
@@ -1351,7 +1382,8 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
use_pooler=True, use_decoder=True, use_classifier=True, output_attention=False,
output_all_encodings=False, use_token_type_embed=True,
root=os.path.join(get_home_dir(), 'models'),
pretrained_allow_missing=False, **kwargs):
pretrained_allow_missing=False, ignore_extra=False,
hparam_allow_override=False, **kwargs):
"""Any BERT pretrained model.

Parameters
@@ -1414,16 +1446,23 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
If pretrained_allow_missing=True, this will be ignored and the
parameters will be left uninitialized. Otherwise AssertionError is
raised.
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.
hparam_allow_override : bool, default False
If set to True, pre-defined hyper-parameters of the model
(e.g. the number of layers, hidden units) can be overridden.

Returns
-------
BERTModel, gluonnlp.vocab.BERTVocab
"""
predefined_args = bert_hparams[model_name]
mutable_args = ['use_residual', 'dropout', 'word_embed']
mutable_args = frozenset(mutable_args)
assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
'Cannot override predefined model settings.'
predefined_args = bert_hparams[model_name].copy()
if not hparam_allow_override:
mutable_args = ['use_residual', 'dropout', 'word_embed']
mutable_args = frozenset(mutable_args)
assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
'Cannot override predefined model settings.'
predefined_args.update(kwargs)
# encoder
encoder = BERTEncoder(num_layers=predefined_args['num_layers'],
@@ -1450,7 +1489,7 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
use_classifier=use_classifier,
use_token_type_embed=use_token_type_embed)
if pretrained:
ignore_extra = not (use_pooler and use_decoder and use_classifier)
ignore_extra = ignore_extra or not (use_pooler and use_decoder and use_classifier)
_load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
allow_missing=pretrained_allow_missing)
return net, bert_vocab
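When hparam_allow_override is left at its default of False, the whitelist check still applies and only use_residual, dropout, and word_embed may be overridden; anything else trips the assertion. A sketch of the expected failure mode, mirrored by the parametrized test added in tests/unittest/test_models.py below:

import gluonnlp as nlp

# Without hparam_allow_override=True, overriding a predefined
# hyper-parameter such as num_layers hits the whitelist assertion.
try:
    nlp.model.get_model('bert_12_768_12',
                        dataset_name='book_corpus_wiki_en_uncased',
                        pretrained=True, num_layers=6)
except AssertionError as err:
    print(err)   # 'Cannot override predefined model settings.'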
41 changes: 41 additions & 0 deletions tests/unittest/test_models.py
@@ -223,6 +223,47 @@ def test_pretrained_bert_models(disable_missing_parameters):
del model
mx.nd.waitall()

@pytest.mark.serial
@pytest.mark.remote_required
@pytest.mark.parametrize('hparam_allow_override', [False, True])
def test_pretrained_bert_models_override(hparam_allow_override):
models = ['bert_12_768_12', 'bert_24_1024_16',
'roberta_12_768_12', 'roberta_24_1024_16']
pretrained = {
'bert_12_768_12': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased'],
'bert_24_1024_16': ['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased'],
'roberta_12_768_12': ['openwebtext_ccnews_stories_books_cased'],
'roberta_24_1024_16': ['openwebtext_ccnews_stories_books_cased']
}
ones = mx.nd.ones((2, 10))
valid_length = mx.nd.ones((2,))
positions = mx.nd.zeros((2, 3))
for model_name in models:
pretrained_datasets = pretrained.get(model_name)
for dataset in pretrained_datasets:
eprint('testing forward for %s on %s' % (model_name, dataset))

if hparam_allow_override:
model, vocab = nlp.model.get_model(model_name, dataset_name=dataset,
pretrained=True,
root='tests/data/model/',
hparam_allow_override=hparam_allow_override,
ignore_extra=True,
num_layers=6)
else:
with pytest.raises(AssertionError):
model, vocab = nlp.model.get_model(model_name, dataset_name=dataset,
pretrained=True,
root='tests/data/model/',
num_layers=6)
continue
if 'roberta' in model_name:
output = model(ones, valid_length, positions)
else:
output = model(ones, ones, valid_length, positions)
output[0].wait_to_read()
del model
mx.nd.waitall()

@pytest.mark.serial
@pytest.mark.remote_required