39 changes: 19 additions & 20 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
@@ -61,30 +61,31 @@ class DistributedFusedLambInitOpMaker
"The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddOutput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddOutput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddOutput(
"FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...].");
AddOutput("FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1].");
AddOutput("FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1].");
"+ n_2, ...]. It should be in CPUPlace.");
AddOutput(
"WeightDecay",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
"FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
AddOutput(
"FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
AddOutput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is "
"CPUPlace, and its shape is [8]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num].");

"fp32_global_param_num, fp32_weight_decay_end_idx, "
"fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num, "
"fp16_weight_decay_end_idx].");
AddOutput("ParamOrder",
"The reordered parameter order. Inside this op, "
"the parameter would be reordered by data type and weight decay "
"value.");
AddOutput("ParamOut", "The output parameter list.").AsDuplicable();
AddOutput("MasterParamOut",
"The output master parameter list. It would share the memory of "
@@ -96,10 +97,8 @@ class DistributedFusedLambInitOpMaker

AddAttr<float>("beta1", "The initial value of Beta1Pow.");
AddAttr<float>("beta2", "The initial value of Beta2Pow.");
AddAttr<std::vector<float>>(
"weight_decay",
"The weight decay for each parameter. Its "
"shape is equal to the global parameter number.");
AddAttr<std::vector<int>>("apply_weight_decay",
"Whether to apply weight decay.");
AddAttr<int>("alignment", "The alignment in bytes for the fused tensors.");
AddAttr<int>("rank", "The global rank of the current process.");
AddAttr<int>("nranks", "The global world size.");
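The ParamInfo output described above grows from six to eight integers, adding a weight-decay end index for each dtype group. As a rough, hypothetical illustration of that layout (the struct and helper below are not part of this PR), the CPU buffer could be unpacked field by field in the documented order:

```cpp
// Hypothetical sketch: unpack the 8-element ParamInfo CPU buffer in the
// field order given in the op documentation above. Names are illustrative.
struct ParamInfoView {
  int fp32_shard_param_start_idx;  // param_info[0]
  int fp32_local_param_num;        // param_info[1]
  int fp32_global_param_num;       // param_info[2]
  int fp32_weight_decay_end_idx;   // param_info[3]
  int fp16_shard_param_start_idx;  // param_info[4]
  int fp16_local_param_num;        // param_info[5]
  int fp16_global_param_num;       // param_info[6]
  int fp16_weight_decay_end_idx;   // param_info[7]
};

inline ParamInfoView UnpackParamInfo(const int *param_info /* CPU, shape [8] */) {
  return {param_info[0], param_info[1], param_info[2], param_info[3],
          param_info[4], param_info[5], param_info[6], param_info[7]};
}
```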
162 changes: 71 additions & 91 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu
@@ -258,32 +258,6 @@ static void ShareBufferForNonInitedTensor(framework::Tensor *origin,
<< ") , dtype = " << fused_out->dtype();
}

template <typename OffsetT, typename IndexT>
static __global__ void LambFillFusedIndicesCUDAKernel(const OffsetT *offsets,
IndexT *out,
int offset_num,
int out_num) {
CUDA_KERNEL_LOOP_TYPE(i, out_num, int) {
auto idx = phi::funcs::LowerBound(offsets, offset_num, i);
if (idx == offset_num || offsets[idx] != i) {
--idx;
}
out[i] = idx;
}
}

template <typename T>
static void CopyVectorToTensor(const std::vector<T> &src,
framework::Tensor *dst,
const platform::Place &place,
gpuStream_t stream) {
dst->Resize({static_cast<int64_t>(src.size())});
T *dst_ptr = dst->mutable_data<T>(place);
const T *src_ptr = src.data();
auto nbytes = src.size() * sizeof(T);
memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
}

template <typename T>
static void CopyVectorToCPUTensor(const std::vector<T> &src,
framework::Tensor *dst) {
@@ -294,6 +268,42 @@ static void CopyVectorToCPUTensor(const std::vector<T> &src,
std::memcpy(dst_ptr, src_ptr, nbytes);
}

static size_t ReorderParamGradInfoList(const std::vector<int> &flags,
std::vector<ParamGradInfo> *infos) {
size_t n = infos->size();
std::vector<int> cur_flags;
cur_flags.reserve(n);
for (size_t i = 0; i < n; ++i) {
auto idx = (*infos)[i].idx;
cur_flags.push_back(flags[idx]);
}

auto origin_infos = *infos;
size_t j = 0;
for (size_t i = 0; i < n; ++i) {
if (cur_flags[i]) {
(*infos)[j] = origin_infos[i];
++j;
}
}
size_t ret_idx = j;

for (size_t i = 0; i < n; ++i) {
if (!cur_flags[i]) {
(*infos)[j] = origin_infos[i];
++j;
}
}
return ret_idx;
}

template <typename T>
static T ClipByBound(T x, T low_value, T high_value) {
if (x < low_value) return low_value;
if (x > high_value) return high_value;
return x;
}

template <typename T>
class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
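ReorderParamGradInfoList above is a stable partition: entries whose apply_weight_decay flag (looked up via info.idx) is set are moved to the front while keeping their relative order, and the returned value marks where that flagged block ends. Below is a standalone sketch of the same behavior on a simplified struct; the names here are stand-ins, not PR code:

```cpp
// Standalone illustration (assumption: mirrors ReorderParamGradInfoList above,
// but on a simplified struct). Flagged entries keep their relative order and
// move to the front; the return value is where the flagged block ends.
#include <cassert>
#include <vector>

struct Info { size_t idx; };

static size_t ReorderByFlags(const std::vector<int> &flags,
                             std::vector<Info> *infos) {
  std::vector<Info> origin = *infos;
  size_t j = 0;
  for (const auto &info : origin) {          // first pass: flagged entries
    if (flags[info.idx]) (*infos)[j++] = info;
  }
  size_t ret_idx = j;                        // end of the flagged block
  for (const auto &info : origin) {          // second pass: the rest
    if (!flags[info.idx]) (*infos)[j++] = info;
  }
  return ret_idx;
}

int main() {
  std::vector<Info> infos = {{0}, {1}, {2}, {3}};
  std::vector<int> flags = {0, 1, 0, 1};     // params 1 and 3 use weight decay
  size_t end = ReorderByFlags(flags, &infos);
  assert(end == 2);
  assert(infos[0].idx == 1 && infos[1].idx == 3);  // flagged first, order kept
  assert(infos[2].idx == 0 && infos[3].idx == 2);
  return 0;
}
```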
Expand Down Expand Up @@ -404,6 +414,24 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
info->numel_offset = 0; // not determined yet
}
}
const auto &apply_weight_decay =
ctx.Attr<std::vector<int>>("apply_weight_decay");
size_t fp32_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp32_infos);
size_t fp16_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp16_infos);

auto *param_order_t = ctx.Output<framework::Tensor>("ParamOrder");
auto param_num = fp32_infos.size() + fp16_infos.size();
param_order_t->Resize({static_cast<int16_t>(param_num)});
auto *param_order = param_order_t->mutable_data<int>(platform::CPUPlace());
for (size_t i = 0; i < fp32_infos.size(); ++i) {
param_order[i] = static_cast<int>(fp32_infos[i].idx);
}
for (size_t i = 0; i < fp16_infos.size(); ++i) {
param_order[i + fp32_infos.size()] = static_cast<int>(fp16_infos[i].idx);
}

VLOG(10) << "Fill ParamGradInfo ends";

// Step 2: determine the numel_with_padding and numel_offset
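The ParamOrder tensor filled above is simply the fp32 parameter indices followed by the fp16 parameter indices, each group already reordered so that weight-decay parameters come first. A minimal sketch of that concatenation (an assumed helper, not code from the PR):

```cpp
// Sketch (assumption, not PR code): ParamOrder is the fp32 index list followed
// by the fp16 index list, each already reordered so that weight-decay
// parameters lead their dtype group.
#include <vector>

std::vector<int> BuildParamOrder(const std::vector<int> &fp32_idx,
                                 const std::vector<int> &fp16_idx) {
  std::vector<int> order;
  order.reserve(fp32_idx.size() + fp16_idx.size());
  order.insert(order.end(), fp32_idx.begin(), fp32_idx.end());
  order.insert(order.end(), fp16_idx.begin(), fp16_idx.end());
  return order;  // e.g. fp32 {2, 0, 5} and fp16 {4, 1} give {2, 0, 5, 4, 1}
}
```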
Expand Down Expand Up @@ -568,45 +596,29 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
VLOG(10) << "Found the sharding arguments";

auto *param_info_t = ctx.Output<framework::Tensor>("ParamInfo");
param_info_t->Resize({6});
param_info_t->Resize({8});
auto *param_info = param_info_t->mutable_data<int>(platform::CPUPlace());
param_info[0] = static_cast<int>(fp32_start_idx);
param_info[1] = static_cast<int>(fp32_local_param_num);
param_info[2] = static_cast<int>(fp32_infos.size());
param_info[3] = static_cast<int>(fp16_start_idx + fp32_infos.size());
param_info[4] = static_cast<int>(fp16_local_param_num);
param_info[5] = static_cast<int>(fp16_infos.size());
param_info[3] = ClipByBound<int>(fp32_wd_end_idx, fp32_start_idx,
fp32_start_idx + fp32_local_param_num) -
static_cast<int>(fp32_start_idx);
param_info[4] = static_cast<int>(fp16_start_idx + fp32_infos.size());
param_info[5] = static_cast<int>(fp16_local_param_num);
param_info[6] = static_cast<int>(fp16_infos.size());
param_info[7] = ClipByBound<int>(fp16_wd_end_idx, fp16_start_idx,
fp16_start_idx + fp16_local_param_num) -
static_cast<int>(fp16_start_idx);

VLOG(10) << "Start FP32 idx: " << param_info[0];
VLOG(10) << "Local FP32 param num: " << param_info[1];
VLOG(10) << "Global FP32 param num: " << param_info[2];

VLOG(10) << "Start FP16 idx: " << param_info[3];
VLOG(10) << "Local FP16 param num: " << param_info[4];
VLOG(10) << "Global FP16 param num: " << param_info[5];
VLOG(10) << "Start FP16 idx: " << param_info[4];
VLOG(10) << "Local FP16 param num: " << param_info[5];
VLOG(10) << "Global FP16 param num: " << param_info[6];

// For WeightDecay, shard and perform H2D copy
const auto &origin_weight_decay =
ctx.Attr<std::vector<float>>("weight_decay");
PADDLE_ENFORCE_EQ(params.size(), origin_weight_decay.size(),
platform::errors::InvalidArgument(
"The attr(weight_decay) should have the "
"same length with Input(Param)."));
std::vector<float> shard_weight_decay;
shard_weight_decay.reserve(total_local_param_num);
for (size_t i = 0; i < fp32_local_param_num; ++i) {
shard_weight_decay.push_back(
origin_weight_decay[fp32_infos[i + fp32_start_idx].idx]);
}
for (size_t i = 0; i < fp16_local_param_num; ++i) {
shard_weight_decay.push_back(
origin_weight_decay[fp16_infos[i + fp16_start_idx].idx]);
}

// For FusedIndices, launch CUDA kernel to do binary search
auto *fused_indices_t = ctx.Output<framework::Tensor>("FusedIndices");
fused_indices_t->Resize({static_cast<int64_t>(total_numel)});
auto *fused_indices = fused_indices_t->mutable_data<int>(place);
std::vector<int> numel_offsets;
numel_offsets.reserve(params.size() + 1);
for (const auto &info : fp32_infos) {
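param_info[3] and param_info[7] above convert the global weight-decay end index of each dtype group into a rank-local one by clipping it into the rank's shard range and subtracting the shard start. A small worked example under those assumptions (the helper name is hypothetical):

```cpp
// Worked example (assumption: mirrors the param_info[3]/[7] computation above).
// Globally, parameters are reordered so the first wd_end_idx params of a group
// use weight decay; each rank converts that global boundary into a local one.
#include <algorithm>
#include <cassert>

int LocalWeightDecayEnd(int wd_end_idx, int start_idx, int local_num) {
  int clipped = std::min(std::max(wd_end_idx, start_idx), start_idx + local_num);
  return clipped - start_idx;
}

int main() {
  // 6 of the fp32 params use weight decay; this rank holds params [4, 8).
  assert(LocalWeightDecayEnd(/*wd_end_idx=*/6, /*start_idx=*/4, /*local_num=*/4) == 2);
  // A rank holding params [7, 10) gets no weight-decay params at all.
  assert(LocalWeightDecayEnd(6, 7, 3) == 0);
  return 0;
}
```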
@@ -621,21 +633,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
"The numel_offsets number must be one larger than "
"the parameter number."));
VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
auto *fused_param_offset_t =
ctx.Output<framework::Tensor>("FusedParamOffsets");
fused_param_offset_t->Resize({static_cast<int64_t>(numel_offsets.size())});
auto *fused_param_offset = fused_param_offset_t->mutable_data<int>(place);
memory::Copy(place, fused_param_offset, platform::CPUPlace(),
numel_offsets.data(),
numel_offsets.size() * sizeof(numel_offsets[0]), stream);
auto config = platform::GetGpuLaunchConfig1D(dev_ctx, total_numel);
LambFillFusedIndicesCUDAKernel<<<config.block_per_grid,
config.thread_per_block, 0, stream>>>(
fused_param_offset, fused_indices, numel_offsets.size() - 1,
total_numel);

std::vector<int> lengths;
lengths.reserve(fp32_local_param_num + fp16_local_param_num);

std::vector<int> fp32_partial_numel_offsets;
fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
Expand All @@ -659,9 +656,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
VLOG(10) << "FP32 Partial numel = ["
<< valid_start_n + fp32_infos[i].numel << ","
<< end_n + fp32_infos[i].numel;
lengths.push_back(end_n - valid_start_n);
auto len = end_n - valid_start_n;
fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
lengths.back());
len);
}

std::vector<int> fp16_partial_numel_offsets;
@@ -682,9 +679,9 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_NE(valid_start_n, end_n,
platform::errors::InvalidArgument(
"Indices sharding error. This may be a bug."));
lengths.push_back(end_n - valid_start_n);
auto len = end_n - valid_start_n;
fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() +
lengths.back());
len);
}

CopyVectorToCPUTensor(numel_offsets,
@@ -696,23 +693,6 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
fp16_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));

// Fill the weight decay tensor
PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
platform::errors::InvalidArgument(
"Invalid weight decay sharding. This may be a bug."));
std::vector<float> wd_cpu;
for (size_t i = 0; i < shard_weight_decay.size(); ++i) {
int len = lengths[i];
for (int j = 0; j < len; ++j) {
wd_cpu.push_back(shard_weight_decay[i]);
}
}
PADDLE_ENFORCE_EQ(wd_cpu.size() * nranks, fp32_numel + fp16_numel,
platform::errors::InvalidArgument(
"Invalid weight decay sharding. This may be a bug."));
CopyVectorToTensor(wd_cpu, ctx.Output<framework::Tensor>("WeightDecay"),
place, stream);

auto *global_scale = ctx.Output<framework::Tensor>("GlobalScale");
if (!global_scale->IsInitialized()) {
TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
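The FP32/FP16ShardFusedParamOffsets tensors copied to CPU above are, in effect, prefix sums over each parameter's sharded length on this rank. A tiny sketch of that construction (an assumed helper, not code from the PR):

```cpp
// Small illustration (assumption): the partial-offset vectors above are plain
// prefix sums over each parameter's sharded length, starting from 0.
#include <vector>

std::vector<int> PrefixOffsets(const std::vector<int> &lens) {
  std::vector<int> offsets;
  offsets.reserve(lens.size() + 1);
  offsets.push_back(0);
  for (int len : lens) offsets.push_back(offsets.back() + len);
  return offsets;  // e.g. lens {3, 5, 2} -> offsets {0, 3, 8, 10}
}
```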
34 changes: 19 additions & 15 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -66,28 +66,31 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
"The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddInput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddInput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddInput(
"FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...].");
AddInput("FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1].");
AddInput("FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1].");
AddInput("WeightDecay",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
"+ n_2, ...]. It should be in CPUPlace.");
AddInput(
"FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
AddInput(
"FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
AddInput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is "
"CPUPlace, and its shape is [8]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num].");
"fp32_global_param_num, fp32_weight_decay_end_idx, "
"fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num, "
"fp16_weight_decay_end_idx].");
AddInput("ParamOrder",
"The reordered parameter order. Inside this op, "
"the parameter would be reordered by data type and weight decay "
"value.");

AddInput("LearningRate",
"The fp32 learning rate tensor. Its shape is [1].");
@@ -116,6 +119,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
"max_global_grad_norm",
"The maximum global gradient l2-norm value for clipping. If "
"max_global_grad_norm <= 0, no clipping would be performed.");
AddAttr<float>("weight_decay", "The weight decay value.");
AddAttr<bool>("clip_after_allreduce",
"Whether to clip before allreduce, only valid when the "
"world size is larger than 1.");
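With parameters reordered so that weight-decay parameters lead each dtype group, the new scalar weight_decay attribute plus the per-group weight_decay_end_idx entries in ParamInfo can stand in for the removed per-element WeightDecay tensor. The rule sketched below is an inference from that bookkeeping, not code taken from the PR:

```cpp
// Sketch of the apparent decision rule (assumption inferred from the
// reordering + end-index bookkeeping above, not PR code): a local parameter
// index within a dtype group takes weight decay iff it lies before that
// group's weight_decay_end_idx.
inline float WeightDecayFor(int local_param_idx, int weight_decay_end_idx,
                            float weight_decay /* the new float attribute */) {
  return local_param_idx < weight_decay_end_idx ? weight_decay : 0.0f;
}
```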