 // limitations under the License.

 #include "glog/logging.h"
+#include "paddle/phi/core/distributed/utils.h"
 #include "paddle/phi/core/kernel_registry.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/phi/core/distributed/collective/process_group.h"
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
 #endif

-#include "paddle/phi/core/distributed/comm_context_manager.h"
-
 namespace phi {

 template <typename T, typename Context>
 void PartialAllGatherOpCUDAKernel(const Context& dev_ctx,
                                   const DenseTensor& x_in,
                                   int nranks,
                                   int rank,
-                                  int ring_id,
-                                  bool use_calc_stream,
+                                  int ring_id UNUSED,
+                                  bool use_calc_stream UNUSED,
                                   DenseTensor* out) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto in = &x_in;
   int64_t numel = in->numel();
   ncclDataType_t dtype = phi::ToNCCLDataType(in->dtype());
-  int rid = ring_id;
+
   gpuStream_t stream = nullptr;
   phi::distributed::NCCLCommContext* comm_ctx = nullptr;
-  const auto& comm_context_manager =
-      phi::distributed::CommContextManager::GetInstance();

   int real_nranks = 0;
   int real_rank = 0;

-  PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
-                    true,
-                    common::errors::InvalidArgument(
-                        "You choose to use new communication library by "
-                        "setting environment "
-                        "variable FLAGS_dynamic_static_unified_comm True. "
-                        "But ring_id(%d) is "
-                        "not found in comm_context_manager.",
-                        std::to_string(rid)));
-  comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
-      comm_context_manager.Get(std::to_string(rid)));
+  comm_ctx =
+      static_cast<phi::distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
   PADDLE_ENFORCE_NE(comm_ctx,
                     nullptr,
                     common::errors::Unavailable(
                         "NCCLCommContext is nullptr, collective op should "
                         "have ring_id attr."));

-  stream = comm_ctx->GetStream();
+  stream = dev_ctx.stream();
   real_nranks = comm_ctx->GetSize();
   real_rank = comm_ctx->GetRank();
-  VLOG(3) << "new comm_context_manager has ring_id " << rid;

   PADDLE_ENFORCE_EQ(nranks,
                     real_nranks,
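With this change the kernel no longer resolves ring_id through CommContextManager: it takes whatever NCCLCommContext the framework has attached to the device context (dev_ctx.GetCommContext()) and enqueues the collective on the compute stream (dev_ctx.stream()), which is why ring_id and use_calc_stream are now marked UNUSED. As a rough illustration only, assuming NCCLCommContext::AllGather ultimately issues a plain NCCL all-gather on the stream it is handed, the call shape with raw NCCL looks like the sketch below (the wrapper function is hypothetical, not Paddle code):

#include <cuda_runtime.h>
#include <nccl.h>

// Hedged sketch: the collective the new code path boils down to. `comm`
// and `stream` are assumed to be initialized elsewhere. Enqueueing on the
// compute stream orders the all-gather after kernels already queued there,
// so a separate use_calc_stream switch is unnecessary.
void AllGatherOnStream(const void* send_buf, void* recv_buf,
                       size_t send_numel, ncclDataType_t dtype,
                       ncclComm_t comm, cudaStream_t stream) {
  // recv_buf must hold send_numel * nranks elements; NCCL fills it with
  // every rank's send_numel-element contribution, ordered by rank.
  ncclAllGather(send_buf, recv_buf, send_numel, dtype, comm, stream);
}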
@@ -90,22 +76,8 @@ void PartialAllGatherOpCUDAKernel(const Context& dev_ctx,
   int64_t send_numel = numel / nranks;
   int offset = send_numel * rank;

-  auto map = distributed::ProcessGroupMapFromGid::getInstance();
-  if (map->has(rid)) {
-    // Use ProcessGroup
-    distributed::ProcessGroup* pg = map->get(rid);
-    auto task = pg->AllGather(out, *in, offset, send_numel, /*sync_op*/ true);
-    task->Wait();
-  } else {
-    if (use_calc_stream) {
-      // should ExecutionContext for calc stream.
-      stream = dev_ctx.stream();
-    }
-
-    auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel);
-
-    comm_ctx->AllGather(out, send_buf, stream);
-  }
+  auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel);
+  comm_ctx->AllGather(out, send_buf, stream);
 #else
   PADDLE_THROW(common::errors::PreconditionNotMet(
       "PaddlePaddle should compile with GPU."));