Skip to content

Commit c09a01b

Browse files
committed
[X86][XPU] add reduce_max; fix xpu fill_any_like; test=develop
1 parent acd40c8 commit c09a01b

File tree

4 files changed

+66
-9
lines changed

4 files changed

+66
-9
lines changed

lite/kernels/x86/reduce_compute.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ REGISTER_LITE_KERNEL(reduce_mean,
3333
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
3434
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
3535
.Finalize();
36+
37+
REGISTER_LITE_KERNEL(reduce_max,
38+
kX86,
39+
kFloat,
40+
kNCHW,
41+
paddle::lite::kernels::x86::ReduceMaxCompute<float>,
42+
def)
43+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
44+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
45+
.Finalize();

lite/kernels/x86/reduce_compute.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ struct MeanFunctor {
3838
}
3939
};
4040

41+
struct MaxFunctor {
42+
template <typename X, typename Y, typename Dim>
43+
void operator()(X* x, Y* y, const Dim& dim) {
44+
y->device(lite::fluid::EigenDeviceType<TARGET(kX86)>()) = x->maximum(dim);
45+
}
46+
};
47+
4148
#define HANDLE_DIM(NDIM, RDIM, FUNCTOR) \
4249
if (ndim == NDIM && rdim == RDIM) { \
4350
paddle::lite::kernels::x86:: \
@@ -120,6 +127,44 @@ class ReduceMeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
120127
virtual ~ReduceMeanCompute() = default;
121128
};
122129

130+
template <typename T>
131+
class ReduceMaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
132+
public:
133+
using param_t = operators::ReduceParam;
134+
135+
void Run() override {
136+
auto& param = *param_.get_mutable<operators::ReduceParam>();
137+
auto* input = param.X;
138+
auto* Out = param.Out;
139+
param.Out->template mutable_data<T>();
140+
141+
const auto& dims = param.dim;
142+
bool keep_dim = param.keep_dim;
143+
144+
if (dims.size() == 0) {
145+
// Flatten and reduce 1-D tensor
146+
auto x = lite::fluid::EigenVector<T>::Flatten(*input);
147+
auto out = lite::fluid::EigenScalar<T>::From(Out);
148+
auto reduce_dim = Eigen::array<int, 1>({{0}});
149+
MaxFunctor functor;
150+
functor(&x, &out, reduce_dim);
151+
} else {
152+
int ndim = input->dims().size();
153+
int rdim = dims.size();
154+
HANDLE_DIM(4, 3, MaxFunctor);
155+
HANDLE_DIM(4, 2, MaxFunctor);
156+
HANDLE_DIM(4, 1, MaxFunctor);
157+
HANDLE_DIM(3, 2, MaxFunctor);
158+
HANDLE_DIM(3, 1, MaxFunctor);
159+
HANDLE_DIM(2, 2, MaxFunctor);
160+
HANDLE_DIM(2, 1, MaxFunctor);
161+
HANDLE_DIM(1, 1, MaxFunctor);
162+
}
163+
}
164+
165+
virtual ~ReduceMaxCompute() = default;
166+
};
167+
123168
} // namespace x86
124169
} // namespace kernels
125170
} // namespace lite

lite/kernels/xpu/fill_any_like_compute.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ void FillAnyLikeCompute::Run() {
5252
static_cast<int64_t>(param.value));
5353
break;
5454
}
55+
case -1:
5556
case 5: {
5657
auto data = param.Out->mutable_data<float>(TARGET(kXPU));
5758
r = xdnn::constant<float>(ctx.GetRawContext(),

lite/tests/kernels/reduce_max_compute_test.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ class ReduceMaxComputeTester : public arena::TestCase {
398398
}
399399
};
400400

401-
void test_reduce_max(Place place) {
401+
void test_reduce_max_4d(Place place) {
402402
std::vector<std::vector<int>> reduce_dim{
403403
{0}, {1}, {2}, {3}, {0, 1}, {1, 2}, {2, 3}, {-2, -1}};
404404
for (auto n : {1, 3}) {
@@ -421,7 +421,7 @@ void test_reduce_max(Place place) {
421421
}
422422
}
423423

424-
void test_reduce_max_for_three(Place place) {
424+
void test_reduce_max_3d(Place place) {
425425
std::vector<std::vector<int>> reduce_dim{{0}, {1}, {2}};
426426
for (bool keep_dim : {false, true}) {
427427
for (auto dim : reduce_dim) {
@@ -435,14 +435,15 @@ void test_reduce_max_for_three(Place place) {
435435
}
436436

437437
TEST(ReduceMax, precision) {
438-
// #ifdef LITE_WITH_X86
439-
// Place place(TARGET(kX86));
440-
// #endif
441-
#ifdef LITE_WITH_ARM
442-
Place place(TARGET(kARM));
443-
test_reduce_max(place);
444-
test_reduce_max_for_three(place);
438+
Place place;
439+
#if defined(LITE_WITH_ARM)
440+
place = TARGET(kARM);
441+
#elif defined(LITE_WITH_X86)
442+
place = TARGET(kX86);
445443
#endif
444+
445+
test_reduce_max_4d(place);
446+
test_reduce_max_3d(place);
446447
}
447448

448449
} // namespace lite

0 commit comments

Comments
 (0)