Commit d32a010

Add MultiTensorApply to calculate L2-Norm in DistributedFusedLamb optimizer (#39900)
* add multi tensor apply l2 norm
* add multi_tensor_apply code
* make sizeof(TensorMeta) smaller
* move code to distributed_fused_lamb_op.cu
* remove useless FLAGS
1 parent 639675d commit d32a010
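
Before the per-file diffs, here is a minimal standalone sketch (my own illustration, not code from this commit and not the Paddle implementation) of the two-pass chunked L2-norm idea that distributed_fused_lamb_op.cu implements below: pass 1 runs one thread block per (tensor, chunk) pair and writes a partial square-sum into a zero-filled tensor_num x max_chunk_num buffer; pass 2 reduces each tensor's row of partials. The names kChunkSize, PartialSquareSumKernel, and RowSquareNormKernel are illustrative only; the real code additionally vectorizes loads, uses cub::BlockReduce, and batches chunks across multiple launches.

#include <cuda_runtime.h>

constexpr int kBlockDim = 256;     // must be a power of two for the reduction below
constexpr int kChunkSize = 65536;  // same default chunk size as MultiTensorL2Norm

// Pass 1: grid = (max_chunk_num, tensor_num); each block reduces one chunk.
// `partial` must be zero-filled beforehand (cf. FillZeroWithPtr in the diff),
// because blocks whose chunk lies past the end of their tensor write nothing.
__global__ void PartialSquareSumKernel(const float *x, const int *offsets,
                                       int max_chunk_num, float *partial) {
  int tensor_id = blockIdx.y;
  int chunk_id = blockIdx.x;
  int begin = offsets[tensor_id] + chunk_id * kChunkSize;
  int end = min(begin + kChunkSize, offsets[tensor_id + 1]);
  if (begin >= end) return;  // this chunk does not exist for this tensor

  __shared__ float smem[kBlockDim];
  float sum = 0.0f;
  for (int i = begin + threadIdx.x; i < end; i += kBlockDim) {
    sum += x[i] * x[i];
  }
  smem[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = kBlockDim / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      smem[threadIdx.x] += smem[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    partial[tensor_id * max_chunk_num + chunk_id] = smem[0];
  }
}

// Pass 2: one block per tensor sums that tensor's row of partial results.
__global__ void RowSquareNormKernel(const float *partial, int max_chunk_num,
                                    float *y) {
  __shared__ float smem[kBlockDim];
  const float *row = partial + blockIdx.x * max_chunk_num;
  float sum = 0.0f;
  for (int i = threadIdx.x; i < max_chunk_num; i += kBlockDim) {
    sum += row[i];
  }
  smem[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = kBlockDim / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      smem[threadIdx.x] += smem[threadIdx.x + stride];
    }
    __syncthreads();
  }
  // Squared L2 norm; the kernel in the diff optionally applies sqrtf (NeedSqrt).
  if (threadIdx.x == 0) y[blockIdx.x] = smem[0];
}

// Host-side launch sketch, with d_offsets (length tensor_num + 1) on the GPU
// and a host copy of the offsets used to compute max_chunk_num beforehand:
//   PartialSquareSumKernel<<<dim3(max_chunk_num, tensor_num), kBlockDim, 0,
//                            stream>>>(x, d_offsets, max_chunk_num, partial);
//   RowSquareNormKernel<<<tensor_num, kBlockDim, 0, stream>>>(
//       partial, max_chunk_num, y);

Here offsets plays the same role as the fused parameter offsets in the diff: offsets[i + 1] - offsets[i] is the length of tensor i inside the fused buffer.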

File tree

5 files changed: +355 additions, -116 deletions

paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu

Lines changed: 16 additions & 6 deletions
@@ -284,6 +284,16 @@ static void CopyVectorToTensor(const std::vector<T> &src,
   memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
 }
 
+template <typename T>
+static void CopyVectorToCPUTensor(const std::vector<T> &src,
+                                  framework::Tensor *dst) {
+  dst->Resize({static_cast<int64_t>(src.size())});
+  T *dst_ptr = dst->mutable_data<T>(platform::CPUPlace());
+  const T *src_ptr = src.data();
+  auto nbytes = src.size() * sizeof(T);
+  std::memcpy(dst_ptr, src_ptr, nbytes);
+}
+
 template <typename T>
 class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -677,14 +687,14 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
                         lengths.back());
     }
 
-    CopyVectorToTensor(
+    CopyVectorToCPUTensor(numel_offsets,
+                          ctx.Output<framework::Tensor>("FusedParamOffsets"));
+    CopyVectorToCPUTensor(
         fp32_partial_numel_offsets,
-        ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"), place,
-        stream);
-    CopyVectorToTensor(
+        ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"));
+    CopyVectorToCPUTensor(
         fp16_partial_numel_offsets,
-        ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"), place,
-        stream);
+        ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));
 
     // Fill the weight decay tensor
     PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),

paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc

Lines changed: 1 addition & 6 deletions
@@ -33,12 +33,7 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    if (var_name == "ParamInfo") {
-      return expected_kernel_type;
-    } else {
-      return framework::OperatorWithKernel::GetKernelTypeForVar(
-          var_name, tensor, expected_kernel_type);
-    }
+    return expected_kernel_type;
   }
 };
 
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu

Lines changed: 179 additions & 103 deletions
@@ -14,8 +14,10 @@
 
 #include <cmath>
 #include "paddle/fluid/memory/buffer.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
 #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
+#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
 #include "paddle/fluid/operators/tensor_to_string.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/collective_helper.h"
@@ -40,6 +42,163 @@ namespace operators {
 template <typename T>
 using MasterT = typename details::MPTypeTrait<T>::Type;
 
+template <typename T>
+static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
+  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
+#endif
+}
+
+template <typename T, int BlockDim, int VecSize>
+struct L2NormFunctor {
+  DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
+                         const T *x, MasterT<T> *y, int max_chunk_num) const {
+    using MT = MasterT<T>;
+    const T *ptr = x + offset;
+
+    using BlockReduce = cub::BlockReduce<MT, BlockDim>;
+    __shared__ typename BlockReduce::TempStorage storage;
+
+    MT square_sum = static_cast<MT>(0);
+    int i;
+    for (i = threadIdx.x * VecSize; i + VecSize <= size;
+         i += (BlockDim * VecSize)) {
+      platform::AlignedVector<T, VecSize> tmp_vec;
+      platform::Load(ptr + i, &tmp_vec);
+#pragma unroll
+      for (int j = 0; j < VecSize; ++j) {
+        auto tmp = static_cast<MT>(tmp_vec[j]);
+        square_sum += (tmp * tmp);
+      }
+    }
+
+    for (; i < size; ++i) {
+      auto tmp = static_cast<MT>(ptr[i]);
+      square_sum += (tmp * tmp);
+    }
+
+    square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
+    if (threadIdx.x == 0) {
+      y[tensor_id * max_chunk_num + chunk_id] = square_sum;
+    }
+  }
+};
+
+template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
+static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
+    const InT *x, OutT *y, int max_chunk_num) {
+  int tensor_id = blockIdx.x;
+  x += (tensor_id * max_chunk_num);
+  using BlockReduce = cub::BlockReduce<InT, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage storage;
+  InT sum = static_cast<InT>(0);
+  for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
+    sum += x[i];
+  }
+  sum = BlockReduce(storage).Reduce(sum, cub::Sum());
+  if (threadIdx.x == 0) {
+    if (NeedSqrt) {
+      y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
+    } else {
+      y[blockIdx.x] = static_cast<OutT>(sum);
+    }
+  }
+}
+
+template <typename T>
+static int GetChunkedVecSize(const T *ptr, int chunk_size) {
+  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+  auto address = reinterpret_cast<uintptr_t>(ptr);
+  constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
+  constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
+  constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
+  if (address % vec8 == 0 && chunk_size % vec8 == 0) {
+    return std::min(8, valid_vec_size);
+  } else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
+    return std::min(4, valid_vec_size);
+  } else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
+  }
+}
+
+#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
+  case __vec_size: {                                    \
+    constexpr int kVecSize = __vec_size;                \
+    __VA_ARGS__;                                        \
+    break;                                              \
+  }
+
+#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...)    \
+  do {                                                \
+    switch (__vec_size) {                             \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
+    }                                                 \
+  } while (0)
+
+// TODO(zengjinle): which chunk_size is better?
+template <typename InT, typename OutT, bool NeedSqrt = false,
+          int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680,
+          int BlockDim = 512>
+static void MultiTensorL2Norm(const platform::CUDAPlace &place,
+                              gpuStream_t stream, const InT *x,
+                              const int *offsets, int n, OutT *y,
+                              int chunk_size = 65536) {
+  if (n <= 0) return;
+
+  constexpr int kNumTensor = MaxTensorNumPerLaunch;
+  constexpr int kNumChunk = MaxChunkNumPerLaunch;
+  constexpr int kBlockDim = BlockDim;
+
+  int max_chunk_num = -1;
+  int vec_size = 8;
+  int total_chunk_num = 0;
+  for (int i = 0; i < n; ++i) {
+    vec_size = std::min(
+        vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
+    int length = offsets[i + 1] - offsets[i];
+    auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
+    max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
+    total_chunk_num += tmp_chunk_num;
+  }
+
+  VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
+          << " , total_chunk_num = " << total_chunk_num
+          << " , tensor_num = " << n;
+
+  using MT = MasterT<InT>;
+  memory::Buffer tmp_out(place);
+  auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
+  FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
+
+#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL                         \
+  do {                                                              \
+    using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;       \
+    VLOG(10) << __func__ << " " << typeid(InT).name()               \
+             << " VecSize = " << kVecSize;                          \
+    MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>(   \
+        FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \
+        max_chunk_num);                                             \
+  } while (0)
+
+  PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
+#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL
+
+  MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
+                                         NeedSqrt><<<n, kBlockDim, 0, stream>>>(
+      tmp_out_ptr, y, max_chunk_num);
+}
+
 template <int LogLevel>
 static void LogParamAndTrustRatioDivSquareNorm(
     const framework::ExecutionContext &ctx, const float *param_square_norm,
@@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
       num_items, reduction_op, init, stream));
 }
 
-template <typename InputIteratorT, typename OutputIteratorT,
-          typename OffsetIteratorT, typename ReductionOp, typename T>
-static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
-                                     int num_segments,
-                                     OffsetIteratorT d_begin_offsets,
-                                     OffsetIteratorT d_end_offsets,
-                                     ReductionOp reduction_op, T initial_value,
-                                     gpuStream_t stream,
-                                     memory::Buffer *buffer) {
-  void *d_temp_storage = nullptr;
-  size_t temp_storage_bytes = 0;
-  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-      d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-      d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
-  d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
-  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-      d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-      d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
-}
-
-template <typename T>
-struct AddConstantFunctor {
-  explicit AddConstantFunctor(T bias) : bias_(bias) {}
-
-  T operator()(T x) const { return x + bias_; }
-
- private:
-  T bias_;
-};
-
-template <typename T>
-struct OffsetWithBiasFunctor {
-  OffsetWithBiasFunctor(const T *offset, T bias)
-      : offset_(offset), bias_(bias) {}
-
-  HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; }
-
-  HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor<T> &) const {
-    return true;
-  }
-
- private:
-  const T *offset_;
-  const T bias_;
-};
-
-template <typename T, typename OffsetT>
-static void CubDeviceSegmentedSquareNorm(const T *x, MasterT<T> *y, int n,
-                                         const OffsetT *offset,
-                                         OffsetT init_offset,
-                                         gpuStream_t stream,
-                                         memory::Buffer *buffer) {
-  if (n <= 0) return;
-  cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
-      x, SquareFunctor<T>());
-  if (init_offset == static_cast<OffsetT>(0)) {
-    CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(),
-                             static_cast<MasterT<T>>(0), stream, buffer);
-  } else {
-    cub::CountingInputIterator<OffsetT> cnt_iter(0);
-    OffsetWithBiasFunctor<OffsetT> functor(offset, init_offset);
-    cub::TransformInputIterator<OffsetT, OffsetWithBiasFunctor<OffsetT>,
-                                cub::CountingInputIterator<OffsetT>>
-        offset_iter(cnt_iter, functor);
-    CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1,
-                             cub::Sum(), static_cast<MasterT<T>>(0), stream,
-                             buffer);
-  }
-}
-
 template <typename T>
 static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
                                   gpuStream_t stream,
@@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
   }
 }
 
-template <typename T>
-static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
-  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
-#ifdef PADDLE_WITH_HIP
-  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
-#else
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
-#endif
-}
-
 template <typename T>
 class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         fp16_partial_fused_offsets_t->data<int>();
 
     VLOG(1) << "FusedParamOffsets: "
-            << FlattenToString(fused_offsets, fused_offsets_t->numel(), place);
+            << FlattenToString(fused_offsets, fused_offsets_t->numel(),
+                               fused_offsets_t->place());
     VLOG(1) << "FP32ShardFusedParamOffsets: "
             << FlattenToString(fp32_partial_fused_offsets,
-                               fp32_partial_fused_offsets_t->numel(), place);
+                               fp32_partial_fused_offsets_t->numel(),
+                               fp32_partial_fused_offsets_t->place());
     VLOG(1) << "FP16ShardFusedParamOffsets: "
             << FlattenToString(fp16_partial_fused_offsets,
-                               fp16_partial_fused_offsets_t->numel(), place);
+                               fp16_partial_fused_offsets_t->numel(),
+                               fp16_partial_fused_offsets_t->place());
 
     if (num_devices > 1) {
       if (use_master_param_norm) {
@@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
        FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
      }
    }
-    CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm,
-                                 fp32_global_param_num, fused_offsets, 0,
-                                 stream, &cub_tmp_buffer);
+    MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
+                      fp32_global_param_num, param_square_norm);
     if (use_master_param_norm) {
-      CubDeviceSegmentedSquareNorm(
-          master_param + fp16_offset, param_square_norm + fp16_local_start_idx,
-          fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
-          &cub_tmp_buffer);
+      MultiTensorL2Norm(place, stream, master_param + fp16_offset,
+                        fp16_partial_fused_offsets, fp16_local_param_num,
+                        param_square_norm + fp16_local_start_idx);
     } else {
       // NOTE: extra computation is performed. We can improve this performance
       // if needed in the future.
-      CubDeviceSegmentedSquareNorm(
-          fp16_param, param_square_norm + fp32_global_param_num,
-          fp16_global_param_num, fused_offsets + fp32_global_param_num,
-          static_cast<int>(fp32_numel), stream, &cub_tmp_buffer);
+      MultiTensorL2Norm(
+          place, stream, fp16_param, fused_offsets + fp32_global_param_num,
+          fp16_global_param_num, param_square_norm + fp32_global_param_num);
     }
 
-    CubDeviceSegmentedSquareNorm(
-        trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx,
-        fp32_local_param_num, fp32_partial_fused_offsets, 0, stream,
-        &cub_tmp_buffer);
-    CubDeviceSegmentedSquareNorm(
-        trust_ratio_div + fp32_numel_each_device,
-        trust_ratio_div_square_norm + fp16_local_start_idx,
-        fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
-        &cub_tmp_buffer);
+    MultiTensorL2Norm(place, stream, trust_ratio_div,
+                      fp32_partial_fused_offsets, fp32_local_param_num,
+                      trust_ratio_div_square_norm + fp32_local_start_idx);
+    MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
+                      fp16_partial_fused_offsets, fp16_local_param_num,
+                      trust_ratio_div_square_norm + fp16_local_start_idx);
 
     VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
             << FlattenToString(trust_ratio_div_square_norm, param_num, place);
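
The new paddle/fluid/operators/optimizers/multi_tensor_apply.h header that MultiTensorL2Norm relies on is added by this commit but is not among the hunks shown above. For orientation only, the sketch below (my own simplified illustration under hypothetical names such as ChunkMetaSketch and MultiTensorApplySketch, not Paddle's actual header) shows how a launcher compatible with the functor contract used by L2NormFunctor, namely (tensor_id, chunk_id, offset, size, extra args...), can be built: expand every tensor into per-chunk metadata and invoke the functor once per chunk, one thread block per chunk. The real launcher packs the metadata into kernel arguments and caps the number of tensors and chunks per launch (MaxTensorNumPerLaunch / MaxChunkNumPerLaunch), which is presumably why the commit message mentions shrinking sizeof(TensorMeta); this sketch ignores those limits and simply copies the metadata to device memory.

#include <algorithm>
#include <vector>
#include <cuda_runtime.h>

// Hypothetical per-chunk metadata; Paddle's TensorMeta differs.
struct ChunkMetaSketch {
  int tensor_id;  // which tensor this chunk belongs to
  int chunk_id;   // chunk index within that tensor
  int offset;     // start offset of the chunk relative to the fused buffer x
  int size;       // number of valid elements in the chunk
};

// One block per chunk; the functor (e.g. an L2NormFunctor-like callable with a
// device operator()) does the per-chunk work on (x, y, max_chunk_num, ...).
template <typename Functor, typename... Args>
__global__ void MultiTensorApplyKernelSketch(Functor functor,
                                             const ChunkMetaSketch *metas,
                                             Args... args) {
  const ChunkMetaSketch m = metas[blockIdx.x];
  functor(m.tensor_id, m.chunk_id, m.offset, m.size, args...);
}

// Host-side driver: build the chunk list from host-visible offsets, then launch.
// Error checking is omitted for brevity.
template <int BlockDim, typename Functor, typename... Args>
void MultiTensorApplySketch(Functor functor, cudaStream_t stream,
                            const int *offsets,  // host pointer, length n + 1
                            int n, int chunk_size, Args... args) {
  std::vector<ChunkMetaSketch> metas;
  for (int i = 0; i < n; ++i) {
    int length = offsets[i + 1] - offsets[i];
    int chunk_num = (length + chunk_size - 1) / chunk_size;
    for (int c = 0; c < chunk_num; ++c) {
      int begin = offsets[i] - offsets[0] + c * chunk_size;
      int remaining = offsets[i + 1] - offsets[0] - begin;
      metas.push_back({i, c, begin, std::min(chunk_size, remaining)});
    }
  }
  if (metas.empty()) return;

  ChunkMetaSketch *d_metas = nullptr;
  cudaMalloc(&d_metas, metas.size() * sizeof(ChunkMetaSketch));
  cudaMemcpy(d_metas, metas.data(), metas.size() * sizeof(ChunkMetaSketch),
             cudaMemcpyHostToDevice);
  MultiTensorApplyKernelSketch<<<static_cast<unsigned>(metas.size()), BlockDim,
                                 0, stream>>>(functor, d_metas, args...);
  cudaStreamSynchronize(stream);  // only so d_metas can be freed immediately
  cudaFree(d_metas);
}

Conceptually, plugging an L2NormFunctor-style callable into such a launcher, with BlockDim matching the functor's BlockDim template argument, corresponds to the first pass of MultiTensorL2Norm; the reduce-again kernel in the hunk above then collapses the per-chunk partial sums into one norm per tensor.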
