Commit 95a5636

[fluid_ops] partial_allgather process_group (#67917)
1 parent: 7bedc47
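
Summary: the kernel now obtains its NCCLCommContext directly from the device context via dev_ctx.GetCommContext() and issues the collective on dev_ctx.stream(), rather than looking the context up in CommContextManager by ring_id or routing through the ProcessGroupMapFromGid path; the ring_id and use_calc_stream attributes are accordingly marked UNUSED.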

File tree

1 file changed (+9, -37 lines)


paddle/phi/kernels/gpu/partial_allgather_kernel.cu

Lines changed: 9 additions & 37 deletions
@@ -13,59 +13,45 @@
 // limitations under the License.
 
 #include "glog/logging.h"
+#include "paddle/phi/core/distributed/utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/phi/core/distributed/collective/process_group.h"
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
 #endif
 
-#include "paddle/phi/core/distributed/comm_context_manager.h"
-
 namespace phi {
 
 template <typename T, typename Context>
 void PartialAllGatherOpCUDAKernel(const Context& dev_ctx,
                                   const DenseTensor& x_in,
                                   int nranks,
                                   int rank,
-                                  int ring_id,
-                                  bool use_calc_stream,
+                                  int ring_id UNUSED,
+                                  bool use_calc_stream UNUSED,
                                   DenseTensor* out) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto in = &x_in;
   int64_t numel = in->numel();
   ncclDataType_t dtype = phi::ToNCCLDataType(in->dtype());
-  int rid = ring_id;
+
   gpuStream_t stream = nullptr;
   phi::distributed::NCCLCommContext* comm_ctx = nullptr;
-  const auto& comm_context_manager =
-      phi::distributed::CommContextManager::GetInstance();
 
   int real_nranks = 0;
   int real_rank = 0;
 
-  PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
-                    true,
-                    common::errors::InvalidArgument(
-                        "You choose to use new communication library by "
-                        "setting environment "
-                        "variable FLAGS_dynamic_static_unified_comm True. "
-                        "But ring_id(%d) is "
-                        "not found in comm_context_manager.",
-                        std::to_string(rid)));
-  comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
-      comm_context_manager.Get(std::to_string(rid)));
+  comm_ctx =
+      static_cast<phi::distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
   PADDLE_ENFORCE_NE(comm_ctx,
                     nullptr,
                     common::errors::Unavailable(
                         "NCCLCommContext is nullptr, collective op should "
                         "has ring_id attr."));
 
-  stream = comm_ctx->GetStream();
+  stream = dev_ctx.stream();
   real_nranks = comm_ctx->GetSize();
   real_rank = comm_ctx->GetRank();
-  VLOG(3) << "new comm_context_manager has ring_id " << rid;
 
   PADDLE_ENFORCE_EQ(nranks,
                     real_nranks,
@@ -90,22 +76,8 @@ void PartialAllGatherOpCUDAKernel(const Context& dev_ctx,
   int64_t send_numel = numel / nranks;
   int offset = send_numel * rank;
 
-  auto map = distributed::ProcessGroupMapFromGid::getInstance();
-  if (map->has(rid)) {
-    // Use ProcessGroup
-    distributed::ProcessGroup* pg = map->get(rid);
-    auto task = pg->AllGather(out, *in, offset, send_numel, /*sync_op*/ true);
-    task->Wait();
-  } else {
-    if (use_calc_stream) {
-      // should ExecutionContext for calc stream.
-      stream = dev_ctx.stream();
-    }
-
-    auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel);
-
-    comm_ctx->AllGather(out, send_buf, stream);
-  }
+  auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel);
+  comm_ctx->AllGather(out, send_buf, stream);
 #else
   PADDLE_THROW(common::errors::PreconditionNotMet(
       "PaddlePaddle should compile with GPU."));
