1- // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22//
33// Licensed under the Apache License, Version 2.0 (the "License");
44// you may not use this file except in compliance with the License.
1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include " paddle/fluid/platform/collective_helper.h"
1615#include " paddle/phi/kernels/funcs/multi_tensor_apply_util.h"
1716
1817#include " paddle/phi/backends/context_pool.h"
3332#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
3433#include " paddle/common/flags.h"
3534#include " paddle/phi/core/distributed/nccl_comm_context.h"
36- COMMON_DECLARE_bool (dynamic_static_unified_comm);
3735#endif
3836
3937#ifdef __NVCC__
@@ -903,30 +901,19 @@ static bool CreatePreMulScaleOpIfSupported(
903901 ncclRedOp_t *op,
904902 distributed::NCCLCommContext *comm_ctx = nullptr ) {
905903#if NCCL_VERSION_CODE >= 21100
906- if (FLAGS_dynamic_static_unified_comm) {
907- PADDLE_ENFORCE_NOT_NULL (
908- comm_ctx,
909- phi::errors::InvalidArgument (
910- " You choose to use new communication library by "
911- " setting environment "
912- " variable FLAGS_dynamic_static_unified_comm True. "
913- " But parameter of comm_ctx should not be nullptr." ));
914- int ver = comm_ctx->GetNcclVersion ();
915- if (ver >= 21100 ) {
916- VLOG (10 ) << " ncclRedOpCreatePreMulSum is supported." ;
917- comm_ctx->RedOpCreatePreMulSum (
918- op, const_cast <void *>(scale), dtype, ncclScalarDevice);
919- return true ;
920- }
921- } else {
922- int ver;
923- PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::ncclGetVersion (&ver));
924- if (ver >= 21100 ) {
925- VLOG (10 ) << " ncclRedOpCreatePreMulSum is supported." ;
926- PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::ncclRedOpCreatePreMulSum (
927- op, const_cast <void *>(scale), dtype, ncclScalarDevice, comm));
928- return true ;
929- }
904+ PADDLE_ENFORCE_NOT_NULL (
905+ comm_ctx,
906+ phi::errors::InvalidArgument (
907+ " You choose to use new communication library by "
908+ " setting environment "
909+ " variable FLAGS_dynamic_static_unified_comm True. "
910+ " But parameter of comm_ctx should not be nullptr." ));
911+ int ver = comm_ctx->GetNcclVersion ();
912+ if (ver >= 21100 ) {
913+ VLOG (10 ) << " ncclRedOpCreatePreMulSum is supported." ;
914+ comm_ctx->RedOpCreatePreMulSum (
915+ op, const_cast <void *>(scale), dtype, ncclScalarDevice);
916+ return true ;
930917 }
931918#endif
932919 VLOG (10 ) << " ncclRedOpCreatePreMulSum is not supported." ;
@@ -940,18 +927,14 @@ static void DestoryOpIfSupported(
940927#if NCCL_VERSION_CODE >= 21100
941928 VLOG (10 ) << " ncclRedOpDestroy starts" ;
942929
943- if (FLAGS_dynamic_static_unified_comm) {
944- PADDLE_ENFORCE_NOT_NULL (
945- comm_ctx,
946- phi::errors::InvalidArgument (
947- " You choose to use new communication library by "
948- " setting environment "
949- " variable FLAGS_dynamic_static_unified_comm True. "
950- " But parameter of comm_ctx should not be nullptr." ));
951- comm_ctx->RedOpDestroy (op);
952- } else {
953- PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::ncclRedOpDestroy (op, comm));
954- }
930+ PADDLE_ENFORCE_NOT_NULL (
931+ comm_ctx,
932+ phi::errors::InvalidArgument (
933+ " You choose to use new communication library by "
934+ " setting environment "
935+ " variable FLAGS_dynamic_static_unified_comm True. "
936+ " But parameter of comm_ctx should not be nullptr." ));
937+ comm_ctx->RedOpDestroy (op);
955938 VLOG (10 ) << " ncclRedOpDestroy ends" ;
956939
957940#endif
@@ -989,15 +972,13 @@ static void NCCLSumWithScaleBase(const T *sendbuff,
989972 const phi::GPUContext &dev_ctx,
990973 distributed::NCCLCommContext *comm_ctx,
991974 const T *scale = nullptr ) {
992- if (FLAGS_dynamic_static_unified_comm) {
993- PADDLE_ENFORCE_NOT_NULL (
994- comm_ctx,
995- phi::errors::InvalidArgument (
996- " You choose to use new communication library by "
997- " setting environment "
998- " variable FLAGS_dynamic_static_unified_comm True. "
999- " But parameter of comm_ctx should not be nullptr." ));
1000- }
975+ PADDLE_ENFORCE_NOT_NULL (
976+ comm_ctx,
977+ phi::errors::InvalidArgument (
978+ " You choose to use new communication library by "
979+ " setting environment "
980+ " variable FLAGS_dynamic_static_unified_comm True. "
981+ " But parameter of comm_ctx should not be nullptr." ));
1001982
1002983 static_assert (
1003984 std::is_same<T, float >::value || std::is_same<T, dtype::float16>::value,
@@ -1758,71 +1739,45 @@ void DistributedFusedLambKernel(
17581739 int64_t global_rank = 0 , local_rank = 0 ;
17591740 ncclComm_t global_comm = nullptr , local_comm = nullptr ,
17601741 external_comm = nullptr ;
1761- paddle::platform::NCCLComm *nccl_comm_handle = nullptr ,
1762- *local_nccl_comm_handle = nullptr ;
17631742 distributed::NCCLCommContext *comm_ctx = nullptr , *local_comm_ctx = nullptr ,
17641743 *external_comm_ctx = nullptr ;
17651744
17661745 const auto &comm_context_manager =
17671746 phi::distributed::CommContextManager::GetInstance ();
17681747
1769- if (FLAGS_dynamic_static_unified_comm) {
1770- CheckCommContextHasRingId (comm_context_manager, ring_ids[0 ]);
1771-
1772- comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1773- comm_context_manager.Get (std::to_string (ring_ids[0 ])));
1774- PADDLE_ENFORCE_NE (comm_ctx,
1775- nullptr ,
1776- phi::errors::Unavailable (
1777- " NCCLCommContext is nullptr, collective op should "
1778- " has ring_id attr." ));
1779-
1780- global_comm = comm_ctx->GetNcclComm ();
1781- global_rank = comm_ctx->GetRank ();
1782- if (local_shard) {
1783- CheckCommContextHasRingId (comm_context_manager, ring_ids[1 ]);
1748+ CheckCommContextHasRingId (comm_context_manager, ring_ids[0 ]);
17841749
1785- local_comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1786- comm_context_manager.Get (std::to_string (ring_ids[1 ])));
1787- local_comm = local_comm_ctx->GetNcclComm ();
1788- local_rank = local_comm_ctx->GetRank ();
1789- if (use_hierarchical_allreduce) {
1790- CheckCommContextHasRingId (comm_context_manager, ring_ids[2 ]);
1750+ comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1751+ comm_context_manager.Get (std::to_string (ring_ids[0 ])));
1752+ PADDLE_ENFORCE_NE (comm_ctx,
1753+ nullptr ,
1754+ phi::errors::Unavailable (
1755+ " NCCLCommContext is nullptr, collective op should "
1756+ " has ring_id attr." ));
17911757
1792- external_comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1793- comm_context_manager.Get (std::to_string (ring_ids[2 ])));
1794- external_comm = external_comm_ctx->GetNcclComm ();
1795- }
1796- } else {
1797- local_comm = global_comm;
1798- local_rank = global_rank;
1758+ global_comm = comm_ctx->GetNcclComm ();
1759+ global_rank = comm_ctx->GetRank ();
1760+ if (local_shard) {
1761+ CheckCommContextHasRingId (comm_context_manager, ring_ids[1 ]);
1762+
1763+ local_comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1764+ comm_context_manager.Get (std::to_string (ring_ids[1 ])));
1765+ local_comm = local_comm_ctx->GetNcclComm ();
1766+ local_rank = local_comm_ctx->GetRank ();
1767+ if (use_hierarchical_allreduce) {
1768+ CheckCommContextHasRingId (comm_context_manager, ring_ids[2 ]);
1769+
1770+ external_comm_ctx = static_cast <phi::distributed::NCCLCommContext *>(
1771+ comm_context_manager.Get (std::to_string (ring_ids[2 ])));
1772+ external_comm = external_comm_ctx->GetNcclComm ();
17991773 }
1800-
1801- VLOG (3 ) << " new comm_context_manager has ring_id " << ring_ids[0 ];
18021774 } else {
1803- if (nranks > 1 ) {
1804- nccl_comm_handle =
1805- paddle::platform::NCCLCommContext::Instance ().Get (ring_ids[0 ], place);
1806- global_comm = nccl_comm_handle->comm ();
1807- global_rank = nccl_comm_handle->rank ();
1808- if (local_shard) {
1809- local_nccl_comm_handle =
1810- paddle::platform::NCCLCommContext::Instance ().Get (ring_ids[1 ],
1811- place);
1812- local_comm = local_nccl_comm_handle->comm ();
1813- local_rank = local_nccl_comm_handle->rank ();
1814- if (use_hierarchical_allreduce) {
1815- external_comm = paddle::platform::NCCLCommContext::Instance ()
1816- .Get (ring_ids[2 ], place)
1817- ->comm ();
1818- }
1819- } else {
1820- local_comm = global_comm;
1821- local_rank = global_rank;
1822- }
1823- }
1775+ local_comm = global_comm;
1776+ local_rank = global_rank;
18241777 }
18251778
1779+ VLOG (3 ) << " new comm_context_manager has ring_id " << ring_ids[0 ];
1780+
18261781 memory_utils::Buffer grad_norm_square_buffer (place);
18271782 auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc <float >(2 );
18281783 memory_utils::Buffer cub_tmp_buffer (place);
0 commit comments