Skip to content

Commit 9bc1f34

Browse files
[BugFix][OpenCL] Fix concat image impl when axis is not 1. test=develop (#4241)
* [BugFix][OpenCL] Fix concat image impl when concat axis is not 1 * fix code when axis == 1. test=develop * fix illegal access when print debug info. test=develop * fix typo
1 parent 4539964 commit 9bc1f34

File tree

1 file changed

+14
-32
lines changed

1 file changed

+14
-32
lines changed

lite/kernels/opencl/concat_image_compute.cc

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,13 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
3838
void PrepareForRun() override {
3939
auto& context = ctx_->As<OpenCLContext>();
4040
concat_param_ = param_.get_mutable<param_t>();
41+
axis_ = concat_param_->axis;
42+
if (-1 == axis_) {
43+
axis_ = concat_param_->x[0]->dims().size() - 1;
44+
}
4145

4246
auto inputs = concat_param_->x;
43-
auto axis_ = concat_param_->axis;
4447
auto output_tensor_dims = concat_param_->output->dims();
45-
auto* axis_tensor = concat_param_->axis_tensor;
46-
if (axis_tensor != nullptr) {
47-
// auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
48-
// axis = axis_tensor_data[0];
49-
}
5048

5149
if (inputs.size() == 2) {
5250
kernel_func_name_ = "concat2";
@@ -100,8 +98,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
10098
width_ = output_tensor_dims[0]; // n
10199
flag_ = 2;
102100
break;
103-
case 3:
104-
case -1: // width
101+
case 3: // width
105102
width_ = output_tensor_dims[1]; // c
106103
flag_ = 3;
107104
break;
@@ -113,17 +110,12 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
113110
auto input0_tensor_dims = inputs[0]->dims();
114111
for (int i = 1; i < inputs.size(); i++) {
115112
auto dims = inputs[i]->dims();
116-
// auto flag = CHECK_EQ_OR_FALSE(input0_tensor_dims.size(), dims.size());
117-
if (input0_tensor_dims.size() != dims.size()) {
118-
printf("input shape must be same \n");
119-
return;
120-
}
113+
CHECK(input0_tensor_dims.size() == dims.size())
114+
<< "All inputs must have the same axes!";
121115
for (int i = 0; i < dims.size(); i++) {
122116
if (i != axis_) {
123-
if (input0_tensor_dims[i] != dims[i]) {
124-
printf("input shape must be same \n");
125-
return;
126-
}
117+
CHECK(input0_tensor_dims[i] == dims[i])
118+
<< "All inputs must have the same shape, except at concat axis!";
127119
}
128120
}
129121
}
@@ -151,29 +143,18 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
151143
VLOG(4) << "concat input shape: ";
152144
for (size_t i = 0; i < inputs.size(); i++) {
153145
VLOG(4) << "inputs [" << i << "]"
154-
<< "[" << inputs[i]->dims().size() << "D]:"
155-
<< " dims:" << inputs[i]->dims()[0] << " "
156-
<< inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " "
157-
<< inputs[i]->dims()[3];
146+
<< " dims:" << inputs[i]->dims();
158147
}
159148

160149
VLOG(4) << "concat output shape: ";
161-
VLOG(4) << " out dims: "
162-
<< "[" << output_tensor_dims.size()
163-
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
164-
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3];
150+
VLOG(4) << " out dims: " << output_tensor_dims;
165151
VLOG(4) << "axis_: " << axis_;
166152
VLOG(4) << "flag_: " << flag_;
167153

168154
VLOG(4) << TargetToStr(concat_param_->output->target());
169-
VLOG(4) << "output_image_shape(w,h):" << output_image_shape["width"] << " "
155+
VLOG(4) << "output_image_shape(w,h): " << output_image_shape["width"] << " "
170156
<< output_image_shape["height"];
171-
VLOG(4) << "output_tensor_dims[" << output_tensor_dims.size()
172-
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
173-
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3]
174-
<< "output_tensor_dims[output_tensor_dims.size() - 1]"
175-
<< output_tensor_dims[output_tensor_dims.size() - 1];
176-
VLOG(4) << "output_tensor_w: " << output_tensor_w << ", flag_: " << flag_;
157+
VLOG(4) << "output_tensor_w: " << output_tensor_w;
177158
VLOG(4) << "width_:" << width_;
178159
VLOG(4) << "global_work_size: "
179160
<< output_tensor_dims[output_tensor_dims.size() - 1] << " "
@@ -433,6 +414,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
433414
}
434415
#endif
435416

417+
private:
436418
int axis_ = 1;
437419
int flag_ = 1;
438420
int width_ = 1;

0 commit comments

Comments
 (0)