1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
17+ from typing import (
18+ TYPE_CHECKING ,
19+ TypedDict ,
20+ )
21+
22+ from typing_extensions import NotRequired , Unpack
23+
1524import paddle
1625import paddle .nn .functional as F
1726from paddle import nn
27+ from paddle ._typing import Size2
1828from paddle .base .param_attr import ParamAttr
1929from paddle .nn import (
2030 AdaptiveAvgPool2D ,
2737from paddle .nn .initializer import Uniform
2838from paddle .utils .download import get_weights_path_from_url
2939
40+ if TYPE_CHECKING :
41+ from paddle import Tensor
3042__all__ = []
3143
3244model_urls = {
3749}
3850
3951
40- def xavier (channels , filter_size ) :
52+ def xavier (channels : int , filter_size : int ) -> ParamAttr :
4153 stdv = (3.0 / (filter_size ** 2 * channels )) ** 0.5
4254 param_attr = ParamAttr (initializer = Uniform (- stdv , stdv ))
4355 return param_attr
4456
4557
4658class ConvLayer (nn .Layer ):
4759 def __init__ (
48- self , num_channels , num_filters , filter_size , stride = 1 , groups = 1
60+ self ,
61+ num_channels : int ,
62+ num_filters : int ,
63+ filter_size : int ,
64+ stride : Size2 = 1 ,
65+ groups : int = 1 ,
4966 ):
5067 super ().__init__ ()
5168
@@ -59,22 +76,22 @@ def __init__(
5976 bias_attr = False ,
6077 )
6178
62- def forward (self , inputs ) :
79+ def forward (self , inputs : Tensor ) -> Tensor :
6380 y = self ._conv (inputs )
6481 return y
6582
6683
6784class Inception (nn .Layer ):
6885 def __init__ (
6986 self ,
70- input_channels ,
71- output_channels ,
72- filter1 ,
73- filter3R ,
74- filter3 ,
75- filter5R ,
76- filter5 ,
77- proj ,
87+ input_channels : int ,
88+ output_channels : int ,
89+ filter1 : int ,
90+ filter3R : int ,
91+ filter3 : int ,
92+ filter5R : int ,
93+ filter5 : int ,
94+ proj : int ,
7895 ):
7996 super ().__init__ ()
8097
@@ -87,7 +104,7 @@ def __init__(
87104
88105 self ._convprj = ConvLayer (input_channels , proj , 1 )
89106
90- def forward (self , inputs ) :
107+ def forward (self , inputs : Tensor ) -> Tensor :
91108 conv1 = self ._conv1 (inputs )
92109
93110 conv3r = self ._conv3r (inputs )
@@ -132,7 +149,7 @@ class GoogLeNet(nn.Layer):
132149 [1, 1000] [1, 1000] [1, 1000]
133150 """
134151
135- def __init__ (self , num_classes = 1000 , with_pool = True ):
152+ def __init__ (self , num_classes : int = 1000 , with_pool : bool = True ) -> None :
136153 super ().__init__ ()
137154 self .num_classes = num_classes
138155 self .with_pool = with_pool
@@ -181,7 +198,7 @@ def __init__(self, num_classes=1000, with_pool=True):
181198 self ._drop_o2 = Dropout (p = 0.7 , mode = "downscale_in_infer" )
182199 self ._out2 = Linear (1024 , num_classes , weight_attr = xavier (1024 , 1 ))
183200
184- def forward (self , inputs ) :
201+ def forward (self , inputs : Tensor ) -> tuple [ Tensor , Tensor , Tensor ] :
185202 x = self ._conv (inputs )
186203 x = self ._pool (x )
187204 x = self ._conv_1 (x )
@@ -227,10 +244,17 @@ def forward(self, inputs):
227244 out2 = self ._drop_o2 (out2 )
228245 out2 = self ._out2 (out2 )
229246
230- return [out , out1 , out2 ]
247+ return out , out1 , out2
248+
249+
250+ class _GoogLeNetOptions (TypedDict ):
251+ num_classes : NotRequired [int ]
252+ with_pool : NotRequired [bool ]
231253
232254
233- def googlenet (pretrained = False , ** kwargs ):
255+ def googlenet (
256+ pretrained : bool = False , ** kwargs : Unpack [_GoogLeNetOptions ]
257+ ) -> GoogLeNet :
234258 """GoogLeNet (Inception v1) model architecture from
235259 `"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_.
236260
0 commit comments