Commit 2408c47

【Hackathon 6th Fundable Projects 3 No.101】 [fluid_ops] distributed_fused_lamb (#66573)
* Fix
* Fix
* Fix
1 parent 5d81c93 commit 2408c47

5 files changed: +134 -163 lines changed


paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc

Lines changed: 0 additions & 61 deletions

@@ -169,64 +169,3 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
                              ops::DistributedFusedLambOp,
                              ops::DistributedFusedLambOpMaker);
-
-namespace phi {
-namespace fusion {
-
-template <typename T, typename Context>
-void DistributedFusedLambKernel(const Context &dev_ctx,
-                                const std::vector<const DenseTensor *> &param,
-                                const std::vector<const DenseTensor *> &grad,
-                                const paddle::optional<DenseTensor> &fp32_param,
-                                const paddle::optional<DenseTensor> &fp32_grad,
-                                const paddle::optional<DenseTensor> &fp16_param,
-                                const paddle::optional<DenseTensor> &fp16_grad,
-                                const DenseTensor &moment1,
-                                const DenseTensor &moment2,
-                                const DenseTensor &beta1_pow,
-                                const DenseTensor &beta2_pow,
-                                const DenseTensor &param_offsets,
-                                const DenseTensor &fp32_partial_offsets,
-                                const DenseTensor &fp16_partial_offsets,
-                                const DenseTensor &param_info,
-                                const DenseTensor &param_order,
-                                const DenseTensor &learning_rate,
-                                const DenseTensor &global_scale,
-                                int acc_steps,
-                                float beta1,
-                                float beta2,
-                                float epsilon,
-                                float max_global_grad_norm,
-                                float weight_decay,
-                                bool clip_after_allreduce,
-                                bool use_master_param_norm,
-                                bool use_master_acc_grad,
-                                bool is_grad_scaled_by_nranks,
-                                bool use_hierarchical_allreduce,
-                                int64_t nranks,
-                                const std::vector<int> &ring_ids,
-                                DenseTensor *fp32_param_out,
-                                DenseTensor *fp16_param_out,
-                                DenseTensor *fp32_acc_grad,
-                                DenseTensor *fp16_acc_grad,
-                                DenseTensor *moment1_out,
-                                DenseTensor *moment2_out,
-                                DenseTensor *beta1_pow_out,
-                                DenseTensor *beta2_pow_out,
-                                DenseTensor *param_out,
-                                DenseTensor *found_inf,
-                                DenseTensor *acc_step,
-                                DenseTensor *stop_update,
-                                DenseTensor *step) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "The distributed_fused_lamb operator does not support CPU yet."));
-}
-
-}  // namespace fusion
-}  // namespace phi
-
-PD_REGISTER_KERNEL(distributed_fused_lamb,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::fusion::DistributedFusedLambKernel,
-                   float) {}
Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void DistributedFusedLambKernel(const Context &dev_ctx,
+                                const std::vector<const DenseTensor *> &param,
+                                const std::vector<const DenseTensor *> &grad,
+                                const paddle::optional<DenseTensor> &fp32_param,
+                                const paddle::optional<DenseTensor> &fp32_grad,
+                                const paddle::optional<DenseTensor> &fp16_param,
+                                const paddle::optional<DenseTensor> &fp16_grad,
+                                const DenseTensor &moment1,
+                                const DenseTensor &moment2,
+                                const DenseTensor &beta1_pow,
+                                const DenseTensor &beta2_pow,
+                                const DenseTensor &param_offsets,
+                                const DenseTensor &fp32_partial_offsets,
+                                const DenseTensor &fp16_partial_offsets,
+                                const DenseTensor &param_info,
+                                const DenseTensor &param_order,
+                                const DenseTensor &learning_rate,
+                                const DenseTensor &global_scale,
+                                int acc_steps,
+                                float beta1,
+                                float beta2,
+                                float epsilon,
+                                float max_global_grad_norm,
+                                float weight_decay,
+                                bool clip_after_allreduce,
+                                bool use_master_param_norm,
+                                bool use_master_acc_grad,
+                                bool is_grad_scaled_by_nranks,
+                                bool use_hierarchical_allreduce,
+                                int64_t nranks,
+                                const std::vector<int> &ring_ids,
+                                DenseTensor *fp32_param_out,
+                                DenseTensor *fp16_param_out,
+                                DenseTensor *fp32_acc_grad,
+                                DenseTensor *fp16_acc_grad,
+                                DenseTensor *moment1_out,
+                                DenseTensor *moment2_out,
+                                DenseTensor *beta1_pow_out,
+                                DenseTensor *beta2_pow_out,
+                                DenseTensor *param_out,
+                                DenseTensor *found_inf,
+                                DenseTensor *acc_step,
+                                DenseTensor *stop_update,
+                                DenseTensor *step) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The distributed_fused_lamb operator does not support CPU yet."));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambKernel,
+                   float) {}

paddle/phi/kernels/funcs/math.h

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 
 #include "math.h"  // NOLINT
 #include "paddle/common/hostdevice.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
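
The diff does not say why this header is now needed; a plausible reading (an assumption, not something stated in the commit) is that a helper declared in funcs/math.h, or in code that includes it, is instantiated with phi::dtype::bfloat16, which requires the full type definition to be visible, just as the existing include does for phi::dtype::float16. The sketch below is a generic, hypothetical illustration of that rule; HalveForExample is not part of math.h.

// Hypothetical illustration only; HalveForExample is not a Paddle function.
// A template like this compiles for a reduced-precision dtype only when the
// dtype's full definition is visible at the point of instantiation, which is
// what the added bfloat16 include guarantees for every user of math.h.
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"

template <typename T>
T HalveForExample(T x) {
  // Round-trip through float, the usual pattern for fp16/bf16 arithmetic.
  return static_cast<T>(static_cast<float>(x) * 0.5f);
}

// Both instantiations need their dtype header included first:
//   HalveForExample(phi::dtype::float16(3.0f));
//   HalveForExample(phi::dtype::bfloat16(3.0f));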

paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu renamed to paddle/phi/kernels/gpu/distributed_fused_lamb_kernel.cu

Lines changed: 56 additions & 101 deletions

@@ -1,4 +1,4 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/phi/kernels/funcs/multi_tensor_apply_util.h"
 
 #include "paddle/phi/backends/context_pool.h"
@@ -33,7 +32,6 @@
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/common/flags.h"
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
-COMMON_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 
 #ifdef __NVCC__
@@ -903,30 +901,19 @@ static bool CreatePreMulScaleOpIfSupported(
     ncclRedOp_t *op,
     distributed::NCCLCommContext *comm_ctx = nullptr) {
 #if NCCL_VERSION_CODE >= 21100
-  if (FLAGS_dynamic_static_unified_comm) {
-    PADDLE_ENFORCE_NOT_NULL(
-        comm_ctx,
-        phi::errors::InvalidArgument(
-            "You choose to use new communication library by "
-            "setting environment "
-            "variable FLAGS_dynamic_static_unified_comm True. "
-            "But parameter of comm_ctx should not be nullptr."));
-    int ver = comm_ctx->GetNcclVersion();
-    if (ver >= 21100) {
-      VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
-      comm_ctx->RedOpCreatePreMulSum(
-          op, const_cast<void *>(scale), dtype, ncclScalarDevice);
-      return true;
-    }
-  } else {
-    int ver;
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
-    if (ver >= 21100) {
-      VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
-      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
-          op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
-      return true;
-    }
+  PADDLE_ENFORCE_NOT_NULL(
+      comm_ctx,
+      phi::errors::InvalidArgument(
+          "You choose to use new communication library by "
+          "setting environment "
+          "variable FLAGS_dynamic_static_unified_comm True. "
+          "But parameter of comm_ctx should not be nullptr."));
+  int ver = comm_ctx->GetNcclVersion();
+  if (ver >= 21100) {
+    VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
+    comm_ctx->RedOpCreatePreMulSum(
+        op, const_cast<void *>(scale), dtype, ncclScalarDevice);
+    return true;
   }
 #endif
   VLOG(10) << "ncclRedOpCreatePreMulSum is not supported.";
@@ -940,18 +927,14 @@ static void DestoryOpIfSupported(
 #if NCCL_VERSION_CODE >= 21100
   VLOG(10) << "ncclRedOpDestroy starts";
 
-  if (FLAGS_dynamic_static_unified_comm) {
-    PADDLE_ENFORCE_NOT_NULL(
-        comm_ctx,
-        phi::errors::InvalidArgument(
-            "You choose to use new communication library by "
-            "setting environment "
-            "variable FLAGS_dynamic_static_unified_comm True. "
-            "But parameter of comm_ctx should not be nullptr."));
-    comm_ctx->RedOpDestroy(op);
-  } else {
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm));
-  }
+  PADDLE_ENFORCE_NOT_NULL(
+      comm_ctx,
+      phi::errors::InvalidArgument(
+          "You choose to use new communication library by "
+          "setting environment "
+          "variable FLAGS_dynamic_static_unified_comm True. "
+          "But parameter of comm_ctx should not be nullptr."));
+  comm_ctx->RedOpDestroy(op);
   VLOG(10) << "ncclRedOpDestroy ends";
 
 #endif
@@ -989,15 +972,13 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
                                 const phi::GPUContext &dev_ctx,
                                 distributed::NCCLCommContext *comm_ctx,
                                 const T *scale = nullptr) {
-  if (FLAGS_dynamic_static_unified_comm) {
-    PADDLE_ENFORCE_NOT_NULL(
-        comm_ctx,
-        phi::errors::InvalidArgument(
-            "You choose to use new communication library by "
-            "setting environment "
-            "variable FLAGS_dynamic_static_unified_comm True. "
-            "But parameter of comm_ctx should not be nullptr."));
-  }
+  PADDLE_ENFORCE_NOT_NULL(
+      comm_ctx,
+      phi::errors::InvalidArgument(
+          "You choose to use new communication library by "
+          "setting environment "
+          "variable FLAGS_dynamic_static_unified_comm True. "
+          "But parameter of comm_ctx should not be nullptr."));
 
   static_assert(
       std::is_same<T, float>::value || std::is_same<T, dtype::float16>::value,
@@ -1758,71 +1739,45 @@ void DistributedFusedLambKernel(
   int64_t global_rank = 0, local_rank = 0;
   ncclComm_t global_comm = nullptr, local_comm = nullptr,
              external_comm = nullptr;
-  paddle::platform::NCCLComm *nccl_comm_handle = nullptr,
-                             *local_nccl_comm_handle = nullptr;
   distributed::NCCLCommContext *comm_ctx = nullptr, *local_comm_ctx = nullptr,
                                *external_comm_ctx = nullptr;
 
   const auto &comm_context_manager =
       phi::distributed::CommContextManager::GetInstance();
 
-  if (FLAGS_dynamic_static_unified_comm) {
-    CheckCommContextHasRingId(comm_context_manager, ring_ids[0]);
-
-    comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
-        comm_context_manager.Get(std::to_string(ring_ids[0])));
-    PADDLE_ENFORCE_NE(comm_ctx,
-                      nullptr,
-                      phi::errors::Unavailable(
-                          "NCCLCommContext is nullptr, collective op should "
-                          "has ring_id attr."));
-
-    global_comm = comm_ctx->GetNcclComm();
-    global_rank = comm_ctx->GetRank();
-    if (local_shard) {
-      CheckCommContextHasRingId(comm_context_manager, ring_ids[1]);
+  CheckCommContextHasRingId(comm_context_manager, ring_ids[0]);
 
-      local_comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
-          comm_context_manager.Get(std::to_string(ring_ids[1])));
-      local_comm = local_comm_ctx->GetNcclComm();
-      local_rank = local_comm_ctx->GetRank();
-      if (use_hierarchical_allreduce) {
-        CheckCommContextHasRingId(comm_context_manager, ring_ids[2]);
+  comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
+      comm_context_manager.Get(std::to_string(ring_ids[0])));
+  PADDLE_ENFORCE_NE(comm_ctx,
+                    nullptr,
+                    phi::errors::Unavailable(
+                        "NCCLCommContext is nullptr, collective op should "
+                        "has ring_id attr."));
 
-        external_comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
-            comm_context_manager.Get(std::to_string(ring_ids[2])));
-        external_comm = external_comm_ctx->GetNcclComm();
-      }
-    } else {
-      local_comm = global_comm;
-      local_rank = global_rank;
+  global_comm = comm_ctx->GetNcclComm();
+  global_rank = comm_ctx->GetRank();
+  if (local_shard) {
+    CheckCommContextHasRingId(comm_context_manager, ring_ids[1]);
+
+    local_comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
+        comm_context_manager.Get(std::to_string(ring_ids[1])));
+    local_comm = local_comm_ctx->GetNcclComm();
+    local_rank = local_comm_ctx->GetRank();
+    if (use_hierarchical_allreduce) {
+      CheckCommContextHasRingId(comm_context_manager, ring_ids[2]);
+
+      external_comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
+          comm_context_manager.Get(std::to_string(ring_ids[2])));
+      external_comm = external_comm_ctx->GetNcclComm();
     }
-
-    VLOG(3) << "new comm_context_manager has ring_id " << ring_ids[0];
   } else {
-    if (nranks > 1) {
-      nccl_comm_handle =
-          paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
-      global_comm = nccl_comm_handle->comm();
-      global_rank = nccl_comm_handle->rank();
-      if (local_shard) {
-        local_nccl_comm_handle =
-            paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1],
-                                                              place);
-        local_comm = local_nccl_comm_handle->comm();
-        local_rank = local_nccl_comm_handle->rank();
-        if (use_hierarchical_allreduce) {
-          external_comm = paddle::platform::NCCLCommContext::Instance()
-                              .Get(ring_ids[2], place)
-                              ->comm();
-        }
-      } else {
-        local_comm = global_comm;
-        local_rank = global_rank;
-      }
-    }
+    local_comm = global_comm;
+    local_rank = global_rank;
   }
 
+  VLOG(3) << "new comm_context_manager has ring_id " << ring_ids[0];
+
   memory_utils::Buffer grad_norm_square_buffer(place);
   auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
   memory_utils::Buffer cub_tmp_buffer(place);
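
Before this change the kernel chose between the legacy paddle::platform::NCCLCommContext singleton and the unified phi::distributed::CommContextManager depending on FLAGS_dynamic_static_unified_comm; after it, only the manager path remains and a missing context is a hard error. The sketch below restates that lookup pattern as it appears in the added lines. It is non-authoritative: GetRequiredNcclCommContext is a hypothetical helper (the kernel inlines this logic), and the include paths are assumptions that would only compile inside the Paddle source tree.

// Minimal sketch of the single remaining lookup path, mirroring the diff.
// GetRequiredNcclCommContext is hypothetical; include paths are assumptions.
#include <string>

#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/enforce.h"

static phi::distributed::NCCLCommContext *GetRequiredNcclCommContext(
    int ring_id) {
  const auto &manager = phi::distributed::CommContextManager::GetInstance();
  // Look up the context registered under this ring id; there is no longer a
  // fallback to the legacy fluid communicator when it is missing.
  auto *comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
      manager.Get(std::to_string(ring_id)));
  PADDLE_ENFORCE_NE(comm_ctx,
                    nullptr,
                    phi::errors::Unavailable(
                        "NCCLCommContext is nullptr, collective op should "
                        "has ring_id attr."));
  return comm_ctx;
}

In the kernel itself the same pattern runs up to three times, once per ring id: ring_ids[0] for the global communicator, ring_ids[1] when local_shard is set, and ring_ids[2] when use_hierarchical_allreduce is also set, exactly as the added lines above show.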

test/legacy_test/test_distributed_fused_lamb_op_with_clip.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def run_test(
     os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm)
     os.environ['GRADIENT_MERGE_STEPS'] = str(gradient_merge_steps)
     os.environ['USE_MASTER_ACC_GRAD'] = str(1 if use_master_acc_grad else 0)
-    os.environ["FLAGS_dynamic_static_unified_comm"] = "0"
+    os.environ["FLAGS_dynamic_static_unified_comm"] = "1"
     os.environ.update(need_env)
 
     touch_file_env = 'SUCCESS_TOUCH_FILE'
