@@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
     std::vector<phi::DenseTensor>& inputs,
     std::vector<phi::DenseTensor>& outputs,
     Fn fn,
-    CommType op_type) {
+    CommType op_type,
+    bool sync_op,
+    bool use_calc_stream) {
   const auto places = GetPlaceList(inputs);
   const auto key = GetKeyFromPlaces(places);
 
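The two new flags presumably flow through the template declaration in the header as well (not shown in this diff); a sketch of the expected shape, mirroring the definition above rather than quoting the actual header:

```cpp
// Presumed declaration in ProcessGroupCustom.h (sketch, not verbatim):
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type,
    bool sync_op,           // new: caller treats the op as synchronous
    bool use_calc_stream);  // new: launch on the compute stream, not the comm stream
```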
@@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
   }
 
   auto& ccl_comms = places_to_customcomm_[key];
-  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+  if (!use_calc_stream) {
+    SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+  }
   auto task = CreateTask(places, rank_, op_type, inputs);
   task->SetOutputs(outputs);
 
   for (size_t i = 0; i < inputs.size(); ++i) {
     phi::DeviceGuard guard(places[i]);
-    const auto& ccl_stream = places_to_ctx_[key][i]->stream();
+    const auto& ccl_stream =
+        use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
+                              phi::DeviceContextPool::Instance().Get(places[i]))
+                              ->stream()
+                        : places_to_ctx_[key][i]->stream();
     phi::stream::Stream stream(places[i], ccl_stream);
     fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
   }
 
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    phi::DeviceGuard guard(places[i]);
-    task->control_events_[i].Record(*places_to_ctx_[key][i]);
+  if (!use_calc_stream) {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      phi::DeviceGuard guard(places[i]);
+      task->control_events_[i].Record(*places_to_ctx_[key][i]);
+    }
   }
   return task;
 }
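This hunk encodes two stream disciplines: on the dedicated comm stream the collective must first wait for pending compute work and then record events so later kernels can wait on it; on the calc (compute) stream the stream's own ordering serializes everything, so both steps are skipped. A minimal, self-contained sketch of that control flow, using stub types rather than Paddle's real stream/event API:

```cpp
#include <functional>
#include <iostream>

// Stub stand-ins for device streams and events (hypothetical, for illustration).
struct Stream { const char* name; };
struct Event  { bool recorded = false; };

Stream calc_stream{"calc"};  // default compute stream
Stream comm_stream{"comm"};  // dedicated communication stream

void RunCollective(bool use_calc_stream,
                   Event& done,
                   const std::function<void(Stream&)>& launch) {
  if (!use_calc_stream) {
    // Comm stream waits for queued compute work, mirroring SyncDefaultStream().
    std::cout << "comm stream waits on calc stream\n";
  }
  Stream& s = use_calc_stream ? calc_stream : comm_stream;
  launch(s);  // the CCL kernel goes onto the chosen stream
  if (!use_calc_stream) {
    // Record completion so later compute can wait on the collective,
    // mirroring control_events_[i].Record(...).
    done.recorded = true;
  }
}

int main() {
  Event e;
  RunCollective(true,  e, [](Stream& s) { std::cout << "allreduce on " << s.name << "\n"; });
  RunCollective(false, e, [](Stream& s) { std::cout << "allreduce on " << s.name << "\n"; });
}
```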
@@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
             comm,
             stream);
       },
-      CommType::ALLGATHER);
+      CommType::ALLGATHER,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
@@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
             comm,
             stream);
       },
-      CommType::ALLGATHER);
+      CommType::ALLGATHER,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     bool use_calc_stream) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllReduce(in_wrapper, out_wrapper, opts);
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllReduce(
+            device_type_,
+            input.data(),
+            output.data(),
+            input.numel(),
+            phi::ccl::ToCCLDataType(input.dtype()),
+            ToCustomCCLRedType(opts.reduce_op),
+            comm,
+            stream);
+      },
+      CommType::ALLREDUCE,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op  // for compatibility, no use now
 ) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllReduce(in_wrapper, out_wrapper, opts);
+  return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
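The compatibility overload now routes through the new implementation instead of the tensor-vector path, pinning `use_calc_stream` to `false`. The two calls below are therefore equivalent; a fragment-style sketch assuming a constructed `ProcessGroupCustom` instance `pg` and device-resident tensors `in`/`out` (names hypothetical):

```cpp
// Hypothetical call site; `pg`, `in`, `out` are assumed to exist.
AllreduceOptions opts;  // reduce_op defaults to ReduceOp::SUM in Paddle's Types.h
pg.AllReduce(&out, in, opts, /*sync_op=*/true);                             // legacy overload
pg.AllReduce(&out, in, opts, /*sync_op=*/true, /*use_calc_stream=*/false);  // what it forwards to
```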
@@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             comm,
             stream);
       },
-      CommType::ALLREDUCE);
+      CommType::ALLREDUCE,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
@@ -389,17 +432,55 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     bool use_calc_stream) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return Broadcast(in_wrapper, out_wrapper, opts);
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        int root = opts.source_rank * in_wrapper.size() + opts.source_root;
+        if (rank_ == root) {
+          return phi::DeviceManager::CCLBroadcast(
+              device_type_,
+              input.data(),
+              input.numel(),
+              phi::ccl::ToCCLDataType(input.dtype()),
+              root,
+              comm,
+              stream);
+        } else {
+          return phi::DeviceManager::CCLBroadcast(
+              device_type_,
+              output.data(),
+              output.numel(),
+              phi::ccl::ToCCLDataType(output.dtype()),
+              root,
+              comm,
+              stream);
+        }
+      },
+      CommType::BROADCAST,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
     const BroadcastOptions& opts,
     bool sync_op) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return Broadcast(in_wrapper, out_wrapper, opts);
+  return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
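The `root` computation above folds the per-rank tensor count into a global rank index. A worked example with hypothetical values, under the common case of one tensor per process (`in_wrapper.size() == 1`):

```cpp
// Hypothetical values; with one tensor per process the multiplier is 1.
BroadcastOptions opts;
opts.source_rank = 2;  // the broadcasting process
opts.source_root = 0;  // tensor index on that process
int root = opts.source_rank * 1 + opts.source_root;  // == 2: rank 2 sends, all others receive
```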
@@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
               stream);
         }
       },
-      CommType::BROADCAST);
+      CommType::BROADCAST,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroupCustom>
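Taken together: with `use_calc_stream == false`, `Collective` records control events, so the returned task can be waited on; with `use_calc_stream == true` no events are recorded, and ordering comes from the compute stream itself. A fragment-style usage sketch, assuming a constructed `ProcessGroupCustom` instance `pg` and device-resident tensors `in`/`out` (names hypothetical; option types from Paddle's collective `Types.h`):

```cpp
// Hypothetical call sites; `pg`, `in`, `out` are assumed to exist.
BroadcastOptions opts;
opts.source_rank = 0;

// Comm-stream path: Wait() presumably blocks on the control events
// recorded in Collective().
auto task = pg.Broadcast(&out, in, opts, /*sync_op=*/true,
                         /*use_calc_stream=*/false);
task->Wait();

// Calc-stream path: no events are recorded; later kernels on the compute
// stream are already ordered after the collective, so no Wait() is needed.
pg.Broadcast(&out, in, opts, /*sync_op=*/true, /*use_calc_stream=*/true);
```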