Skip to content

Commit a5f2327

Browse files
committed
[xpu]: support equal int64 and transpose int64;test=develop
1 parent 8654144 commit a5f2327

File tree

4 files changed

+108
-8
lines changed

4 files changed

+108
-8
lines changed

lite/kernels/x86/transpose_compute.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,24 @@ REGISTER_LITE_KERNEL(transpose2,
3434
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
3535
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
3636
.Finalize();
37+
38+
// Registers an additional x86 `transpose` kernel under the alias `int64`,
// implemented by TransposeCompute<int64_t>. The kernel slot itself is declared
// as kX86 / kFloat / kNCHW; the int64 element type is expressed through the
// tensor bindings below, where both "X" and "Out" are bound with
// PRECISION(kInt64).
REGISTER_LITE_KERNEL(transpose,
39+
kX86,
40+
kFloat,
41+
kNCHW,
42+
paddle::lite::kernels::x86::TransposeCompute<int64_t>,
43+
int64)
44+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
45+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
46+
.Finalize();
47+
48+
// Registers an additional x86 `transpose2` kernel under the alias `int64`,
// implemented by Transpose2Compute<int64_t>. The kernel slot is declared as
// kX86 / kFloat / kNCHW; "X", "Out" and the extra "XShape" output are all
// bound with PRECISION(kInt64).
REGISTER_LITE_KERNEL(transpose2,
49+
kX86,
50+
kFloat,
51+
kNCHW,
52+
paddle::lite::kernels::x86::Transpose2Compute<int64_t>,
53+
int64)
54+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
55+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
56+
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
57+
.Finalize();

lite/kernels/xpu/compare_compute.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ struct LessThanFunctor {
3434
}
3535
};
3636

37+
// Comparison functor for the `equal` op, used as the Functor template argument
// of CompareCompute. operator() forwards all of its arguments unchanged to
// xdnn::broadcast_equal<T> and returns that call's int result.
//
// ctx            - XDNN context passed through to the library call.
// x, y           - input buffers of element type T.
// z              - bool output buffer handed to broadcast_equal.
// xshape, yshape - shapes of x and y as int vectors.
//
// NOTE(review): presumably a return value of 0 means success, as with other
// xdnn calls -- confirm against the XDNN API.
template <typename T>
38+
struct EqualFunctor {
39+
inline int operator()(xdnn::Context* ctx,
40+
const T* x,
41+
const T* y,
42+
bool* z,
43+
const std::vector<int>& xshape,
44+
const std::vector<int>& yshape) const {
45+
return xdnn::broadcast_equal<T>(ctx, x, y, z, xshape, yshape);
46+
}
47+
};
48+
3749
template <PrecisionType PType, class T, class Functor>
3850
void CompareCompute<PType, T, Functor>::Run() {
3951
auto& param = this->template Param<operators::CompareParam>();
@@ -152,3 +164,65 @@ REGISTER_LITE_KERNEL(less_than, kXPU, kFloat, kAny, less_than_int64, int64)
152164
DATALAYOUT(kAny))})
153165
.BindPaddleOpVersion("less_than", 1)
154166
.Finalize();
167+
168+
169+
// `equal` kernel for float inputs on XPU.
// equal_float is CompareCompute instantiated with PRECISION(kFloat), element
// type float and EqualFunctor<float>. It is registered for op `equal` on
// kXPU / kFloat / kAny under the alias `def`: inputs "X" and "Y" are kFloat
// tensors, output "Out" is a kBool tensor, all with DATALAYOUT(kAny). The
// registration is bound to Paddle op version 1 of "equal".
using equal_float = paddle::lite::kernels::xpu::CompareCompute<
170+
PRECISION(kFloat),
171+
float,
172+
paddle::lite::kernels::xpu::EqualFunctor<float>>;
173+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, equal_float, def)
174+
.BindInput("X",
175+
{LiteType::GetTensorTy(TARGET(kXPU),
176+
PRECISION(kFloat),
177+
DATALAYOUT(kAny))})
178+
.BindInput("Y",
179+
{LiteType::GetTensorTy(TARGET(kXPU),
180+
PRECISION(kFloat),
181+
DATALAYOUT(kAny))})
182+
.BindOutput("Out",
183+
{LiteType::GetTensorTy(TARGET(kXPU),
184+
PRECISION(kBool),
185+
DATALAYOUT(kAny))})
186+
.BindPaddleOpVersion("equal", 1)
187+
.Finalize();
188+
189+
// `equal` kernel for int32 inputs on XPU.
// equal_int32 is CompareCompute instantiated with element type int and
// EqualFunctor<int>. Note that the precision template argument and the
// registration slot are both kFloat; the int32 element type is expressed only
// through the tensor bindings. Registered for op `equal` on kXPU / kFloat /
// kAny under the alias `int32`: inputs "X" and "Y" are kInt32 tensors, output
// "Out" is a kBool tensor, all with DATALAYOUT(kAny). Bound to Paddle op
// version 1 of "equal".
using equal_int32 = paddle::lite::kernels::xpu::CompareCompute<
190+
PRECISION(kFloat),
191+
int,
192+
paddle::lite::kernels::xpu::EqualFunctor<int>>;
193+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, equal_int32, int32)
194+
.BindInput("X",
195+
{LiteType::GetTensorTy(TARGET(kXPU),
196+
PRECISION(kInt32),
197+
DATALAYOUT(kAny))})
198+
.BindInput("Y",
199+
{LiteType::GetTensorTy(TARGET(kXPU),
200+
PRECISION(kInt32),
201+
DATALAYOUT(kAny))})
202+
.BindOutput("Out",
203+
{LiteType::GetTensorTy(TARGET(kXPU),
204+
PRECISION(kBool),
205+
DATALAYOUT(kAny))})
206+
.BindPaddleOpVersion("equal", 1)
207+
.Finalize();
208+
209+
// `equal` kernel for int64 inputs on XPU.
// equal_int64 is CompareCompute instantiated with element type int64_t and
// EqualFunctor<int64_t>. As with the int32 variant, the precision template
// argument and the registration slot are kFloat; the int64 element type is
// expressed only through the tensor bindings. Registered for op `equal` on
// kXPU / kFloat / kAny under the alias `int64`: inputs "X" and "Y" are kInt64
// tensors, output "Out" is a kBool tensor, all with DATALAYOUT(kAny). Bound to
// Paddle op version 1 of "equal".
// The type alias is spelled `equal_int64` to match the sibling aliases
// equal_float and equal_int32 (it was previously misspelled `euqal_int64`).
using equal_int64 = paddle::lite::kernels::xpu::CompareCompute<
210+
PRECISION(kFloat),
211+
int64_t,
212+
paddle::lite::kernels::xpu::EqualFunctor<int64_t>>;
213+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, equal_int64, int64)
214+
.BindInput("X",
215+
{LiteType::GetTensorTy(TARGET(kXPU),
216+
PRECISION(kInt64),
217+
DATALAYOUT(kAny))})
218+
.BindInput("Y",
219+
{LiteType::GetTensorTy(TARGET(kXPU),
220+
PRECISION(kInt64),
221+
DATALAYOUT(kAny))})
222+
.BindOutput("Out",
223+
{LiteType::GetTensorTy(TARGET(kXPU),
224+
PRECISION(kBool),
225+
DATALAYOUT(kAny))})
226+
.BindPaddleOpVersion("equal", 1)
227+
.Finalize();
228+

lite/kernels/xpu/transpose_compute.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace lite {
2222
namespace kernels {
2323
namespace xpu {
2424

25-
void TransposeCompute::Run() {
25+
template <class T>
26+
void TransposeCompute <T>::Run() {
2627
auto& param = this->Param<param_t>();
2728
auto& ctx = this->ctx_->As<XPUContext>();
2829
auto x = param.x;
@@ -38,10 +39,11 @@ void TransposeCompute::Run() {
3839
for (int i = 0; i < ndims; ++i) {
3940
x_shape_host[i] = x_dims[i];
4041
}
42+
4143
int r =
42-
xdnn::transpose<float>(ctx.GetRawContext(),
43-
x->data<float>(),
44-
param.output->mutable_data<float>(TARGET(kXPU)),
44+
xdnn::transpose<T>(ctx.GetRawContext(),
45+
x->data<T>(),
46+
param.output->mutable_data<T>(TARGET(kXPU)),
4547
x_shape_host,
4648
axis);
4749
CHECK_EQ(r, 0);
@@ -56,7 +58,7 @@ REGISTER_LITE_KERNEL(transpose,
5658
kXPU,
5759
kFloat,
5860
kNCHW,
59-
paddle::lite::kernels::xpu::TransposeCompute,
61+
paddle::lite::kernels::xpu::TransposeCompute<float>,
6062
def)
6163
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
6264
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -66,17 +68,18 @@ REGISTER_LITE_KERNEL(transpose2,
6668
kXPU,
6769
kFloat,
6870
kNCHW,
69-
paddle::lite::kernels::xpu::TransposeCompute,
71+
paddle::lite::kernels::xpu::TransposeCompute<float>,
7072
def)
7173
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
7274
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
7375
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
7476
.Finalize();
77+
7578
REGISTER_LITE_KERNEL(transpose2,
7679
kXPU,
7780
kFloat,
7881
kNCHW,
79-
paddle::lite::kernels::xpu::TransposeCompute,
82+
paddle::lite::kernels::xpu::TransposeCompute<int64_t>,
8083
def_int64)
8184
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
8285
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})

lite/kernels/xpu/transpose_compute.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ namespace lite {
2121
namespace kernels {
2222
namespace xpu {
2323

24-
class TransposeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
24+
template <class T>
25+
class TransposeCompute
26+
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
2527
public:
2628
using param_t = operators::TransposeParam;
2729

0 commit comments

Comments
 (0)