1919#include " paddle/phi/backends/gpu/gpu_info.h"
2020#include " paddle/phi/core/kernel_registry.h"
2121#include " paddle/phi/core/utils/data_type.h"
22+ #include " paddle/phi/kernels/funcs/math_function.h"
2223
2324DECLARE_bool (cudnn_deterministic);
2425
@@ -45,11 +46,6 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
4546 }
4647}
4748
48- template <typename T>
49- __global__ void index_select_grad_init (T* input_grad, int64_t N) {
50- CUDA_KERNEL_LOOP_TYPE (idx, N, int64_t ) { input_grad[idx] = 0.0 ; }
51- }
52-
5349template <typename T, typename Context>
5450void IndexSelectGradKernel (const Context& ctx,
5551 const DenseTensor& x,
@@ -93,8 +89,8 @@ void IndexSelectGradKernel(const Context& ctx,
9389 dim3 grid_dim = dim3 ((numel + block_dim - 1 ) / block_dim);
9490 paddle::platform::LimitGridDim (ctx, &grid_dim);
9591
96- index_select_grad_init<T> <<<grid_dim, block_dim, 0 , stream>>> (in_grad_data,
97- numel );
92+ phi::funcs::SetConstant<phi::GPUContext, T> index_select_grad_init;
93+ index_select_grad_init (ctx, x_grad, static_cast <T>( 0 ) );
9894
9995 if (FLAGS_cudnn_deterministic) {
10096 VLOG (2 ) << " Run grad kernel of index_select with single thread." ;
0 commit comments