55 changes: 0 additions & 55 deletions paddle/fluid/operators/collective/c_comm_init_all_op.cc

This file was deleted.

20 changes: 0 additions & 20 deletions paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc

This file was deleted.

95 changes: 0 additions & 95 deletions paddle/fluid/operators/collective/c_comm_init_all_op.h

This file was deleted.

20 changes: 0 additions & 20 deletions paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc

This file was deleted.

1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -169,6 +169,7 @@
'c_reduce_sum',
'c_softmax_with_cross_entropy',
'c_split',
'comm_init_all',
'decayed_adagrad',
'distributed_fused_lamb',
'distributed_fused_lamb_',
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -62,6 +62,7 @@
"full",
"partial_send",
"push_dense",
"comm_init_all",
]

# prim op with one input and one output, with no attribute
177 changes: 176 additions & 1 deletion paddle/phi/core/platform/device/xpu/bkcl_helper.h
@@ -33,8 +33,183 @@
#include "xpu/bkcl.h"
#include "xpu/runtime.h"

#define BKCL_ID_VARNAME "BKCLID"

namespace paddle {
namespace platform {} // namespace platform
namespace platform {

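// The following accessors read the rank id, device id and rank count from the
// first three ints of the (opaque) BKCL communicator handle; they depend on
// BKCL's internal layout.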
inline int GetBKCLRankID(BKCLContext_t comm) {
return reinterpret_cast<int *>(comm)[0];
}

inline int GetBKCLDevID(BKCLContext_t comm) {
return reinterpret_cast<int *>(comm)[1];
}

inline int GetBKCLNRanks(BKCLContext_t comm) {
return reinterpret_cast<int *>(comm)[2];
}

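// RAII guard: the constructor acquires a process-wide mutex and calls
// bkcl_group_start(); the destructor calls bkcl_group_end() and releases the
// mutex. A usage sketch (illustrative only):
//
//   {
//     BKCLGroupGuard guard;
//     // issue several BKCL calls here; they are submitted as one group
//   }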
class BKCLGroupGuard {
public:
static std::mutex &BKCLMutex() {
static std::mutex mtx;
return mtx;
}

inline BKCLGroupGuard() {
BKCLMutex().lock();
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
}

inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
BKCLMutex().unlock();
}
};

struct BKCLContext {
std::unique_ptr<phi::XPUContext> ctx_;
BKCLContext_t comm_;

explicit BKCLContext(int dev_id)
: ctx_(new phi::XPUContext(phi::XPUPlace(dev_id))), comm_{nullptr} {}

XPUStream stream() const { return ctx_->stream(); }
BKCLContext_t comm() const { return comm_; }

int device_id() const { return ctx_->GetPlace().device; }
};

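// Arguments handed to each per-device initialization thread.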
struct InitBKCLPara {
BKCLUniqueId *bkcl_id;
int rank;
int nranks;
int dev_id;
BKCLContext_t *ctx;
};

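// Thread entry point: bind the thread to para->dev_id and join the BKCL
// communicator as rank para->rank out of para->nranks.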
static void *init_bkcl_context_func(void *args) {
auto *para = reinterpret_cast<InitBKCLPara *>(args);
platform::SetXPUDeviceId(para->dev_id);
PADDLE_ENFORCE_XPU_SUCCESS(
bkcl_init_rank(para->ctx, para->rank, para->nranks, para->bkcl_id));
return nullptr;
}

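// Owns one BKCLContext per local XPU device and initializes the BKCL
// communicators across num_trainers_ trainers.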
struct BKCLContextMap {
std::unordered_map<int, BKCLContext> contexts_;
std::vector<int> order_;
std::vector<phi::Place> places_;
size_t num_trainers_;
size_t trainer_id_;
BKCLUniqueId *bkcl_id_;

explicit BKCLContextMap(const std::vector<phi::Place> &places,
BKCLUniqueId *bkcl_id = nullptr,
size_t num_trainers = 1,
size_t trainer_id = 0) {
places_ = places;
bkcl_id_ = bkcl_id;
num_trainers_ = num_trainers;
trainer_id_ = trainer_id;
}

// bkcl_init_rank blocks until all ranks have joined, so the communicators
// must be initialized concurrently, one thread per device.
int init() {
PADDLE_ENFORCE_EQ(
!places_.empty(),
true,
common::errors::InvalidArgument("The list of BKCL places must not be empty."));
order_.reserve(places_.size());
for (auto &p : places_) {
int dev_id = p.device;
order_.emplace_back(dev_id);
contexts_.emplace(dev_id, BKCLContext(dev_id));
}
PADDLE_ENFORCE_EQ(
order_.size(),
contexts_.size(),
common::errors::Unavailable("BKCLContextMap does not support two or "
"more contexts on the same device"));

std::unique_ptr<BKCLContext_t[]> comms(new BKCLContext_t[order_.size()]);
std::unique_ptr<InitBKCLPara[]> paras(new InitBKCLPara[order_.size()]);
std::unique_ptr<pthread_t[]> pids(new pthread_t[order_.size()]);
BKCLResult_t ret;
BKCLUniqueId id;
// With a single trainer and no externally provided id, create a new BKCL
// unique id for the local communicators.
if (num_trainers_ == 1 && bkcl_id_ == nullptr) {
ret = bkcl_get_unique_id(&id);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS,
ret,
common::errors::PreconditionNotMet(
"bkcl get unique id failed [%d]", ret));
bkcl_id_ = &id;
}
PADDLE_ENFORCE_NOT_NULL(
bkcl_id_,
common::errors::InvalidArgument("The BKCL id should not be null."));
{
int nranks = num_trainers_ * order_.size();
for (size_t i = 0; i < order_.size(); ++i) {
int rank;
if (order_.size() > 1) {
rank = trainer_id_ * order_.size() + i;
} else {
rank = trainer_id_;
}

paras[i].rank = rank;
paras[i].nranks = nranks;
paras[i].dev_id = order_[i];
paras[i].bkcl_id = bkcl_id_;
paras[i].ctx = &comms[i];
PADDLE_ENFORCE_EQ(pthread_create(&pids[i],
nullptr,
init_bkcl_context_func,
reinterpret_cast<void *>(&paras[i])),
0,
common::errors::External("pthread_create failed"));
}
for (size_t i = 0; i < order_.size(); i++) {
pthread_join(pids[i], nullptr);
}
}
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
return 0;
}

BKCLContextMap(const BKCLContextMap &other) = delete;
BKCLContextMap &operator=(const BKCLContextMap &other) = delete;

phi::XPUContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }

phi::XPUContext *DevCtx(phi::Place p) const { return DevCtx(p.device); }

const BKCLContext &at(phi::Place p) const { return this->at(p.device); }

const BKCLContext &at(int dev_id) const { return contexts_.at(dev_id); }

void WaitAll() {
for (auto &p : contexts_) {
p.second.ctx_->Wait();
}
}
};

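// Name of the scope variable that stores the pos-th BKCL unique id,
// e.g. GetFlatBKCLVarName(0) == "BKCLID" and GetFlatBKCLVarName(2) == "BKCLID_2".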
inline std::string GetFlatBKCLVarName(size_t pos) {
if (pos == 0) {
return BKCL_ID_VARNAME;
}
return string::Sprintf("%s_%d", BKCL_ID_VARNAME, static_cast<int>(pos));
}

} // namespace platform
} // namespace paddle

#endif // PADDLE_WITH_XPU_BKCL
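
For context, a minimal sketch of how the restored helper could be driven in a single-process, single-trainer setup. The two local XPU device ids and the wrapper function below are illustrative assumptions, not code from this PR:

#include <vector>

#include "paddle/phi/core/platform/device/xpu/bkcl_helper.h"

#ifdef PADDLE_WITH_XPU_BKCL
// Illustrative helper (assumed name): build communicators for two local XPU
// devices and wait for their contexts to drain.
void InitLocalBKCLComms() {
  std::vector<phi::Place> places = {phi::XPUPlace(0), phi::XPUPlace(1)};

  // Single trainer and no externally supplied BKCL id, so init() generates one.
  paddle::platform::BKCLContextMap ctx_map(places);
  ctx_map.init();  // spawns one pthread per device to call bkcl_init_rank

  // Per-device context and communicator are now available.
  phi::XPUContext *xpu_ctx = ctx_map.DevCtx(/*dev_id=*/0);
  BKCLContext_t comm = ctx_map.at(/*dev_id=*/0).comm();
  (void)xpu_ctx;
  (void)comm;

  ctx_map.WaitAll();  // wait for every device's XPUContext to finish
}
#endif  // PADDLE_WITH_XPU_BKCL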