
Commit 5fba2a9

HermitSun and LiYuRio authored
[Cherry-pick] Collective communication APIs (#46922)
* Support both use_calc_stream and sync_op in send recv APIs (#46023)
* Support both use_calc_stream and sync_op in allgather API (#46295)
* Support both use_calc_stream and sync_op in collective communication API (#46761)
* Move group and all reduce from collective to communication (#45848)
* Completes bfloat16 dtype for collective api in eager mode (#45844)
* Fix collective APIs cannot be recognized when building docs (#46962)

Co-authored-by: LiYuRio <[email protected]>
1 parent 10225d2 commit 5fba2a9
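
The common thread in this series of changes is that each collective now comes in two flavors selected by a sync_op flag: a synchronous call that has completed by the time it returns, and an asynchronous call that returns a task the caller must wait on. The snippet below is a toy, self-contained model of that calling convention using only the C++ standard library; the names (AllReduceToy) are invented for illustration and this is not Paddle code.

    #include <future>
    #include <iostream>
    #include <vector>

    // Toy stand-in for a collective: doubles every element "across ranks".
    // With sync_op=true the call blocks until the work is done; with
    // sync_op=false the caller owns the returned task and must wait on it.
    std::future<void> AllReduceToy(std::vector<float>& data, bool sync_op) {
      auto task = std::async(std::launch::async, [&data] {
        for (auto& v : data) v *= 2.0f;
      });
      if (sync_op) task.wait();  // synchronous flavor: finished before returning
      return task;
    }

    int main() {
      std::vector<float> grads(4, 1.0f);
      auto task = AllReduceToy(grads, /*sync_op=*/false);  // asynchronous flavor
      task.wait();                    // caller decides when to synchronize
      std::cout << grads[0] << std::endl;  // prints 2
    }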

File tree: 71 files changed (+5209 / -629 lines)


paddle/fluid/distributed/collective/ProcessGroup.h

Lines changed: 127 additions & 18 deletions
@@ -122,6 +122,16 @@ class ProcessGroup {
         "ProcessGroup%s does not support broadcast", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
+      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
+      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
+      const BroadcastOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support broadcast with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) {
     PADDLE_THROW(platform::errors::InvalidArgument(
@@ -134,38 +144,89 @@ class ProcessGroup {
         "ProcessGroup%s does not support send", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      std::vector<phi::DenseTensor>&, int, bool) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support send with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Recv(
-      std::vector<phi::DenseTensor>& tensors, int) {  // NOLINT
+      std::vector<phi::DenseTensor>&, int) {  // NOLINT
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support receive", GetBackendName()));
+        "ProcessGroup%s does not support recv", GetBackendName()));
   }
 
-  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
-                                                           int,
-                                                           int,
-                                                           int) {  // NOLINT
+  virtual std::shared_ptr<ProcessGroup::Task> Recv(
+      std::vector<phi::DenseTensor>&, int, bool) {  // NOLINT
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support send", GetBackendName()));
+        "ProcessGroup%s does not support recv with sync_op flag",
+        GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
+      phi::DenseTensor&,  // NOLINT
+      int,
+      int64_t,
+      int64_t) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support send_partial", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
+      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support send_partial with sync_op flag",
+        GetBackendName()));
   }
 
   virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
-      phi::DenseTensor& tensors, int, int, int) {  // NOLINT
+      phi::DenseTensor&,  // NOLINT
+      int,
+      int64_t,
+      int64_t) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support receive", GetBackendName()));
+        "ProcessGroup%s does not support recv_partial", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
+      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support recv_partial with sync_op flag",
+        GetBackendName()));
   }
 
   virtual std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>&,    // NOLINT
       std::vector<phi::DenseTensor>&) {  // NOLINT
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support AllGather", GetBackendName()));
+        "ProcessGroup%s does not support all_gather", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support all_gather with sync_op flag",
+        GetBackendName()));
   }
 
   virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      int offset,
-      int length) {  // NOLINT
+      int64_t offset,
+      int64_t length) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      int64_t offset,
+      int64_t length,
+      bool) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
   }
@@ -177,6 +238,14 @@ class ProcessGroup {
         "ProcessGroup%s does not support AllToAll", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support alltoall", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
       std::vector<phi::DenseTensor>&,  // NOLINT
       std::vector<phi::DenseTensor>&,  // NOLINT
@@ -186,26 +255,66 @@ class ProcessGroup {
         "ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<int64_t>&,
+      std::vector<int64_t>&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support alltoall_single", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Reduce(
       std::vector<phi::DenseTensor>&,  // NOLINT
       std::vector<phi::DenseTensor>&,  // NOLINT
       const ReduceOptions& opts) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support Reduce", GetBackendName()));
+        "ProcessGroup%s does not support reduce", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
+      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
+      const ReduceOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support reduce with sync_op flag",
+        GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      const ScatterOptions&) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support scatter", GetBackendName()));
   }
 
   virtual std::shared_ptr<ProcessGroup::Task> Scatter(
       std::vector<phi::DenseTensor>&,  // NOLINT
       std::vector<phi::DenseTensor>&,  // NOLINT
-      const ScatterOptions&) {  // NOLINT
+      const ScatterOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support scatter with sync_op flag",
+        GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      const ReduceScatterOptions&,
+      bool) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support Scatter", GetBackendName()));
+        "ProcessGroup%s does not support reduce_scatter with sync_op flag",
+        GetBackendName()));
   }
 
   virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
-      phi::DenseTensor&,  // NOLINT
-      phi::DenseTensor&,  // NOLINT
-      const ReduceScatterOptions&) {  // NOLINT
+      phi::DenseTensor&,              // NOLINT
+      phi::DenseTensor&,              // NOLINT
+      const ReduceScatterOptions&) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support ReduceScatter", GetBackendName()));
   }
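
The pattern added throughout ProcessGroup.h is the same each time: next to an existing virtual, a second overload is declared with a trailing bool sync_op, and the base-class body simply throws "does not support ... with sync_op flag" so that backends which have not been ported yet fail loudly rather than silently. Below is a stripped-down, self-contained sketch of that pattern; it uses plain exceptions instead of PADDLE_THROW, and the names CollectiveBase and NcclLikeBackend are invented for the example.

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <vector>

    class CollectiveBase {
     public:
      virtual ~CollectiveBase() = default;
      virtual std::string BackendName() const { return "BASE"; }

      // Legacy entry point, kept for callers that never pass sync_op.
      virtual void AllGather(std::vector<float>& in, std::vector<float>& out) {
        throw std::runtime_error(BackendName() + " does not support all_gather");
      }
      // New entry point: the default rejects it, so only backends that
      // explicitly override the sync_op overload accept the flag.
      virtual void AllGather(std::vector<float>& in, std::vector<float>& out,
                             bool sync_op) {
        throw std::runtime_error(BackendName() +
                                 " does not support all_gather with sync_op flag");
      }
    };

    class NcclLikeBackend : public CollectiveBase {
     public:
      using CollectiveBase::AllGather;  // keep the legacy overload visible
      std::string BackendName() const override { return "NCCL_LIKE"; }
      void AllGather(std::vector<float>& in, std::vector<float>& out,
                     bool sync_op) override {
        out = in;  // stand-in for the real gather across ranks
        if (sync_op) { /* a real backend would block on its comm stream here */ }
      }
    };

    int main() {
      NcclLikeBackend pg;
      std::vector<float> in{1, 2, 3}, out;
      pg.AllGather(in, out, /*sync_op=*/true);  // overridden: succeeds
      try {
        pg.AllGather(in, out);  // legacy overload: still rejected by the base
      } catch (const std::exception& e) {
        std::cout << e.what() << std::endl;
      }
    }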

paddle/fluid/distributed/collective/ProcessGroupCustom.cc

Lines changed: 2 additions & 2 deletions
@@ -267,8 +267,8 @@ void* XcclGetPointerByOffset(void* raw_pointer,
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length) {
+    int64_t offset,
+    int64_t length) {
   PADDLE_ENFORCE_EQ(
       CheckTensorsInCustomPlace(in_tensors, device_type_),
       true,
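
The only change in this file is widening offset and length from int to int64_t. That matters because a partial all-gather addresses a slice of a flattened tensor by element offset, and a 32-bit int overflows once a tensor exceeds roughly 2.1 billion elements. A small, self-contained illustration of the arithmetic (the numbers are made up):

    #include <cstdint>
    #include <iostream>

    int main() {
      // A partial all-gather works on the slice [offset, offset + length) of a
      // flattened tensor. For ~3.2e9 elements split across 8 ranks, the offset
      // of the last rank no longer fits in a 32-bit int (max 2147483647).
      const int64_t numel = 3200000000LL;  // total elements
      const int world_size = 8;
      const int rank = 7;
      const int64_t length = numel / world_size;   // 400000000
      const int64_t offset = length * rank;        // 2800000000 > INT32_MAX
      std::cout << "rank " << rank << " owns [" << offset << ", "
                << offset + length << ")" << std::endl;
      return 0;
    }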

paddle/fluid/distributed/collective/ProcessGroupCustom.h

Lines changed: 4 additions & 4 deletions
@@ -80,8 +80,8 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      int offset,
-      int length) override;
+      int64_t offset,
+      int64_t length) override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,
@@ -117,8 +117,8 @@ class ProcessGroupCustom : public ProcessGroup {
   std::set<int> used_place_ids_;
 
  private:
-  void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,
-                     int root,  // NOLINT
+  void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,  // NOLINT
+                     int root,
                      int server_fd);
 
   void BroadcastUniqueCustomID(

paddle/fluid/distributed/collective/ProcessGroupGloo.cc

Lines changed: 11 additions & 0 deletions
@@ -88,6 +88,9 @@ namespace distributed {
     case experimental::DataType::BOOL:       \
       func<bool>(args);                      \
       break;                                 \
+    case experimental::DataType::BFLOAT16:   \
+      func<bfloat16>(args);                  \
+      break;                                 \
     default:                                 \
       VLOG(0) << "Error: Unknown DataType."; \
       exit(-1);                              \
@@ -293,6 +296,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
     std::vector<phi::DenseTensor>& inputs,
     std::vector<phi::DenseTensor>& outputs,
     const AllreduceOptions& opts) {
+  return AllReduce(inputs, outputs, opts, true);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
+    std::vector<phi::DenseTensor>& inputs,
+    std::vector<phi::DenseTensor>& outputs,
+    const AllreduceOptions& opts,
+    bool sync_op) {
   auto tag = next_tag();
   std::shared_ptr<GlooTask> task;
   auto context = get_context();
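
Two things happen in this file: the dtype-dispatch macro gains a BFLOAT16 case, and ProcessGroupGloo keeps its original three-argument AllReduce as a thin wrapper that forwards to the new overload with sync_op fixed to true, so existing callers keep the blocking semantics they had before. Below is a minimal, self-contained sketch of that forwarding idiom; the class and method names are invented and the "reduction" is just a local sum.

    #include <iostream>
    #include <numeric>
    #include <vector>

    class GlooLikeGroup {
     public:
      // Legacy signature: preserved, and defined purely in terms of the new
      // one so there is a single code path. sync_op=true keeps the old
      // blocking behavior for callers that never heard of the flag.
      double AllReduceSum(const std::vector<double>& values) {
        return AllReduceSum(values, /*sync_op=*/true);
      }

      // New signature: the sync_op flag is where asynchronous behavior would
      // be wired in; this toy version always computes eagerly.
      double AllReduceSum(const std::vector<double>& values, bool sync_op) {
        (void)sync_op;  // a real backend would pick stream/wait behavior here
        return std::accumulate(values.begin(), values.end(), 0.0);
      }
    };

    int main() {
      GlooLikeGroup pg;
      std::vector<double> grads{0.5, 1.5, 2.0};
      std::cout << pg.AllReduceSum(grads) << std::endl;         // legacy call: 4
      std::cout << pg.AllReduceSum(grads, false) << std::endl;  // new call: 4
    }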

paddle/fluid/distributed/collective/ProcessGroupGloo.h

Lines changed: 6 additions & 0 deletions
@@ -120,6 +120,12 @@ class ProcessGroupGloo : public ProcessGroup {
       std::vector<phi::DenseTensor>& outputs,
       const AllreduceOptions& opts = AllreduceOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      std::vector<phi::DenseTensor>& inputs,
+      std::vector<phi::DenseTensor>& outputs,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
