This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit d228fdd

add the latest version of the calculation for LAMB
1 parent d5f03e4 commit d228fdd

File tree

2 files changed (+35, -16 lines)


src/gluonnlp/optimizer/lamb.py

Lines changed: 34 additions & 15 deletions
@@ -27,11 +27,21 @@

 @register
 class LAMB(Optimizer):
-    """The LAMB optimizer:
-    It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`.
-    https://arxiv.org/abs/1904.00962
+    """The LAMB optimizer proposed in
+    `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_.

     Updates are applied by::

+        if use_latest:
+            grad = clip(grad * rescale_grad, clip_gradient)
+            m = beta1 * m + (1 - beta1) * grad
+            v = beta2 * v + (1 - beta2) * (grad**2)
+            r1 = minimum(maximum(w.norm(), lower_bound), upper_bound)
+            g = m / (sqrt(v) + epsilon) + wd * w
+            r2 = g.norm()
+            r = 1. if r1 == 0. or r2 == 0. else r1 / r2
+            lr = r * lr
+            w = w - lr * g
+        else:
         grad = clip(grad * rescale_grad, clip_gradient)
         m = beta1 * m + (1 - beta1) * grad
         v = beta2 * v + (1 - beta2) * (grad**2)
@@ -40,8 +50,7 @@ class LAMB(Optimizer):
         r1 = w.norm()
         g = m_hat / (sqrt(v_hat + epsilon)) + wd * w
         r2 = g.norm()
-        r = 1. if r1 == 0. or r2 == 0. else minimum(
-            maximum(r1 / r2, lower_bound), upper_bound)
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
         lr = r * lr
         w = w - lr * g

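To make the docstring's new update rule concrete, here is a minimal, self-contained sketch of one step of the use_latest=True rule, written with NumPy for illustration only; the function name lamb_step and all default constants are assumptions, not part of gluonnlp:

    import numpy as np

    def lamb_step(w, grad, m, v, lr=0.001, beta1=0.9, beta2=0.999,
                  epsilon=1e-6, wd=0.01, lower_bound=1e-3, upper_bound=10.0):
        # One LAMB update following the new rule (no bias correction).
        m = beta1 * m + (1 - beta1) * grad            # first moment
        v = beta2 * v + (1 - beta2) * grad ** 2       # second moment
        r1 = np.clip(np.linalg.norm(w), lower_bound, upper_bound)  # clamped weight norm
        g = m / (np.sqrt(v) + epsilon) + wd * w       # adaptive direction plus weight decay
        r2 = np.linalg.norm(g)
        r = 1. if r1 == 0. or r2 == 0. else r1 / r2   # layer-wise trust ratio
        return w - r * lr * g, m, v

    w, m, v = np.ones(4), np.zeros(4), np.zeros(4)
    w, m, v = lamb_step(w, np.full(4, 0.1), m, v)     # one optimization step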
@@ -54,20 +63,26 @@ class LAMB(Optimizer):
     epsilon : float, optional, default is 1e-6
         Small value to avoid division by 0.
     lower_bound : float, optional, default is 1e-3
-        Lower limit of lamb_trust_ratio
+        Lower limit of the weight norm
     upper_bound : float, optional, default is 10.0
-        Upper limit of lamb_trust_ratio
+        Upper limit of the weight norm
+    use_latest : bool, optional, default is True
+        Whether to use the latest version of LAMB, which differs from the
+        old version in several ways (for example, bias correction is removed).
     """

     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
-                 lower_bound=1e-3, upper_bound=10.0, lazy_update=False, **kwargs):
+                 lower_bound=1e-3, upper_bound=10.0, use_latest=True,
+                 lazy_update=False, **kwargs):
         super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
         self.lazy_update = lazy_update
+        self.use_latest = use_latest

     def create_state(self, index, weight):
         stype = weight.stype if self.lazy_update else 'default'
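Since the class is registered as an MXNet optimizer via @register, the new flag should be reachable through the usual Gluon Trainer path. A hedged usage sketch follows; the string name 'lamb' and the toy network are assumptions for illustration, not taken from this repository:

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(10)
    net.initialize()
    trainer = mx.gluon.Trainer(
        net.collect_params(), 'lamb',                 # registered optimizer name (assumed)
        {'learning_rate': 0.001, 'use_latest': True,  # new flag introduced by this commit
         'lower_bound': 1e-3, 'upper_bound': 10.0})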
@@ -93,17 +108,21 @@ def update(self, index, weight, grad, state):
         mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
         var[:] = self.beta2 * var + (1. - self.beta2) * square(grad)

-        # execution bias correction
-        mean_hat = mean / (1. - power(self.beta1, t))
-        var_hat = var / (1. - power(self.beta2, t))
-
         r1 = weight.norm()
-        g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight
+        if self.use_latest:
+            r1 = minimum(maximum(r1, self.lower_bound), self.upper_bound)
+            g = mean / (sqrt(var) + self.epsilon) + wd * weight
+
+        else:
+            # apply bias correction
+            mean_hat = mean / (1. - power(self.beta1, t))
+            var_hat = var / (1. - power(self.beta2, t))
+            g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight
+
         r2 = g.norm()

         # calculate lamb_trust_ratio
-        r = 1. if r1 == 0. or r2 == 0. else minimum(
-            maximum(r1 / r2, self.lower_bound), self.upper_bound)
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
         lr *= r

         # update weight
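The behavioral change here is where the [lower_bound, upper_bound] clamp is applied: the old code clamped the trust ratio r1 / r2, while the new code clamps the weight norm r1 before taking the ratio. A small NumPy illustration with hypothetical norms (the values 20.0 and 4.0 are made up for this example):

    import numpy as np

    lower_bound, upper_bound = 1e-3, 10.0
    r1, r2 = 20.0, 4.0                                   # hypothetical weight/update norms

    r_old = np.clip(r1 / r2, lower_bound, upper_bound)   # old rule: clamp the ratio -> 5.0
    r_new = np.clip(r1, lower_bound, upper_bound) / r2   # new rule: clamp the norm  -> 2.5
    print(r_old, r_new)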

tests/unittest/test_lamb.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def test_lamb_for_fashion_mnist():
     batch_size = 512
     transformer = gdata.vision.transforms.ToTensor()
     if sys.platform.startswith('win'):
-        num_workers = 0  # 0表示不用额外的进程来加速读取数据
+        num_workers = 0  # 0 disables multi-processing.
     else:
         num_workers = 4
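For context, the num_workers value in this test feeds a Gluon DataLoader. A hedged reconstruction of the surrounding setup follows; the dataset and batch wiring are assumptions inferred from the visible lines, not quoted from the test file:

    import sys
    from mxnet.gluon import data as gdata

    transformer = gdata.vision.transforms.ToTensor()
    mnist_train = gdata.vision.FashionMNIST(train=True)
    num_workers = 0 if sys.platform.startswith('win') else 4  # Windows lacks fork()-based workers
    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                                  batch_size=512, shuffle=True,
                                  num_workers=num_workers)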
