Skip to content

Commit 4ca9b7b

Browse files
authored
[Typing][A-16,A-19] Add type annotations for base Layer and containers (#65190)
1 parent b29ab37 commit 4ca9b7b

File tree

5 files changed

+266
-138
lines changed

5 files changed

+266
-138
lines changed

python/paddle/base/framework.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,9 +710,11 @@ def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
710710
# introducing compatibility issues, add this decorator
711711
# NOTE(chenweihang): not using `wrap_decorator` here is because `wrap_decorator` will
712712
# move kwargs to args, which doesn't work in this decorate case
713-
def deprecate_stat_dict(func):
713+
def deprecate_stat_dict(
714+
func: Callable[_InputT, _RetT]
715+
) -> Callable[_InputT, _RetT]:
714716
@functools.wraps(func)
715-
def wrapper(*args, **kwargs):
717+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
716718
if "stat_dict" in kwargs:
717719
warnings.warn(
718720
"The argument `stat_dict` has deprecated, please change it to `state_dict`.",

python/paddle/base/unique_name.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __call__(self, name):
9595
generator = UniqueNameGenerator()
9696

9797

98-
def generate(key):
98+
def generate(key: str) -> str:
9999
"""
100100
Generate unique name with prefix key. Currently, Paddle distinguishes the
101101
names of the same key by numbering it from zero. For example, when key=fc,

python/paddle/nn/layer/container.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
import typing
1518
from collections import OrderedDict
1619
from collections.abc import Iterable, Mapping
20+
from typing import Any, Iterator, Sequence
21+
22+
from typing_extensions import Self
23+
24+
from paddle import Tensor
1725

1826
from ...base.dygraph.base import param_guard
1927
from ...base.framework import Parameter
@@ -67,30 +75,38 @@ class LayerDict(Layer):
6775
6876
"""
6977

70-
def __init__(self, sublayers=None):
78+
def __init__(
79+
self,
80+
sublayers: (
81+
LayerDict
82+
| typing.Mapping[str, Layer]
83+
| Sequence[tuple[str, Layer]]
84+
| None
85+
) = None,
86+
) -> None:
7187
super().__init__()
7288
if sublayers is not None:
7389
self.update(sublayers)
7490

75-
def __getitem__(self, key):
91+
def __getitem__(self, key: str) -> Layer:
7692
return self._sub_layers[key]
7793

78-
def __setitem__(self, key, sublayer):
94+
def __setitem__(self, key: str, sublayer: Layer) -> Layer:
7995
return self.add_sublayer(key, sublayer)
8096

81-
def __delitem__(self, key):
97+
def __delitem__(self, key: str) -> None:
8298
del self._sub_layers[key]
8399

84-
def __len__(self):
100+
def __len__(self) -> int:
85101
return len(self._sub_layers)
86102

87-
def __iter__(self):
103+
def __iter__(self) -> Iterator[Layer]:
88104
return iter(self._sub_layers)
89105

90-
def __contains__(self, key):
106+
def __contains__(self, key: str) -> bool:
91107
return key in self._sub_layers
92108

93-
def clear(self):
109+
def clear(self) -> None:
94110
"""
95111
Clear all the sublayers in the LayerDict.
96112
@@ -120,7 +136,7 @@ def clear(self):
120136
"""
121137
self._sub_layers.clear()
122138

123-
def pop(self, key):
139+
def pop(self, key: str) -> Layer:
124140
"""
125141
Remove the key from the LayerDict and return the layer of the key.
126142
@@ -152,7 +168,7 @@ def pop(self, key):
152168
del self[key]
153169
return v
154170

155-
def keys(self):
171+
def keys(self) -> Iterable[str]:
156172
"""
157173
Return the iterable of the keys in LayerDict.
158174
@@ -181,7 +197,7 @@ def keys(self):
181197
"""
182198
return self._sub_layers.keys()
183199

184-
def items(self):
200+
def items(self) -> Iterable[tuple[str, Layer]]:
185201
"""
186202
Return the iterable of the key/value pairs in LayerDict.
187203
@@ -210,7 +226,7 @@ def items(self):
210226
"""
211227
return self._sub_layers.items()
212228

213-
def values(self):
229+
def values(self) -> Iterable[Layer]:
214230
"""
215231
Return the iterable of the values in LayerDict.
216232
@@ -239,7 +255,12 @@ def values(self):
239255
"""
240256
return self._sub_layers.values()
241257

242-
def update(self, sublayers):
258+
def update(
259+
self,
260+
sublayers: (
261+
LayerDict | typing.Mapping[str, Layer] | Sequence[tuple[str, Layer]]
262+
),
263+
) -> None:
243264
"""
244265
Update the key/values pairs in sublayers to the LayerDict, overwriting the existing keys.
245266
@@ -353,29 +374,29 @@ class ParameterList(Layer):
353374
[5, 4]
354375
"""
355376

356-
def __init__(self, parameters=None):
377+
def __init__(self, parameters: Iterable[Tensor] | None = None) -> None:
357378
super().__init__()
358379
if parameters is not None:
359380
for idx, param in enumerate(parameters):
360381
assert isinstance(param, Parameter)
361382
self.add_parameter(str(idx), param)
362383

363-
def __getitem__(self, idx):
384+
def __getitem__(self, idx: int) -> Tensor:
364385
with param_guard(self._parameters):
365386
return self._parameters[str(idx)]
366387

367-
def __setitem__(self, idx, param):
388+
def __setitem__(self, idx: int, param: Tensor) -> None:
368389
assert isinstance(param, Parameter)
369390
setattr(self, str(idx), param)
370391

371-
def __len__(self):
392+
def __len__(self) -> int:
372393
return len(self._parameters)
373394

374-
def __iter__(self):
395+
def __iter__(self) -> Iterator[Tensor]:
375396
with param_guard(self._parameters):
376397
return iter(self._parameters.values())
377398

378-
def append(self, parameter):
399+
def append(self, parameter: Tensor) -> Self:
379400
"""Appends a given parameter at the end of the list.
380401
381402
Parameters:
@@ -412,13 +433,13 @@ class LayerList(Layer):
412433
... return x
413434
"""
414435

415-
def __init__(self, sublayers=None):
436+
def __init__(self, sublayers: Iterable[Layer] | None = None) -> None:
416437
super().__init__()
417438
if sublayers is not None:
418439
for idx, layer in enumerate(sublayers):
419440
self.add_sublayer(str(idx), layer)
420441

421-
def _get_abs_idx(self, idx):
442+
def _get_abs_idx(self, idx: int) -> int:
422443
if isinstance(idx, int):
423444
if not (-len(self) <= idx < len(self)):
424445
raise IndexError(
@@ -428,18 +449,18 @@ def _get_abs_idx(self, idx):
428449
idx += len(self)
429450
return idx
430451

431-
def __getitem__(self, idx):
452+
def __getitem__(self, idx: int) -> Layer:
432453
if isinstance(idx, slice):
433454
return self.__class__(list(self._sub_layers.values())[idx])
434455
else:
435456
idx = self._get_abs_idx(idx)
436457
return self._sub_layers[str(idx)]
437458

438-
def __setitem__(self, idx, sublayer):
459+
def __setitem__(self, idx: int, sublayer: Layer) -> None:
439460
idx = self._get_abs_idx(idx)
440461
return setattr(self, str(idx), sublayer)
441462

442-
def __delitem__(self, idx):
463+
def __delitem__(self, idx: int) -> None:
443464
if isinstance(idx, slice):
444465
for k in range(len(self._sub_layers))[idx]:
445466
delattr(self, str(k))
@@ -451,13 +472,13 @@ def __delitem__(self, idx):
451472
list(zip(str_indices, self._sub_layers.values()))
452473
)
453474

454-
def __len__(self):
475+
def __len__(self) -> int:
455476
return len(self._sub_layers)
456477

457-
def __iter__(self):
478+
def __iter__(self) -> Iterator[Layer]:
458479
return iter(self._sub_layers.values())
459480

460-
def append(self, sublayer):
481+
def append(self, sublayer: Layer) -> Self:
461482
"""
462483
Appends a sublayer to the end of the list.
463484
@@ -478,7 +499,7 @@ def append(self, sublayer):
478499
self.add_sublayer(str(len(self)), sublayer)
479500
return self
480501

481-
def insert(self, index, sublayer):
502+
def insert(self, index: int, sublayer: Layer) -> None:
482503
"""
483504
Insert a sublayer before a given index in the list.
484505
@@ -510,7 +531,7 @@ def insert(self, index, sublayer):
510531
self._sub_layers[str(i)] = self._sub_layers[str(i - 1)]
511532
self._sub_layers[str(index)] = sublayer
512533

513-
def extend(self, sublayers):
534+
def extend(self, sublayers: Iterable[Layer]) -> Self:
514535
"""
515536
Appends sublayers to the end of the list.
516537
@@ -575,7 +596,7 @@ class Sequential(Layer):
575596
576597
"""
577598

578-
def __init__(self, *layers):
599+
def __init__(self, *layers: Layer | tuple[str, Layer] | list[Any]) -> None:
579600
super().__init__()
580601
if len(layers) > 0 and isinstance(layers[0], (list, tuple)):
581602
for name, layer in layers:
@@ -584,7 +605,7 @@ def __init__(self, *layers):
584605
for idx, layer in enumerate(layers):
585606
self.add_sublayer(str(idx), layer)
586607

587-
def __getitem__(self, name):
608+
def __getitem__(self, name: str) -> Layer:
588609
if isinstance(name, slice):
589610
return self.__class__(*(list(self._sub_layers.values())[name]))
590611
elif isinstance(name, str):
@@ -598,19 +619,19 @@ def __getitem__(self, name):
598619
raise IndexError(f'index {name} is out of range')
599620
return list(self._sub_layers.values())[name]
600621

601-
def __setitem__(self, name, layer):
622+
def __setitem__(self, name: str, layer: Layer) -> None:
602623
assert isinstance(layer, Layer)
603624
setattr(self, str(name), layer)
604625

605-
def __delitem__(self, name):
626+
def __delitem__(self, name: str) -> None:
606627
name = str(name)
607628
assert name in self._sub_layers
608629
del self._sub_layers[name]
609630

610-
def __len__(self):
631+
def __len__(self) -> int:
611632
return len(self._sub_layers)
612633

613-
def forward(self, input):
634+
def forward(self, input: Any) -> Any:
614635
for layer in self._sub_layers.values():
615636
input = layer(input)
616637
return input

0 commit comments

Comments
 (0)