@@ -24,7 +24,8 @@ namespace lite {
2424namespace kernels {
2525namespace xpu {
2626
27- void BilinearInterpCompute::Run () {
27+ template <typename InType, PrecisionType PType>
28+ void BilinearInterpCompute<InType, PType>::Run() {
2829 auto & param = this ->template Param <param_t >();
2930 auto & ctx = this ->ctx_ ->template As <XPUContext>();
3031 lite::Tensor* X = param.X ;
@@ -47,22 +48,23 @@ void BilinearInterpCompute::Run() {
4748 } else {
4849 trans_mode = 2 ;
4950 }
50- int r = xdnn::interpolate2d<float >(ctx.GetRawContext (),
51- X->data <float >(),
52- Out->mutable_data <float >(TARGET (kXPU )),
53- n,
54- c,
55- in_h,
56- in_w,
57- out_h,
58- out_w,
59- false ,
60- trans_mode,
61- true );
51+ int r = xdnn::interpolate2d<InType >(ctx.GetRawContext (),
52+ X->data <InType >(),
53+ Out->mutable_data <InType >(TARGET (kXPU )),
54+ n,
55+ c,
56+ in_h,
57+ in_w,
58+ out_h,
59+ out_w,
60+ false ,
61+ trans_mode,
62+ true );
6263 CHECK_EQ (r, 0 );
6364}
6465
65- void NearestInterpCompute::Run () {
66+ template <typename InType, PrecisionType PType>
67+ void NearestInterpCompute<InType, PType>::Run() {
6668 auto & param = this ->template Param <param_t >();
6769 auto & ctx = this ->ctx_ ->template As <XPUContext>();
6870 lite::Tensor* X = param.X ;
@@ -77,18 +79,18 @@ void NearestInterpCompute::Run() {
7779 bool align_corners = param.align_corners ;
7880 int trans_mode = (align_corners == true ) ? 0 : 2 ;
7981
80- int r = xdnn::interpolate2d<float >(ctx.GetRawContext (),
81- X->data <float >(),
82- Out->mutable_data <float >(TARGET (kXPU )),
83- n,
84- c,
85- in_h,
86- in_w,
87- out_h,
88- out_w,
89- true ,
90- trans_mode,
91- true );
82+ int r = xdnn::interpolate2d<InType >(ctx.GetRawContext (),
83+ X->data <InType >(),
84+ Out->mutable_data <InType >(TARGET (kXPU )),
85+ n,
86+ c,
87+ in_h,
88+ in_w,
89+ out_h,
90+ out_w,
91+ true ,
92+ trans_mode,
93+ true );
9294
9395 CHECK_EQ (r, 0 );
9496}
@@ -98,12 +100,40 @@ void NearestInterpCompute::Run() {
98100} // namespace lite
99101} // namespace paddle
100102
103+ namespace xpu = paddle::lite::kernels::xpu;
104+
105+ using BiliInterp_FP32 = xpu::BilinearInterpCompute<float , PRECISION(kFloat )>;
106+ using BiliInterp_FP16 = xpu::BilinearInterpCompute<float16, PRECISION(kFP16 )>;
107+ using NearInterp_FP32 = xpu::NearestInterpCompute<float , PRECISION(kFloat )>;
108+ using NearInterp_FP16 = xpu::NearestInterpCompute<float16, PRECISION(kFP16 )>;
109+
110+ REGISTER_LITE_KERNEL (bilinear_interp, kXPU , kFloat , kNCHW , BiliInterp_FP32, def)
111+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
112+ .BindInput(" OutSize" ,
113+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
114+ .BindInput(" SizeTensor" ,
115+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
116+ .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kHost ))})
117+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ))})
118+ .Finalize();
119+
101120REGISTER_LITE_KERNEL (bilinear_interp,
102121 kXPU ,
103- kFloat ,
122+ kFP16 ,
104123 kNCHW ,
105- paddle::lite::kernels::xpu::BilinearInterpCompute,
106- def)
124+ BiliInterp_FP16,
125+ DISABLE_XPU1_binterp_FP16)
126+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
127+ .BindInput(" OutSize" ,
128+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
129+ .BindInput(" SizeTensor" ,
130+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
131+ .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kHost ))})
132+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
133+ .Finalize();
134+
135+ REGISTER_LITE_KERNEL (
136+ bilinear_interp_v2, kXPU , kFloat , kNCHW , BiliInterp_FP32, def)
107137 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
108138 .BindInput(" OutSize" ,
109139 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
@@ -115,10 +145,20 @@ REGISTER_LITE_KERNEL(bilinear_interp,
115145
116146REGISTER_LITE_KERNEL (bilinear_interp_v2,
117147 kXPU ,
118- kFloat ,
148+ kFP16 ,
119149 kNCHW ,
120- paddle::lite::kernels::xpu::BilinearInterpCompute,
121- def)
150+ BiliInterp_FP16,
151+ DISABLE_XPU1_binterp_v2_FP16)
152+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
153+ .BindInput(" OutSize" ,
154+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
155+ .BindInput(" SizeTensor" ,
156+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
157+ .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kHost ))})
158+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
159+ .Finalize();
160+
161+ REGISTER_LITE_KERNEL (nearest_interp, kXPU , kFloat , kNCHW , NearInterp_FP32, def)
122162 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
123163 .BindInput(" OutSize" ,
124164 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
@@ -130,10 +170,21 @@ REGISTER_LITE_KERNEL(bilinear_interp_v2,
130170
131171REGISTER_LITE_KERNEL (nearest_interp,
132172 kXPU ,
133- kFloat ,
173+ kFP16 ,
134174 kNCHW ,
135- paddle::lite::kernels::xpu::NearestInterpCompute,
136- def)
175+ NearInterp_FP16,
176+ DISABLE_XPU1_ninterp_FP16)
177+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
178+ .BindInput(" OutSize" ,
179+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
180+ .BindInput(" SizeTensor" ,
181+ {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
182+ .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kHost ))})
183+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
184+ .Finalize();
185+
186+ REGISTER_LITE_KERNEL (
187+ nearest_interp_v2, kXPU , kFloat , kNCHW , NearInterp_FP32, def)
137188 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
138189 .BindInput(" OutSize" ,
139190 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
@@ -145,15 +196,15 @@ REGISTER_LITE_KERNEL(nearest_interp,
145196
146197REGISTER_LITE_KERNEL (nearest_interp_v2,
147198 kXPU ,
148- kFloat ,
199+ kFP16 ,
149200 kNCHW ,
150- paddle::lite::kernels::xpu::NearestInterpCompute ,
151- def )
152- .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
201+ NearInterp_FP16 ,
202+ DISABLE_XPU1_niterp_v2_FP16 )
203+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION ( kFP16 ) )})
153204 .BindInput(" OutSize" ,
154205 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
155206 .BindInput(" SizeTensor" ,
156207 {LiteType::GetTensorTy (TARGET (kHost ), PRECISION (kInt32 ))})
157208 .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kHost ))})
158- .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ))})
209+ .BindOutput(" Out" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION ( kFP16 ) )})
159210 .Finalize();
0 commit comments