Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 14 additions & 32 deletions lite/kernels/opencl/concat_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
concat_param_ = param_.get_mutable<param_t>();
axis_ = concat_param_->axis;
if (-1 == axis_) {
axis_ = concat_param_->x[0]->dims().size() - 1;
}

auto inputs = concat_param_->x;
auto axis_ = concat_param_->axis;
auto output_tensor_dims = concat_param_->output->dims();
auto* axis_tensor = concat_param_->axis_tensor;
if (axis_tensor != nullptr) {
// auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
// axis = axis_tensor_data[0];
}

if (inputs.size() == 2) {
kernel_func_name_ = "concat2";
Expand Down Expand Up @@ -100,8 +98,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
width_ = output_tensor_dims[0]; // n
flag_ = 2;
break;
case 3:
case -1: // width
case 3: // width
width_ = output_tensor_dims[1]; // c
flag_ = 3;
break;
Expand All @@ -113,17 +110,12 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
auto input0_tensor_dims = inputs[0]->dims();
for (int i = 1; i < inputs.size(); i++) {
auto dims = inputs[i]->dims();
// auto flag = CHECK_EQ_OR_FALSE(input0_tensor_dims.size(), dims.size());
if (input0_tensor_dims.size() != dims.size()) {
printf("input shape must be same \n");
return;
}
CHECK(input0_tensor_dims.size() == dims.size())
<< "All inputs must have the same axes!";
for (int i = 0; i < dims.size(); i++) {
if (i != axis_) {
if (input0_tensor_dims[i] != dims[i]) {
printf("input shape must be same \n");
return;
}
CHECK(input0_tensor_dims[i] == dims[i])
<< "All inputs must have the same shape, except at concat axis!";
}
}
}
Expand Down Expand Up @@ -151,29 +143,18 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
VLOG(4) << "concat input shape: ";
for (size_t i = 0; i < inputs.size(); i++) {
VLOG(4) << "inputs [" << i << "]"
<< "[" << inputs[i]->dims().size() << "D]:"
<< " dims:" << inputs[i]->dims()[0] << " "
<< inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " "
<< inputs[i]->dims()[3];
<< " dims:" << inputs[i]->dims();
}

VLOG(4) << "concat output shape: ";
VLOG(4) << " out dims: "
<< "[" << output_tensor_dims.size()
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3];
VLOG(4) << " out dims: " << output_tensor_dims;
VLOG(4) << "axis_: " << axis_;
VLOG(4) << "flag_: " << flag_;

VLOG(4) << TargetToStr(concat_param_->output->target());
VLOG(4) << "output_image_shape(w,h):" << output_image_shape["width"] << " "
VLOG(4) << "output_image_shape(w,h): " << output_image_shape["width"] << " "
<< output_image_shape["height"];
VLOG(4) << "output_tensor_dims[" << output_tensor_dims.size()
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3]
<< "output_tensor_dims[output_tensor_dims.size() - 1]"
<< output_tensor_dims[output_tensor_dims.size() - 1];
VLOG(4) << "output_tensor_w: " << output_tensor_w << ", flag_: " << flag_;
VLOG(4) << "output_tensor_w: " << output_tensor_w;
VLOG(4) << "width_:" << width_;
VLOG(4) << "global_work_size: "
<< output_tensor_dims[output_tensor_dims.size() - 1] << " "
Expand Down Expand Up @@ -433,6 +414,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
}
#endif

private:
int axis_ = 1;
int flag_ = 1;
int width_ = 1;
Expand Down