@@ -31,6 +31,7 @@ namespace cub = hipcub;
 
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 
 namespace paddle {
 namespace operators {
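
The new include pulls in details::MPTypeTrait, which maps a storage type to the wider type used for intermediate accumulation. A minimal self-contained sketch of the trait this diff relies on, assuming the usual specialization pattern (the authoritative definitions live in fp16_type_traits.h; the float16 struct here is a stand-in):

#include <type_traits>

namespace platform {
struct float16 { unsigned short x; };  // stand-in for paddle::platform::float16
}

// Assumed shape of the trait: every type accumulates in itself,
// except float16, which accumulates in float.
template <typename T>
struct MPTypeTrait {
  using Type = T;
};

template <>
struct MPTypeTrait<platform::float16> {
  using Type = float;
};

static_assert(std::is_same<MPTypeTrait<float>::Type, float>::value, "");
static_assert(
    std::is_same<MPTypeTrait<platform::float16>::Type, float>::value, "");
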
@@ -66,39 +67,66 @@ struct Array {
   T data_[ElementCount];
 };
 
+// reduce a 1d array to a single element
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
+                               TransformOp transformer, MPType init,
+                               int reduce_num) {
+  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  MPType local_data = init;
+  for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
+    local_data = static_cast<MPType>(
+        reducer(local_data, static_cast<MPType>(transformer(x[i]))));
+  }
+  __syncthreads();
+
+  local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);
+
+  if (threadIdx.x == 0) {
+    y[blockIdx.x] = static_cast<Ty>(local_data);
+  }
+}
+
 // reduce the last axis of 2d array
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
 __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
-                               TransformOp transformer, Ty init,
+                               TransformOp transformer, MPType init,
                                int reduce_num) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
     reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim, int Rank, int ReduceRank>
 __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             Array<int, Rank> x_strides,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, Array<int, Rank> x_strides,
                              Array<int, ReduceRank> reduce_dim,
                              Array<int, ReduceRank> reduce_strides,
                              Array<int, Rank - ReduceRank> left_dim,
                              Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   Array<int, Rank> sub_index;
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
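
The new ReduceKernel1D combines a grid-stride loop with cub::BlockReduce: each thread folds a strided subset of the input into a private accumulator, then the block collapses the per-thread partials into one value per block. A self-contained sketch of the same pattern, specialized to a float sum for clarity (names here are illustrative, not Paddle's):

#include <cub/cub.cuh>

// Minimal standalone version of the grid-stride + BlockReduce pattern
// used by ReduceKernel1D, specialized to summing floats.
template <int BlockDim>
__global__ void SumKernel1D(const float* x, float* y, int n) {
  typedef cub::BlockReduce<float, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Grid-stride loop: thread t handles x[t], x[t + stride], x[t + 2*stride], ...
  float local = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    local += x[i];
  }

  // Collapse the per-thread partials to one value per block; the result
  // is only valid in thread 0, which writes one partial per block.
  float block_sum = BlockReduce(temp_storage).Sum(local);
  if (threadIdx.x == 0) y[blockIdx.x] = block_sum;
}

// Example launch: SumKernel1D<256><<<blocks, 256, 0, stream>>>(d_x, d_partial, n);
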
@@ -114,7 +142,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
 
   int idx_x = 0;
   for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
+  MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));
 
   for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
     int reduce_idx = i;
@@ -125,16 +153,16 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
 
     int idx_x = 0;
     for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var = static_cast<Ty>(
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
+    reduce_var = static_cast<MPType>(
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
   }
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }
 
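
The switch from Ty to MPType in every accumulator above is the point of this patch: float16 has a 10-bit mantissa, so a running fp16 sum stops growing once it reaches 2048 (the gap between adjacent representable fp16 values there is 2, and +1.0 rounds back down under round-to-nearest-even). A host-side illustration of that failure mode, using the host-callable conversion intrinsics from cuda_fp16.h to emulate rounding after every add:

#include <cuda_fp16.h>
#include <cstdio>

// Illustration only: summing 4096 ones in fp16 stalls at 2048, while an
// fp32 accumulator (the MPType choice for float16) stays exact.
int main() {
  __half h_sum = __float2half(0.f);
  float f_sum = 0.f;
  for (int i = 0; i < 4096; ++i) {
    // Simulate fp16 accumulation: add in float, round back to fp16.
    h_sum = __float2half(__half2float(h_sum) + 1.f);
    f_sum += 1.f;  // MPType-style fp32 accumulation is exact in this range
  }
  std::printf("fp16 accumulator: %g\n", __half2float(h_sum));  // prints 2048
  std::printf("fp32 accumulator: %g\n", f_sum);                // prints 4096
}
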
@@ -192,6 +220,53 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
   }
 }
 
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const MPType& init,
+                      int reduce_num, gpuStream_t stream) {
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transformer);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+}
+
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const MPType& init,
+                      int reduce_num, gpuStream_t stream) {
+  int element_per_block = BlockDim * 10;
+  int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
+
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<MPType>(
+      framework::make_ddim(
+          {static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
+      place);
+
+  // each block reduces its chunk of the input to an interim result
+  ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
+                 BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
+      x_data, temp_storage, reducer, transformer, init, reduce_num);
+  // reduce the interim results to the final value
+  ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
+                 BlockDim><<<1, BlockDim, 0, stream>>>(
+      temp_storage, y_data, reducer, transformer, init, block_per_grid);
+}
+
 template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp>
 static void TensorReduceImpl(
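
The two LaunchCubReduceKernel overloads use std::enable_if so that exactly one of them survives overload resolution for any given input type: the generic path hands the work to cub::DeviceReduce, while float16 takes the hand-written two-pass path that accumulates in MPType (float) rather than letting CUB accumulate in the half-precision output type. The dispatch mechanism in isolation, with a stand-in float16 struct (illustrative only):

#include <cstdio>
#include <type_traits>

struct float16 {};  // stand-in for paddle::platform::float16

// Exactly one overload survives SFINAE for any T, so callers just write
// Launch(ptr) and the fp16 special case is chosen at compile time.
template <typename T>
typename std::enable_if<!std::is_same<T, float16>::value, void>::type
Launch(const T*) { std::puts("generic path: cub::DeviceReduce"); }

template <typename T>
typename std::enable_if<std::is_same<T, float16>::value, void>::type
Launch(const T*) { std::puts("fp16 path: two-pass ReduceKernel1D"); }

int main() {
  float f;
  float16 h;
  Launch(&f);  // generic path
  Launch(&h);  // fp16 path
}
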
@@ -201,45 +276,40 @@ static void TensorReduceImpl(
     const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
     const std::vector<int>& left_dim, const std::vector<int>& left_strides,
     gpuStream_t stream) {
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  MPType init_mp = static_cast<MPType>(init);
+
 #define CUB_RANK_CASE(i, ...)             \
   case i: {                               \
     constexpr auto kRank = i;             \
     switch (reduce_rank) { __VA_ARGS__; } \
   } break
 
-#define CUB_REDUCE_RANK_CASE(i, ...)                              \
-  case i: {                                                       \
-    constexpr auto kReduceRank = i;                               \
-    ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,  \
-                 kReduceRank><<<left_num, BlockDim, 0, stream>>>( \
-        x_data, y_data, reducer, transformer, init, reduce_num,   \
-        Array<int, kRank>::From(x_strides),                       \
-        Array<int, kReduceRank>::From(reduce_dim),                \
-        Array<int, kReduceRank>::From(reduce_strides),            \
-        Array<int, kRank - kReduceRank>::From(left_dim),          \
-        Array<int, kRank - kReduceRank>::From(left_strides));     \
+#define CUB_REDUCE_RANK_CASE(i, ...)                                     \
+  case i: {                                                              \
+    constexpr auto kReduceRank = i;                                      \
+    ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
+                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(        \
+        x_data, y_data, reducer, transformer, init_mp, reduce_num,       \
+        Array<int, kRank>::From(x_strides),                              \
+        Array<int, kReduceRank>::From(reduce_dim),                       \
+        Array<int, kReduceRank>::From(reduce_strides),                   \
+        Array<int, kRank - kReduceRank>::From(left_dim),                 \
+        Array<int, kRank - kReduceRank>::From(left_strides));            \
   } break
 
   int rank = x_strides.size();
   int reduce_rank = reduce_strides.size();
   if (rank == reduce_rank) {
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        place);
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
+    LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
+        x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
+        stream);
     return;
   }
   if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
-    ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
+    ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
                    BlockDim><<<left_num, BlockDim, 0, stream>>>(
-        x_data, y_data, reducer, transformer, init, reduce_num);
+        x_data, y_data, reducer, transformer, init_mp, reduce_num);
     return;
   }
   /*
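
TensorReduceImpl now converts the caller-supplied init value to MPType once (init_mp) and threads it through every kernel launch. The nested CUB_RANK_CASE / CUB_REDUCE_RANK_CASE macros exist to turn a pair of runtime rank values into compile-time template arguments. A self-contained illustration of that dispatch technique, with illustrative names in place of the real kernel:

#include <cstdio>

// Map a runtime rank onto a compile-time template argument via a switch,
// mirroring what CUB_RANK_CASE / CUB_REDUCE_RANK_CASE generate.
template <int Rank>
void ReduceForRank() {
  std::printf("instantiated for rank %d\n", Rank);
}

#define RANK_CASE(i)         \
  case i: {                  \
    constexpr int kRank = i; \
    ReduceForRank<kRank>();  \
  } break

void Dispatch(int rank) {
  switch (rank) {
    RANK_CASE(1);
    RANK_CASE(2);
    RANK_CASE(3);
  }
}

int main() { Dispatch(2); }  // prints "instantiated for rank 2"
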
@@ -366,8 +436,7 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }
 
-template <typename Tx, typename ReduceOp,
-          template <typename, typename> class TransformOp>
+template <typename Tx, typename ReduceOp, template <typename> class TransformOp>
 struct TensorReduceFunctor {
   const framework::Tensor& x;
   framework::Tensor* y;
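
The template template parameter for TransformOp drops from two type parameters to one, so transforms are now parameterized on a single type rather than on separate input and output types. A hypothetical transform matching the new one-parameter shape (illustrative only; the real transforms live elsewhere in the reduce-op headers, and HOSTDEVICE is Paddle's host/device qualifier macro):

// Hypothetical one-parameter transform conforming to the new signature:
// the type parameter is the type the transform casts into, and the input
// type is deduced per call.
template <typename Ty>
struct IdentityTransform {
  template <typename Tx>
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};

The call site in apply() below changes accordingly, instantiating TransformOp<Ty> instead of TransformOp<Tx, Ty>.
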
@@ -389,9 +458,9 @@ struct TensorReduceFunctor {
 
   void apply() const {
     const Ty& init_cast = static_cast<Ty>(init);
-    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
-        x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
-        stream);
+    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(x, y, origin_reduce_dims,
+                                                    init_cast, reducer,
+                                                    TransformOp<Ty>(), stream);
   }
 };
 