Skip to content

Commit 3bbbe0f

Browse files
zrr1999SigureMo
andauthored
[Typing][A-96, B-43] Add type annotations for python/paddle/hapi/callbacks.py, python/paddle/framework/framework.py (#65777)
--------- Co-authored-by: SigureMo <[email protected]>
1 parent 3ccc952 commit 3bbbe0f

File tree

3 files changed

+249
-130
lines changed

3 files changed

+249
-130
lines changed

python/paddle/base/layer_helper_base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import copy
17+
from typing import TYPE_CHECKING
1718

1819
import numpy as np
1920

@@ -33,12 +34,15 @@
3334
from .initializer import _global_bias_initializer, _global_weight_initializer
3435
from .param_attr import ParamAttr, WeightNormParamAttr
3536

37+
if TYPE_CHECKING:
38+
from paddle._typing.dtype_like import _DTypeLiteral
39+
3640
__all__ = []
3741

3842

3943
class LayerHelperBase:
4044
# global dtype
41-
__dtype = "float32"
45+
__dtype: _DTypeLiteral = "float32"
4246

4347
def __init__(self, name, layer_type):
4448
self._layer_type = layer_type
@@ -265,9 +269,11 @@ def __weight_normalize(g, v, dim):
265269
# to achieve the subset.
266270
w = paddle.tensor.math._multiply_with_axis(
267271
x=v,
268-
y=scale
269-
if dim is None
270-
else paddle.reshape(x=scale, shape=[v.shape[dim]]),
272+
y=(
273+
scale
274+
if dim is None
275+
else paddle.reshape(x=scale, shape=[v.shape[dim]])
276+
),
271277
axis=-1 if dim is None else dim,
272278
)
273279
# To serialize the original parameter for inference, maybe a

python/paddle/framework/framework.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
1518

1619
import numpy as np
1720

1821
import paddle
1922
from paddle.base.data_feeder import convert_dtype
20-
21-
# TODO: define framework api
2223
from paddle.base.layer_helper_base import LayerHelperBase
2324

25+
if TYPE_CHECKING:
26+
from paddle._typing.dtype_like import DTypeLike, _DTypeLiteral
27+
2428
__all__ = []
2529

2630

27-
def set_default_dtype(d):
31+
def set_default_dtype(d: DTypeLike) -> None:
2832
"""
2933
Set default dtype. The default dtype is initially float32.
3034
@@ -73,14 +77,14 @@ def set_default_dtype(d):
7377
LayerHelperBase.set_default_dtype(d)
7478

7579

76-
def get_default_dtype():
80+
def get_default_dtype() -> _DTypeLiteral:
7781
"""
7882
Get the current default dtype. The default dtype is initially float32.
7983
8084
Args:
8185
None.
8286
Returns:
83-
String, this global dtype only supports float16, float32, float64.
87+
str, this global dtype only supports float16, float32, float64.
8488
8589
Examples:
8690
.. code-block:: python

0 commit comments

Comments
 (0)