@@ -218,33 +218,16 @@ void PowKernel(const Context& dev_ctx,
218218 DenseTensor* out) {
219219 using XPUType = typename XPUTypeTrait<T>::Type;
220220 dev_ctx.template Alloc <T>(out);
221- T pow_factor = factor. to <T>();
221+
222222 const XPUType* x_data = reinterpret_cast <const XPUType*>(x.data <T>());
223223 XPUType* y_data = reinterpret_cast <XPUType*>(out->data <T>());
224+ XPUType pow_factor = static_cast <XPUType>(pad_value.to <T>());
224225
225226 auto xpu_context = dev_ctx.x_context ();
226- // allocate temp memory for factor on xpu
227- xpu::ctx_guard RAII_GUARD (xpu_context);
228- XPUType* factor_data = RAII_GUARD.alloc_l3_or_gm <XPUType>(1 );
229- PADDLE_ENFORCE_NOT_NULL (
230- factor_data, errors::External (" XPU alloc_l3_or_gm returns nullptr" ));
231- memory_utils::Copy (dev_ctx.GetPlace (),
232- static_cast <void *>(factor_data),
233- phi::CPUPlace (),
234- static_cast <void *>(&pow_factor),
235- sizeof (T));
236-
237- auto x_dims = common::vectorize<int >(x.dims ());
238- // use [1] to replace [], because xpu not support []
239- if (x_dims.size () == 0 ) {
240- x_dims = std::vector<int >({1 });
241- }
242-
243- // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
244- // std::vector<int>& xshape, const std::vector<int>& yshape);
227+
245228 int r =
246- xpu::broadcast_pow (xpu_context, x_data, factor_data , y_data, x_dims, { 1 });
247- PADDLE_ENFORCE_XDNN_SUCCESS (r, " broadcast_pow " );
229+ xpu::pow_tensor_scalar (xpu_context, x_data, pow_factor , y_data, x. numel ())
230+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " pow_tensor_scalar " );
248231}
249232
250233template <typename T>
0 commit comments