 from .optimizer import Optimizer
 
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired
+
     from paddle import Tensor
     from paddle.nn.clip import GradientClipBase
 
     from .optimizer import _ParameterConfig
 
+    class _LambParameterConfig(_ParameterConfig):
+        beta1: NotRequired[float | Tensor]
+        beta2: NotRequired[float | Tensor]
+        epsilon: NotRequired[float | Tensor]
+        lamb_weight_decay: NotRequired[float]
+        exclude_from_weight_decay_fn: NotRequired[
+            Callable[[Tensor], bool] | None
+        ]
+
+
 __all__ = []
 
 
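For context on the hunk above: `_LambParameterConfig` types the per-parameter-group dicts accepted by `Lamb(parameters=...)`. Below is a minimal sketch (not part of the patch) of a group dict with the shape these keys describe, assuming the usual `params` key inherited from `_ParameterConfig`; the concrete values and the bias-matching lambda are illustrative only.

import paddle

linear = paddle.nn.Linear(10, 10)

# A parameter-group dict matching _LambParameterConfig; every key besides
# "params" is optional (NotRequired). Values here are assumptions for
# illustration, not defaults taken from the patch.
group = {
    "params": linear.parameters(),
    "beta1": 0.9,
    "beta2": 0.999,
    "epsilon": 1e-6,
    "lamb_weight_decay": 0.01,
    "exclude_from_weight_decay_fn": lambda p: "bias" in p.name,
}
lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=[group])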
@@ -62,14 +74,14 @@ class Lamb(Optimizer):
     learning rate, :math:`\\lambda` the LAMB weight decay rate.
 
     Args:
-        learning_rate (float|Variable, optional): the learning rate used to update parameters. \
+        learning_rate (float|Tensor, optional): the learning rate used to update parameters. \
             Can be a float value or a Variable with data type float32. Default 0.001.
         lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Remind that weight_decay should be None.
-        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
+        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
             Default 0.9.
-        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
+        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
             Default 0.999.
-        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
+        epsilon (float|Tensor, optional): A small float value for numerical stability. Default 1e-6.
         parameters (list|tuple|None, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
             This parameter is required in dygraph mode. And you can specify different options for \
             different parameter groups such as the learning rate, weight decay, etc, \
@@ -98,7 +110,8 @@ class Lamb(Optimizer):
             >>> loss = paddle.mean(out)
             >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
             >>> beta2 = paddle.to_tensor([0.85], dtype="float32")
-            >>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
+            >>> lamb = paddle.optimizer.Lamb(
+            ...     learning_rate=0.002, beta1=beta1, beta2=beta2, parameters=linear.parameters(), lamb_weight_decay=0.01)
             >>> back = out.backward()
             >>> lamb.step()
             >>> lamb.clear_grad()
@@ -113,10 +126,12 @@ def __init__(
         self,
         learning_rate: float | Tensor = 0.001,
         lamb_weight_decay: float = 0.01,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        epsilon: float = 1e-6,
-        parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+        beta1: float | Tensor = 0.9,
+        beta2: float | Tensor = 0.999,
+        epsilon: float | Tensor = 1e-6,
+        parameters: Sequence[Tensor]
+        | Sequence[_LambParameterConfig]
+        | None = None,
         grad_clip: GradientClipBase | None = None,
         exclude_from_weight_decay_fn: Callable[[Tensor], bool] | None = None,
         multi_precision: bool = False,
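The `__init__` hunk above widens `beta1`, `beta2`, and `epsilon` to `float | Tensor`. Below is a small sketch (not part of the patch) exercising those annotations together with `exclude_from_weight_decay_fn`, assuming the standard dygraph workflow from the docstring example; the `.b_0` bias-name suffix check is an assumption for illustration.

import paddle

linear = paddle.nn.Linear(10, 10)
beta1 = paddle.to_tensor([0.9], dtype="float32")
epsilon = paddle.to_tensor([1e-6], dtype="float32")

lamb = paddle.optimizer.Lamb(
    learning_rate=0.002,
    beta1=beta1,      # float | Tensor
    epsilon=epsilon,  # float | Tensor
    lamb_weight_decay=0.01,
    parameters=linear.parameters(),
    # Skip LAMB weight decay for bias parameters (name suffix is an assumption).
    exclude_from_weight_decay_fn=lambda p: p.name.endswith(".b_0"),
)

out = linear(paddle.rand([4, 10]))
paddle.mean(out).backward()
lamb.step()
lamb.clear_grad()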