1
+ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License. */
14
+
15
+ #include <cl_common.h>
16
+
17
/*
 * Optimized 5x5 2-D convolution on image2d-backed tensors.
 *
 * Each work-item accumulates up to five output pixels along the width axis
 * (columns item_w_id, item_w_id + item_w, ..., item_w_id + 4 * item_w) for
 * one CL_DTYPE4 of output channels, then applies activation_type4() and
 * writes the results to output_image.
 *
 * Work-item mapping:
 *   get_global_id(0) -> output channel block (item_ch_id)
 *   get_global_id(1) -> first output column handled by this item (item_w_id)
 *   get_global_id(2) -> output row; reduced modulo out_h to find the input
 *                       row, so it presumably spans batch * out_h rows
 *                       (NOTE(review): confirm against the host dispatch).
 *
 * Parameters:
 *   item_ch, item_h   not referenced in the kernel body.
 *   item_w            distance, in output columns, between the five pixels
 *                     computed by one work-item.
 *   input_image       input feature map; x = channel_block * in_w + column.
 *   filter_image      filter weights; x = in_channel_block * 5 + kernel
 *                     column, y = out_channel * 5 + kernel row.
 *   bias              only present when BIASE_CH or BIASE_ELE is defined.
 *   output_image      destination; x = item_ch_id * out_w + column.
 *   stride, pad       applied identically to both spatial axes.
 *   dilation, batch   not referenced in the kernel body.
 *   in_ch             number of input channels; lanes of the last 4-channel
 *                     block beyond in_ch are skipped.
 *   in_w, in_h        input spatial size, used for padding detection.
 *   out_w, out_h      output spatial size.
 *
 * NOTE(review): the first output pixel (out_w_id0) is written without a
 * bounds check, as in the original; this assumes the host never launches
 * work-items with item_w_id >= out_w.
 */
__kernel void conv2d_5x5_opt(__private const int item_ch,
                             __private const int item_w,
                             __private const int item_h,
                             __read_only image2d_t input_image,
                             __read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
                             __read_only image2d_t bias,
#endif
                             __write_only image2d_t output_image,
                             __private const int stride,
                             __private const int pad,
                             __private const int dilation,
                             __private const int batch,
                             __private const int in_ch,
                             __private const int in_w,
                             __private const int in_h,
                             __private const int out_w,
                             __private const int out_h) {
  // All reads below use integer (int2) coordinates. The OpenCL C spec
  // requires CLK_NORMALIZED_COORDS_FALSE for integer-coordinate reads;
  // with CLK_NORMALIZED_COORDS_TRUE the returned values are undefined.
  // CLK_ADDRESS_CLAMP is kept so that the -1 coordinates used for padding
  // read the border colour.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  // filter
  const int filter_w = 5;
  const int filter_h = 5;

  // item_id
  const int item_ch_id = get_global_id(0);
  const int item_w_id = get_global_id(1);
  const int item_h_id = get_global_id(2);

  // out_width_id_per_blk and out_batch_id
  int out_w_base_id = item_ch_id * out_w;
  int out_w_id0 = item_w_id;
  int out_w_id1 = out_w_id0 + item_w;
  int out_w_id2 = out_w_id1 + item_w;
  int out_w_id3 = out_w_id2 + item_w;
  int out_w_id4 = out_w_id3 + item_w;

  // in_width_id_per_blk and in_height_id_per_batch
  int in_h_id = (item_h_id % out_h) * stride - pad;
  int in_w_id0 = item_w_id * stride - pad;
  int in_w_id1 = in_w_id0 + item_w * stride;
  int in_w_id2 = in_w_id1 + item_w * stride;
  int in_w_id3 = in_w_id2 + item_w * stride;
  int in_w_id4 = in_w_id3 + item_w * stride;

  // Accumulators start at zero so that lanes which receive no bias (either
  // no bias macro defined, or the column lies outside out_w) never feed
  // indeterminate values into mad()/activation_type4().
  CL_DTYPE4 output[5] = {0.0f};

#ifdef BIASE_CH

  // Per-channel bias: one value shared by all five output pixels.
  output[0] =
      READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
  output[1] = output[0];
  output[2] = output[0];
  output[3] = output[0];
  output[4] = output[0];

#elif defined(BIASE_ELE)

  // Per-element bias: read at the same coordinates the output is written to.
  output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
                            bias,
                            sampler,
                            (int2)(out_w_base_id + out_w_id0, item_h_id));
  if (out_w_id1 < out_w) {
    output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
                              bias,
                              sampler,
                              (int2)(out_w_base_id + out_w_id1, item_h_id));
  }
  if (out_w_id2 < out_w) {
    output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
                              bias,
                              sampler,
                              (int2)(out_w_base_id + out_w_id2, item_h_id));
  }
  if (out_w_id3 < out_w) {
    output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
                              bias,
                              sampler,
                              (int2)(out_w_base_id + out_w_id3, item_h_id));
  }
  if (out_w_id4 < out_w) {
    output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
                              bias,
                              sampler,
                              (int2)(out_w_base_id + out_w_id4, item_h_id));
  }

#endif

  CL_DTYPE4 filter[4] = {0.0f};
  CL_DTYPE4 filter_trans[4] = {0.0f};
  CL_DTYPE4 input[5] = {0.0f};

  // First filter-image row for each of the four output channels of this
  // channel block.
  int filter_h_val0 = item_ch_id * 4 * filter_h;
  int filter_h_val1 = filter_h_val0 + filter_h;
  int filter_h_val2 = filter_h_val1 + filter_h;
  int filter_h_val3 = filter_h_val2 + filter_h;

  for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
    // Number of lanes in this 4-channel block that lie beyond in_ch
    // (non-zero only for the last block when in_ch is not a multiple of 4).
    int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;

    const int in_w_base_id = mul24(ch, in_w);

    int filter_w_val = ch * filter_w;

    for (int h = 0; h < filter_h; h++) {
      // Rows outside the input map to -1 so the clamped read yields the
      // border colour (padding).
      int in_h_val =
          select(in_h_id + h, -1, (in_h_id + h < 0 || in_h_id + h >= in_h));

      for (int w = 0; w < filter_w; w++) {
        // Same padding trick for each of the five input columns.
        int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
                               -1,
                               (in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
        int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
                               -1,
                               (in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
        int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
                               -1,
                               (in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
        int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
                               -1,
                               (in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
        int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
                               -1,
                               (in_w_id4 + w < 0 || in_w_id4 + w >= in_w));

        filter[0] =
            READ_IMG_TYPE(CL_DTYPE_CHAR,
                          filter_image,
                          sampler,
                          (int2)(filter_w_val + w,
                                 filter_h_val0 + h));  // in_ch:0-3,out_ch:0
        filter[1] =
            READ_IMG_TYPE(CL_DTYPE_CHAR,
                          filter_image,
                          sampler,
                          (int2)(filter_w_val + w,
                                 filter_h_val1 + h));  // in_ch:0-3,out_ch:1
        filter[2] =
            READ_IMG_TYPE(CL_DTYPE_CHAR,
                          filter_image,
                          sampler,
                          (int2)(filter_w_val + w,
                                 filter_h_val2 + h));  // in_ch:0-3,out_ch:2
        filter[3] =
            READ_IMG_TYPE(CL_DTYPE_CHAR,
                          filter_image,
                          sampler,
                          (int2)(filter_w_val + w,
                                 filter_h_val3 + h));  // in_ch:0-3,out_ch:3

        // Transpose so each vector holds one input channel's weights for
        // all four output channels; this lets a scalar input lane be
        // multiplied against a whole output vector with a single mad().
        filter_trans[0] = (CL_DTYPE4)(filter[0].x,
                                      filter[1].x,
                                      filter[2].x,
                                      filter[3].x);  // in_ch:0,out_ch:0-3
        filter_trans[1] = (CL_DTYPE4)(filter[0].y,
                                      filter[1].y,
                                      filter[2].y,
                                      filter[3].y);  // in_ch:1,out_ch:0-3
        filter_trans[2] = (CL_DTYPE4)(filter[0].z,
                                      filter[1].z,
                                      filter[2].z,
                                      filter[3].z);  // in_ch:2,out_ch:0-3
        filter_trans[3] = (CL_DTYPE4)(filter[0].w,
                                      filter[1].w,
                                      filter[2].w,
                                      filter[3].w);  // in_ch:3,out_ch:0-3

        input[0] = READ_IMG_TYPE(
            CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val));
        input[1] = READ_IMG_TYPE(
            CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val));
        input[2] = READ_IMG_TYPE(
            CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val));
        input[3] = READ_IMG_TYPE(
            CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val));
        input[4] = READ_IMG_TYPE(
            CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val));

        // Input channel lane 0 is always valid.
        output[0] = mad(input[0].x, filter_trans[0], output[0]);
        output[1] = mad(input[1].x, filter_trans[0], output[1]);
        output[2] = mad(input[2].x, filter_trans[0], output[2]);
        output[3] = mad(input[3].x, filter_trans[0], output[3]);
        output[4] = mad(input[4].x, filter_trans[0], output[4]);

        // Lanes 1..3 are accumulated only when they fall inside in_ch.
        if (ch_surplus < 3) {
          output[0] = mad(input[0].y, filter_trans[1], output[0]);
          output[1] = mad(input[1].y, filter_trans[1], output[1]);
          output[2] = mad(input[2].y, filter_trans[1], output[2]);
          output[3] = mad(input[3].y, filter_trans[1], output[3]);
          output[4] = mad(input[4].y, filter_trans[1], output[4]);
        }
        if (ch_surplus < 2) {
          output[0] = mad(input[0].z, filter_trans[2], output[0]);
          output[1] = mad(input[1].z, filter_trans[2], output[1]);
          output[2] = mad(input[2].z, filter_trans[2], output[2]);
          output[3] = mad(input[3].z, filter_trans[2], output[3]);
          output[4] = mad(input[4].z, filter_trans[2], output[4]);
        }
        if (ch_surplus < 1) {
          output[0] = mad(input[0].w, filter_trans[3], output[0]);
          output[1] = mad(input[1].w, filter_trans[3], output[1]);
          output[2] = mad(input[2].w, filter_trans[3], output[2]);
          output[3] = mad(input[3].w, filter_trans[3], output[3]);
          output[4] = mad(input[4].w, filter_trans[3], output[4]);
        }
      }
    }
  }

  output[0] = activation_type4(output[0]);
  output[1] = activation_type4(output[1]);
  output[2] = activation_type4(output[2]);
  output[3] = activation_type4(output[3]);
  output[4] = activation_type4(output[4]);

  // Columns 1..4 are written only when they fall inside the output width.
  WRITE_IMG_TYPE(CL_DTYPE_CHAR,
                 output_image,
                 (int2)(out_w_base_id + out_w_id0, item_h_id),
                 output[0]);
  if (out_w_id1 < out_w) {
    WRITE_IMG_TYPE(CL_DTYPE_CHAR,
                   output_image,
                   (int2)(out_w_base_id + out_w_id1, item_h_id),
                   output[1]);
  }
  if (out_w_id2 < out_w) {
    WRITE_IMG_TYPE(CL_DTYPE_CHAR,
                   output_image,
                   (int2)(out_w_base_id + out_w_id2, item_h_id),
                   output[2]);
  }
  if (out_w_id3 < out_w) {
    WRITE_IMG_TYPE(CL_DTYPE_CHAR,
                   output_image,
                   (int2)(out_w_base_id + out_w_id3, item_h_id),
                   output[3]);
  }
  if (out_w_id4 < out_w) {
    WRITE_IMG_TYPE(CL_DTYPE_CHAR,
                   output_image,
                   (int2)(out_w_base_id + out_w_id4, item_h_id),
                   output[4]);
  }
}
0 commit comments