Skip to content

Commit db4e2ea

Browse files
committed
[XPU] remove xpu_wait in PowKernel
1 parent 4a99813 commit db4e2ea

File tree

1 file changed

+5
-22
lines changed

1 file changed

+5
-22
lines changed

paddle/phi/kernels/xpu/activation_kernel.cc

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -218,33 +218,16 @@ void PowKernel(const Context& dev_ctx,
218218
DenseTensor* out) {
219219
using XPUType = typename XPUTypeTrait<T>::Type;
220220
dev_ctx.template Alloc<T>(out);
221-
T pow_factor = factor.to<T>();
221+
222222
const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
223223
XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
224+
XPUType pow_factor = static_cast<XPUType>(pad_value.to<T>());
224225

225226
auto xpu_context = dev_ctx.x_context();
226-
// allocate temp memory for factor on xpu
227-
xpu::ctx_guard RAII_GUARD(xpu_context);
228-
XPUType* factor_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(1);
229-
PADDLE_ENFORCE_NOT_NULL(
230-
factor_data, errors::External("XPU alloc_l3_or_gm returns nullptr"));
231-
memory_utils::Copy(dev_ctx.GetPlace(),
232-
static_cast<void*>(factor_data),
233-
phi::CPUPlace(),
234-
static_cast<void*>(&pow_factor),
235-
sizeof(T));
236-
237-
auto x_dims = common::vectorize<int>(x.dims());
238-
// use [1] to replace [], because xpu not support []
239-
if (x_dims.size() == 0) {
240-
x_dims = std::vector<int>({1});
241-
}
242-
243-
// broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
244-
// std::vector<int>& xshape, const std::vector<int>& yshape);
227+
245228
int r =
246-
xpu::broadcast_pow(xpu_context, x_data, factor_data, y_data, x_dims, {1});
247-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
229+
xpu::pow_tensor_scalar(xpu_context, x_data, pow_factor, y_data, x.numel())
230+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow_tensor_scalar");
248231
}
249232

250233
template <typename T>

0 commit comments

Comments
 (0)