Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
dedd1a9
revise se_resnext for imagenet classification
BigFishMaster Apr 10, 2018
7103698
Add the method to train a SE-ResNeXt model
BigFishMaster Apr 10, 2018
c51499f
Update ImageNet2012 URL
BigFishMaster Apr 12, 2018
7401df1
update
BigFishMaster Apr 16, 2018
5dde5bf
update readme
BigFishMaster Apr 16, 2018
58e8075
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
BigFishMaster Apr 17, 2018
e833a60
Update readme with download URL
BigFishMaster Apr 17, 2018
2338db7
Merge branch 'model_update' of https://github.com/BigFishMaster/model…
BigFishMaster Apr 17, 2018
5504ee0
add unzip
BigFishMaster Apr 17, 2018
52bac00
train.py update
BigFishMaster Apr 26, 2018
369bdb1
Merge branch 'develop' into model_update
BigFishMaster Apr 26, 2018
cc2fa70
update se_resnext.py with parallel_exe
BigFishMaster Apr 26, 2018
dd9271a
Merge branch 'model_update' of https://github.com/BigFishMaster/model…
BigFishMaster Apr 26, 2018
a8703bf
delete train function in se_resnext.py
BigFishMaster Apr 26, 2018
aa08d91
Merge branch 'develop' into model_update
BigFishMaster Apr 26, 2018
a67370c
add eval.py and infer.py
BigFishMaster Apr 26, 2018
3b0ecb5
add eval.py and infer.py
BigFishMaster Apr 26, 2018
bf17277
add eval.py and infer.py
BigFishMaster Apr 26, 2018
52258be
move cosine_decay from se_resnext.py to train.py
BigFishMaster Apr 28, 2018
e270bb9
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
BigFishMaster Apr 28, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions fluid/image_classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,82 @@ The minimum PaddlePaddle version needed for the code sample in this directory is

This model built with paddle fluid is still under active development and is not
the final version. We welcome feedbacks.

## Introduction

The current code support the training of [SE-ResNeXt](https://arxiv.org/abs/1709.01507) (50/152 layers).

## Data Preparation

1. Download ImageNet-2012 dataset
```
cd data/
mkdir -p ILSVRC2012/
cd ILSVRC2012/
# get training set
wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
# get validation set
wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar
# prepare directory
tar xf ILSVRC2012_img_train.tar
tar xf ILSVRC2012_img_val.tar

# unzip all classes data using unzip.sh
sh unzip.sh
```

2. Download training and validation label files from [ImageNet2012 url](https://pan.baidu.com/s/1Y6BCo0nmxsm_FsEqmx2hKQ)(password:```wx99```). Untar it into workspace ```ILSVRC2012/```. The files include

**train_list.txt**: training list of imagenet 2012 classification task, with each line seperated by SPACE.
```
train/n02483708/n02483708_2436.jpeg 369
train/n03998194/n03998194_7015.jpeg 741
train/n04523525/n04523525_38118.jpeg 884
train/n04596742/n04596742_3032.jpeg 909
train/n03208938/n03208938_7065.jpeg 535
...
```
**val_list.txt**: validation list of imagenet 2012 classification task, with each line seperated by SPACE.
```
val/ILSVRC2012_val_00000001.jpeg 65
val/ILSVRC2012_val_00000002.jpeg 970
val/ILSVRC2012_val_00000003.jpeg 230
val/ILSVRC2012_val_00000004.jpeg 809
val/ILSVRC2012_val_00000005.jpeg 516
...
```
**synset_words.txt**: the semantic label of each class.

## Training a model

To start a training task, one can use command line as:

```
python train.py --num_layers=50 --batch_size=8 --with_mem_opt=True --parallel_exe=False
```
## Finetune a model
```
python train.py --num_layers=50 --batch_size=8 --with_mem_opt=True --parallel_exe=False --pretrained_model="pretrain/96/"
```
TBD
## Inference
```
python infer.py --num_layers=50 --batch_size=8 --model='model/90' --test_list=''
```
TBD

## Results

The SE-ResNeXt-50 model is trained by starting with learning rate ```0.1``` and decaying it by ```0.1``` after each ```10``` epoches. Top-1/Top-5 Validation Accuracy on ImageNet 2012 is listed in table.

|model | [original paper(Fig.5)](https://arxiv.org/abs/1709.01507) | Pytorch | Paddle fluid
|- | :-: |:-: | -:
|SE-ResNeXt-50 | 77.6%/- | 77.71%/93.63% | 77.42%/93.50%



## Released models
|model | Baidu Cloud
|- | -:
|SE-ResNeXt-50 | [url]()
TBD
9 changes: 9 additions & 0 deletions fluid/image_classification/data/ILSVRC2012/unzip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
cd train

dir=./
for x in `ls *.tar`
do
filename=`basename $x .tar`
mkdir $filename
tar -xvf $x -C ./$filename
done
83 changes: 83 additions & 0 deletions fluid/image_classification/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import sys
import numpy as np
import argparse
import functools

import paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments
from se_resnext import SE_ResNeXt
import reader

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('model_dir', str, '', "The model path.")
# yapf: enable


def eval(args):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=args.num_layers)
cost = fluid.layers.cross_entropy(input=out, label=label)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
avg_cost = fluid.layers.mean(x=cost)

inference_program = fluid.default_main_program().clone(for_test=True)

place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)

if not os.path.exists(args.model_dir):
raise ValueError("The model path [%s] does not exist." %
(args.model_dir))
if not os.path.exists(args.test_list):
raise ValueError("The test lists [%s] does not exist." %
(args.test_list))

def if_exist(var):
return os.path.exists(os.path.join(args.model_dir, var.name))

fluid.io.load_vars(exe, args.model_dir, predicate=if_exist)

test_reader = paddle.batch(
reader.test(args.test_list), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

fetch_list = [avg_cost, acc_top1, acc_top5]

test_info = [[], [], []]
for batch_id, data in enumerate(test_reader()):
loss, acc1, acc5 = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_list)
test_info[0].append(loss[0])
test_info[1].append(acc1[0])
test_info[2].append(acc5[0])
if batch_id % 1 == 0:
print("Test {0}, loss {1}, acc1 {2}, acc5 {3}"
.format(batch_id, loss[0], acc1[0], acc5[0]))
sys.stdout.flush()

test_loss = np.array(test_info[0]).mean()
test_acc1 = np.array(test_info[1]).mean()
test_acc5 = np.array(test_info[2]).mean()

print("Test loss {0}, acc1 {1}, acc5 {2}".format(test_loss, test_acc1,
test_acc5))
sys.stdout.flush()


if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
eval(args)
69 changes: 69 additions & 0 deletions fluid/image_classification/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import sys
import numpy as np
import argparse
import functools

import paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments
from se_resnext import SE_ResNeXt
import reader

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('model_dir', str, '', "The model path.")
# yapf: enable


def infer(args):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=args.num_layers)
out = fluid.layers.softmax(input=out)

inference_program = fluid.default_main_program().clone(for_test=True)

place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)

if not os.path.exists(args.model_dir):
raise ValueError("The model path [%s] does not exist." %
(args.model_dir))
if not os.path.exists(args.test_list):
raise ValueError("The test lists [%s] does not exist." %
(args.test_list))

def if_exist(var):
return os.path.exists(os.path.join(args.model_dir, var.name))

fluid.io.load_vars(exe, args.model_dir, predicate=if_exist)

test_reader = paddle.batch(
reader.infer(args.test_list), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image])

fetch_list = [out]

TOPK = 1
for batch_id, data in enumerate(test_reader()):
result = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_list)
result = result[0]
pred_label = np.argsort(result)[::-1][0][0]
print("Test {0}-score {1}, class {2}: "
.format(batch_id, result[0][pred_label], pred_label))
sys.stdout.flush()


if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
infer(args)
18 changes: 9 additions & 9 deletions fluid/image_classification/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import functools
import numpy as np
import paddle.v2 as paddle
import paddle
from PIL import Image, ImageEnhance

random.seed(0)
Expand All @@ -13,9 +13,9 @@
THREAD = 8
BUF_SIZE = 1024

DATA_DIR = 'ILSVRC2012'
TRAIN_LIST = 'ILSVRC2012/train_list.txt'
TEST_LIST = 'ILSVRC2012/test_list.txt'
DATA_DIR = 'data/ILSVRC2012'
TRAIN_LIST = 'data/ILSVRC2012/train_list.txt'
TEST_LIST = 'data/ILSVRC2012/val_list.txt'

img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
Expand Down Expand Up @@ -123,7 +123,7 @@ def process_image(sample, mode, color_jitter, rotate):
if mode == 'train' or mode == 'test':
return img, sample[1]
elif mode == 'infer':
return img
return [img]


def _reader_creator(file_list,
Expand Down Expand Up @@ -151,13 +151,13 @@ def reader():
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)


def train():
def train(file_list=TRAIN_LIST):
return _reader_creator(
TRAIN_LIST, 'train', shuffle=True, color_jitter=True, rotate=True)
file_list, 'train', shuffle=True, color_jitter=False, rotate=False)


def test():
return _reader_creator(TEST_LIST, 'test', shuffle=False)
def test(file_list=TEST_LIST):
return _reader_creator(file_list, 'test', shuffle=False)


def infer(file_list):
Expand Down
51 changes: 34 additions & 17 deletions fluid/image_classification/se_resnext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import paddle.v2 as paddle
import os
import numpy as np
import time
import sys
import paddle
import paddle.fluid as fluid
import reader
import paddle.fluid.layers.control_flow as control_flow
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import math


def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
Expand All @@ -19,23 +28,28 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
def squeeze_excitation(input, num_channels, reduction_ratio):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool,
size=num_channels / reduction_ratio,
act='relu')
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
act='sigmoid')
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv)))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale


def shortcut(input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out:
if stride == 1:
filter_size = 1
else:
filter_size = 3
if ch_in != ch_out or stride != 1:
filter_size = 1
return conv_bn_layer(input, ch_out, filter_size, stride)
else:
return input
Expand Down Expand Up @@ -66,8 +80,8 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
def SE_ResNeXt(input, class_dim, infer=False, layers=50):
supported_layers = [50, 152]
if layers not in supported_layers:
print("supported layers are", supported_layers, "but input layer is",
layers)
print("supported layers are", supported_layers, \
"but input layer is ", layers)
exit()
if layers == 50:
cardinality = 32
Expand Down Expand Up @@ -96,10 +110,7 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=50):
conv = conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
input=conv, pool_size=3, pool_stride=2, pool_padding=1, \
pool_type='max')

for block in range(len(depth)):
Expand All @@ -112,10 +123,16 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=50):
reduction_ratio=reduction_ratio)

pool = fluid.layers.pool2d(
input=conv, pool_size=0, pool_type='avg', global_pooling=True)
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
if not infer:
drop = fluid.layers.dropout(x=pool, dropout_prob=0.2)
drop = fluid.layers.dropout(x=pool, dropout_prob=0.5)
else:
drop = pool
out = fluid.layers.fc(input=drop, size=class_dim, act='softmax')
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
out = fluid.layers.fc(input=drop,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
Loading