Skip to content

Commit d5ae09c

Browse files
committed
implement wide resnet
1 parent 17adcf6 commit d5ae09c

File tree

5 files changed

+151
-3
lines changed

5 files changed

+151
-3
lines changed

python/paddle/tests/test_vision_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ def test_resnet101(self):
7171
def test_resnet152(self):
7272
self.models_infer('resnet152')
7373

74+
def test_wide_resnet50(self):
75+
self.models_infer('wide_resnet50')
76+
77+
def test_wide_resnet101(self):
78+
self.models_infer('wide_resnet101')
79+
80+
def test_wide_resnet101_pretrained(self):
81+
self.models_infer('wide_resnet101', pretrained=False)
82+
7483
def test_vgg16_num_classes(self):
7584
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
7685

python/paddle/vision/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from .models import resnet50 # noqa: F401
3535
from .models import resnet101 # noqa: F401
3636
from .models import resnet152 # noqa: F401
37+
from .models import WideResNet # noqa: F401
38+
from .models import wide_resnet50 # noqa: F401
39+
from .models import wide_resnet101 # noqa: F401
3740
from .models import MobileNetV1 # noqa: F401
3841
from .models import mobilenet_v1 # noqa: F401
3942
from .models import MobileNetV2 # noqa: F401

python/paddle/vision/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from .resnet import resnet50 # noqa: F401
1919
from .resnet import resnet101 # noqa: F401
2020
from .resnet import resnet152 # noqa: F401
21+
from .wideresnet import WideResNet # noqa: F401
22+
from .wideresnet import wide_resnet50 # noqa: F401
23+
from .wideresnet import wide_resnet101 # noqa: F401
2124
from .mobilenetv1 import MobileNetV1 # noqa: F401
2225
from .mobilenetv1 import mobilenet_v1 # noqa: F401
2326
from .mobilenetv2 import MobileNetV2 # noqa: F401

python/paddle/vision/models/resnet.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class ResNet(nn.Layer):
155155
depth (int): layers of resnet, default: 50.
156156
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
157157
will not be defined. Default: 1000.
158+
width_per_group (int):
158159
with_pool (bool): use pool before the last fc layer or not. Default: True.
159160
160161
Examples:
@@ -169,7 +170,12 @@ class ResNet(nn.Layer):
169170
170171
"""
171172

172-
def __init__(self, block, depth, num_classes=1000, with_pool=True):
173+
def __init__(self,
174+
block,
175+
depth,
176+
num_classes=1000,
177+
width_per_group=64,
178+
with_pool=True):
173179
super(ResNet, self).__init__()
174180
layer_cfg = {
175181
18: [2, 2, 2, 2],
@@ -180,6 +186,7 @@ def __init__(self, block, depth, num_classes=1000, with_pool=True):
180186
}
181187
layers = layer_cfg[depth]
182188
self.num_classes = num_classes
189+
self.base_width = width_per_group
183190
self.with_pool = with_pool
184191
self._norm_layer = nn.BatchNorm2D
185192

@@ -225,11 +232,16 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
225232

226233
layers = []
227234
layers.append(
228-
block(self.inplanes, planes, stride, downsample, 1, 64,
235+
block(self.inplanes, planes, stride, downsample, 1, self.base_width,
229236
previous_dilation, norm_layer))
230237
self.inplanes = planes * block.expansion
231238
for _ in range(1, blocks):
232-
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
239+
layers.append(
240+
block(
241+
self.inplanes,
242+
planes,
243+
base_width=self.base_width,
244+
norm_layer=norm_layer))
233245

234246
return nn.Sequential(*layers)
235247

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import division
16+
from __future__ import print_function
17+
18+
import paddle
19+
import paddle.nn as nn
20+
21+
from paddle.utils.download import get_weights_path_from_url
22+
from paddle.vision.models.resnet import BottleneckBlock, ResNet
23+
24+
__all__ = []
25+
26+
model_urls = {'wide_resnet50': ('', ''), 'wide_resnet101': ('', '')}
27+
28+
29+
class WideResNet(nn.Layer):
30+
"""Wide ResNet model from
31+
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
32+
33+
The model is the same as ResNet except for the bottleneck number of channels
34+
which is twice larger in every block. The number of channels in outer 1x1
35+
convolutions is the same.
36+
37+
Args:
38+
Block (BasicBlock|BottleneckBlock): block module of model.
39+
depth (int): layers of resnet, default: 50.
40+
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
41+
will not be defined. Default: 1000.
42+
width_per_group (int): channel nums of each group
43+
with_pool (bool): use pool before the last fc layer or not. Default: True.
44+
45+
Examples:
46+
.. code-block:: python
47+
48+
from paddle.vision.models import WideResNet
49+
50+
wide_resnet50 = WideResNet(50)
51+
52+
wide_resnet101 = WideResNet(101)
53+
54+
"""
55+
56+
def __init__(self,
57+
depth,
58+
num_classes=1000,
59+
width_per_group=64,
60+
with_pool=True):
61+
super(WideResNet, self).__init__()
62+
self.layers = ResNet(BottleneckBlock, depth, num_classes,
63+
width_per_group * 2, with_pool)
64+
65+
def forward(self, x):
66+
return self.layers.forward(x)
67+
68+
69+
def _wide_resnet(arch, depth, pretrained, **kwargs):
70+
model = WideResNet(depth, **kwargs)
71+
if pretrained:
72+
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
73+
arch)
74+
weight_path = get_weights_path_from_url(model_urls[arch][0],
75+
model_urls[arch][1])
76+
77+
param = paddle.load(weight_path)
78+
model.set_dict(param)
79+
80+
return model
81+
82+
83+
def wide_resnet50(pretrained=False, **kwargs):
84+
"""Wide ResNet 50-layer model
85+
86+
Args:
87+
pretrained (bool): If True, returns a model pre-trained on ImageNet
88+
89+
Examples:
90+
.. code-block:: python
91+
92+
from paddle.vision.models import wide_resnet50
93+
94+
# build model
95+
model = wide_resnet50()
96+
97+
# build model and load imagenet pretrained weight
98+
# model = wide_resnet50(pretrained=True)
99+
"""
100+
kwargs['width_per_group'] = 64 * 2
101+
return _wide_resnet('wide_resnet50', 50, pretrained, **kwargs)
102+
103+
104+
def wide_resnet101(pretrained=False, **kwargs):
105+
"""Wide ResNet 101-layer model
106+
107+
Args:
108+
pretrained (bool): If True, returns a model pre-trained on ImageNet
109+
110+
Examples:
111+
.. code-block:: python
112+
113+
from paddle.vision.models import wide_resnet101
114+
115+
# build model
116+
model = wide_resnet101()
117+
118+
# build model and load imagenet pretrained weight
119+
# model = wide_resnet101(pretrained=True)
120+
"""
121+
return _wide_resnet('wide_resnet101', 101, pretrained, **kwargs)

0 commit comments

Comments
 (0)