1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
17+ import typing
1518from collections import OrderedDict
1619from collections .abc import Iterable , Mapping
20+ from typing import Any , Iterator , Sequence
21+
22+ from typing_extensions import Self
23+
24+ from paddle import Tensor
1725
1826from ...base .dygraph .base import param_guard
1927from ...base .framework import Parameter
@@ -67,30 +75,38 @@ class LayerDict(Layer):
6775
6876 """
6977
70- def __init__ (self , sublayers = None ):
78+ def __init__ (
79+ self ,
80+ sublayers : (
81+ LayerDict
82+ | typing .Mapping [str , Layer ]
83+ | Sequence [tuple [str , Layer ]]
84+ | None
85+ ) = None ,
86+ ) -> None :
7187 super ().__init__ ()
7288 if sublayers is not None :
7389 self .update (sublayers )
7490
75- def __getitem__ (self , key ) :
91+ def __getitem__ (self , key : str ) -> Layer :
7692 return self ._sub_layers [key ]
7793
78- def __setitem__ (self , key , sublayer ) :
94+ def __setitem__ (self , key : str , sublayer : Layer ) -> Layer :
7995 return self .add_sublayer (key , sublayer )
8096
81- def __delitem__ (self , key ) :
97+ def __delitem__ (self , key : str ) -> None :
8298 del self ._sub_layers [key ]
8399
84- def __len__ (self ):
100+ def __len__ (self ) -> int :
85101 return len (self ._sub_layers )
86102
87- def __iter__ (self ):
103+ def __iter__ (self ) -> Iterator [ Layer ] :
88104 return iter (self ._sub_layers )
89105
90- def __contains__ (self , key ) :
106+ def __contains__ (self , key : str ) -> bool :
91107 return key in self ._sub_layers
92108
93- def clear (self ):
109+ def clear (self ) -> None :
94110 """
95111 Clear all the sublayers in the LayerDict.
96112
@@ -120,7 +136,7 @@ def clear(self):
120136 """
121137 self ._sub_layers .clear ()
122138
123- def pop (self , key ) :
139+ def pop (self , key : str ) -> Layer :
124140 """
125141 Remove the key from the LayerDict and return the layer of the key.
126142
@@ -152,7 +168,7 @@ def pop(self, key):
152168 del self [key ]
153169 return v
154170
155- def keys (self ):
171+ def keys (self ) -> Iterable [ str ] :
156172 """
157173 Return the iterable of the keys in LayerDict.
158174
@@ -181,7 +197,7 @@ def keys(self):
181197 """
182198 return self ._sub_layers .keys ()
183199
184- def items (self ):
200+ def items (self ) -> Iterable [ tuple [ str , Layer ]] :
185201 """
186202 Return the iterable of the key/value pairs in LayerDict.
187203
@@ -210,7 +226,7 @@ def items(self):
210226 """
211227 return self ._sub_layers .items ()
212228
213- def values (self ):
229+ def values (self ) -> Iterable [ Layer ] :
214230 """
215231 Return the iterable of the values in LayerDict.
216232
@@ -239,7 +255,12 @@ def values(self):
239255 """
240256 return self ._sub_layers .values ()
241257
242- def update (self , sublayers ):
258+ def update (
259+ self ,
260+ sublayers : (
261+ LayerDict | typing .Mapping [str , Layer ] | Sequence [tuple [str , Layer ]]
262+ ),
263+ ) -> None :
243264 """
244265 Update the key/values pairs in sublayers to the LayerDict, overwriting the existing keys.
245266
@@ -353,29 +374,29 @@ class ParameterList(Layer):
353374 [5, 4]
354375 """
355376
356- def __init__ (self , parameters = None ):
377+ def __init__ (self , parameters : Iterable [ Tensor ] | None = None ) -> None :
357378 super ().__init__ ()
358379 if parameters is not None :
359380 for idx , param in enumerate (parameters ):
360381 assert isinstance (param , Parameter )
361382 self .add_parameter (str (idx ), param )
362383
363- def __getitem__ (self , idx ) :
384+ def __getitem__ (self , idx : int ) -> Tensor :
364385 with param_guard (self ._parameters ):
365386 return self ._parameters [str (idx )]
366387
367- def __setitem__ (self , idx , param ) :
388+ def __setitem__ (self , idx : int , param : Tensor ) -> None :
368389 assert isinstance (param , Parameter )
369390 setattr (self , str (idx ), param )
370391
371- def __len__ (self ):
392+ def __len__ (self ) -> int :
372393 return len (self ._parameters )
373394
374- def __iter__ (self ):
395+ def __iter__ (self ) -> Iterator [ Tensor ] :
375396 with param_guard (self ._parameters ):
376397 return iter (self ._parameters .values ())
377398
378- def append (self , parameter ) :
399+ def append (self , parameter : Tensor ) -> Self :
379400 """Appends a given parameter at the end of the list.
380401
381402 Parameters:
@@ -412,13 +433,13 @@ class LayerList(Layer):
412433 ... return x
413434 """
414435
415- def __init__ (self , sublayers = None ):
436+ def __init__ (self , sublayers : Iterable [ Layer ] | None = None ) -> None :
416437 super ().__init__ ()
417438 if sublayers is not None :
418439 for idx , layer in enumerate (sublayers ):
419440 self .add_sublayer (str (idx ), layer )
420441
421- def _get_abs_idx (self , idx ) :
442+ def _get_abs_idx (self , idx : int ) -> int :
422443 if isinstance (idx , int ):
423444 if not (- len (self ) <= idx < len (self )):
424445 raise IndexError (
@@ -428,18 +449,18 @@ def _get_abs_idx(self, idx):
428449 idx += len (self )
429450 return idx
430451
431- def __getitem__ (self , idx ) :
452+ def __getitem__ (self , idx : int ) -> Layer :
432453 if isinstance (idx , slice ):
433454 return self .__class__ (list (self ._sub_layers .values ())[idx ])
434455 else :
435456 idx = self ._get_abs_idx (idx )
436457 return self ._sub_layers [str (idx )]
437458
438- def __setitem__ (self , idx , sublayer ) :
459+ def __setitem__ (self , idx : int , sublayer : Layer ) -> None :
439460 idx = self ._get_abs_idx (idx )
440461 return setattr (self , str (idx ), sublayer )
441462
442- def __delitem__ (self , idx ) :
463+ def __delitem__ (self , idx : int ) -> None :
443464 if isinstance (idx , slice ):
444465 for k in range (len (self ._sub_layers ))[idx ]:
445466 delattr (self , str (k ))
@@ -451,13 +472,13 @@ def __delitem__(self, idx):
451472 list (zip (str_indices , self ._sub_layers .values ()))
452473 )
453474
454- def __len__ (self ):
475+ def __len__ (self ) -> int :
455476 return len (self ._sub_layers )
456477
457- def __iter__ (self ):
478+ def __iter__ (self ) -> Iterator [ Layer ] :
458479 return iter (self ._sub_layers .values ())
459480
460- def append (self , sublayer ) :
481+ def append (self , sublayer : Layer ) -> Self :
461482 """
462483 Appends a sublayer to the end of the list.
463484
@@ -478,7 +499,7 @@ def append(self, sublayer):
478499 self .add_sublayer (str (len (self )), sublayer )
479500 return self
480501
481- def insert (self , index , sublayer ) :
502+ def insert (self , index : int , sublayer : Layer ) -> None :
482503 """
483504 Insert a sublayer before a given index in the list.
484505
@@ -510,7 +531,7 @@ def insert(self, index, sublayer):
510531 self ._sub_layers [str (i )] = self ._sub_layers [str (i - 1 )]
511532 self ._sub_layers [str (index )] = sublayer
512533
513- def extend (self , sublayers ) :
534+ def extend (self , sublayers : Iterable [ Layer ]) -> Self :
514535 """
515536 Appends sublayers to the end of the list.
516537
@@ -575,7 +596,7 @@ class Sequential(Layer):
575596
576597 """
577598
578- def __init__ (self , * layers ) :
599+ def __init__ (self , * layers : Layer | tuple [ str , Layer ] | list [ Any ]) -> None :
579600 super ().__init__ ()
580601 if len (layers ) > 0 and isinstance (layers [0 ], (list , tuple )):
581602 for name , layer in layers :
@@ -584,7 +605,7 @@ def __init__(self, *layers):
584605 for idx , layer in enumerate (layers ):
585606 self .add_sublayer (str (idx ), layer )
586607
587- def __getitem__ (self , name ) :
608+ def __getitem__ (self , name : str ) -> Layer :
588609 if isinstance (name , slice ):
589610 return self .__class__ (* (list (self ._sub_layers .values ())[name ]))
590611 elif isinstance (name , str ):
@@ -598,19 +619,19 @@ def __getitem__(self, name):
598619 raise IndexError (f'index { name } is out of range' )
599620 return list (self ._sub_layers .values ())[name ]
600621
601- def __setitem__ (self , name , layer ) :
622+ def __setitem__ (self , name : str , layer : Layer ) -> None :
602623 assert isinstance (layer , Layer )
603624 setattr (self , str (name ), layer )
604625
605- def __delitem__ (self , name ) :
626+ def __delitem__ (self , name : str ) -> None :
606627 name = str (name )
607628 assert name in self ._sub_layers
608629 del self ._sub_layers [name ]
609630
610- def __len__ (self ):
631+ def __len__ (self ) -> int :
611632 return len (self ._sub_layers )
612633
613- def forward (self , input ) :
634+ def forward (self , input : Any ) -> Any :
614635 for layer in self ._sub_layers .values ():
615636 input = layer (input )
616637 return input
0 commit comments