Skip to content

Commit 5a9adcc

Browse files
committed
replace index_select_grad_init with SetConstant
1 parent bd32fac commit 5a9adcc

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

paddle/phi/kernels/gpu/index_select_grad_kernel.cu

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/phi/backends/gpu/gpu_info.h"
2020
#include "paddle/phi/core/kernel_registry.h"
2121
#include "paddle/phi/core/utils/data_type.h"
22+
#include "paddle/phi/kernels/funcs/math_function.h"
2223

2324
DECLARE_bool(cudnn_deterministic);
2425

@@ -45,11 +46,6 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
4546
}
4647
}
4748

48-
template <typename T>
49-
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
50-
CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) { input_grad[idx] = 0.0; }
51-
}
52-
5349
template <typename T, typename Context>
5450
void IndexSelectGradKernel(const Context& ctx,
5551
const DenseTensor& x,
@@ -93,8 +89,8 @@ void IndexSelectGradKernel(const Context& ctx,
9389
dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
9490
paddle::platform::LimitGridDim(ctx, &grid_dim);
9591

96-
index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
97-
numel);
92+
phi::funcs::SetConstant<phi::GPUContext, T> index_select_grad_init;
93+
index_select_grad_init(ctx, x_grad, static_cast<T>(0));
9894

9995
if (FLAGS_cudnn_deterministic) {
10096
VLOG(2) << "Run grad kernel of index_select with single thread.";

0 commit comments

Comments
 (0)