
Commit a4815f9

megeminico63oc authored and committed
[Typing] Update _ParameterConfig in Optimizer (PaddlePaddle#65277)
1 parent 987e933 commit a4815f9

File tree: 6 files changed, +58 -34 lines


python/paddle/optimizer/adam.py

Lines changed: 10 additions & 9 deletions
@@ -18,8 +18,6 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Sequence
 
-from typing_extensions import NotRequired
-
 import paddle
 from paddle import _C_ops, pir
 from paddle.base.libpaddle import DataType
@@ -33,20 +31,23 @@
     in_dynamic_or_pir_mode,
     in_pir_mode,
 )
-from .optimizer import Optimizer, _ParameterConfig
-
-
-class _AdamParameterConfig(_ParameterConfig):
-    beta1: NotRequired[float | Tensor]
-    beta2: NotRequired[float | Tensor]
-
+from .optimizer import Optimizer
 
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired
+
     from paddle import Tensor
     from paddle.nn.clip import GradientClipBase
     from paddle.regularizer import WeightDecayRegularizer
 
     from .lr import LRScheduler
+    from .optimizer import _ParameterConfig
+
+    class _AdamParameterConfig(_ParameterConfig):
+        beta1: NotRequired[float | Tensor]
+        beta2: NotRequired[float | Tensor]
+        epsilon: NotRequired[float | Tensor]
+        lazy_mode: NotRequired[bool]
 
 
 __all__ = []
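
For context (not part of the commit): the keys declared by _AdamParameterConfig describe the per-parameter-group dicts that Paddle optimizers already accept through the parameters argument. A minimal usage sketch, assuming the documented parameter-group API; the layer names are illustrative only:

import paddle

# Two illustrative layers; any trainable parameters would do.
linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)

# Each dict follows the shape _AdamParameterConfig annotates:
# 'params' is required, the remaining keys are NotRequired overrides.
adam = paddle.optimizer.Adam(
    learning_rate=0.001,
    parameters=[
        {"params": linear_1.parameters()},
        {
            "params": linear_2.parameters(),
            "learning_rate": 0.0001,
            "beta1": 0.8,
        },
    ],
    weight_decay=0.01,
)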

python/paddle/optimizer/adamax.py

Lines changed: 14 additions & 6 deletions
@@ -21,17 +21,25 @@
 from ..base import core, framework
 from ..base.dygraph import no_grad
 from ..base.framework import name_scope
-from .adam import _AdamParameterConfig
 from .optimizer import Optimizer
 
-__all__ = []
-
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired
+
     from paddle import Tensor
     from paddle.nn.clip import GradientClipBase
     from paddle.regularizer import WeightDecayRegularizer
 
     from .lr import LRScheduler
+    from .optimizer import _ParameterConfig
+
+    class _AdamaxParameterConfig(_ParameterConfig):
+        beta1: NotRequired[float | Tensor]
+        beta2: NotRequired[float | Tensor]
+        epsilon: NotRequired[float | Tensor]
+
+
+__all__ = []
 
 
 class Adamax(Optimizer):
@@ -69,7 +77,7 @@ class Adamax(Optimizer):
         beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
             It should be a float number or a 0-D Tensor with shape [] and data type as float32.
             The default value is 0.999.
-        epsilon (float, optional): A small float value for numerical stability.
+        epsilon (float|Tensor, optional): A small float value for numerical stability.
             The default value is 1e-08.
         parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
             This parameter is required in dygraph mode. And you can specify different options for
@@ -154,9 +162,9 @@ def __init__(
         learning_rate: float | LRScheduler = 0.001,
         beta1: float | Tensor = 0.9,
         beta2: float | Tensor = 0.999,
-        epsilon: float = 1e-8,
+        epsilon: float | Tensor = 1e-8,
         parameters: Sequence[Tensor]
-        | Sequence[_AdamParameterConfig]
+        | Sequence[_AdamaxParameterConfig]
        | None = None,
         weight_decay: float | WeightDecayRegularizer | None = None,
         grad_clip: GradientClipBase | None = None,
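
A small usage sketch (illustrative, not part of the commit): beta1/beta2 already accept a 0-D float32 Tensor per the existing docstring, and this change widens epsilon's annotation to float | Tensor as well.

import paddle

linear = paddle.nn.Linear(10, 10)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.999], dtype="float32")

# epsilon stays a plain float here; the updated signature now also admits
# a Tensor value (epsilon: float | Tensor = 1e-8).
adamax = paddle.optimizer.Adamax(
    learning_rate=0.001,
    beta1=beta1,
    beta2=beta2,
    epsilon=1e-8,
    parameters=linear.parameters(),
)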

python/paddle/optimizer/adamw.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ class AdamW(Optimizer):
         beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
             It should be a float number or a 0-D Tensor with shape [] and data type as float32.
             The default value is 0.999.
-        epsilon (float, optional): A small float value for numerical stability.
+        epsilon (float|Tensor, optional): A small float value for numerical stability.
             The default value is 1e-08.
         parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``.
             This parameter is required in dygraph mode. And you can specify different options for
@@ -170,7 +170,7 @@ def __init__(
         learning_rate: float | LRScheduler = 0.001,
         beta1: float | Tensor = 0.9,
         beta2: float | Tensor = 0.999,
-        epsilon: float = 1e-8,
+        epsilon: float | Tensor = 1e-8,
         parameters: Sequence[Tensor]
         | Sequence[_AdamParameterConfig]
         | None = None,

python/paddle/optimizer/lamb.py

Lines changed: 24 additions & 9 deletions
@@ -24,11 +24,23 @@
 from .optimizer import Optimizer
 
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired
+
     from paddle import Tensor
     from paddle.nn.clip import GradientClipBase
 
     from .optimizer import _ParameterConfig
 
+    class _LambParameterConfig(_ParameterConfig):
+        beta1: NotRequired[float | Tensor]
+        beta2: NotRequired[float | Tensor]
+        epsilon: NotRequired[float | Tensor]
+        lamb_weight_decay: NotRequired[float]
+        exclude_from_weight_decay_fn: NotRequired[
+            Callable[[Tensor], bool] | None
+        ]
+
+
 __all__ = []
 
 
@@ -62,14 +74,14 @@ class Lamb(Optimizer):
     learning rate, :math:`\\lambda` the LAMB weight decay rate.
 
     Args:
-        learning_rate (float|Variable, optional): the learning rate used to update parameters. \
+        learning_rate (float|Tensor, optional): the learning rate used to update parameters. \
            Can be a float value or a Variable with data type float32. Default 0.001.
         lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Remind that weight_decay should be None.
-        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
+        beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
             Default 0.9.
-        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
+        beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
             Default 0.999.
-        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
+        epsilon (float|Tensor, optional): A small float value for numerical stability. Default 1e-6.
         parameters (list|tuple|None, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
             This parameter is required in dygraph mode. And you can specify different options for \
             different parameter groups such as the learning rate, weight decay, etc, \
@@ -98,7 +110,8 @@ class Lamb(Optimizer):
             >>> loss = paddle.mean(out)
             >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
             >>> beta2 = paddle.to_tensor([0.85], dtype="float32")
-            >>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
+            >>> lamb = paddle.optimizer.Lamb(
+            ...     learning_rate=0.002, beta1=beta1, beta2=beta2, parameters=linear.parameters(), lamb_weight_decay=0.01)
             >>> back = out.backward()
             >>> lamb.step()
             >>> lamb.clear_grad()
@@ -113,10 +126,12 @@ def __init__(
         self,
         learning_rate: float | Tensor = 0.001,
         lamb_weight_decay: float = 0.01,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        epsilon: float = 1e-6,
-        parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+        beta1: float | Tensor = 0.9,
+        beta2: float | Tensor = 0.999,
+        epsilon: float | Tensor = 1e-6,
+        parameters: Sequence[Tensor]
+        | Sequence[_LambParameterConfig]
+        | None = None,
         grad_clip: GradientClipBase | None = None,
         exclude_from_weight_decay_fn: Callable[[Tensor], bool] | None = None,
         multi_precision: bool = False,
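
The updated doctest above, expanded into a standalone sketch (shapes and values follow the doc example; otherwise illustrative, not part of the commit):

import paddle

inp = paddle.uniform(shape=[10, 10], dtype="float32", min=-0.1, max=0.1)
linear = paddle.nn.Linear(10, 10)
out = linear(inp)
loss = paddle.mean(out)

# Tensor-valued beta1/beta2 match the widened float | Tensor annotations
# introduced by this commit.
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.85], dtype="float32")
lamb = paddle.optimizer.Lamb(
    learning_rate=0.002,
    beta1=beta1,
    beta2=beta2,
    parameters=linear.parameters(),
    lamb_weight_decay=0.01,
)

loss.backward()
lamb.step()
lamb.clear_grad()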

python/paddle/optimizer/optimizer.py

Lines changed: 7 additions & 8 deletions
@@ -20,7 +20,6 @@
 from typing import TYPE_CHECKING, Callable, Sequence
 
 import numpy as np
-from typing_extensions import NotRequired, TypedDict
 
 import paddle
 import paddle.autograd as imperative_base
@@ -50,19 +49,19 @@
 from ..base.layer_helper import LayerHelper, LayerHelperBase
 from .lr import LRScheduler
 
-
-class _ParameterConfig(TypedDict):
-    params: Sequence[Tensor]
-    weight_decay: NotRequired[float | WeightDecayRegularizer | None]
-    learning_rate: NotRequired[float | Tensor | LRScheduler | None]
-
-
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired, TypedDict
+
     from paddle import Tensor
     from paddle.nn.clip import GradientClipBase
 
     from ..base.framework import Operator, Program
 
+    class _ParameterConfig(TypedDict):
+        params: Sequence[Tensor]
+        weight_decay: NotRequired[float | WeightDecayRegularizer | None]
+        learning_rate: NotRequired[float | Tensor | LRScheduler | None]
+
 
 __all__ = []
 
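Because _ParameterConfig is now defined inside the if TYPE_CHECKING: block, it exists only for static type checkers and is no longer present in the module at runtime. A minimal sketch of how annotation-only consumers can keep referring to it (hypothetical helper, for illustration; not part of the commit):

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    # Seen only by type checkers; after this commit, _ParameterConfig is not
    # defined at runtime, so it must not be imported outside TYPE_CHECKING.
    from paddle import Tensor
    from paddle.optimizer.optimizer import _ParameterConfig


def make_param_groups(params: Sequence[Tensor]) -> list[_ParameterConfig]:
    # With `from __future__ import annotations`, the annotation above is a
    # string and never evaluated at runtime, so the checker-only name is safe.
    return [{"params": params, "learning_rate": 0.01}]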

python/paddle/optimizer/sgd.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 
     from .lr import LRScheduler
     from .optimizer import _ParameterConfig
+
 __all__ = []
 
 