This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
77 changes: 64 additions & 13 deletions src/gluonnlp/optimizer/bert_adam.py
@@ -16,11 +16,13 @@
# under the License.

"""Weight updating functions."""
import os
import warnings
import numpy
from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, NDArray, full
from mxnet.ndarray.contrib import mp_adamw_update, adamw_update
from mxnet.ndarray.contrib import mp_adamw_update, adamw_update, \
multi_mp_adamw_update, multi_adamw_update

__all__ = ['BERTAdam']

@@ -63,6 +65,8 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
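        # Number of tensors fused per multi-tensor kernel call, read from
        # MXNET_OPTIMIZER_AGGREGATION_SIZE and clamped to [1, 50] (default 4).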
self.aggregate_num = max(1, min(50, int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE',
'4'))))

def create_state_multi_precision(self, index, weight):
"""multi-precision state creation function."""
@@ -88,31 +92,78 @@ def update(self, index, weight, grad, state):

def update_multi_precision(self, index, weight, grad, state):
"""multi-precision update function"""
use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)

def _update_impl(self, indices, weight, grad, state, multi_precision=False):
"""update function"""
aggregate = self.aggregate_num > 1
if not isinstance(indices, (tuple, list)):
indices = [indices]
weight = [weight]
grad = [grad]
state = [state]
for w_i, g_i in zip(weight, grad):
assert(isinstance(w_i, NDArray))
assert(isinstance(g_i, NDArray))
aggregate = (aggregate and
w_i.stype == 'default' and
g_i.stype == 'default')
self._update_count(indices)
lr = self._get_lr(indices)
wd = self._get_wd(indices)
lrs = self._get_lrs(indices)
wds = self._get_wds(indices)

# pylint: disable=access-member-before-definition
if not isinstance(self.rescale_grad, NDArray):
self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight.context)
self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight[0].context)
else:
self.rescale_grad = self.rescale_grad.as_in_context(weight.context)
self.rescale_grad = self.rescale_grad.as_in_context(weight[0].context)

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'rescale_grad': self.rescale_grad}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
if not multi_precision:
mean, var = state
adamw_update(weight, grad, mean, var, out=weight,
lr=1, wd=wd, eta=lr, **kwargs)

if aggregate:
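            # Walk the tensor list in chunks of `aggregate_num`, updating each
            # chunk with a single fused multi-tensor kernel call.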
current_index = 0
while current_index < len(indices):
sidx = current_index
eidx = min(current_index + self.aggregate_num, len(indices))
if not multi_precision:
mean, var = list(zip(*state[sidx:eidx]))
multi_adamw_update(weight[sidx:eidx],
grad[sidx:eidx],
mean, var,
out=weight[sidx:eidx],
size=len(weight[sidx:eidx]),
lrs=list(numpy.ones(len(weight[sidx:eidx]))),
wds=wds[sidx:eidx],
etas=lrs[sidx:eidx],
**kwargs)
else:
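                    # Each multi-precision state is ((mean, var), fp32 master weight);
                    # transpose to get the per-chunk lists of means and variances.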
mean_var = list(zip(*state[sidx:eidx]))[0]
tmean_var = list(zip(*mean_var))
mean = tmean_var[0]
var = tmean_var[1]
multi_mp_adamw_update(weight[sidx:eidx],
grad[sidx:eidx],
mean, var,
list(zip(*state[sidx:eidx]))[1],
out=weight[sidx:eidx],
size=len(weight[sidx:eidx]),
lrs=list(numpy.ones(len(weight[sidx:eidx]))),
wds=wds[sidx:eidx],
etas=lrs[sidx:eidx],
**kwargs)
current_index += self.aggregate_num
else:
mean, var = state[0]
mp_adamw_update(weight, grad, mean, var, state[1], out=weight,
lr=1, wd=wd, eta=lr, **kwargs)
for w_i, g_i, s_i, lr, wd in zip(weight, grad, state, lrs, wds):
if not multi_precision:
mean, var = s_i
adamw_update(w_i, g_i, mean, var, out=w_i,
lr=1, wd=wd, eta=lr, **kwargs)
else:
mean, var = s_i[0]
mp_adamw_update(w_i, g_i, mean, var, s_i[1], out=w_i,
lr=1, wd=wd, eta=lr, **kwargs)
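
For reference, a minimal sketch of how the aggregated code path above is exercised (assuming MXNet >= 1.6, which provides the contrib multi-tensor kernels, and a GluonNLP build that includes this change; shapes and hyper-parameters here are illustrative only):

# Minimal sketch, not part of the PR: drives the aggregated update path.
import os
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '4'  # read at construction, clamped to [1, 50]

import mxnet as mx
from gluonnlp.optimizer import BERTAdam

opt = BERTAdam(learning_rate=1e-3, wd=0.01)

# Any list of parameters works; these shapes are illustrative.
shapes = [(1024, 1024), (4096, 1024), (1024,)]
weights = [mx.nd.random.uniform(shape=s) for s in shapes]
grads = [mx.nd.random.uniform(shape=s) for s in shapes]
states = [opt.create_state_multi_precision(i, w) for i, w in enumerate(weights)]

# Passing lists of indices/weights/grads/states triggers multi_adamw_update,
# which updates the tensors in place, aggregate_num at a time.
opt.update_multi_precision(list(range(len(weights))), weights, grads, states)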
103 changes: 78 additions & 25 deletions tests/unittest/test_optimizer.py
@@ -34,33 +34,54 @@ def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
rtol=1e-4, atol=1e-5, compare_states=True):
"""Compare opt1 and opt2."""
if w_stype == 'default':
w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
w1 = w2.copyto(default_context())
elif w_stype == 'row_sparse' or w_stype == 'csr':
w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
w1 = w2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")
if g_stype == 'default':
g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
g1 = g2.copyto(default_context())
elif g_stype == 'row_sparse' or g_stype == 'csr':
g2 = rand_ndarray(shape, g_stype, dtype=dtype)
g1 = g2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")
if not isinstance(shape, list):
if w_stype == 'default':
w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
w1 = w2.copyto(default_context())
elif w_stype == 'row_sparse' or w_stype == 'csr':
w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
w1 = w2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")
if g_stype == 'default':
g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
g1 = g2.copyto(default_context())
elif g_stype == 'row_sparse' or g_stype == 'csr':
g2 = rand_ndarray(shape, g_stype, dtype=dtype)
g1 = g2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")

state1 = opt1.create_state_multi_precision(0, w1)
state2 = opt2.create_state_multi_precision(0, w2)
if compare_states:
compare_ndarray_tuple(state1, state2)
state1 = opt1.create_state_multi_precision(0, w1)
state2 = opt2.create_state_multi_precision(0, w2)
if compare_states:
compare_ndarray_tuple(state1, state2)

opt1.update_multi_precision(0, w1, g1, state1)
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
opt1.update_multi_precision(0, w1, g1, state1)
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
else:
# test multi-tensor: Opt1 single-tensor reference, Opt2 multi-tensor
from copy import deepcopy
ntensors = len(shape)
w1, g1 = [], []
for s in shape:
w1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
g1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
w1 = tuple(w1)
w2 = deepcopy(w1)
g1 = tuple(g1)
g2 = deepcopy(g1)
state2 = [opt2.create_state_multi_precision(0, w2[i]) for i in range(ntensors)]
opt2.update_multi_precision(list(range(ntensors)), w2, g2, state2)
for i in range(ntensors):
state1 = opt1.create_state_multi_precision(i, w1[i])
opt1.update_multi_precision(i, w1[i], g1[i], state1)
if compare_states:
compare_ndarray_tuple(state1, state2[i], rtol, atol)
assert_almost_equal(w1[i].asnumpy(), w2[i].asnumpy(), rtol=rtol, atol=atol)

# BERT ADAM
class PyBERTAdam(mx.optimizer.Optimizer):
@@ -165,3 +186,35 @@ def test_bert_adam():
except ImportError:
print('skipping test_bert_adam() because an old version of MXNet is found')
return

def test_bert_multi_adam():
opt1 = PyBERTAdam
opt2 = optimizer.BERTAdam
    # parameter shapes as in BERT-large
dims_x = [1024, 4096, 1024, 1024]
dims_y = [1, 1, 1024, 4096]
dims_occurrences = [3, 1, 2, 2]
nlayers = 2
    shapes = []
    for _ in range(nlayers):
        for i, (dx, dy) in enumerate(zip(dims_x, dims_y)):
            for _ in range(dims_occurrences[i]):
                shapes.append((dx, dy))
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}]
for dtype in [np.float16, np.float32]:
for cg_option in cg_options:
for rg_option in rg_options:
for wd_option in wd_options:
kwarg = {}
kwarg.update(cg_option)
kwarg.update(rg_option)
kwarg.update(wd_option)
if np.float16 == dtype:
kwarg['multi_precision'] = True
rtol = 1e-3
else:
rtol = 1e-4
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes, dtype,
rtol=rtol, atol=2e-5)
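
As a usage note, a minimal sketch of how the extended compare_optimizer selects its two code paths (names as in the test file above; assuming its usual imports, i.e. numpy as np and from gluonnlp import optimizer): a plain shape tuple keeps the original single-tensor comparison, while a list of shapes exercises the new multi-tensor branch.

# Minimal sketch, not part of the PR: both calls should leave opt1 and opt2
# with numerically matching weights and optimizer states.
single_shape = (1024, 4096)                             # not a list -> single-tensor path
multi_shapes = [(1024, 1), (1024, 1024), (4096, 1024)]  # list -> multi-tensor path

compare_optimizer(PyBERTAdam(wd=0.01), optimizer.BERTAdam(wd=0.01),
                  single_shape, np.float32)
compare_optimizer(PyBERTAdam(wd=0.01), optimizer.BERTAdam(wd=0.01),
                  multi_shapes, np.float32, rtol=1e-4, atol=2e-5)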