Skip to content

Commit 03fd5f6

Browse files
authored
Merge pull request #2686 from qingqing01/row_conv_fix
Fix bug for flowers dataset and row_conv.
2 parents ea641da + 0925681 commit 03fd5f6

File tree

4 files changed

+37
-13
lines changed

4 files changed

+37
-13
lines changed

python/paddle/trainer/config_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,10 +2082,10 @@ def __init__(self, name, inputs, **xargs):
20822082
class RowConvLayer(LayerBase):
20832083
def __init__(self, name, inputs, context_length, **xargs):
20842084
super(RowConvLayer, self).__init__(
2085-
name, 'maxout', 0, inputs=inputs, **xargs)
2085+
name, 'row_conv', 0, inputs=inputs, **xargs)
20862086
config_assert(
20872087
len(self.inputs) == 1,
2088-
'TransLayer must have one and only one input')
2088+
'row convolution layer must have one and only one input.')
20892089
input_layer = self.get_input_layer(0)
20902090
row_conv_conf = self.config.inputs[0].row_conv_conf
20912091
row_conv_conf.context_length = context_length

python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layers {
77
}
88
layers {
99
name: "__row_conv_layer_0__"
10-
type: "maxout"
10+
type: "row_conv"
1111
size: 2560
1212
active_type: "relu"
1313
inputs {

python/paddle/v2/dataset/flowers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131
import cPickle
3232
import itertools
33+
import functools
3334
from common import download
3435
import tarfile
3536
import scipy.io as scio
@@ -54,21 +55,26 @@
5455
VALID_FLAG = 'valid'
5556

5657

57-
def default_mapper(sample):
58+
def default_mapper(is_train, sample):
5859
'''
5960
map image bytes data to type needed by model input layer
6061
'''
6162
img, label = sample
6263
img = load_image_bytes(img)
63-
img = simple_transform(img, 256, 224, True)
64+
img = simple_transform(
65+
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
6466
return img.flatten().astype('float32'), label
6567

6668

69+
train_mapper = functools.partial(default_mapper, True)
70+
test_mapper = functools.partial(default_mapper, False)
71+
72+
6773
def reader_creator(data_file,
6874
label_file,
6975
setid_file,
7076
dataset_name,
71-
mapper=default_mapper,
77+
mapper,
7278
buffered_size=1024,
7379
use_xmap=True):
7480
'''
@@ -118,7 +124,7 @@ def reader():
118124
return map_readers(mapper, reader)
119125

120126

121-
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
127+
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
122128
'''
123129
Create flowers training set reader.
124130
It returns a reader, each sample in the reader is
@@ -141,7 +147,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
141147
buffered_size, use_xmap)
142148

143149

144-
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
150+
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
145151
'''
146152
Create flowers test set reader.
147153
It returns a reader, each sample in the reader is
@@ -164,7 +170,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
164170
buffered_size, use_xmap)
165171

166172

167-
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
173+
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
168174
'''
169175
Create flowers validation set reader.
170176
It returns a reader, each sample in the reader is

python/paddle/v2/image.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,12 @@ def left_right_flip(im):
262262
return im[:, ::-1, :]
263263

264264

265-
def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
265+
def simple_transform(im,
266+
resize_size,
267+
crop_size,
268+
is_train,
269+
is_color=True,
270+
mean=None):
266271
"""
267272
Simply data argumentation for training. These operations include
268273
resizing, croping and flipping.
@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
288293
im = left_right_flip(im)
289294
else:
290295
im = center_crop(im, crop_size)
291-
im = to_chw(im)
296+
if len(im.shape) == 3:
297+
im = to_chw(im)
298+
299+
im = im.astype('float32')
300+
if mean is not None:
301+
mean = np.array(mean, dtype=np.float32)
302+
# mean value, may be one value per channel
303+
if mean.ndim == 1:
304+
mean = mean[:, np.newaxis, np.newaxis]
305+
else:
306+
# elementwise mean
307+
assert len(mean.shape) == len(im)
308+
im -= mean
292309

293310
return im
294311

@@ -297,7 +314,8 @@ def load_and_transform(filename,
297314
resize_size,
298315
crop_size,
299316
is_train,
300-
is_color=True):
317+
is_color=True,
318+
mean=None):
301319
"""
302320
Load image from the input file `filename` and transform image for
303321
data argumentation. Please refer to the `simple_transform` interface
@@ -318,5 +336,5 @@ def load_and_transform(filename,
318336
:type is_train: bool
319337
"""
320338
im = load_image(filename)
321-
im = simple_transform(im, resize_size, crop_size, is_train, is_color)
339+
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
322340
return im

0 commit comments

Comments
 (0)