@@ -38,15 +38,13 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
38
38
void PrepareForRun () override {
39
39
auto & context = ctx_->As <OpenCLContext>();
40
40
concat_param_ = param_.get_mutable <param_t >();
41
+ axis_ = concat_param_->axis ;
42
+ if (-1 == axis_) {
43
+ axis_ = concat_param_->x [0 ]->dims ().size () - 1 ;
44
+ }
41
45
42
46
auto inputs = concat_param_->x ;
43
- auto axis_ = concat_param_->axis ;
44
47
auto output_tensor_dims = concat_param_->output ->dims ();
45
- auto * axis_tensor = concat_param_->axis_tensor ;
46
- if (axis_tensor != nullptr ) {
47
- // auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
48
- // axis = axis_tensor_data[0];
49
- }
50
48
51
49
if (inputs.size () == 2 ) {
52
50
kernel_func_name_ = " concat2" ;
@@ -100,8 +98,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
100
98
width_ = output_tensor_dims[0 ]; // n
101
99
flag_ = 2 ;
102
100
break ;
103
- case 3 :
104
- case -1 : // width
101
+ case 3 : // width
105
102
width_ = output_tensor_dims[1 ]; // c
106
103
flag_ = 3 ;
107
104
break ;
@@ -113,17 +110,12 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
113
110
auto input0_tensor_dims = inputs[0 ]->dims ();
114
111
for (int i = 1 ; i < inputs.size (); i++) {
115
112
auto dims = inputs[i]->dims ();
116
- // auto flag = CHECK_EQ_OR_FALSE(input0_tensor_dims.size(), dims.size());
117
- if (input0_tensor_dims.size () != dims.size ()) {
118
- printf (" input shape must be same \n " );
119
- return ;
120
- }
113
+ CHECK (input0_tensor_dims.size () == dims.size ())
114
+ << " All inputs must have the same axes!" ;
121
115
for (int i = 0 ; i < dims.size (); i++) {
122
116
if (i != axis_) {
123
- if (input0_tensor_dims[i] != dims[i]) {
124
- printf (" input shape must be same \n " );
125
- return ;
126
- }
117
+ CHECK (input0_tensor_dims[i] == dims[i])
118
+ << " All inputs must have the same shape, except at concat axis!" ;
127
119
}
128
120
}
129
121
}
@@ -151,29 +143,18 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
151
143
VLOG (4 ) << " concat input shape: " ;
152
144
for (size_t i = 0 ; i < inputs.size (); i++) {
153
145
VLOG (4 ) << " inputs [" << i << " ]"
154
- << " [" << inputs[i]->dims ().size () << " D]:"
155
- << " dims:" << inputs[i]->dims ()[0 ] << " "
156
- << inputs[i]->dims ()[1 ] << " " << inputs[i]->dims ()[2 ] << " "
157
- << inputs[i]->dims ()[3 ];
146
+ << " dims:" << inputs[i]->dims ();
158
147
}
159
148
160
149
VLOG (4 ) << " concat output shape: " ;
161
- VLOG (4 ) << " out dims: "
162
- << " [" << output_tensor_dims.size ()
163
- << " D]:" << output_tensor_dims[0 ] << " " << output_tensor_dims[1 ]
164
- << " " << output_tensor_dims[2 ] << " " << output_tensor_dims[3 ];
150
+ VLOG (4 ) << " out dims: " << output_tensor_dims;
165
151
VLOG (4 ) << " axis_: " << axis_;
166
152
VLOG (4 ) << " flag_: " << flag_;
167
153
168
154
VLOG (4 ) << TargetToStr (concat_param_->output ->target ());
169
- VLOG (4 ) << " output_image_shape(w,h):" << output_image_shape[" width" ] << " "
155
+ VLOG (4 ) << " output_image_shape(w,h): " << output_image_shape[" width" ] << " "
170
156
<< output_image_shape[" height" ];
171
- VLOG (4 ) << " output_tensor_dims[" << output_tensor_dims.size ()
172
- << " D]:" << output_tensor_dims[0 ] << " " << output_tensor_dims[1 ]
173
- << " " << output_tensor_dims[2 ] << " " << output_tensor_dims[3 ]
174
- << " output_tensor_dims[output_tensor_dims.size() - 1]"
175
- << output_tensor_dims[output_tensor_dims.size () - 1 ];
176
- VLOG (4 ) << " output_tensor_w: " << output_tensor_w << " , flag_: " << flag_;
157
+ VLOG (4 ) << " output_tensor_w: " << output_tensor_w;
177
158
VLOG (4 ) << " width_:" << width_;
178
159
VLOG (4 ) << " global_work_size: "
179
160
<< output_tensor_dims[output_tensor_dims.size () - 1 ] << " "
@@ -433,6 +414,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
433
414
}
434
415
#endif
435
416
417
+ private:
436
418
int axis_ = 1 ;
437
419
int flag_ = 1 ;
438
420
int width_ = 1 ;
0 commit comments