Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions fluid/fcn/README.cn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
**FCN语义分割**
---

**概述**
---
FCN全称:Fully Convolutional Networks for Semantic Segmentation, 是基于深度学习算法完成图像语义分割任务的开创性工作[1]。本示例旨在介绍如何使用PaddlePaddle中的FCN模型进行语义分割。下面首先简要介绍FCN原理,然后介绍示例包含文件及如何使用,接着介绍如何在PASCAL VOC数据集上训练和测试模型。

**FCN原理**
---
FCN基于卷积神经网络实现“端到端”的分割:输入是测试图像,输出为分割结果。论文基于VGG16[2]作为基础网络进行特征提取,不过对基础网络进行了改写以适应图像语义分割任务,具体包含:
1. 将网络中全连接层转化为卷积操作,以接受任意大小的输入图像。
2. 使用转置卷积的方式对特征图进行上采样,以输出和输入图像相同分辨率的特征图。
3. 引入Skip-Connection的连接方式,在网络深层引入浅层信息,以得到更精细的分割结果。

下图为FCN框架:
![FCN框架](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/fcn_network.png?raw=true)

深度网络浅层具有丰富的空间细节信息,而语义信息主要集中于网络深层,由此论文在网络深层引入浅层信息作为补充。具体来说,论文中提出了三个分割模型:FCN-32s,FCN-16s和FCN-8s,FCN-32s直接使用转置卷积的方式对pool5层的输出进行上采样;FCN-16s首先对pool5层的输出进行上采样,然后和pool4层的输出使用sum操作进行特征融合,再进行上采样;FCN-8s引入了更浅层的pool3层信息进行特征融合。

**示例总览**
---
本示例共包含以下文件:

表1. 示例文件

文件 | 用途 |
------------------------- | ------------------------------------- |
train.py | 训练脚本 |
infer.py | 测试脚本,给定图片及模型,完成测试 |
vgg_fcn.py | FCN模型框架定义脚本 |
data_provider.py | 数据处理脚本,生成训练和测试数据 |
utils.py | 常用函数脚本 |
data/prepare_voc_data.py | 准备PASCAL VOC训练和测试文件 |

**PASCAL VOC数据集**
---
**数据准备**

1. 请首先下载数据集:[VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html)[3]训练集和[VOC2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html)[4]测试集。将下载好的数据解压,目录结构为:`data/VOCdevkit/VOC2012`和`data/VOCdevkit/VOC2007`。
2. 进入`data`目录,运行`python prepare_voc_data.py`,即可生成`voc2012_trainval.txt`和`voc2007_test.txt`。

下面是`voc2012_trainval.txt`前几行输入示例:
```
VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg voc_processed/2007_000032.png
VOCdevkit/VOC2012/JPEGImages/2007_000033.jpg voc_processed/2007_000033.png
VOCdevkit/VOC2012/JPEGImages/2007_000039.jpg voc_processed/2007_000039.png
```
下面是`voc2007_test.txt`前几行输入示例:
```
VOCdevkit/VOC2007/JPEGImages/000068.jpg
VOCdevkit/VOC2007/JPEGImages/000175.jpg
VOCdevkit/VOC2007/JPEGImages/000243.jpg
```

**预训练模型准备**

下载预训练的VGG16模型,我们提供了一个转化好的模型,下载地址:[VGG16](https://pan.baidu.com/s/1ekZ5O-lp3lGvAOZ4KSXKDQ),将其放置到:`models/vgg16_weights.tar`, 然后解压用于初始化。

**模型训练**

直接执行`python train.py --fcn_arch fcn-32s`即可训练FCN-32s模型,现在同时支持FCN-16s和FCN-8s分割模型。`train.py`中关键逻辑:
```python
weights_dict = resolve_caffe_model(args.pretrain_model)
for k, v in weights_dict.items():
_tensor = fluid.global_scope().find_var(k).get_tensor()
_shape = np.array(_tensor).shape
_tensor.set(v, place)

data_args = data_provider.Settings(
data_dir=args.data_dir,
resize_h=args.img_height,
resize_w=args.img_width,
mean_value=mean_value)
```
主要包括:
1. 调用`resolve_caffe_model`得到预训练模型参数,然后基于fluid中tensor的`set`函数为模型赋初值。
2. 调用`data_provider.Settings`配置数据预处理参数,运行时可通过命令行对相应参数进行配置。
3. 训练中每隔一定epoch会调用`fluid.io.save_inference_model`存储模型。

下面给出了FCN-32s,FCN-16s和FCN-8s在VOC数据集上训练的Loss曲线:

![FCN训练损失曲线](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/train_loss.png?raw=true)

**模型测试**

执行`python infer.py --fcn_arch fcn-32s`即可使用训练好的FCN-32s模型对输入图片进行分割,预测结果保存在`demo`文件夹,具体可通过`--vis_dir`进行配置。`infer.py`中关键逻辑:
```python
model_dir = os.path.join(args.model_dir, '%s-model' % args.fcn_arch)
assert(os.path.exists(model_dir))
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_dir, exe)

predict = exe.run(inference_program, feed={feed_target_names[0]:img_data}, fetch_list=fetch_targets)
res = np.argmax(np.squeeze(predict[0]), axis=0)
res = convert_to_color_label(res)
```
主要包括:
1. 调用`fluid.io.load_inference_model`加载训练好的模型。
2. 调用`convert_to_color_label`将模型预测结果可视化为VOC对应格式。

下图是FCN-32s,FCN-16s和FCN-8s的部分测试结果:

![FCN-32s分割结果](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/seg_res.png?raw=true)

我们提供了训练好的模型用于测试:
[FCN-32s](https://pan.baidu.com/s/1j8pltdzgssmxbXFgHWmCNQ)(密码:dk0i)
[FCN-16s](https://pan.baidu.com/s/1idapCRSxWsJKSqqswUGDSw)(密码:q8gu)
[FCN-8s](https://pan.baidu.com/s/1GcO-mcOWo_VF65X3xwPnpA)(密码:du9x)

**引用**
---
1. Jonathan Long, Evan Shelhamer, Trevor Darrell. [Fully convolutional networks for semantic segmentation](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf). IEEE conference on computer vision and pattern recognition, 2015.
2. Simonyan, Karen, and Andrew Zisserman. [Very deep convolutional networks for large-scale image recognition](https://arxiv.org/abs/1409.1556). arXiv preprint arXiv:1409.1556 (2014).
3. [Visual Object Classes Challenge 2012 (VOC2012)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html)
4. [The PASCAL Visual Object Classes Challenge 2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html)
112 changes: 112 additions & 0 deletions fluid/fcn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
**Fully Convolutional Networks for Semantic Segmentation**
---

**Introduction**
---
FCN[1](Fully Convolutional Networks) is one of the pioneering work in semantic segmentation. This example demonstrates how to use the FCN model in PaddlePaddle for image segmentation. We first provide a brief introduction to the FCN principle, and then describe how to train and evaluate the model in PASCAL VOC dataset.

**FCN Architecture**
---
FCN is an end-to-end network for semantic segmentation, it takes the input image and with a froward propagation, the output is the predicted result. FCN is based on VGG16[2], but differs as following:
1. Convert the fully connected layers into fully convolutional layers, so as to take input of arbitrary size.
2. The deconvolutional layer is used to upsample the feature map to the input dimensions.
3. The skip-connection architecture is defined to combine deep, coarse, semantic information and shallow, fine, apperance information.

The overall structure of FCN is shown below:
![FCN_ARCH](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/fcn_network.png?raw=true)

FCN learns to combine coarse, high layer information with fine, low layer information. Layers are shown as grids that reveal relative spatial coarseness. Only pooling and prediction layers are shown, intermediate convolutional layers are omitted. FCN-32s upsamples stride 32 predictions back to pixels in a single step. FCN-16s combines predictions from both the final layer and the pool4 layer, at stride 16, so the net predict finer details, while retaining high-level semantic information. FCN-8s adds predictions from pool3, at stride 8, provide further precision.

**Example Overview**
---
This example contains the following files:

Table 1. Directory structure

File | Description |
------------------------- | ------------------------------------- |
train.py | Training script |
infer.py | Prediction using the trained model |
vgg_fcn.py | Defining FCN structure |
data_provider.py | Data processing scripts, generating train and test data |
utils.py | Contains common functions |
data/prepare_voc_data.py | Prepare PASCAL VOC data list for training and test |

**PASCAL VOC Data set**
---
**Data Preparation**

First download the data set: [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html)[3] train dataset and [VOC2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html)[4] test dataset, and then unzip the data as `data/VOCdevkit/VOC2012` and `data/VOCdevkit/VOC2007`.

Next, run `python prepare_voc_data.py` to generate `voc2012_trainval.txt` and `voc2007_test.txt`.

The data in `voc2012_trainval.txt` will look like:
```
VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg voc_processed/2007_000032.png
VOCdevkit/VOC2012/JPEGImages/2007_000033.jpg voc_processed/2007_000033.png
VOCdevkit/VOC2012/JPEGImages/2007_000039.jpg voc_processed/2007_000039.png
```
The data in `voc2007_test.txt` will look like:
```
VOCdevkit/VOC2007/JPEGImages/000068.jpg
VOCdevkit/VOC2007/JPEGImages/000175.jpg
VOCdevkit/VOC2007/JPEGImages/000243.jpg
```

**To Use Pre-trained Model**

We also provide a pre-trained model of VGG16. To use the model, download the file: [VGG16](https://pan.baidu.com/s/1ekZ5O-lp3lGvAOZ4KSXKDQ) and place it in: `models/vgg16_weights.tar`, and then unzip.

**Training**

Next, run `python train.py --fcn_arch fcn-32s` to train the FCN-32s model, we also provide model structure of FCN-16s and FCN-8s. The relevant function is as following:
```python
weights_dict = resolve_caffe_model(args.pretrain_model)
for k, v in weights_dict.items():
_tensor = fluid.global_scope().find_var(k).get_tensor()
_shape = np.array(_tensor).shape
_tensor.set(v, place)

data_args = data_provider.Settings(
data_dir=args.data_dir,
resize_h=args.img_height,
resize_w=args.img_width,
mean_value=mean_value)
```
Below is the description about this script:
1. Call `resolve_caffe_model` to get the pre-trained model parameters, and then use the `set` function in fluid to initialize the model.
2. Call `data_provider.Settings` to pass configuration parameters, which can be set by command line.
3. Call `fluid.io.save_inference_model` to save the model per epoch.

Below is the training loss of FCN-32s, FCN-16s and FCN-8s in VOC dataset.
![FCN_LOSS](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/train_loss.png?raw=true)

**Model Assessment**

Run `python infer.py` to evaluate the trained model, the predicted result is save in `demo` directory, which can be set by `--vis_dir` in command line. The relevant function is as following:
```python
model_dir = os.path.join(args.model_dir, '%s-model' % args.fcn_arch)
assert(os.path.exists(model_dir))
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_dir, exe)

predict = exe.run(inference_program, feed={feed_target_names[0]:img_data}, fetch_list=fetch_targets)
res = np.argmax(np.squeeze(predict[0]), axis=0)
res = convert_to_color_label(res)
```
Description:
the `fluid.io.load_inference_model` is called to load the trained model, the `convert_to_color_label` function is used to visualize the predicted as VOC format.

Below is the segmentation result of FCN-32s, FCN-16s and FCN-8s:
![FCN-32s-seg](https://github.com/chengyuz/models/blob/yucheng/fluid/fcn/images/seg_res.png?raw=true)

We provide the trained FCN model:
[FCN-32s](https://pan.baidu.com/s/1j8pltdzgssmxbXFgHWmCNQ)[Password: dk0i]
[FCN-16s](https://pan.baidu.com/s/1idapCRSxWsJKSqqswUGDSw)(Password: q8gu)
[FCN-8s](https://pan.baidu.com/s/1GcO-mcOWo_VF65X3xwPnpA)(Password: du9x)

**References**
---
1. Jonathan Long, Evan Shelhamer, Trevor Darrell. [Fully convolutional networks for semantic segmentation](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf). IEEE conference on computer vision and pattern recognition, 2015.
2. Simonyan, Karen, and Andrew Zisserman. [Very deep convolutional networks for large-scale image recognition](https://arxiv.org/abs/1409.1556). arXiv preprint arXiv:1409.1556 (2014).
3. [Visual Object Classes Challenge 2012 (VOC2012)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html)
4. [The PASCAL Visual Object Classes Challenge 2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html)
123 changes: 123 additions & 0 deletions fluid/fcn/data/prepare_voc_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import absolute_import
import os
import sys
import cv2
import numpy as np
import shutil
import pdb


def pascal_classes():
classes = {
'aeroplane': 1,
'bicycle': 2,
'bird': 3,
'boat': 4,
'bottle': 5,
'bus': 6,
'car': 7,
'cat': 8,
'chair': 9,
'cow': 10,
'diningtable': 11,
'dog': 12,
'horse': 13,
'motorbike': 14,
'person': 15,
'pottedplant': 16,
'sheep': 17,
'sofa': 18,
'train': 19,
'tvmonitor': 20
}
return classes


def process_dir(_dir):
if os.path.exists(_dir):
shutil.rmtree(_dir)
os.makedirs(_dir)


def pascal_palette():
palette = {
(0, 0, 0): 0,
(128, 0, 0): 1,
(0, 128, 0): 2,
(128, 128, 0): 3,
(0, 0, 128): 4,
(128, 0, 128): 5,
(0, 128, 128): 6,
(128, 128, 128): 7,
(64, 0, 0): 8,
(192, 0, 0): 9,
(64, 128, 0): 10,
(192, 128, 0): 11,
(64, 0, 128): 12,
(192, 0, 128): 13,
(64, 128, 128): 14,
(192, 128, 128): 15,
(0, 64, 0): 16,
(128, 64, 0): 17,
(0, 192, 0): 18,
(128, 192, 0): 19,
(0, 64, 128): 20,
(224, 224, 192): 0
}
return palette


def convert_from_color_label(img):
'''Convert the Pascal VOC label format to train.

Args:
img: The label result of Pascal VOC.
'''
palette = pascal_palette()
for c, i in palette.items():
_c = (c[2], c[1], c[0]) # the image channel read by opencv is (b, g, r)
m = np.all(img == np.array(_c).reshape(1, 1, 3), axis=2)
img[m] = i
return img


def main():
out_dir = 'voc_processed'
process_dir(out_dir)
out_train_f = open('voc2012_trainval.txt', 'w')
out_test_f = open('voc2007_test.txt', 'w')

devkit_dir = 'VOCdevkit'
trainval_file = os.path.join(devkit_dir, 'VOC2012', 'ImageSets',
'Segmentation', 'trainval.txt')
segclass_dir = os.path.join(devkit_dir, 'VOC2012', 'SegmentationClass')
train_image_dir = os.path.join(devkit_dir, 'VOC2012', 'JPEGImages')
test_image_dir = os.path.join(devkit_dir, 'VOC2007', 'JPEGImages')
test_file = os.path.join(devkit_dir, 'VOC2007', 'ImageSets', 'Segmentation',
'test.txt')

with open(trainval_file, 'r') as input_f:
for line in input_f:
img = cv2.imread(
os.path.join(segclass_dir, '%s.png' % line.strip()))
img = convert_from_color_label(img)

out_label_path = os.path.join(out_dir, '%s.png' % line.strip())
cv2.imwrite(out_label_path, img)

img_path = os.path.join(train_image_dir, '%s.jpg' % line.strip())
assert (os.path.exists(img_path))

out_train_f.write('%s %s \n' % (img_path, out_label_path))
out_train_f.flush()
out_train_f.close()

with open(test_file, 'r') as input_f:
for line in input_f:
img_path = os.path.join(test_image_dir, '%s.jpg \n' % line.strip())
out_test_f.write(img_path)
out_test_f.close()


if __name__ == '__main__':
main()
Loading