Skip to content

Commit bc572bd

Browse files
authored
[LITE][XPU] Add xpu softsign kernel (#4860)
test=develop, test=xpu
1 parent 0c4423b commit bc572bd

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

lite/kernels/xpu/activation_compute.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,17 @@ void LeakyReluCompute::Run() {
186186
CHECK_EQ(r, 0);
187187
}
188188

189+
void SoftsignCompute::Run() {
190+
auto& param = this->Param<param_t>();
191+
auto& ctx = this->ctx_->As<XPUContext>();
192+
193+
int r = xdnn::softsign(ctx.GetRawContext(),
194+
param.X->data<float>(),
195+
param.Out->mutable_data<float>(TARGET(kXPU)),
196+
param.X->numel());
197+
CHECK_EQ(r, 0);
198+
}
199+
189200
} // namespace xpu
190201
} // namespace kernels
191202
} // namespace lite
@@ -197,6 +208,12 @@ REGISTER_LITE_KERNEL(
197208
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
198209
.Finalize();
199210

211+
REGISTER_LITE_KERNEL(
212+
relu6, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::Relu6Compute, def)
213+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
214+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
215+
.Finalize();
216+
200217
REGISTER_LITE_KERNEL(
201218
tanh, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::TanhCompute, def)
202219
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -289,8 +306,12 @@ REGISTER_LITE_KERNEL(leaky_relu,
289306
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
290307
.Finalize();
291308

292-
REGISTER_LITE_KERNEL(
293-
relu6, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::Relu6Compute, def)
309+
REGISTER_LITE_KERNEL(softsign,
310+
kXPU,
311+
kFloat,
312+
kNCHW,
313+
paddle::lite::kernels::xpu::SoftsignCompute,
314+
def)
294315
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
295316
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
296317
.Finalize();

lite/kernels/xpu/activation_compute.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ class LeakyReluCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
146146
virtual ~LeakyReluCompute() = default;
147147
};
148148

149+
class SoftsignCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
150+
public:
151+
using param_t = operators::ActivationParam;
152+
153+
virtual void Run();
154+
155+
virtual ~SoftsignCompute() = default;
156+
};
157+
149158
} // namespace xpu
150159
} // namespace kernels
151160
} // namespace lite

0 commit comments

Comments
 (0)