
Commit e4ebf38

Update ProcessGroupCustom for sync_op compatibility (#47976)
* refactor: update pg custom
* fix: use new api in ut
* fix: typo
* revert: recover legacy apis
* fix: add GetDeviceContext
1 parent 39c8506 commit e4ebf38
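
This commit moves the custom-device process group onto the new sync_op-style collective entry points. A minimal sketch of the resulting call pattern, distilled from the updated unit test further below (it assumes `pg` is the process group created in that test's setup on the custom/XCCL backend with two ranks, and that the imports mirror the test file's):

    import numpy as np
    import paddle
    from paddle.fluid import core  # assumed to mirror the test's imports

    # `pg` is assumed to be created by the test's setup (not shown here);
    # only the new call signatures are illustrated.
    x = paddle.to_tensor(np.random.random([16, 32]).astype('float32'))

    # all_reduce now takes an explicit reduce op and sync_op flag
    task = pg.all_reduce(x, core.ReduceOp.SUM, sync_op=True)
    task.wait()

    # broadcast takes the source rank plus sync_op; wait() replaces synchronize()
    task = pg.broadcast(x, 0, sync_op=True)
    task.wait()
    assert task.is_completed()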

File tree

3 files changed: +131 -84 lines changed

paddle/fluid/distributed/collective/ProcessGroupCustom.cc

Lines changed: 101 additions & 69 deletions
@@ -202,38 +202,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
   return task;
 }
 
-std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
-    std::vector<phi::DenseTensor>& in_tensors,
-    std::vector<phi::DenseTensor>& out_tensors) {
-  PADDLE_ENFORCE_EQ(
-      CheckTensorsInCustomPlace(in_tensors, device_type_),
-      true,
-      platform::errors::InvalidArgument(
-          "All inputs should be in CustomPlace(%s).", device_type_));
-  PADDLE_ENFORCE_EQ(
-      CheckTensorsInCustomPlace(out_tensors, device_type_),
-      true,
-      platform::errors::InvalidArgument(
-          "All outputs should be in CustomPlace(%s).", device_type_));
-  return Collective(
-      in_tensors,
-      out_tensors,
-      [&](phi::DenseTensor& input,
-          phi::DenseTensor& output,
-          phi::ccl::CCLComm comm,
-          const phi::stream::Stream& stream) {
-        return phi::DeviceManager::CCLAllGather(
-            device_type_,
-            input.data(),
-            output.data(),
-            input.numel(),
-            phi::ccl::ToCCLDataType(input.dtype()),
-            comm,
-            stream);
-      },
-      CommType::ALLGATHER);
-}
-
 void* XcclGetPointerByOffset(void* raw_pointer,
                              size_t offset,
                              experimental::DataType type) {
@@ -259,13 +227,13 @@ void* XcclGetPointerByOffset(void* raw_pointer,
   return nullptr;
 }
 
-// NOTE: this is ONLY for compatibility
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
     int64_t offset,
     int64_t numel,
-    bool sync_op) {
+    bool sync_op  // for compatibility, no use now
+) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
   return Collective(
@@ -287,6 +255,105 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
       CommType::ALLGATHER);
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const AllreduceOptions& opts,
+    bool sync_op  // for compatibility, no use now
+) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return AllReduce(in_wrapper, out_wrapper, opts);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const BroadcastOptions& opts,
+    bool sync_op  // for compatibility, no use now
+) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return Broadcast(in_wrapper, out_wrapper, opts);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
+    const BarrierOptions& opts) {
+  // Only support single card single process
+  PADDLE_ENFORCE_GE(opts.device_id,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "The barrier device id must greater or equal than 0."));
+  platform::CustomPlace place(device_type_, opts.device_id);
+  auto allocator = std::unique_ptr<phi::Allocator>(
+      new paddle::experimental::DefaultAllocator(place));
+  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
+  phi::DenseTensor barrier_tensor{allocator.get(), meta};
+
+  auto task = ProcessGroupCustom::AllReduce(&barrier_tensor,
+                                            barrier_tensor,
+                                            {},
+                                            /*sync_op*/ true);
+  auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
+  xccl_task->barrierTensors_ = {barrier_tensor};
+  return task;
+}
+
+const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
+    const Place& place) const {
+  const std::string key = GetKeyFromPlace(place);
+  const auto& iter = places_to_ctx_.find(key);
+  PADDLE_ENFORCE_NE(
+      iter,
+      places_to_ctx_.end(),
+      platform::errors::NotFound(
+          "Cannot find the device context in this process group."));
+  return *iter->second[0];
+}
+
+phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
+  std::vector<Place> places = {place};
+  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
+  PADDLE_ENFORCE_NE(iter,
+                    places_to_customcomm_.end(),
+                    platform::errors::InvalidArgument(
+                        "Cannot find nccl comm in process group."));
+  return iter->second[0]->GetCustomCCLComm();
+}
+
+// TODO(sunyilun): methods below will be removed later
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_tensors, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllGather(
+            device_type_,
+            input.data(),
+            output.data(),
+            input.numel(),
+            phi::ccl::ToCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
@@ -366,40 +433,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
       CommType::BROADCAST);
 }
 
-std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
-    const BarrierOptions& opts) {
-  // Only support single card single process
-  PADDLE_ENFORCE_GE(opts.device_id,
-                    0,
-                    platform::errors::PreconditionNotMet(
-                        "The barrier device id must greater or equal than 0."));
-  platform::CustomPlace place(device_type_, opts.device_id);
-  std::vector<phi::CustomPlace> places = {place};
-  std::vector<phi::DenseTensor> barrierTensors;
-  barrierTensors.reserve(places.size());
-
-  for (auto& place : places) {
-    phi::DeviceGuard guard(place);
-    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
-    auto allocator = std::unique_ptr<phi::Allocator>(
-        new paddle::experimental::DefaultAllocator(place));
-    barrierTensors.emplace_back(allocator.get(), meta);
-  }
-  auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors);
-  auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
-  xccl_task->barrierTensors_ = std::move(barrierTensors);
-  return task;
-}
-
-phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
-  std::vector<Place> places = {place};
-  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
-  PADDLE_ENFORCE_NE(iter,
-                    places_to_customcomm_.end(),
-                    platform::errors::InvalidArgument(
-                        "Cannot find nccl comm in process group."));
-  return iter->second[0]->GetCustomCCLComm();
-}
-
 }  // namespace distributed
 }  // namespace paddle
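
The reworked Barrier above reduces to an all-reduce over a one-element FLOAT32 tensor followed by a wait on the returned task. A rough Python analogue of that pattern, for illustration only (it assumes the same `pg` and imports as the XCCL unit test further below, not the C++ code path itself):

    import paddle
    from paddle.fluid import core

    # a one-element tensor used purely as a synchronization token
    barrier_tensor = paddle.to_tensor([1.0], dtype='float32')
    task = pg.all_reduce(barrier_tensor, core.ReduceOp.SUM, sync_op=True)
    task.wait()  # every rank has reached the barrier once wait() returns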

paddle/fluid/distributed/collective/ProcessGroupCustom.h

Lines changed: 20 additions & 5 deletions
@@ -78,6 +78,26 @@ class ProcessGroupCustom : public ProcessGroup {
       int64_t numel,
       bool sync_op) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const BroadcastOptions& opts,
+      bool sync_op) override;
+
+  std::shared_ptr<ProcessGroup::Task> Barrier(
+      const BarrierOptions& = BarrierOptions()) override;
+
+  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+
+  phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
+
+  // TODO(sunyilun): methods below will be removed later
   std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors) override;
@@ -92,11 +112,6 @@ class ProcessGroupCustom : public ProcessGroup {
       std::vector<phi::DenseTensor>& out_tensors,
       const BroadcastOptions& = BroadcastOptions()) override;
 
-  std::shared_ptr<ProcessGroup::Task> Barrier(
-      const BarrierOptions& = BarrierOptions()) override;
-
-  phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
-
  protected:
   virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
       std::vector<Place> places,

python/paddle/fluid/tests/custom_runtime/process_group_xccl.py

Lines changed: 10 additions & 10 deletions
@@ -63,11 +63,11 @@ def test_create_process_group_xccl(self):
 
         sum_result = tensor_x + tensor_y
         if pg.rank() == 0:
-            task = pg.allreduce(tensor_x)
+            task = pg.all_reduce(tensor_x, core.ReduceOp.SUM, sync_op=True)
             task.wait()
             # assert np.array_equal(tensor_x, sum_result)
         else:
-            task = pg.allreduce(tensor_y)
+            task = pg.all_reduce(tensor_y, core.ReduceOp.SUM, sync_op=True)
             task.wait()
             # assert np.array_equal(tensor_y, sum_result)
 
@@ -81,11 +81,11 @@ def test_create_process_group_xccl(self):
         max_result = paddle.maximum(tensor_x, tensor_y)
 
         if pg.rank() == 0:
-            task = pg.allreduce(tensor_x, core.ReduceOp.MAX)
+            task = pg.all_reduce(tensor_x, core.ReduceOp.MAX, sync_op=True)
             task.wait()
             # assert np.array_equal(tensor_x, max_result)
         else:
-            task = pg.allreduce(tensor_y, core.ReduceOp.MAX)
+            task = pg.all_reduce(tensor_y, core.ReduceOp.MAX, sync_op=True)
             task.wait()
             # assert np.array_equal(tensor_y, max_result)
 
@@ -101,14 +101,14 @@ def test_create_process_group_xccl(self):
 
         broadcast_result = paddle.assign(tensor_x)
         if pg.rank() == 0:
-            task = pg.broadcast(tensor_x, 0)
-            task.synchronize()
+            task = pg.broadcast(tensor_x, 0, sync_op=True)
+            task.wait()
             # paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
             assert task.is_completed()
             # assert np.array_equal(broadcast_result, tensor_x)
         else:
-            task = pg.broadcast(tensor_y, 0)
-            task.synchronize()
+            task = pg.broadcast(tensor_y, 0, sync_op=True)
+            task.wait()
             # paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
             assert task.is_completed()
             # assert np.array_equal(broadcast_result, tensor_y)
@@ -139,12 +139,12 @@ def test_create_process_group_xccl(self):
         out = np.random.random(out_shape).astype(self.dtype)
         tensor_out = paddle.to_tensor(out)
         if pg.rank() == 0:
-            task = pg.all_gather(tensor_x, tensor_out)
+            task = pg.all_gather(tensor_out, tensor_x, sync_op=True)
             task.wait()
             # paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
         # rank 1
         else:
-            task = pg.all_gather(tensor_y, tensor_out)
+            task = pg.all_gather(tensor_out, tensor_y, sync_op=True)
             task.wait()
             # paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
         out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
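
Note that all_gather now takes the output tensor first. A short sketch of consuming the gathered buffer, following the slicing the test itself performs (it assumes two ranks and the same `pg`, `tensor_x`, `tensor_y`, `out_shape`, and `tensor_out` as above):

    # output-first argument order in the new signature
    local = tensor_x if pg.rank() == 0 else tensor_y
    task = pg.all_gather(tensor_out, local, sync_op=True)
    task.wait()

    # the gathered buffer holds rank 0's chunk followed by rank 1's chunk
    out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
    out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2], [out_shape[0]])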
