Skip to content

Commit 0c4423b

Browse files
authored
[arm][kernel]fix: broadcast dim fix logic updated test=develop (#4859)
* [arm][kernel]fix: broadcast dim fix logic updated test=develop
* fix: use VLOG(4) test=develop
* fix: update broadcast test when x.dim.size smaller test=develop
1 parent fc85b90 commit 0c4423b

File tree

3 files changed

+38
-22
lines changed

3 files changed

+38
-22
lines changed

lite/kernels/arm/elementwise_compute.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ inline bool is_fast_broadcast(const DDim& x_dims,
5454
axis = x_dims.size() - y_dims.size();
5555
}
5656
if (axis < 0) {
57-
LOG(INFO) << "Fast broadcast chk fail, for x_dims smaller.";
57+
VLOG(4) << "Fast broadcast chk fail, for x_dims smaller.";
5858
return false;
5959
}
6060
DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
6161
axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
6262
if (x_dims.size() == y_dim_trim.size()) {
63-
LOG(INFO)
63+
VLOG(4)
6464
<< "Fast broadcast chk fail, for y's shape not really contained in x";
6565
return false;
6666
}
@@ -72,7 +72,7 @@ inline bool is_fast_broadcast(const DDim& x_dims,
7272
}
7373
for (int i = 0; i < y_dim_trim.size(); ++i) {
7474
if (x_dims[i + axis] != y_dim_trim[i]) {
75-
LOG(WARNING) << "Fast broadcast chk fail, for dimension mismatch.";
75+
VLOG(4) << "Fast broadcast chk fail, for dimension mismatch.";
7676
return false;
7777
}
7878
(*n) *= y_dim_trim[i];
@@ -151,10 +151,6 @@ void elementwise_compute_template(paddle::lite::KernelBase* kernel,
151151
auto& param = kernel->template Param<OpParamType>();
152152
auto x = param.X;
153153
auto y = param.Y;
154-
if (opd_swap_able == OprandSwapable::YES &&
155-
x->dims().size() < y->dims().size()) {
156-
std::swap(x, y);
157-
}
158154

159155
auto* x_data = x->template data<T>();
160156
auto* y_data = y->template data<T>();

lite/kernels/host/elementwise_op_func.h

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ void BatchElementWiseArg<Elem_t, DimValue_t>::Update(
271271
BroadcastType broadcast_type) {
272272
// arg checking
273273
if (broadcast_type == BroadcastType::UNKNOWN) {
274-
LOG(INFO) << "No broadcast type input";
274+
VLOG(4) << "No broadcast type input";
275275
broadcast_type = get_broadcast_type(x_dims, y_dims, dim_size);
276276
}
277277
if (broadcast_type == BroadcastType::UNKNOWN ||
@@ -281,7 +281,7 @@ void BatchElementWiseArg<Elem_t, DimValue_t>::Update(
281281
}
282282
if (broadcast_type == BroadcastType::SAME_DIM) {
283283
broadcast_type = BroadcastType::BOTH_CONTINUOUS;
284-
LOG(INFO) << "Same dim detected";
284+
VLOG(4) << "Same dim detected";
285285
// SAME_DIM should not be treated as broadcast. For SAME_DIM is a special
286286
// case of BOTH_CONTINUOUS, we could still process it.
287287
}
@@ -492,13 +492,25 @@ void fix_x_y_dims(const Tensor *X,
492492
}
493493
} else {
494494
if (X->dims().size() != Out->dims().size()) {
495-
LOG(FATAL) << "X and OUT dim size mismatch";
496-
}
497-
for (int i = 0; i < out_dim_size; ++i) {
498-
x_dims[i] = X->dims()[i];
499-
}
500-
for (int i = axis; i < out_dim_size; ++i) {
501-
y_dims[i + axis] = Y->dims()[i];
495+
if (Y->dims().size() != Out->dims().size()) {
496+
LOG(FATAL) << "X/Y and OUT dim size mismatch";
497+
} else {
498+
VLOG(4) << "Arguments broke API reference, for X.dims().size() is "
499+
"smaller and axis is set";
500+
for (int i = 0; i < out_dim_size; ++i) {
501+
y_dims[i] = Y->dims()[i];
502+
}
503+
for (int i = 0; i < X->dims().size(); ++i) {
504+
x_dims[i + axis] = X->dims()[i];
505+
}
506+
}
507+
} else {
508+
for (int i = 0; i < out_dim_size; ++i) {
509+
x_dims[i] = X->dims()[i];
510+
}
511+
for (int i = 0; i < Y->dims().size(); ++i) {
512+
y_dims[i + axis] = Y->dims()[i];
513+
}
502514
}
503515
}
504516
}

lite/tests/kernels/elementwise_common_broadcast_test.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,13 @@ class ElementwiseComputeTester : public arena::TestCase {
168168
void PrepareData() override {
169169
std::vector<T> dx(x_dims_.production());
170170
for (size_t i = 0; i < dx.size(); i++) {
171-
dx[i] = i;
171+
dx[i] = 1;
172172
}
173173
SetCommonTensor(x_, x_dims_, dx.data());
174174

175175
std::vector<T> dy(y_dims_.production());
176176
for (size_t i = 0; i < dy.size(); i++) {
177-
dy[i] = 2 * i + 1;
177+
dy[i] = 1;
178178
}
179179
SetCommonTensor(y_, y_dims_, dy.data());
180180
}
@@ -204,10 +204,12 @@ bool RunOnRandomArgs(const Place& place,
204204
std::vector<int> y_dim_cut;
205205

206206
int axis = -1;
207-
bool cut_dimension = randbool();
207+
static bool cut_dimension = true;
208+
cut_dimension = !cut_dimension;
208209
if (cut_dimension) {
209210
// generate x_dim_cut and y_dim_cut by remove dimension
210-
bool use_axis = randbool();
211+
static bool use_axis = true;
212+
use_axis = !use_axis;
211213
if (use_axis) {
212214
x_dim_cut = x_dim_full;
213215
// we will cut y only, and set tail of y to be 1
@@ -225,6 +227,12 @@ bool RunOnRandomArgs(const Place& place,
225227
for (int i = dim_size - tail1_num; i < dim_size; ++i) {
226228
y_dim_full[i] = 1;
227229
}
230+
static bool swap_x_and_y = true;
231+
swap_x_and_y = !swap_x_and_y;
232+
if (swap_x_and_y) {
233+
std::swap(x_dim_cut, y_dim_cut);
234+
std::swap(x_dim_full, y_dim_full);
235+
}
228236
} else {
229237
// we will cut x or y
230238
if (randbool()) {
@@ -301,7 +309,7 @@ bool RunOnRandomArgs(const Place& place,
301309
#ifdef LITE_WITH_ARM
302310

303311
TEST(elementwise_broadcast, compute_fp32) {
304-
const int TEST_RETEAT_NUM = 5;
312+
const int TEST_RETEAT_NUM = 10;
305313
for (int repeat_count = 0; repeat_count < TEST_RETEAT_NUM; ++repeat_count) {
306314
EXPECT_TRUE(paddle::lite::RunOnRandomArgs<float>(
307315
TARGET(kARM), "def", "add", "", [](float l, float r) {
@@ -323,7 +331,7 @@ TEST(elementwise_broadcast, compute_fp32) {
323331
}
324332

325333
TEST(elementwise_broadcast, compute_i32) {
326-
const int TEST_RETEAT_NUM = 5;
334+
const int TEST_RETEAT_NUM = 10;
327335
for (int repeat_count = 0; repeat_count < TEST_RETEAT_NUM; ++repeat_count) {
328336
EXPECT_TRUE(paddle::lite::RunOnRandomArgs<int32_t>(
329337
paddle::lite::Place(TARGET(kARM), PRECISION(kInt32)),

0 commit comments

Comments (0)