@@ -815,7 +815,8 @@ def hybrid_forward(self, F, inputs, valid_length=None):
 
 def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                    root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
-                   use_classifier=True, pretrained_allow_missing=False, **kwargs):
+                   use_classifier=True, pretrained_allow_missing=False,
+                   hparam_allow_override=False, **kwargs):
     """Generic BERT BASE model.
 
     The number of layers (L) is 12, number of units (H) is 768, and the
@@ -873,6 +874,9 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     The pretrained parameters for dataset_name
     'openwebtext_book_corpus_wiki_en_uncased' were obtained by running the
@@ -905,13 +909,15 @@ def bert_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
     return get_bert_model(model_name='bert_12_768_12', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=pretrained_allow_missing, **kwargs)
+                          pretrained_allow_missing=pretrained_allow_missing,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), use_pooler=True,
                     use_decoder=True, use_classifier=True,
                     root=os.path.join(get_home_dir(), 'models'),
-                    pretrained_allow_missing=False, **kwargs):
+                    pretrained_allow_missing=False,
+                    hparam_allow_override=False, **kwargs):
     """Generic BERT LARGE model.
 
     The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -952,6 +958,9 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
@@ -960,12 +969,14 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
     return get_bert_model(model_name='bert_24_1024_16', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=pretrained_allow_missing, **kwargs)
+                          pretrained_allow_missing=pretrained_allow_missing,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                       use_decoder=True,
-                      root=os.path.join(get_home_dir(), 'models'), **kwargs):
+                      root=os.path.join(get_home_dir(), 'models'),
+                      hparam_allow_override=False, **kwargs):
     """Generic RoBERTa BASE model.
 
     The number of layers (L) is 12, number of units (H) is 768, and the
@@ -990,19 +1001,24 @@ def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu
         MXNET_HOME defaults to '~/.mxnet'.
     use_decoder : bool, default True
         Whether to include the decoder for masked language model prediction.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
     return get_roberta_model(model_name='roberta_12_768_12', vocab=vocab, dataset_name=dataset_name,
                              pretrained=pretrained, ctx=ctx,
-                             use_decoder=use_decoder, root=root, **kwargs)
+                             use_decoder=use_decoder, root=root,
+                             hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                        use_decoder=True,
-                       root=os.path.join(get_home_dir(), 'models'), **kwargs):
+                       root=os.path.join(get_home_dir(), 'models'),
+                       hparam_allow_override=False, **kwargs):
     """Generic RoBERTa LARGE model.
 
     The number of layers (L) is 24, number of units (H) is 1024, and the
@@ -1027,19 +1043,22 @@ def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cp
         MXNET_HOME defaults to '~/.mxnet'.
     use_decoder : bool, default True
         Whether to include the decoder for masked language model prediction.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
     return get_roberta_model(model_name='roberta_24_1024_16', vocab=vocab,
                              dataset_name=dataset_name, pretrained=pretrained, ctx=ctx,
-                             use_decoder=use_decoder,
-                             root=root, **kwargs)
+                             use_decoder=use_decoder, root=root,
+                             hparam_allow_override=hparam_allow_override, **kwargs)
 
 def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
                     root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True,
-                    use_classifier=True, **kwargs):
+                    use_classifier=True, hparam_allow_override=False, **kwargs):
     """Baidu ERNIE model.
 
     Reference:
@@ -1073,6 +1092,9 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
         Whether to include the decoder for masked language model prediction.
     use_classifier : bool, default True
         Whether to include the classifier for next sentence classification.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
@@ -1081,13 +1103,14 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu()
     return get_bert_model(model_name='ernie_12_768_12', vocab=vocab, dataset_name=dataset_name,
                           pretrained=pretrained, ctx=ctx, use_pooler=use_pooler,
                           use_decoder=use_decoder, use_classifier=use_classifier, root=root,
-                          pretrained_allow_missing=False, **kwargs)
+                          pretrained_allow_missing=False,
+                          hparam_allow_override=hparam_allow_override, **kwargs)
 
 
 def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(),
-                      use_decoder=True, output_attention=False,
-                      output_all_encodings=False, root=os.path.join(get_home_dir(), 'models'),
-                      **kwargs):
+                      use_decoder=True, output_attention=False, output_all_encodings=False,
+                      root=os.path.join(get_home_dir(), 'models'), ignore_extra=False,
+                      hparam_allow_override=False, **kwargs):
     """Any RoBERTa pretrained model.
 
     Parameters
@@ -1121,17 +1144,25 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
         Whether to include attention weights of each encoding cell to the output.
     output_all_encodings : bool, default False
         Whether to output encodings of all encoder cells.
+    ignore_extra : bool, default False
+        Whether to silently ignore parameters from the file that are not
+        present in this Block.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     RoBERTaModel, gluonnlp.vocab.Vocab
     """
-    predefined_args = bert_hparams[model_name]
-    mutable_args = ['use_residual', 'dropout', 'word_embed']
-    mutable_args = frozenset(mutable_args)
-    assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
-        'Cannot override predefined model settings.'
+    predefined_args = bert_hparams[model_name].copy()
+    if not hparam_allow_override:
+        mutable_args = ['use_residual', 'dropout', 'word_embed']
+        mutable_args = frozenset(mutable_args)
+        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
+            'Cannot override predefined model settings.'
     predefined_args.update(kwargs)
+
     # encoder
     encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                           num_layers=predefined_args['num_layers'],
@@ -1156,7 +1187,7 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained
                        word_embed=predefined_args['word_embed'],
                        use_decoder=use_decoder)
     if pretrained:
-        ignore_extra = not use_decoder
+        ignore_extra = ignore_extra or not use_decoder
         _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
                                 allow_missing=False)
     return net, bert_vocab
@@ -1165,7 +1196,8 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                    use_pooler=True, use_decoder=True, use_classifier=True, output_attention=False,
                    output_all_encodings=False, use_token_type_embed=True,
                    root=os.path.join(get_home_dir(), 'models'),
-                   pretrained_allow_missing=False, **kwargs):
+                   pretrained_allow_missing=False, ignore_extra=False,
+                   hparam_allow_override=False, **kwargs):
     """Any BERT pretrained model.
 
     Parameters
@@ -1228,16 +1260,23 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
         If pretrained_allow_missing=True, this will be ignored and the
         parameters will be left uninitialized. Otherwise AssertionError is
         raised.
+    ignore_extra : bool, default False
+        Whether to silently ignore parameters from the file that are not
+        present in this Block.
+    hparam_allow_override : bool, default False
+        If set to True, pre-defined hyper-parameters of the model
+        (e.g. the number of layers, hidden units) can be overridden.
 
     Returns
     -------
     BERTModel, gluonnlp.vocab.BERTVocab
     """
-    predefined_args = bert_hparams[model_name]
-    mutable_args = ['use_residual', 'dropout', 'word_embed']
-    mutable_args = frozenset(mutable_args)
-    assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
-        'Cannot override predefined model settings.'
+    predefined_args = bert_hparams[model_name].copy()
+    if not hparam_allow_override:
+        mutable_args = ['use_residual', 'dropout', 'word_embed']
+        mutable_args = frozenset(mutable_args)
+        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
+            'Cannot override predefined model settings.'
     predefined_args.update(kwargs)
     # encoder
     encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
@@ -1267,7 +1306,7 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                     use_classifier=use_classifier,
                     use_token_type_embed=use_token_type_embed)
     if pretrained:
-        ignore_extra = ignore_extra or not (use_pooler and use_decoder and use_classifier)
+        ignore_extra = ignore_extra or not (use_pooler and use_decoder and use_classifier)
         _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
                                 allow_missing=pretrained_allow_missing)
     return net, bert_vocab
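
A minimal usage sketch of the new flag (illustrative only, not part of this patch; the dataset name and num_layers value below are placeholders). With hparam_allow_override=True, a predefined hyper-parameter such as num_layers can be passed through **kwargs without tripping the 'Cannot override predefined model settings.' assertion; pretrained=False is used because the published 12-layer checkpoint would not match the modified architecture.

import mxnet as mx
import gluonnlp as nlp

# Override a predefined hyper-parameter (num_layers) on the BASE config.
# Without hparam_allow_override=True this call would raise an AssertionError.
model, vocab = nlp.model.bert_12_768_12(
    dataset_name='book_corpus_wiki_en_uncased',
    pretrained=False,
    ctx=mx.cpu(),
    hparam_allow_override=True,
    num_layers=6)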