
Commit ff0caf9

Fix (#66943)
1 parent f1f5f11 commit ff0caf9

File tree

14 files changed: +322 additions, -191 deletions


paddle/fluid/operators/collective/c_comm_init_all_op.cc

Lines changed: 0 additions & 55 deletions
This file was deleted.

paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc

Lines changed: 0 additions & 20 deletions
This file was deleted.

paddle/fluid/operators/collective/c_comm_init_all_op.h

Lines changed: 0 additions & 95 deletions
This file was deleted.

paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc

Lines changed: 0 additions & 20 deletions
This file was deleted.

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@
     'c_reduce_sum',
     'c_softmax_with_cross_entropy',
     'c_split',
+    'comm_init_all',
     'decayed_adagrad',
     'distributed_fused_lamb',
     'distributed_fused_lamb_',

paddle/fluid/primitive/codegen/gen.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@
     "full",
     "partial_send",
     "push_dense",
+    "comm_init_all",
 ]

 # prim op with one input and one output, with no attribute

paddle/phi/core/platform/device/xpu/bkcl_helper.h

Lines changed: 176 additions & 1 deletion
@@ -33,8 +33,183 @@
 #include "xpu/bkcl.h"
 #include "xpu/runtime.h"

+#define BKCL_ID_VARNAME "BKCLID"
+
 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+
+inline int GetBKCLRankID(BKCLContext_t comm) {
+  return reinterpret_cast<int *>(comm)[0];
+}
+
+inline int GetBKCLDevID(BKCLContext_t comm) {
+  return reinterpret_cast<int *>(comm)[1];
+}
+
+inline int GetBKCLNRanks(BKCLContext_t comm) {
+  return reinterpret_cast<int *>(comm)[2];
+}
+
+class BKCLGroupGuard {
+ public:
+  static std::mutex &BKCLMutex() {
+    static std::mutex mtx;
+    return mtx;
+  }
+
+  inline BKCLGroupGuard() {
+    BKCLMutex().lock();
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
+  }
+
+  inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
+    BKCLMutex().unlock();
+  }
+};
+
+struct BKCLContext {
+  std::unique_ptr<phi::XPUContext> ctx_;
+  BKCLContext_t comm_;
+
+  explicit BKCLContext(int dev_id)
+      : ctx_(new phi::XPUContext(phi::XPUPlace(dev_id))), comm_{nullptr} {}
+
+  XPUStream stream() const { return ctx_->stream(); }
+  BKCLContext_t comm() const { return comm_; }
+
+  int device_id() const { return ctx_->GetPlace().device; }
+};
+
+struct InitBKCLPara {
+  BKCLUniqueId *bkcl_id;
+  int rank;
+  int nranks;
+  int dev_id;
+  BKCLContext_t *ctx;
+};
+
+static void *init_bkcl_context_func(void *args) {
+  struct InitBKCLPara *para = (struct InitBKCLPara *)args;
+  platform::SetXPUDeviceId(para->dev_id);
+  PADDLE_ENFORCE_XPU_SUCCESS(
+      bkcl_init_rank(para->ctx, para->rank, para->nranks, para->bkcl_id));
+  return nullptr;
+}
+
+struct BKCLContextMap {
+  std::unordered_map<int, BKCLContext> contexts_;
+  std::vector<int> order_;
+  std::vector<phi::Place> places_;
+  size_t num_trainers_;
+  size_t trainer_id_;
+  BKCLUniqueId *bkcl_id_;
+
+  explicit BKCLContextMap(const std::vector<phi::Place> &places,
+                          BKCLUniqueId *bkcl_id = nullptr,
+                          size_t num_trainers = 1,
+                          size_t trainer_id = 0) {
+    places_ = places;
+    bkcl_id_ = bkcl_id;
+    num_trainers_ = num_trainers;
+    trainer_id_ = trainer_id;
+  }
+
+  // Communicators must be initialized concurrently, one thread per device;
+  // init() joins all threads before the map is used.
+  int init() {
+    PADDLE_ENFORCE_EQ(
+        !places_.empty(),
+        true,
+        common::errors::InvalidArgument("The BKCL place should not be empty."));
+    order_.reserve(places_.size());
+    for (auto &p : places_) {
+      int dev_id = p.device;
+      order_.emplace_back(dev_id);
+      contexts_.emplace(dev_id, BKCLContext(dev_id));
+    }
+    PADDLE_ENFORCE_EQ(
+        order_.size(),
+        contexts_.size(),
+        common::errors::Unavailable("BKCL Context Map does not support "
+                                    "two or more contexts on the same device"));
+
+    std::unique_ptr<BKCLContext_t[]> comms(new BKCLContext_t[order_.size()]);
+    std::unique_ptr<InitBKCLPara[]> paras(new InitBKCLPara[order_.size()]);
+    std::unique_ptr<pthread_t[]> pids(new pthread_t[order_.size()]);
+    BKCLResult_t ret;
+    BKCLUniqueId id;
+    // If num_trainers == 1, create a new bkcl id for the local comms.
+    if (num_trainers_ == 1 && bkcl_id_ == nullptr) {
+      ret = bkcl_get_unique_id(&id);
+      PADDLE_ENFORCE_EQ(BKCL_SUCCESS,
+                        ret,
+                        common::errors::PreconditionNotMet(
+                            "bkcl get unique id failed [%d]", ret));
+      bkcl_id_ = &id;
+    }
+    PADDLE_ENFORCE_NOT_NULL(
+        bkcl_id_,
+        common::errors::InvalidArgument("The BKCL id should not be null."));
+    {
+      int nranks = num_trainers_ * order_.size();
+      for (size_t i = 0; i < order_.size(); ++i) {
+        int rank;
+        if (order_.size() > 1) {
+          rank = trainer_id_ * order_.size() + i;
+        } else {
+          rank = trainer_id_;
+        }
+
+        paras[i].rank = rank;
+        paras[i].nranks = nranks;
+        paras[i].dev_id = order_[i];
+        paras[i].bkcl_id = bkcl_id_;
+        paras[i].ctx = &comms[i];
+        PADDLE_ENFORCE_EQ(pthread_create(&pids[i],
+                                         nullptr,
+                                         init_bkcl_context_func,
+                                         reinterpret_cast<void *>(&paras[i])),
+                          0,
+                          common::errors::External("pthread_create failed"));
+      }
+      for (size_t i = 0; i < order_.size(); i++) {
+        pthread_join(pids[i], nullptr);
+      }
+    }
+    int i = 0;
+    for (auto &dev_id : order_) {
+      contexts_.at(dev_id).comm_ = comms[i++];
+    }
+    return 0;
+  }
+
+  BKCLContextMap(const BKCLContextMap &other) = delete;
+  BKCLContextMap &operator=(const BKCLContextMap &other) = delete;
+
+  phi::XPUContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
+
+  phi::XPUContext *DevCtx(phi::Place p) const { return DevCtx(p.device); }
+
+  const BKCLContext &at(phi::Place p) const { return this->at(p.device); }
+
+  const BKCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
+
+  void WaitAll() {
+    for (auto &p : contexts_) {
+      p.second.ctx_->Wait();
+    }
+  }
+};
+
+inline std::string GetFlatBKCLVarName(size_t pos) {
+  if (pos == 0) {
+    return BKCL_ID_VARNAME;
+  }
+  return string::Sprintf("%s_%d", BKCL_ID_VARNAME, static_cast<int>(pos));
+}
+
+}  // namespace platform
 }  // namespace paddle

 #endif  // PADDLE_WITH_XPU_BKCL
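
As a point of reference, here is a minimal usage sketch of the restored BKCLContextMap helper. It is not part of this commit: the two-device setup, the InitLocalBKCL function name, and the single-trainer assumption are illustrative only, and it presumes a Paddle build with PADDLE_WITH_XPU_BKCL enabled.

    // Hypothetical sketch (not from this commit): create BKCL communicators
    // for two local XPU devices with the BKCLContextMap declared above.
    #include <vector>
    #include "paddle/phi/core/platform/device/xpu/bkcl_helper.h"

    void InitLocalBKCL() {
      // Single-trainer case: bkcl_id stays nullptr, so init() generates a
      // fresh BKCLUniqueId for the local communicators.
      std::vector<phi::Place> places = {phi::XPUPlace(0), phi::XPUPlace(1)};
      paddle::platform::BKCLContextMap ctx_map(places);
      ctx_map.init();  // one pthread per device calls bkcl_init_rank

      phi::XPUContext *dev_ctx0 = ctx_map.DevCtx(0);  // per-device context
      BKCLContext_t comm0 = ctx_map.at(0).comm();     // per-device communicator
      (void)dev_ctx0;
      (void)comm0;

      ctx_map.WaitAll();  // wait for all device streams to finish
    }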
