Skip to content

Commit 8937205

Browse files
SigureMoAinavopithygit
authored
add googlenet (#36034)
* update AvgPool2D to AdaptiveAvgPool2D * class_num -> num_classes * add en doc * add googlenet to pretrained test * remove weights name * add parameter with_pool * update en doc * fix googlenet out shape * 2020 -> 2021 Co-authored-by: Ainavo <[email protected]> Co-authored-by: pithygit <[email protected]> Co-authored-by: Ainavo <[email protected]> Co-authored-by: pithygit <[email protected]>
1 parent 442688a commit 8937205

File tree

5 files changed

+265
-2
lines changed

5 files changed

+265
-2
lines changed

python/paddle/tests/test_pretrained_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def infer(self, arch):
5454
def test_models(self):
5555
arches = [
5656
'mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16', 'alexnet',
57-
'resnext50_32x4d', 'inception_v3', 'densenet121'
57+
'resnext50_32x4d', 'inception_v3', 'densenet121', 'googlenet'
5858
]
5959
for arch in arches:
6060
self.infer(arch)

python/paddle/tests/test_vision_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def test_resnext152_64x4d(self):
109109
def test_inception_v3(self):
110110
self.models_infer('inception_v3')
111111

112+
def test_googlenet(self):
113+
self.models_infer('googlenet')
114+
112115
def test_vgg16_num_classes(self):
113116
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
114117

python/paddle/vision/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
from .models import resnext152_64x4d # noqa: F401
6262
from .models import InceptionV3 # noqa: F401
6363
from .models import inception_v3 # noqa: F401
64+
from .models import GoogLeNet # noqa: F401
65+
from .models import googlenet # noqa: F401
6466
from .transforms import BaseTransform # noqa: F401
6567
from .transforms import Compose # noqa: F401
6668
from .transforms import Resize # noqa: F401

python/paddle/vision/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from .resnext import resnext152_64x4d # noqa: F401
4646
from .inceptionv3 import InceptionV3 # noqa: F401
4747
from .inceptionv3 import inception_v3 # noqa: F401
48+
from .googlenet import GoogLeNet # noqa: F401
49+
from .googlenet import googlenet # noqa: F401
4850

4951
__all__ = [ #noqa
5052
'ResNet',
@@ -79,5 +81,7 @@
7981
'resnext152_32x4d',
8082
'resnext152_64x4d',
8183
'InceptionV3',
82-
'inception_v3'
84+
'inception_v3',
85+
'GoogLeNet',
86+
'googlenet',
8387
]
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import division
16+
from __future__ import print_function
17+
18+
import paddle
19+
import paddle.nn as nn
20+
import paddle.nn.functional as F
21+
22+
from paddle.nn import Conv2D, Linear, Dropout
23+
from paddle.nn import MaxPool2D, AvgPool2D, AdaptiveAvgPool2D
24+
from paddle.nn.initializer import Uniform
25+
from paddle.fluid.param_attr import ParamAttr
26+
from paddle.utils.download import get_weights_path_from_url
27+
28+
__all__ = []
29+
30+
model_urls = {
31+
"googlenet":
32+
("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GoogLeNet_pretrained.pdparams",
33+
"80c06f038e905c53ab32c40eca6e26ae")
34+
}
35+
36+
37+
def xavier(channels, filter_size):
38+
stdv = (3.0 / (filter_size**2 * channels))**0.5
39+
param_attr = ParamAttr(initializer=Uniform(-stdv, stdv))
40+
return param_attr
41+
42+
43+
class ConvLayer(nn.Layer):
44+
def __init__(self,
45+
num_channels,
46+
num_filters,
47+
filter_size,
48+
stride=1,
49+
groups=1):
50+
super(ConvLayer, self).__init__()
51+
52+
self._conv = Conv2D(
53+
in_channels=num_channels,
54+
out_channels=num_filters,
55+
kernel_size=filter_size,
56+
stride=stride,
57+
padding=(filter_size - 1) // 2,
58+
groups=groups,
59+
bias_attr=False)
60+
61+
def forward(self, inputs):
62+
y = self._conv(inputs)
63+
return y
64+
65+
66+
class Inception(nn.Layer):
67+
def __init__(self, input_channels, output_channels, filter1, filter3R,
68+
filter3, filter5R, filter5, proj):
69+
super(Inception, self).__init__()
70+
71+
self._conv1 = ConvLayer(input_channels, filter1, 1)
72+
self._conv3r = ConvLayer(input_channels, filter3R, 1)
73+
self._conv3 = ConvLayer(filter3R, filter3, 3)
74+
self._conv5r = ConvLayer(input_channels, filter5R, 1)
75+
self._conv5 = ConvLayer(filter5R, filter5, 5)
76+
self._pool = MaxPool2D(kernel_size=3, stride=1, padding=1)
77+
78+
self._convprj = ConvLayer(input_channels, proj, 1)
79+
80+
def forward(self, inputs):
81+
conv1 = self._conv1(inputs)
82+
83+
conv3r = self._conv3r(inputs)
84+
conv3 = self._conv3(conv3r)
85+
86+
conv5r = self._conv5r(inputs)
87+
conv5 = self._conv5(conv5r)
88+
89+
pool = self._pool(inputs)
90+
convprj = self._convprj(pool)
91+
92+
cat = paddle.concat([conv1, conv3, conv5, convprj], axis=1)
93+
cat = F.relu(cat)
94+
return cat
95+
96+
97+
class GoogLeNet(nn.Layer):
98+
"""GoogLeNet (Inception v1) model architecture from
99+
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_
100+
101+
Args:
102+
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
103+
will not be defined. Default: 1000.
104+
with_pool (bool, optional): use pool before the last fc layer or not. Default: True.
105+
106+
Examples:
107+
.. code-block:: python
108+
109+
import paddle
110+
from paddle.vision.models import GoogLeNet
111+
112+
# build model
113+
model = GoogLeNet()
114+
115+
x = paddle.rand([1, 3, 224, 224])
116+
out, out1, out2 = model(x)
117+
118+
print(out.shape)
119+
"""
120+
121+
def __init__(self, num_classes=1000, with_pool=True):
122+
super(GoogLeNet, self).__init__()
123+
self.num_classes = num_classes
124+
self.with_pool = with_pool
125+
126+
self._conv = ConvLayer(3, 64, 7, 2)
127+
self._pool = MaxPool2D(kernel_size=3, stride=2)
128+
self._conv_1 = ConvLayer(64, 64, 1)
129+
self._conv_2 = ConvLayer(64, 192, 3)
130+
131+
self._ince3a = Inception(192, 192, 64, 96, 128, 16, 32, 32)
132+
self._ince3b = Inception(256, 256, 128, 128, 192, 32, 96, 64)
133+
134+
self._ince4a = Inception(480, 480, 192, 96, 208, 16, 48, 64)
135+
self._ince4b = Inception(512, 512, 160, 112, 224, 24, 64, 64)
136+
self._ince4c = Inception(512, 512, 128, 128, 256, 24, 64, 64)
137+
self._ince4d = Inception(512, 512, 112, 144, 288, 32, 64, 64)
138+
self._ince4e = Inception(528, 528, 256, 160, 320, 32, 128, 128)
139+
140+
self._ince5a = Inception(832, 832, 256, 160, 320, 32, 128, 128)
141+
self._ince5b = Inception(832, 832, 384, 192, 384, 48, 128, 128)
142+
143+
if with_pool:
144+
# out
145+
self._pool_5 = AdaptiveAvgPool2D(1)
146+
# out1
147+
self._pool_o1 = AvgPool2D(kernel_size=5, stride=3)
148+
# out2
149+
self._pool_o2 = AvgPool2D(kernel_size=5, stride=3)
150+
151+
if num_classes > 0:
152+
# out
153+
self._drop = Dropout(p=0.4, mode="downscale_in_infer")
154+
self._fc_out = Linear(
155+
1024, num_classes, weight_attr=xavier(1024, 1))
156+
157+
# out1
158+
self._conv_o1 = ConvLayer(512, 128, 1)
159+
self._fc_o1 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
160+
self._drop_o1 = Dropout(p=0.7, mode="downscale_in_infer")
161+
self._out1 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))
162+
163+
# out2
164+
self._conv_o2 = ConvLayer(528, 128, 1)
165+
self._fc_o2 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
166+
self._drop_o2 = Dropout(p=0.7, mode="downscale_in_infer")
167+
self._out2 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))
168+
169+
def forward(self, inputs):
170+
x = self._conv(inputs)
171+
x = self._pool(x)
172+
x = self._conv_1(x)
173+
x = self._conv_2(x)
174+
x = self._pool(x)
175+
176+
x = self._ince3a(x)
177+
x = self._ince3b(x)
178+
x = self._pool(x)
179+
180+
ince4a = self._ince4a(x)
181+
x = self._ince4b(ince4a)
182+
x = self._ince4c(x)
183+
ince4d = self._ince4d(x)
184+
x = self._ince4e(ince4d)
185+
x = self._pool(x)
186+
187+
x = self._ince5a(x)
188+
ince5b = self._ince5b(x)
189+
190+
out, out1, out2 = ince5b, ince4a, ince4d
191+
192+
if self.with_pool:
193+
out = self._pool_5(out)
194+
out1 = self._pool_o1(out1)
195+
out2 = self._pool_o2(out2)
196+
197+
if self.num_classes > 0:
198+
out = self._drop(out)
199+
out = paddle.squeeze(out, axis=[2, 3])
200+
out = self._fc_out(out)
201+
202+
out1 = self._conv_o1(out1)
203+
out1 = paddle.flatten(out1, start_axis=1, stop_axis=-1)
204+
out1 = self._fc_o1(out1)
205+
out1 = F.relu(out1)
206+
out1 = self._drop_o1(out1)
207+
out1 = self._out1(out1)
208+
209+
out2 = self._conv_o2(out2)
210+
out2 = paddle.flatten(out2, start_axis=1, stop_axis=-1)
211+
out2 = self._fc_o2(out2)
212+
out2 = self._drop_o2(out2)
213+
out2 = self._out2(out2)
214+
215+
return [out, out1, out2]
216+
217+
218+
def googlenet(pretrained=False, **kwargs):
219+
"""GoogLeNet (Inception v1) model architecture from
220+
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_
221+
222+
Args:
223+
pretrained (bool): If True, returns a model pre-trained on ImageNet
224+
225+
Examples:
226+
.. code-block:: python
227+
228+
import paddle
229+
from paddle.vision.models import googlenet
230+
231+
# build model
232+
model = googlenet()
233+
234+
# build model and load imagenet pretrained weight
235+
# model = googlenet(pretrained=True)
236+
237+
x = paddle.rand([1, 3, 224, 224])
238+
out, out1, out2 = model(x)
239+
240+
print(out.shape)
241+
"""
242+
model = GoogLeNet(**kwargs)
243+
arch = "googlenet"
244+
if pretrained:
245+
assert (
246+
arch in model_urls
247+
), "{} model do not have a pretrained model now, you should set pretrained=False".format(
248+
arch)
249+
weight_path = get_weights_path_from_url(model_urls[arch][0],
250+
model_urls[arch][1])
251+
252+
param = paddle.load(weight_path)
253+
model.set_dict(param)
254+
return model

0 commit comments

Comments
 (0)