This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
1 change: 1 addition & 0 deletions docs/api/modules/index.rst
@@ -12,3 +12,4 @@ Package Reference
model.train
loss
initializer
optimizer
43 changes: 43 additions & 0 deletions docs/api/modules/optimizer.rst
@@ -0,0 +1,43 @@
gluonnlp.optimizer
======================

GluonNLP provides specialized optimizers for training models in natural language processing.

.. currentmodule:: gluonnlp.optimizer

BERTAdam Optimizer
--------------------------

The Adam optimizer with weight decay regularization for BERT.

.. autosummary::
:nosignatures:

BERTAdam
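
A minimal usage sketch (an illustration; the model ``net`` and the hyperparameter values below are assumptions): importing ``gluonnlp.optimizer`` registers the optimizer with MXNet, so it can be selected by name through a Gluon ``Trainer``.

.. code-block:: python

    from mxnet import gluon
    from gluonnlp.optimizer import BERTAdam  # importing registers the optimizer with MXNet

    # net is assumed to be an already-initialized Gluon block.
    trainer = gluon.Trainer(net.collect_params(), 'bertadam',
                            {'learning_rate': 2e-5, 'wd': 0.01})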

LAMB Optimizer
--------------------------

Implementation of the LAMB optimizer from the paper `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_.

In the paper, the empirical results demonstrate the superior performance of LAMB for BERT and ResNet-50 training.
By increasing the batch size to the memory limit of a TPUv3 pod, BERT training time can be reduced from 3 days to 76 minutes.

.. code-block:: none

@inproceedings{You2019LargeBO,
title={Large Batch Optimization for Deep Learning: Training BERT in 76 minutes},
author={Yang You and Jing Li and Sashank J. Reddi and Jonathan Hseu and Sanjiv Kumar and Srinadh Bhojanapalli and Xiaodan Song and James Demmel and Cho-Jui Hsieh},
year={2019}}

.. autosummary::
:nosignatures:

LAMB
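
A usage sketch mirroring the unit test added in this PR; the LAMB-specific hyperparameters shown are the documented defaults, and ``net`` is assumed to be an already-initialized Gluon block.

.. code-block:: python

    from mxnet import gluon
    from gluonnlp.optimizer import LAMB  # importing registers the optimizer with MXNet

    trainer = gluon.Trainer(net.collect_params(), 'LAMB',
                            {'learning_rate': 0.001,
                             'lower_bound': 1e-3,
                             'upper_bound': 10.0,
                             'bias_correction': False})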

API Reference
-------------

.. automodule:: gluonnlp.optimizer
:members:
:imported-members:
3 changes: 2 additions & 1 deletion src/gluonnlp/optimizer/__init__.py
@@ -21,5 +21,6 @@
"""NLP optimizer."""

from .bert_adam import *
from .lamb import *

__all__ = bert_adam.__all__
__all__ = bert_adam.__all__ + lamb.__all__
128 changes: 128 additions & 0 deletions src/gluonnlp/optimizer/lamb.py
@@ -0,0 +1,128 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""LAMB optimizer"""

from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, NDArray
from mxnet.ndarray import square, power, sqrt, maximum, minimum, clip

__all__ = ['LAMB']


@register
class LAMB(Optimizer):
"""The LAMB optimizer proposed in
`Reducing BERT Pre-Training Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_.

If bias_correction is set to False, updates are applied by::

grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
r1 = min(max(w.norm(), lower_bound), upper_bound)
g = m / (sqrt(v) + epsilon) + wd * w
r2 = g.norm()
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr = r * lr
w = w - lr * g

Otherwise, updates are applied by::

grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
m_hat = m / (1 - power(beta1, t))
v_hat = v / (1 - power(beta2, t))
r1 = w.norm()
g = m_hat / sqrt(v_hat + epsilon) + wd * w
r2 = g.norm()
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr = r * lr
w = w - lr * g

Parameters
----------
beta1 : float, optional, default is 0.9
Exponential decay rate for the first moment estimates.
beta2 : float, optional, default is 0.999
Exponential decay rate for the second moment estimates.
epsilon : float, optional, default is 1e-6
Small value to avoid division by 0.
lower_bound : float, optional, default is 1e-3
Lower limit of the weight norm.
upper_bound : float, optional, default is 10.0
Upper limit of the weight norm.
bias_correction : bool, optional, default is False
Whether to use bias correction. In the latest version of the LAMB paper,
bias correction was removed and a few simplifying changes were made.
"""

def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.bias_correction = bias_correction

def create_state(self, index, weight):
stype = weight.stype
return (zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=stype), # mean
zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=stype)) # variance

def update(self, index, weight, grad, state):
assert isinstance(weight, NDArray)
assert isinstance(grad, NDArray)
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

# preprocess grad
grad *= self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)

mean, var = state
mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
var[:] = self.beta2 * var + (1. - self.beta2) * square(grad)

r1 = weight.norm()
if not self.bias_correction:
r1 = minimum(maximum(r1, self.lower_bound), self.upper_bound)
g = mean / (sqrt(var) + self.epsilon) + wd * weight

else:
# apply bias correction
mean_hat = mean / (1. - power(self.beta1, t))
var_hat = var / (1. - power(self.beta2, t))
g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight

r2 = g.norm()

# calculate lamb_trust_ratio
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr *= r

# update weight
weight[:] -= lr * g
88 changes: 88 additions & 0 deletions tests/unittest/test_lamb.py
@@ -0,0 +1,88 @@
import sys
import mxnet as mx
from mxnet.gluon import data as gdata
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn

from gluonnlp.optimizer import LAMB


def test_lamb_for_fashion_mnist():
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

batch_size = 512
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0 # 0 disables multi-processing.
else:
num_workers = 4

train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
batch_size, shuffle=False,
num_workers=num_workers)

net = nn.Sequential()
net.add(nn.Conv2D(6, kernel_size=5),
nn.BatchNorm(),
nn.Activation('relu'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(16, kernel_size=5),
nn.BatchNorm(),
nn.Activation('relu'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Dense(120),
nn.BatchNorm(),
nn.Activation('relu'),
nn.Dense(84),
nn.BatchNorm(),
nn.Activation('relu'),
nn.Dense(10))

ctx = mx.cpu()
net.initialize(ctx=ctx)

trainer = gluon.Trainer(net.collect_params(), 'LAMB', {'learning_rate': 0.001})

loss = gluon.loss.SoftmaxCrossEntropyLoss()

num_epochs = 5

def evaluate_accuracy(data_iter, net, ctx):
"""Evaluate accuracy of a model on the given data set."""
acc_sum, n = 0.0, 0.0
for X, y in data_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx)
y_hat = net(X)

y = y.astype('float32')
acc_sum += (y_hat.argmax(axis=1) == y).sum().asscalar()
n += y.size
return acc_sum / n

def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
trainer, ctx):
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
for X, y in train_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y).sum()
l.backward()

trainer.step(batch_size)
y = y.astype('float32')
train_l_sum += l.asscalar()
train_acc_sum += (y_hat.argmax(axis=1) == y).sum().asscalar()
n += y.size
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train(net, train_iter, test_iter, loss, num_epochs, batch_size, trainer, ctx)