
#include <cmath>
#include "paddle/fluid/memory/buffer.h"
+ #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
+ #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/collective_helper.h"
@@ -40,6 +42,163 @@ namespace operators {
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;

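+ // Zero-fill a device buffer of n elements of type T on the given stream.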
+ template <typename T>
+ static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
+   static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+ #ifdef PADDLE_WITH_HIP
+   PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
+ #else
+   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
+ #endif
+ }
+
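+ // Stage 1 of the multi-tensor L2 norm: each thread block squares and sums one
+ // chunk of one tensor, then writes the partial result to
+ // y[tensor_id * max_chunk_num + chunk_id].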
+ template <typename T, int BlockDim, int VecSize>
+ struct L2NormFunctor {
+   DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
+                          const T *x, MasterT<T> *y, int max_chunk_num) const {
+     using MT = MasterT<T>;
+     const T *ptr = x + offset;
+
+     using BlockReduce = cub::BlockReduce<MT, BlockDim>;
+     __shared__ typename BlockReduce::TempStorage storage;
+
+     MT square_sum = static_cast<MT>(0);
+     int i;
+     for (i = threadIdx.x * VecSize; i + VecSize <= size;
+          i += (BlockDim * VecSize)) {
+       platform::AlignedVector<T, VecSize> tmp_vec;
+       platform::Load(ptr + i, &tmp_vec);
+ #pragma unroll
+       for (int j = 0; j < VecSize; ++j) {
+         auto tmp = static_cast<MT>(tmp_vec[j]);
+         square_sum += (tmp * tmp);
+       }
+     }
+
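+     // Scalar tail for elements left over after the vectorized loop.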
+     for (; i < size; ++i) {
+       auto tmp = static_cast<MT>(ptr[i]);
+       square_sum += (tmp * tmp);
+     }
+
+     square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
+     if (threadIdx.x == 0) {
+       y[tensor_id * max_chunk_num + chunk_id] = square_sum;
+     }
+   }
+ };
+
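+ // Stage 2: one thread block per tensor sums that tensor's per-chunk partial
+ // results and optionally takes the square root of the total.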
+ template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
+ static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
+     const InT *x, OutT *y, int max_chunk_num) {
+   int tensor_id = blockIdx.x;
+   x += (tensor_id * max_chunk_num);
+   using BlockReduce = cub::BlockReduce<InT, BlockDim>;
+   __shared__ typename BlockReduce::TempStorage storage;
+   InT sum = static_cast<InT>(0);
+   for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
+     sum += x[i];
+   }
+   sum = BlockReduce(storage).Reduce(sum, cub::Sum());
+   if (threadIdx.x == 0) {
+     if (NeedSqrt) {
+       y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
+     } else {
+       y[blockIdx.x] = static_cast<OutT>(sum);
+     }
+   }
+ }
+
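+ // Choose the widest vector width (8, 4, 2, or 1 elements) for which both the
+ // pointer address and the chunk size are suitably aligned, capped so a single
+ // load does not exceed 128 bits.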
+ template <typename T>
+ static int GetChunkedVecSize(const T *ptr, int chunk_size) {
+   static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+
+   constexpr int max_load_bits = 128;
+   int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+   auto address = reinterpret_cast<uintptr_t>(ptr);
+   constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
+   constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
+   constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
+   if (address % vec8 == 0 && chunk_size % vec8 == 0) {
+     return std::min(8, valid_vec_size);
+   } else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
+     return std::min(4, valid_vec_size);
+   } else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
+     return std::min(2, valid_vec_size);
+   } else {
+     return 1;
+   }
+ }
+
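+ // Dispatch the runtime-selected vector size to a compile-time constant: each
+ // case instantiates __VA_ARGS__ with kVecSize fixed to 8, 4, 2, or 1.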
+ #define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
+   case __vec_size: {                                     \
+     constexpr int kVecSize = __vec_size;                 \
+     __VA_ARGS__;                                         \
+     break;                                               \
+   }
+
+ #define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...)      \
+   do {                                                  \
+     switch (__vec_size) {                               \
+       PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__);   \
+       PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__);   \
+       PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__);   \
+       PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__);   \
+     }                                                   \
+   } while (0)
+
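+ // Host-side driver for the two-stage L2 norm over n tensors packed in x, with
+ // offsets holding the n + 1 tensor boundaries. Stage 1 (L2NormFunctor) writes
+ // per-chunk partial square sums into an (n x max_chunk_num) temporary buffer;
+ // stage 2 (MultiTensorL2NormReduceAgainCUDAKernel) reduces each row into y.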
+ // TODO(zengjinle): which chunk_size is better?
+ template <typename InT, typename OutT, bool NeedSqrt = false,
+           int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680,
+           int BlockDim = 512>
+ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
+                               gpuStream_t stream, const InT *x,
+                               const int *offsets, int n, OutT *y,
+                               int chunk_size = 65536) {
+   if (n <= 0) return;
+
+   constexpr int kNumTensor = MaxTensorNumPerLaunch;
+   constexpr int kNumChunk = MaxChunkNumPerLaunch;
+   constexpr int kBlockDim = BlockDim;
+
+   int max_chunk_num = -1;
+   int vec_size = 8;
+   int total_chunk_num = 0;
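+   // Pick the widest vector width usable for every tensor's starting address,
+   // and count chunks per tensor; max_chunk_num sizes each tensor's row in the
+   // temporary buffer below.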
+   for (int i = 0; i < n; ++i) {
+     vec_size = std::min(
+         vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
+     int length = offsets[i + 1] - offsets[i];
+     auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
+     max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
+     total_chunk_num += tmp_chunk_num;
+   }
+
+   VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
+           << " , total_chunk_num = " << total_chunk_num
+           << " , tensor_num = " << n;
+
+   using MT = MasterT<InT>;
+   memory::Buffer tmp_out(place);
+   auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
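+   // Zero the whole buffer: tensors with fewer than max_chunk_num chunks leave
+   // unused slots, and stage 2 sums all max_chunk_num entries per tensor.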
+   FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
+
+ #define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL                          \
+   do {                                                               \
+     using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;        \
+     VLOG(10) << __func__ << " " << typeid(InT).name()                \
+              << " VecSize = " << kVecSize;                           \
+     MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>(    \
+         FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr,  \
+         max_chunk_num);                                              \
+   } while (0)
+
+   PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
+ #undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL
+
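+   // One block per tensor: reduce its row of partial sums into the final norm.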
+   MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
+                                          NeedSqrt><<<n, kBlockDim, 0, stream>>>(
+       tmp_out_ptr, y, max_chunk_num);
+ }
+
template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
    const framework::ExecutionContext &ctx, const float *param_square_norm,
@@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
                                num_items, reduction_op, init, stream));
}

- template <typename InputIteratorT, typename OutputIteratorT,
-           typename OffsetIteratorT, typename ReductionOp, typename T>
- static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
-                                      int num_segments,
-                                      OffsetIteratorT d_begin_offsets,
-                                      OffsetIteratorT d_end_offsets,
-                                      ReductionOp reduction_op, T initial_value,
-                                      gpuStream_t stream,
-                                      memory::Buffer *buffer) {
-   void *d_temp_storage = nullptr;
-   size_t temp_storage_bytes = 0;
-   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-       d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-       d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
-   d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
-   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-       d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-       d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
- }
-
- template <typename T>
- struct AddConstantFunctor {
-   explicit AddConstantFunctor(T bias) : bias_(bias) {}
-
-   T operator()(T x) const { return x + bias_; }
-
-  private:
-   T bias_;
- };
-
- template <typename T>
- struct OffsetWithBiasFunctor {
-   OffsetWithBiasFunctor(const T *offset, T bias)
-       : offset_(offset), bias_(bias) {}
-
-   HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; }
-
-   HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor<T> &) const {
-     return true;
-   }
-
-  private:
-   const T *offset_;
-   const T bias_;
- };
-
- template <typename T, typename OffsetT>
- static void CubDeviceSegmentedSquareNorm(const T *x, MasterT<T> *y, int n,
-                                          const OffsetT *offset,
-                                          OffsetT init_offset,
-                                          gpuStream_t stream,
-                                          memory::Buffer *buffer) {
-   if (n <= 0) return;
-   cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
-       x, SquareFunctor<T>());
-   if (init_offset == static_cast<OffsetT>(0)) {
-     CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(),
-                              static_cast<MasterT<T>>(0), stream, buffer);
-   } else {
-     cub::CountingInputIterator<OffsetT> cnt_iter(0);
-     OffsetWithBiasFunctor<OffsetT> functor(offset, init_offset);
-     cub::TransformInputIterator<OffsetT, OffsetWithBiasFunctor<OffsetT>,
-                                 cub::CountingInputIterator<OffsetT>>
-         offset_iter(cnt_iter, functor);
-     CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1,
-                              cub::Sum(), static_cast<MasterT<T>>(0), stream,
-                              buffer);
-   }
- }
-
template <typename T>
static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
                                  gpuStream_t stream,
@@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
  }
}

- template <typename T>
- static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
-   static_assert(!std::is_same<T, void>::value, "T cannot be void.");
- #ifdef PADDLE_WITH_HIP
-   PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
- #else
-   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
- #endif
- }
-
template <typename T>
class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
@@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
        fp16_partial_fused_offsets_t->data<int>();

    VLOG(1) << "FusedParamOffsets: "
-           << FlattenToString(fused_offsets, fused_offsets_t->numel(), place);
+           << FlattenToString(fused_offsets, fused_offsets_t->numel(),
+                              fused_offsets_t->place());
    VLOG(1) << "FP32ShardFusedParamOffsets: "
            << FlattenToString(fp32_partial_fused_offsets,
-                              fp32_partial_fused_offsets_t->numel(), place);
+                              fp32_partial_fused_offsets_t->numel(),
+                              fp32_partial_fused_offsets_t->place());
    VLOG(1) << "FP16ShardFusedParamOffsets: "
            << FlattenToString(fp16_partial_fused_offsets,
-                              fp16_partial_fused_offsets_t->numel(), place);
+                              fp16_partial_fused_offsets_t->numel(),
+                              fp16_partial_fused_offsets_t->place());

    if (num_devices > 1) {
      if (use_master_param_norm) {
@@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
        FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
      }
    }
-     CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm,
-                                  fp32_global_param_num, fused_offsets, 0,
-                                  stream, &cub_tmp_buffer);
+     MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
+                       fp32_global_param_num, param_square_norm);
    if (use_master_param_norm) {
-       CubDeviceSegmentedSquareNorm(
-           master_param + fp16_offset, param_square_norm + fp16_local_start_idx,
-           fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
-           &cub_tmp_buffer);
+       MultiTensorL2Norm(place, stream, master_param + fp16_offset,
+                         fp16_partial_fused_offsets, fp16_local_param_num,
+                         param_square_norm + fp16_local_start_idx);
    } else {
      // NOTE: extra computation is performed. We can improve this performance
      // if needed in the future.
-       CubDeviceSegmentedSquareNorm(
-           fp16_param, param_square_norm + fp32_global_param_num,
-           fp16_global_param_num, fused_offsets + fp32_global_param_num,
-           static_cast<int>(fp32_numel), stream, &cub_tmp_buffer);
+       MultiTensorL2Norm(
+           place, stream, fp16_param, fused_offsets + fp32_global_param_num,
+           fp16_global_param_num, param_square_norm + fp32_global_param_num);
    }

-     CubDeviceSegmentedSquareNorm(
-         trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx,
-         fp32_local_param_num, fp32_partial_fused_offsets, 0, stream,
-         &cub_tmp_buffer);
-     CubDeviceSegmentedSquareNorm(
-         trust_ratio_div + fp32_numel_each_device,
-         trust_ratio_div_square_norm + fp16_local_start_idx,
-         fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
-         &cub_tmp_buffer);
+     MultiTensorL2Norm(place, stream, trust_ratio_div,
+                       fp32_partial_fused_offsets, fp32_local_param_num,
+                       trust_ratio_div_square_norm + fp32_local_start_idx);
+     MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
+                       fp16_partial_fused_offsets, fp16_local_param_num,
+                       trust_ratio_div_square_norm + fp16_local_start_idx);

    VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
            << FlattenToString(trust_ratio_div_square_norm, param_num, place);