Skip to content

Commit fb429b6

Browse files
committed
Add MobileNetv2 with pretrained weights
1 parent aaa6101 commit fb429b6

File tree

8 files changed

+315
-10
lines changed

8 files changed

+315
-10
lines changed

README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,16 @@ with tf.Session() as sess:
277277
| [DenseNet121](tensornets/densenets.py#L63) | 25.480 | 8.022 | 6.842 | 8.1M | 7.0M | 202.9 | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) |
278278
| [DenseNet169](tensornets/densenets.py#L71) | 23.926 | 6.892 | 6.140 | 14.3M | 12.6M | 219.1 | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) |
279279
| [DenseNet201](tensornets/densenets.py#L79) | 22.936 | 6.542 | 5.724 | 20.2M | 18.3M | 272.0 | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) |
280-
| [MobileNet25](tensornets/mobilenets.py#L84) | 48.418 | 24.208 | 21.196 | 0.5M | 0.2M | 29.27 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
281-
| [MobileNet50](tensornets/mobilenets.py#L91) | 35.708 | 14.376 | 12.180 | 1.3M | 0.8M | 42.32 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
282-
| [MobileNet75](tensornets/mobilenets.py#L98) | 31.588 | 11.758 | 9.878 | 2.6M | 1.8M | 57.23 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
283-
| [MobileNet100](tensornets/mobilenets.py#L105) | 29.576 | 10.496 | 8.774 | 4.3M | 3.2M | 70.69 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
280+
| [MobileNet25](tensornets/mobilenets.py#L156) | 48.418 | 24.208 | 21.196 | 0.5M | 0.2M | 34.46 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
281+
| [MobileNet50](tensornets/mobilenets.py#L163) | 35.708 | 14.376 | 12.180 | 1.3M | 0.8M | 52.46 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
282+
| [MobileNet75](tensornets/mobilenets.py#L170) | 31.588 | 11.758 | 9.878 | 2.6M | 1.8M | 70.11 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
283+
| [MobileNet100](tensornets/mobilenets.py#L177) | 29.576 | 10.496 | 8.774 | 4.3M | 3.2M | 83.41 | [[paper]](https://arxiv.org/abs/1704.04861) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py) |
284+
| [MobileNet35v2](tensornets/mobilenets.py#L184) | 39.914 | 17.568 | 15.422 | 1.7M | 0.4M | 57.04 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
285+
| [MobileNet50v2](tensornets/mobilenets.py#L191) | 34.806 | 13.938 | 11.976 | 2.0M | 0.7M | 64.35 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
286+
| [MobileNet75v2](tensornets/mobilenets.py#L198) | 30.468 | 10.824 | 9.188 | 2.7M | 1.4M | 88.68 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
287+
| [MobileNet100v2](tensornets/mobilenets.py#L205) | 28.664 | 9.858 | 8.322 | 3.5M | 2.3M | 93.82 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
288+
| [MobileNet130v2](tensornets/mobilenets.py#L212) | 25.320 | 7.878 | 6.728 | 5.4M | 3.8M | 130.4 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
289+
| [MobileNet140v2](tensornets/mobilenets.py#L219) | 24.770 | 7.578 | 6.518 | 6.2M | 4.4M | 132.9 | [[paper]](https://arxiv.org/abs/1801.04381) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py) |
284290
| [SqueezeNet](tensornets/squeezenets.py#L45) | 45.566 | 21.960 | 18.578 | 1.2M | 0.7M | 71.43 | [[paper]](https://arxiv.org/abs/1602.07360) [[caffe]](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/train_val.prototxt) |
285291

286292
### Object detection
@@ -310,13 +316,14 @@ with tf.Session() as sess:
310316

311317
## News 📰
312318

319+
- The six variants of MobileNetv2 are released, [5 May 2018]().
313320
- YOLOv3 for COCO and VOC are released, [4 April 2018](https://github.com/taehoonlee/tensornets/commit/d8b2d8a54dc4b775a174035da63561028deb6624).
314321
- Generic object detection models for YOLOv2 and FasterRCNN are released, [26 March 2018](https://github.com/taehoonlee/tensornets/commit/67915e659d2097a96c82ba7740b9e43a8c69858d).
315322

316323
## Future work 🔥
317324

318325
- Add training codes.
319-
- Add image classification models (MobileNetv2, PNASNet).
326+
- Add image classification models (PolyNet, PNASNet).
320327
- Add object detection models (MaskRCNN, SSD).
321328
- Add image segmentation models (FCN, UNet).
322329
- Add image datasets (COCO, OpenImages).

tensornets/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@
3636
from .mobilenets import MobileNet75
3737
from .mobilenets import MobileNet100
3838

39+
from .mobilenets import MobileNet35v2
40+
from .mobilenets import MobileNet50v2
41+
from .mobilenets import MobileNet75v2
42+
from .mobilenets import MobileNet100v2
43+
from .mobilenets import MobileNet130v2
44+
from .mobilenets import MobileNet140v2
45+
3946
from .squeezenets import SqueezeNet
4047

4148
from .capsulenets import CapsuleNet

tensornets/middles.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Collection of representative endpoints for each model."""
22
from __future__ import absolute_import
33

4+
from .utils import tf_later_than
5+
46

57
def names_inceptions(k, first_block, omit_first=False,
68
pool_last=False, resnet=False):
@@ -59,6 +61,37 @@ def names_darknets(k):
5961
return names
6062

6163

64+
def tuple_mobilenetv2():
65+
def baseidx(b):
66+
return [b, b + 3, b + 5]
67+
indices = baseidx(2)
68+
if tf_later_than('1.3.0'):
69+
bn_name = 'FusedBatchNorm:0'
70+
else:
71+
bn_name = 'batchnorm/add_1:0'
72+
names = ['conv1/Relu6:0', 'sconv1/Relu6:0', 'pconv1/bn/' + bn_name]
73+
k = 10
74+
l = 2
75+
for (i, j) in enumerate([2, 3, 4, 3, 3, 1]):
76+
indices += baseidx(k)
77+
names += ["conv%d/conv/Relu6:0" % l,
78+
"conv%d/sconv/Relu6:0" % l,
79+
"conv%d/pconv/bn/%s" % (l, bn_name)]
80+
k += 8
81+
l += 1
82+
for _ in range(j - 1):
83+
indices += (baseidx(k) + [k + 6])
84+
names += ["conv%d/conv/Relu6:0" % l,
85+
"conv%d/sconv/Relu6:0" % l,
86+
"conv%d/pconv/bn/%s" % (l, bn_name),
87+
"conv%d/out:0" % l]
88+
k += 9
89+
l += 1
90+
indices += [k]
91+
names += ["conv%d/Relu6:0" % l]
92+
return (indices, names, -16)
93+
94+
6295
def direct(model_name):
6396
try:
6497
return __middles_dict__[model_name]
@@ -236,6 +269,12 @@ def direct(model_name):
236269
['conv%d/conv/Relu6:0' % (i + 4) for i in range(11)],
237270
-3
238271
),
272+
'mobilenet35v2': tuple_mobilenetv2(),
273+
'mobilenet50v2': tuple_mobilenetv2(),
274+
'mobilenet75v2': tuple_mobilenetv2(),
275+
'mobilenet100v2': tuple_mobilenetv2(),
276+
'mobilenet130v2': tuple_mobilenetv2(),
277+
'mobilenet140v2': tuple_mobilenetv2(),
239278
'squeezenet': (
240279
[9, 16, 17, 24, 31, 32] + list(range(39, 61, 7)),
241280
names_squeezenet(),

tensornets/mobilenets.py

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
"""Collection of MobileNet variants
22
3-
The reference paper:
3+
The reference papers:
44
5+
1. V1
56
- MobileNets: Efficient Convolutional Neural Networks for Mobile Vision
67
Applications, arXiv 2017
78
- Andrew G. Howard et al.
89
- https://arxiv.org/abs/1704.04861
10+
2. V2
11+
- MobileNetV2: Inverted Residuals and Linear Bottlenecks, arXiv 2018
12+
- Mark Sandler et al.
13+
- https://arxiv.org/abs/1801.04381
914
10-
The reference implementation:
15+
The reference implementations:
1116
12-
1. TF Slim
17+
1. (for v1) TF Slim
1318
- https://github.com/tensorflow/models/blob/master/research/slim/nets/
1419
mobilenet_v1.py
20+
2. (for v2) TF Slim
21+
- https://github.com/tensorflow/models/blob/master/research/slim/nets/
22+
mobilenet/mobilenet_v2.py
1523
"""
1624
from __future__ import absolute_import
1725

@@ -22,6 +30,7 @@
2230
from .layers import dropout
2331
from .layers import fc
2432
from .layers import separable_conv2d
33+
from .layers import convbn
2534
from .layers import convbnrelu6 as conv
2635
from .layers import sconvbnrelu6 as sconv
2736

@@ -30,8 +39,8 @@
3039
from .utils import var_scope
3140

3241

33-
def __args__(is_training):
34-
return [([batch_norm], {'decay': 0.9997, 'scale': True, 'epsilon': 0.001,
42+
def __base_args__(is_training, decay):
43+
return [([batch_norm], {'decay': decay, 'scale': True, 'epsilon': 0.001,
3544
'is_training': is_training, 'scope': 'bn'}),
3645
([conv2d], {'padding': 'SAME', 'activation_fn': None,
3746
'biases_initializer': None, 'scope': 'conv'}),
@@ -42,13 +51,33 @@ def __args__(is_training):
4251
'scope': 'sconv'})]
4352

4453

54+
def __args__(is_training):
55+
return __base_args__(is_training, 0.9997)
56+
57+
58+
def __args_v2__(is_training):
59+
return __base_args__(is_training, 0.999)
60+
61+
4562
@var_scope('block')
4663
def block(x, filters, stride=1, scope=None):
4764
x = sconv(x, None, 3, 1, stride=stride, scope='sconv')
4865
x = conv(x, filters, 1, stride=1, scope='conv')
4966
return x
5067

5168

69+
@var_scope('blockv2')
70+
def block2(x, filters, first=False, stride=1, scope=None):
71+
shortcut = x
72+
x = conv(x, 6 * x.shape[-1].value, 1, scope='conv')
73+
x = sconv(x, None, 3, 1, stride=stride, scope='sconv')
74+
x = convbn(x, filters, 1, stride=1, scope='pconv')
75+
if stride == 1 and shortcut.shape[-1].value == filters:
76+
return add(shortcut, x, name='out')
77+
else:
78+
return x
79+
80+
5281
def mobilenet(x, depth_multiplier, is_training, classes, stem,
5382
scope=None, reuse=None):
5483
def depth(d):
@@ -81,6 +110,49 @@ def depth(d):
81110
return x
82111

83112

113+
def mobilenetv2(x, depth_multiplier, is_training, classes, stem,
114+
scope=None, reuse=None):
115+
def depth(d):
116+
d *= depth_multiplier
117+
filters = max(8, int(d + 4) // 8 * 8)
118+
if filters < 0.9 * d:
119+
filters += 8
120+
return filters
121+
x = conv(x, depth(32), 3, stride=2, scope='conv1')
122+
x = sconv(x, None, 3, 1, scope='sconv1')
123+
x = convbn(x, depth(16), 1, scope='pconv1')
124+
125+
x = block2(x, depth(24), stride=2, scope='conv2')
126+
x = block2(x, depth(24), scope='conv3')
127+
128+
x = block2(x, depth(32), stride=2, scope='conv4')
129+
x = block2(x, depth(32), scope='conv5')
130+
x = block2(x, depth(32), scope='conv6')
131+
132+
x = block2(x, depth(64), stride=2, scope='conv7')
133+
x = block2(x, depth(64), scope='conv8')
134+
x = block2(x, depth(64), scope='conv9')
135+
x = block2(x, depth(64), scope='conv10')
136+
137+
x = block2(x, depth(96), scope='conv11')
138+
x = block2(x, depth(96), scope='conv12')
139+
x = block2(x, depth(96), scope='conv13')
140+
141+
x = block2(x, depth(160), stride=2, scope='conv14')
142+
x = block2(x, depth(160), scope='conv15')
143+
x = block2(x, depth(160), scope='conv16')
144+
145+
x = block2(x, depth(320), scope='conv17')
146+
x = conv(x, 1280 * depth_multiplier if depth_multiplier > 1. else 1280, 1,
147+
scope='conv18')
148+
if stem: return x
149+
150+
x = reduce_mean(x, [1, 2], name='avgpool')
151+
x = fc(x, classes, scope='logits')
152+
x = softmax(x, name='probs')
153+
return x
154+
155+
84156
@var_scope('mobilenet25')
85157
@set_args(__args__)
86158
def mobilenet25(x, is_training=False, classes=1000,
@@ -109,8 +181,56 @@ def mobilenet100(x, is_training=False, classes=1000,
109181
return mobilenet(x, 1.0, is_training, classes, stem, scope, reuse)
110182

111183

184+
@var_scope('mobilenet35v2')
185+
@set_args(__args_v2__)
186+
def mobilenet35v2(x, is_training=False, classes=1000,
187+
stem=False, scope=None, reuse=None):
188+
return mobilenetv2(x, 0.35, is_training, classes, stem, scope, reuse)
189+
190+
191+
@var_scope('mobilenet50v2')
192+
@set_args(__args_v2__)
193+
def mobilenet50v2(x, is_training=False, classes=1000,
194+
stem=False, scope=None, reuse=None):
195+
return mobilenetv2(x, 0.50, is_training, classes, stem, scope, reuse)
196+
197+
198+
@var_scope('mobilenet75v2')
199+
@set_args(__args_v2__)
200+
def mobilenet75v2(x, is_training=False, classes=1000,
201+
stem=False, scope=None, reuse=None):
202+
return mobilenetv2(x, 0.75, is_training, classes, stem, scope, reuse)
203+
204+
205+
@var_scope('mobilenet100v2')
206+
@set_args(__args_v2__)
207+
def mobilenet100v2(x, is_training=False, classes=1000,
208+
stem=False, scope=None, reuse=None):
209+
return mobilenetv2(x, 1.0, is_training, classes, stem, scope, reuse)
210+
211+
212+
@var_scope('mobilenet130v2')
213+
@set_args(__args_v2__)
214+
def mobilenet130v2(x, is_training=False, classes=1000,
215+
stem=False, scope=None, reuse=None):
216+
return mobilenetv2(x, 1.3, is_training, classes, stem, scope, reuse)
217+
218+
219+
@var_scope('mobilenet140v2')
220+
@set_args(__args_v2__)
221+
def mobilenet140v2(x, is_training=False, classes=1000,
222+
stem=False, scope=None, reuse=None):
223+
return mobilenetv2(x, 1.4, is_training, classes, stem, scope, reuse)
224+
225+
112226
# Simple alias.
113227
MobileNet25 = mobilenet25
114228
MobileNet50 = mobilenet50
115229
MobileNet75 = mobilenet75
116230
MobileNet100 = mobilenet100
231+
MobileNet35v2 = mobilenet35v2
232+
MobileNet50v2 = mobilenet50v2
233+
MobileNet75v2 = mobilenet75v2
234+
MobileNet100v2 = mobilenet100v2
235+
MobileNet130v2 = mobilenet130v2
236+
MobileNet140v2 = mobilenet140v2

tensornets/preprocess.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ def faster_rcnn_preprocess(x):
165165
'mobilenet50': tfslim_preprocess,
166166
'mobilenet75': tfslim_preprocess,
167167
'mobilenet100': tfslim_preprocess,
168+
'mobilenetv2': tfslim_preprocess,
169+
'mobilenet35v2': tfslim_preprocess,
170+
'mobilenet50v2': tfslim_preprocess,
171+
'mobilenet75v2': tfslim_preprocess,
172+
'mobilenet100v2': tfslim_preprocess,
173+
'mobilenet130v2': tfslim_preprocess,
174+
'mobilenet140v2': tfslim_preprocess,
168175
'squeezenet': bair_preprocess,
169176
'zf': faster_rcnn_preprocess,
170177
'darknet19': darknet_preprocess,

tensornets/pretrained.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,72 @@ def load_mobilenet100(scopes, return_fn=_assign):
530530
return return_fn(scopes, values)
531531

532532

533+
def load_mobilenet35v2(scopes, return_fn=_assign):
534+
"""Converted from the [TF Slim][2]."""
535+
filename = 'mobilenet_v2_035_224.npz'
536+
weights_path = get_file(
537+
filename, __model_url__ + 'mobilenet/' + filename,
538+
cache_subdir='models',
539+
file_hash='cf758f7f8024d39365e553ec924bb395')
540+
values = parse_weights(weights_path)
541+
return return_fn(scopes, values)
542+
543+
544+
def load_mobilenet50v2(scopes, return_fn=_assign):
545+
"""Converted from the [TF Slim][2]."""
546+
filename = 'mobilenet_v2_050_224.npz'
547+
weights_path = get_file(
548+
filename, __model_url__ + 'mobilenet/' + filename,
549+
cache_subdir='models',
550+
file_hash='218d51cd1b12b03ece24054029e7005b')
551+
values = parse_weights(weights_path)
552+
return return_fn(scopes, values)
553+
554+
555+
def load_mobilenet75v2(scopes, return_fn=_assign):
556+
"""Converted from the [TF Slim][2]."""
557+
filename = 'mobilenet_v2_075_224.npz'
558+
weights_path = get_file(
559+
filename, __model_url__ + 'mobilenet/' + filename,
560+
cache_subdir='models',
561+
file_hash='25b5f6c93ebec7558a757e7a70b16b1c')
562+
values = parse_weights(weights_path)
563+
return return_fn(scopes, values)
564+
565+
566+
def load_mobilenet100v2(scopes, return_fn=_assign):
567+
"""Converted from the [TF Slim][2]."""
568+
filename = 'mobilenet_v2_100_224.npz'
569+
weights_path = get_file(
570+
filename, __model_url__ + 'mobilenet/' + filename,
571+
cache_subdir='models',
572+
file_hash='ea55ba8d51df1df59b196d2508a3f262')
573+
values = parse_weights(weights_path)
574+
return return_fn(scopes, values)
575+
576+
577+
def load_mobilenet130v2(scopes, return_fn=_assign):
578+
"""Converted from the [TF Slim][2]."""
579+
filename = 'mobilenet_v2_130_224.npz'
580+
weights_path = get_file(
581+
filename, __model_url__ + 'mobilenet/' + filename,
582+
cache_subdir='models',
583+
file_hash='a2470b36675853bbe107a406b50cd648')
584+
values = parse_weights(weights_path)
585+
return return_fn(scopes, values)
586+
587+
588+
def load_mobilenet140v2(scopes, return_fn=_assign):
589+
"""Converted from the [TF Slim][2]."""
590+
filename = 'mobilenet_v2_140_224.npz'
591+
weights_path = get_file(
592+
filename, __model_url__ + 'mobilenet/' + filename,
593+
cache_subdir='models',
594+
file_hash='6988a66bb89a088ce20e2ae97adca88b')
595+
values = parse_weights(weights_path)
596+
return return_fn(scopes, values)
597+
598+
533599
def load_squeezenet(scopes, return_fn=_assign):
534600
"""Converted from the [Caffe SqueezeNets][8]."""
535601
filename = 'squeezenet.npz'
@@ -704,6 +770,12 @@ def load_ref_faster_rcnn_vgg16_voc(scopes, return_fn=_assign):
704770
'mobilenet50': load_mobilenet50,
705771
'mobilenet75': load_mobilenet75,
706772
'mobilenet100': load_mobilenet100,
773+
'mobilenet35v2': load_mobilenet35v2,
774+
'mobilenet50v2': load_mobilenet50v2,
775+
'mobilenet75v2': load_mobilenet75v2,
776+
'mobilenet100v2': load_mobilenet100v2,
777+
'mobilenet130v2': load_mobilenet130v2,
778+
'mobilenet140v2': load_mobilenet140v2,
707779
'squeezenet': load_squeezenet,
708780
'zf': load_nothing,
709781
'darknet19': load_darknet19,

0 commit comments

Comments
 (0)