Skip to content

Commit 319af27

Browse files
ronny1996 authored and jjyaoao committed
[CustomDevice] add c_identity op (PaddlePaddle#52982)
* [CustomDevice] add c_identity op
* fix use calc stream
1 parent 9ce80f1 commit 319af27

File tree

3 files changed

+120
-19
lines changed

3 files changed

+120
-19
lines changed

paddle/fluid/distributed/collective/process_group_custom.cc

Lines changed: 101 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
187187
std::vector<phi::DenseTensor>& inputs,
188188
std::vector<phi::DenseTensor>& outputs,
189189
Fn fn,
190-
CommType op_type) {
190+
CommType op_type,
191+
bool sync_op,
192+
bool use_calc_stream) {
191193
const auto places = GetPlaceList(inputs);
192194
const auto key = GetKeyFromPlaces(places);
193195

@@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
199201
}
200202

201203
auto& ccl_comms = places_to_customcomm_[key];
202-
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
204+
if (!use_calc_stream) {
205+
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
206+
}
203207
auto task = CreateTask(places, rank_, op_type, inputs);
204208
task->SetOutputs(outputs);
205209

206210
for (size_t i = 0; i < inputs.size(); ++i) {
207211
phi::DeviceGuard guard(places[i]);
208-
const auto& ccl_stream = places_to_ctx_[key][i]->stream();
212+
const auto& ccl_stream =
213+
use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
214+
phi::DeviceContextPool::Instance().Get(places[i]))
215+
->stream()
216+
: places_to_ctx_[key][i]->stream();
209217
phi::stream::Stream stream(places[i], ccl_stream);
210218
fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
211219
}
212220

213-
for (size_t i = 0; i < inputs.size(); ++i) {
214-
phi::DeviceGuard guard(places[i]);
215-
task->control_events_[i].Record(*places_to_ctx_[key][i]);
221+
if (!use_calc_stream) {
222+
for (size_t i = 0; i < inputs.size(); ++i) {
223+
phi::DeviceGuard guard(places[i]);
224+
task->control_events_[i].Record(*places_to_ctx_[key][i]);
225+
}
216226
}
217227
return task;
218228
}
@@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
280290
comm,
281291
stream);
282292
},
283-
CommType::ALLGATHER);
293+
CommType::ALLGATHER,
294+
sync_op,
295+
use_calc_stream);
284296
}
285297

286298
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
@@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
322334
comm,
323335
stream);
324336
},
325-
CommType::ALLGATHER);
337+
CommType::ALLGATHER,
338+
false,
339+
false);
326340
}
327341

328342
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
333347
bool use_calc_stream) {
334348
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
335349
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
336-
return AllReduce(in_wrapper, out_wrapper, opts);
350+
PADDLE_ENFORCE_EQ(
351+
CheckTensorsInCustomPlace(in_wrapper, device_type_),
352+
true,
353+
platform::errors::InvalidArgument(
354+
"All inputs should be in CustomPlace(%s).", device_type_));
355+
PADDLE_ENFORCE_EQ(
356+
CheckTensorsInCustomPlace(out_wrapper, device_type_),
357+
true,
358+
platform::errors::InvalidArgument(
359+
"All outputs should be in CustomPlace(%s).", device_type_));
360+
return Collective(
361+
in_wrapper,
362+
out_wrapper,
363+
[&](phi::DenseTensor& input,
364+
phi::DenseTensor& output,
365+
phi::ccl::CCLComm comm,
366+
const phi::stream::Stream& stream) {
367+
return phi::DeviceManager::CCLAllReduce(
368+
device_type_,
369+
input.data(),
370+
output.data(),
371+
input.numel(),
372+
phi::ccl::ToCCLDataType(input.dtype()),
373+
ToCustomCCLRedType(opts.reduce_op),
374+
comm,
375+
stream);
376+
},
377+
CommType::ALLREDUCE,
378+
sync_op,
379+
use_calc_stream);
337380
}
338381

339382
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
342385
const AllreduceOptions& opts,
343386
bool sync_op // for compatibility, no use now
344387
) {
345-
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
346-
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
347-
return AllReduce(in_wrapper, out_wrapper, opts);
388+
return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
348389
}
349390

350391
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
378419
comm,
379420
stream);
380421
},
381-
CommType::ALLREDUCE);
422+
CommType::ALLREDUCE,
423+
false,
424+
false);
382425
}
383426

384427
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
@@ -389,17 +432,55 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
389432
bool use_calc_stream) {
390433
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
391434
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
392-
return Broadcast(in_wrapper, out_wrapper, opts);
435+
PADDLE_ENFORCE_EQ(
436+
CheckTensorsInCustomPlace(in_wrapper, device_type_),
437+
true,
438+
platform::errors::InvalidArgument(
439+
"All inputs should be in CustomPlace(%s).", device_type_));
440+
PADDLE_ENFORCE_EQ(
441+
CheckTensorsInCustomPlace(out_wrapper, device_type_),
442+
true,
443+
platform::errors::InvalidArgument(
444+
"All outputs should be in CustomPlace(%s).", device_type_));
445+
return Collective(
446+
in_wrapper,
447+
out_wrapper,
448+
[&](phi::DenseTensor& input,
449+
phi::DenseTensor& output,
450+
phi::ccl::CCLComm comm,
451+
const phi::stream::Stream& stream) {
452+
int root = opts.source_rank * in_wrapper.size() + opts.source_root;
453+
if (rank_ == root) {
454+
return phi::DeviceManager::CCLBroadcast(
455+
device_type_,
456+
input.data(),
457+
input.numel(),
458+
phi::ccl::ToCCLDataType(input.dtype()),
459+
root,
460+
comm,
461+
stream);
462+
} else {
463+
return phi::DeviceManager::CCLBroadcast(
464+
device_type_,
465+
output.data(),
466+
output.numel(),
467+
phi::ccl::ToCCLDataType(output.dtype()),
468+
root,
469+
comm,
470+
stream);
471+
}
472+
},
473+
CommType::BROADCAST,
474+
sync_op,
475+
use_calc_stream);
393476
}
394477

395478
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
396479
phi::DenseTensor* out_tensor,
397480
const phi::DenseTensor& in_tensor,
398481
const BroadcastOptions& opts,
399482
bool sync_op) {
400-
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
401-
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
402-
return Broadcast(in_wrapper, out_wrapper, opts);
483+
return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
403484
}
404485

405486
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
@@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
489570
stream);
490571
}
491572
},
492-
CommType::BROADCAST);
573+
CommType::BROADCAST,
574+
false,
575+
false);
493576
}
494577

495578
std::shared_ptr<ProcessGroupCustom>

paddle/fluid/distributed/collective/process_group_custom.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
176176
std::vector<phi::DenseTensor>& inputs, // NOLINT
177177
std::vector<phi::DenseTensor>& outputs, // NOLINT
178178
Fn fn,
179-
CommType op_type);
179+
CommType op_type,
180+
bool sync_op,
181+
bool use_calc_stream);
180182

181183
void CreateCustomManagerCache(const std::string& places_key,
182184
const std::vector<Place>& places);

paddle/fluid/operators/custom_device_common_op_registry.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
1616
#include "paddle/fluid/distributed/collective/process_group.h"
1717
#include "paddle/fluid/operators/collective/c_concat_op.h"
18+
#include "paddle/fluid/operators/collective/c_identity_op.h"
1819
#include "paddle/fluid/operators/load_combine_op.h"
1920
#include "paddle/fluid/operators/run_program_op.h"
2021
#include "paddle/fluid/operators/save_combine_op.h"
@@ -589,6 +590,21 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
589590
paddle::platform::CustomDeviceContext,
590591
paddle::platform::float16>) {}
591592

593+
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
594+
c_identity,
595+
device_type,
596+
paddle::operators::
597+
CIdentityOpKernel<float, paddle::platform::CustomDeviceContext>,
598+
paddle::operators::
599+
CIdentityOpKernel<double, paddle::platform::CustomDeviceContext>,
600+
paddle::operators::
601+
CIdentityOpKernel<int, paddle::platform::CustomDeviceContext>,
602+
paddle::operators::
603+
CIdentityOpKernel<int64_t, paddle::platform::CustomDeviceContext>,
604+
paddle::operators::CIdentityOpKernel<
605+
paddle::platform::float16,
606+
paddle::platform::CustomDeviceContext>) {}
607+
592608
#endif
593609
}
594610

0 commit comments

Comments
 (0)