@@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
   }
 }
 
-template <typename T, typename AccT, int VecSize>
-__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
-                                            T* ds, T* db) {
-  int i = blockIdx.x;
-  AccT ds_sum = static_cast<AccT>(0);
-  AccT db_sum = static_cast<AccT>(0);
-  x += i * imsize;
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
-
-  phi::Array<const T*, 2> ins;
-  ins[0] = x;
-  ins[1] = dy;
-  ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
-                                    &ds_sum);
-  ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
-}
-
 template <typename T>
 __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
                                         T* ds, T* db) {
@@ -622,25 +605,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     int flags =
         (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
     if (data_layout == DataLayout::kNCHW) {
-      using AccT = typename details::MPTypeTrait<T>::Type;
-      constexpr int vec_size = sizeof(float4) / sizeof(T);
-      const int max_num_threads = 1024;
-      int max_block_size = std::min(imsize / vec_size, max_num_threads);
-      int block_size_nchw = 1;
-      while (block_size_nchw < max_block_size) {
-        block_size_nchw *= 2;
-      }
-      block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
-      dim3 blocks(block_size_nchw);
-      if (imsize < vec_size * block_size_nchw) {
-        ScalarGetDsDbCUDAKernel<
-            T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
-            imsize, x_data, dy_data, ds_data, db_data);
-      } else {
-        VectorizedGetDsDbCUDAKernel<
-            T, AccT, vec_size><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
-            imsize, x_data, dy_data, ds_data, db_data);
-      }
+      ScalarGetDsDbCUDAKernel<
+          T><<<x_dims[0] * C, block_size, 0, dev_ctx.stream()>>>(
+          imsize, x_data, dy_data, ds_data, db_data);
 
       if (d_scale || d_bias) {
         const int block = 256;