 @register
 class LAMB(Optimizer):
-    """The LAMB optimizer:
-    It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`.
-    https://arxiv.org/abs/1904.00962
+    """The LAMB optimizer proposed in
+    `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_.

     Updates are applied by::
+        if use_latest:
+            grad = clip(grad * rescale_grad, clip_gradient)
+            m = beta1 * m + (1 - beta1) * grad
+            v = beta2 * v + (1 - beta2) * (grad**2)
+            r1 = minimum(maximum(w.norm(), self.lower_bound), self.upper_bound)
+            g = m / (sqrt(v) + epsilon) + wd * w
+            r2 = g.norm()
+            r = 1. if r1 == 0. or r2 == 0. else r1 / r2
+            lr = r * lr
+            w = w - lr * g
+        else:
         grad = clip(grad * rescale_grad, clip_gradient)
         m = beta1 * m + (1 - beta1) * grad
         v = beta2 * v + (1 - beta2) * (grad**2)
@@ -40,8 +50,7 @@ class LAMB(Optimizer):
         r1 = w.norm()
         g = m_hat / (sqrt(v_hat + epsilon)) + wd * w
         r2 = g.norm()
-        r = 1. if r1 == 0. or r2 == 0. else minimum(
-            maximum(r1 / r2, self.lower_bound), self.upper_bound)
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
         lr = r * lr
         w = w - lr * g

@@ -54,20 +63,26 @@ class LAMB(Optimizer):
     epsilon : float, optional, default is 1e-6
         Small value to avoid division by 0.
     lower_bound : float, optional, default is 1e-3
-        Lower limit of lamb_trust_ratio
+        Lower limit of the weight norm
     upper_bound : float, optional, default is 10.0
-        Upper limit of lamb_trust_ratio
+        Upper limit of the weight norm
+    use_latest : bool, optional, default is True
+        Whether to use the latest version of LAMB. The new version skips the
+        bias correction and clips the weight norm, rather than the trust
+        ratio, to [lower_bound, upper_bound].
     """

     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
-                 lower_bound=1e-3, upper_bound=10.0, lazy_update=False, **kwargs):
+                 lower_bound=1e-3, upper_bound=10.0, use_latest=True,
+                 lazy_update=False, **kwargs):
         super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
         self.lazy_update = lazy_update
+        self.use_latest = use_latest

     def create_state(self, index, weight):
         stype = weight.stype if self.lazy_update else 'default'
@@ -93,17 +108,21 @@ def update(self, index, weight, grad, state):
         mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
         var[:] = self.beta2 * var + (1. - self.beta2) * square(grad)

-        # execution bias correction
-        mean_hat = mean / (1. - power(self.beta1, t))
-        var_hat = var / (1. - power(self.beta2, t))
-
         r1 = weight.norm()
-        g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight
+        if self.use_latest:
+            r1 = minimum(maximum(r1, self.lower_bound), self.upper_bound)
+            g = mean / (sqrt(var) + self.epsilon) + wd * weight
+
+        else:
+            # execution bias correction
+            mean_hat = mean / (1. - power(self.beta1, t))
+            var_hat = var / (1. - power(self.beta2, t))
+            g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight
+
         r2 = g.norm()

         # calculate lamb_trust_ratio
-        r = 1. if r1 == 0. or r2 == 0. else minimum(
-            maximum(r1 / r2, self.lower_bound), self.upper_bound)
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
         lr *= r

         # update weight
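
For reference, here is a minimal NumPy sketch of a single LAMB step that mirrors the docstring pseudocode above, covering both the new `use_latest=True` path (no bias correction, weight norm clipped to the bounds) and the original path (bias-corrected moments, trust ratio clipped). The helper name `lamb_step`, the default hyperparameters, and the use of plain NumPy are illustrative assumptions only, not part of this change.

```python
import numpy as np

def lamb_step(w, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999,
              epsilon=1e-6, wd=0.01, lower_bound=1e-3, upper_bound=10.0,
              use_latest=True):
    """One LAMB update on a single parameter array; returns (w, m, v)."""
    # Adam-style first and second moment estimates.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2

    if use_latest:
        # New variant: no bias correction; the weight norm itself is clipped.
        r1 = np.clip(np.linalg.norm(w), lower_bound, upper_bound)
        g = m / (np.sqrt(v) + epsilon) + wd * w
    else:
        # Original variant: bias-correct the moments first.
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        r1 = np.linalg.norm(w)
        g = m_hat / np.sqrt(v_hat + epsilon) + wd * w

    r2 = np.linalg.norm(g)
    if r1 == 0. or r2 == 0.:
        r = 1.               # degenerate case: no layer-wise rescaling
    elif use_latest:
        r = r1 / r2          # bounds were already applied to the weight norm
    else:
        r = np.clip(r1 / r2, lower_bound, upper_bound)  # clip the trust ratio
    return w - lr * r * g, m, v

# Toy usage: a single step on a 4-element weight vector.
w, m, v = np.ones(4), np.zeros(4), np.zeros(4)
w, m, v = lamb_step(w, grad=np.full(4, 0.1), m=m, v=v, t=1)
```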