Skip to content

Commit 3dc699c

Browse files
authored
[LITE][OPENCL][Image]optimise conv2d 5x5 7x7,test=develop (#3253)
* [LITE][OPENCL][Image]optimise conv2d 5x5 7x7,test=develop * [LITE][OPENCL][Image]optimise conv2d 5x5 7x7,test=develop * [LITE][OPENCL][Image]optimise conv2d 5x5 7x7,test=develop
1 parent 9f579a7 commit 3dc699c

File tree

7 files changed

+927
-25
lines changed

7 files changed

+927
-25
lines changed

lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl renamed to lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <cl_common.h>
22

3-
__kernel void conv2d_1x1(__private const int global_size_dim0,
3+
__kernel void conv2d_1x1_opt(__private const int global_size_dim0,
44
__private const int global_size_dim1,
55
__private const int global_size_dim2,
66
__read_only image2d_t input_image,

lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
2626
__private const int stride,
2727
__private const int pad,
2828
__private const int dilation,
29-
__private const int in_ch,
29+
__private const int batch,
30+
__private const int in_ch,
3031
__private const int in_w,
3132
__private const int in_h,
3233
__private const int out_w,
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <cl_common.h>
16+
17+
// opt version of conv5x5
18+
__kernel void conv2d_5x5_opt(__private const int item_ch,
19+
__private const int item_w,
20+
__private const int item_h,
21+
__read_only image2d_t input_image,
22+
__read_only image2d_t filter_image,
23+
#if defined(BIASE_CH) || defined(BIASE_ELE)
24+
__read_only image2d_t bias,
25+
#endif
26+
__write_only image2d_t output_image,
27+
__private const int stride,
28+
__private const int pad,
29+
__private const int dilation,
30+
__private const int batch,
31+
__private const int in_ch,
32+
__private const int in_w,
33+
__private const int in_h,
34+
__private const int out_w,
35+
__private const int out_h) {
36+
37+
const sampler_t sampler =
38+
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
39+
// filter
40+
const int filter_w = 5;
41+
const int filter_h = 5;
42+
43+
// item_id
44+
const int item_ch_id = get_global_id(0);
45+
const int item_w_id = get_global_id(1);
46+
const int item_h_id = get_global_id(2);
47+
48+
// out_width_id_per_blk and out_batch_id
49+
int out_w_base_id = item_ch_id * out_w;
50+
int out_w_id0 = item_w_id;
51+
int out_w_id1 = out_w_id0 + item_w;
52+
int out_w_id2 = out_w_id1 + item_w;
53+
int out_w_id3 = out_w_id2 + item_w;
54+
int out_w_id4 = out_w_id3 + item_w;
55+
56+
// in_width_id_per_blk and in_height_id_per_batch
57+
int in_h_id = (item_h_id % out_h) * stride - pad;
58+
int in_w_id0 = item_w_id * stride - pad;
59+
int in_w_id1 = in_w_id0 + item_w * stride;
60+
int in_w_id2 = in_w_id1 + item_w * stride;
61+
int in_w_id3 = in_w_id2 + item_w * stride;
62+
int in_w_id4 = in_w_id3 + item_w * stride;
63+
64+
#ifdef BIASE_CH
65+
66+
CL_DTYPE4 output[5];
67+
output[0] =
68+
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
69+
output[1] = output[0];
70+
output[2] = output[0];
71+
output[3] = output[0];
72+
output[4] = output[0];
73+
74+
#elif defined(BIASE_ELE)
75+
76+
CL_DTYPE4 output[5];
77+
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
78+
bias,
79+
sampler,
80+
(int2)(out_w_base_id + out_w_id0, item_h_id));
81+
if (out_w_id1 < out_w) {
82+
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
83+
bias,
84+
sampler,
85+
(int2)(out_w_base_id + out_w_id1, item_h_id));
86+
}
87+
if (out_w_id2 < out_w) {
88+
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
89+
bias,
90+
sampler,
91+
(int2)(out_w_base_id + out_w_id2, item_h_id));
92+
}
93+
if (out_w_id3 < out_w) {
94+
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
95+
bias,
96+
sampler,
97+
(int2)(out_w_base_id + out_w_id3, item_h_id));
98+
}
99+
if (out_w_id4 < out_w) {
100+
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
101+
bias,
102+
sampler,
103+
(int2)(out_w_base_id + out_w_id4, item_h_id));
104+
}
105+
#else
106+
CL_DTYPE4 output[5] = {0.0f};
107+
#endif
108+
109+
CL_DTYPE4 filter[4] = {0.0f};
110+
CL_DTYPE4 filter_trans[4] = {0.0f};
111+
CL_DTYPE4 input[5] = {0.0f};
112+
113+
int filter_h_val0 = item_ch_id * 4 * filter_h;
114+
int filter_h_val1 = filter_h_val0 + filter_h;
115+
int filter_h_val2 = filter_h_val1 + filter_h;
116+
int filter_h_val3 = filter_h_val2 + filter_h;
117+
118+
for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
119+
int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;
120+
121+
const int in_w_base_id = mul24(ch, in_w);
122+
123+
int filter_w_val = ch * filter_w;
124+
125+
for (int h = 0; h < filter_h; h++) {
126+
int in_h_val =
127+
select(in_h_id + h, -1, (in_h_id + h < 0 || in_h_id + h >= in_h));
128+
129+
for (int w = 0; w < filter_w; w++) {
130+
int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
131+
-1,
132+
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
133+
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
134+
-1,
135+
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
136+
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
137+
-1,
138+
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
139+
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
140+
-1,
141+
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
142+
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
143+
-1,
144+
(in_w_id4 + w < 0 || in_w_id4 + w >= in_w));
145+
146+
filter[0] =
147+
READ_IMG_TYPE(CL_DTYPE_CHAR,
148+
filter_image,
149+
sampler,
150+
(int2)(filter_w_val + w,
151+
filter_h_val0 + h)); // in_ch:0-3,out_ch:0
152+
filter[1] =
153+
READ_IMG_TYPE(CL_DTYPE_CHAR,
154+
filter_image,
155+
sampler,
156+
(int2)(filter_w_val + w,
157+
filter_h_val1 + h)); // in_ch:0-3,out_ch:1
158+
filter[2] =
159+
READ_IMG_TYPE(CL_DTYPE_CHAR,
160+
filter_image,
161+
sampler,
162+
(int2)(filter_w_val + w,
163+
filter_h_val2 + h)); // in_ch:0-3,out_ch:2
164+
filter[3] =
165+
READ_IMG_TYPE(CL_DTYPE_CHAR,
166+
filter_image,
167+
sampler,
168+
(int2)(filter_w_val + w,
169+
filter_h_val3 + h)); // in_ch:0-3,out_ch:3
170+
171+
filter_trans[0] = (CL_DTYPE4)(filter[0].x,
172+
filter[1].x,
173+
filter[2].x,
174+
filter[3].x); // in_ch:0,out_ch:0-3
175+
filter_trans[1] = (CL_DTYPE4)(filter[0].y,
176+
filter[1].y,
177+
filter[2].y,
178+
filter[3].y); // in_ch:1,out_ch:0-3
179+
filter_trans[2] = (CL_DTYPE4)(filter[0].z,
180+
filter[1].z,
181+
filter[2].z,
182+
filter[3].z); // in_ch:2,out_ch:0-3
183+
filter_trans[3] = (CL_DTYPE4)(filter[0].w,
184+
filter[1].w,
185+
filter[2].w,
186+
filter[3].w); // in_ch:3,out_ch:0-3
187+
188+
input[0] = READ_IMG_TYPE(
189+
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val));
190+
input[1] = READ_IMG_TYPE(
191+
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val));
192+
input[2] = READ_IMG_TYPE(
193+
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val));
194+
input[3] = READ_IMG_TYPE(
195+
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val));
196+
input[4] = READ_IMG_TYPE(
197+
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val));
198+
199+
output[0] = mad(input[0].x, filter_trans[0], output[0]);
200+
output[1] = mad(input[1].x, filter_trans[0], output[1]);
201+
output[2] = mad(input[2].x, filter_trans[0], output[2]);
202+
output[3] = mad(input[3].x, filter_trans[0], output[3]);
203+
output[4] = mad(input[4].x, filter_trans[0], output[4]);
204+
205+
if (ch_surplus < 3) {
206+
output[0] = mad(input[0].y, filter_trans[1], output[0]);
207+
output[1] = mad(input[1].y, filter_trans[1], output[1]);
208+
output[2] = mad(input[2].y, filter_trans[1], output[2]);
209+
output[3] = mad(input[3].y, filter_trans[1], output[3]);
210+
output[4] = mad(input[4].y, filter_trans[1], output[4]);
211+
}
212+
if (ch_surplus < 2) {
213+
output[0] = mad(input[0].z, filter_trans[2], output[0]);
214+
output[1] = mad(input[1].z, filter_trans[2], output[1]);
215+
output[2] = mad(input[2].z, filter_trans[2], output[2]);
216+
output[3] = mad(input[3].z, filter_trans[2], output[3]);
217+
output[4] = mad(input[4].z, filter_trans[2], output[4]);
218+
}
219+
if (ch_surplus < 1) {
220+
output[0] = mad(input[0].w, filter_trans[3], output[0]);
221+
output[1] = mad(input[1].w, filter_trans[3], output[1]);
222+
output[2] = mad(input[2].w, filter_trans[3], output[2]);
223+
output[3] = mad(input[3].w, filter_trans[3], output[3]);
224+
output[4] = mad(input[4].w, filter_trans[3], output[4]);
225+
}
226+
}
227+
}
228+
}
229+
230+
output[0] = activation_type4(output[0]);
231+
output[1] = activation_type4(output[1]);
232+
output[2] = activation_type4(output[2]);
233+
output[3] = activation_type4(output[3]);
234+
output[4] = activation_type4(output[4]);
235+
236+
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
237+
output_image,
238+
(int2)(out_w_base_id + out_w_id0, item_h_id),
239+
output[0]);
240+
if (out_w_id1 < out_w) {
241+
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
242+
output_image,
243+
(int2)(out_w_base_id + out_w_id1, item_h_id),
244+
output[1]);
245+
}
246+
if (out_w_id2 < out_w) {
247+
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
248+
output_image,
249+
(int2)(out_w_base_id + out_w_id2, item_h_id),
250+
output[2]);
251+
}
252+
if (out_w_id3 < out_w) {
253+
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
254+
output_image,
255+
(int2)(out_w_base_id + out_w_id3, item_h_id),
256+
output[3]);
257+
}
258+
if (out_w_id4 < out_w) {
259+
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
260+
output_image,
261+
(int2)(out_w_base_id + out_w_id4, item_h_id),
262+
output[4]);
263+
}
264+
}

0 commit comments

Comments
 (0)