@@ -202,38 +202,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
202202  return  task;
203203}
204204
205- std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather (
206-     std::vector<phi::DenseTensor>& in_tensors,
207-     std::vector<phi::DenseTensor>& out_tensors) {
208-   PADDLE_ENFORCE_EQ (
209-       CheckTensorsInCustomPlace (in_tensors, device_type_),
210-       true ,
211-       platform::errors::InvalidArgument (
212-           " All inputs should be in CustomPlace(%s)." 
213-   PADDLE_ENFORCE_EQ (
214-       CheckTensorsInCustomPlace (out_tensors, device_type_),
215-       true ,
216-       platform::errors::InvalidArgument (
217-           " All outputs should be in CustomPlace(%s)." 
218-   return  Collective (
219-       in_tensors,
220-       out_tensors,
221-       [&](phi::DenseTensor& input,
222-           phi::DenseTensor& output,
223-           phi::ccl::CCLComm comm,
224-           const  phi::stream::Stream& stream) {
225-         return  phi::DeviceManager::CCLAllGather (
226-             device_type_,
227-             input.data (),
228-             output.data (),
229-             input.numel (),
230-             phi::ccl::ToCCLDataType (input.dtype ()),
231-             comm,
232-             stream);
233-       },
234-       CommType::ALLGATHER);
235- }
236- 
237205void * XcclGetPointerByOffset (void * raw_pointer,
238206                             size_t  offset,
239207                             experimental::DataType type) {
@@ -259,13 +227,13 @@ void* XcclGetPointerByOffset(void* raw_pointer,
259227  return  nullptr ;
260228}
261229
// Single-tensor AllGather overload (new ProcessGroup API): wraps the dense
// input/output tensors into one-element vectors and delegates to the
// vector-based Collective path with CommType::ALLGATHER.
// NOTE(review): the `@@` hunk boundary below elides the Collective lambda
// body; the elided lines are unchanged context in the underlying file.
// `offset`/`numel` appear unused in the visible lines — confirm in the
// elided lambda before relying on them.
262- //  NOTE: this is ONLY for compatibility
263230std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather (
264231    phi::DenseTensor* out_tensor,
265232    const  phi::DenseTensor& in_tensor,
266233    int64_t  offset,
267234    int64_t  numel,
268-     bool  sync_op) {
235+     bool  sync_op  //  for compatibility, no use now
236+ ) {
269237  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
270238  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
271239  return  Collective (
@@ -287,6 +255,105 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
287255      CommType::ALLGATHER);
288256}
289257
258+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
259+     phi::DenseTensor* out_tensor,
260+     const  phi::DenseTensor& in_tensor,
261+     const  AllreduceOptions& opts,
262+     bool  sync_op  //  for compatibility, no use now
263+ ) {
264+   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
265+   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
266+   return  AllReduce (in_wrapper, out_wrapper, opts);
267+ }
268+ 
269+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast (
270+     phi::DenseTensor* out_tensor,
271+     const  phi::DenseTensor& in_tensor,
272+     const  BroadcastOptions& opts,
273+     bool  sync_op  //  for compatibility, no use now
274+ ) {
275+   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
276+   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
277+   return  Broadcast (in_wrapper, out_wrapper, opts);
278+ }
279+ 
280+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier (
281+     const  BarrierOptions& opts) {
282+   //  Only support single card single process
283+   PADDLE_ENFORCE_GE (opts.device_id ,
284+                     0 ,
285+                     platform::errors::PreconditionNotMet (
286+                         " The barrier device id must greater or equal than 0." 
287+   platform::CustomPlace place (device_type_, opts.device_id );
288+   auto  allocator = std::unique_ptr<phi::Allocator>(
289+       new  paddle::experimental::DefaultAllocator (place));
290+   phi::DenseTensorMeta meta (phi::DataType::FLOAT32, phi::DDim{1 });
291+   phi::DenseTensor barrier_tensor{allocator.get (), meta};
292+ 
293+   auto  task = ProcessGroupCustom::AllReduce (&barrier_tensor,
294+                                             barrier_tensor,
295+                                             {},
296+                                             /* sync_op*/ true );
297+   auto  xccl_task = dynamic_cast <ProcessGroupCustom::CustomTask*>(task.get ());
298+   xccl_task->barrierTensors_  = {barrier_tensor};
299+   return  task;
300+ }
301+ 
302+ const  phi::DeviceContext& ProcessGroupCustom::GetDeviceContext (
303+     const  Place& place) const  {
304+   const  std::string key = GetKeyFromPlace (place);
305+   const  auto & iter = places_to_ctx_.find (key);
306+   PADDLE_ENFORCE_NE (
307+       iter,
308+       places_to_ctx_.end (),
309+       platform::errors::NotFound (
310+           " Cannot find the device context in this process group." 
311+   return  *iter->second [0 ];
312+ }
313+ 
314+ phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm (const  Place& place) const  {
315+   std::vector<Place> places = {place};
316+   const  auto & iter = places_to_customcomm_.find (GetKeyFromPlaces (places));
317+   PADDLE_ENFORCE_NE (iter,
318+                     places_to_customcomm_.end (),
319+                     platform::errors::InvalidArgument (
320+                         " Cannot find nccl comm in process group." 
321+   return  iter->second [0 ]->GetCustomCCLComm ();
322+ }
323+ 
324+ //  TODO(sunyilun): methods below will be removed later
325+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather (
326+     std::vector<phi::DenseTensor>& in_tensors,
327+     std::vector<phi::DenseTensor>& out_tensors) {
328+   PADDLE_ENFORCE_EQ (
329+       CheckTensorsInCustomPlace (in_tensors, device_type_),
330+       true ,
331+       platform::errors::InvalidArgument (
332+           " All inputs should be in CustomPlace(%s)." 
333+   PADDLE_ENFORCE_EQ (
334+       CheckTensorsInCustomPlace (out_tensors, device_type_),
335+       true ,
336+       platform::errors::InvalidArgument (
337+           " All outputs should be in CustomPlace(%s)." 
338+   return  Collective (
339+       in_tensors,
340+       out_tensors,
341+       [&](phi::DenseTensor& input,
342+           phi::DenseTensor& output,
343+           phi::ccl::CCLComm comm,
344+           const  phi::stream::Stream& stream) {
345+         return  phi::DeviceManager::CCLAllGather (
346+             device_type_,
347+             input.data (),
348+             output.data (),
349+             input.numel (),
350+             phi::ccl::ToCCLDataType (input.dtype ()),
351+             comm,
352+             stream);
353+       },
354+       CommType::ALLGATHER);
355+ }
356+ 
290357std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
291358    std::vector<phi::DenseTensor>& in_tensors,   //  NOLINT
292359    std::vector<phi::DenseTensor>& out_tensors,  //  NOLINT
@@ -366,40 +433,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
366433      CommType::BROADCAST);
367434}
368435
369- std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier (
370-     const  BarrierOptions& opts) {
371-   //  Only support single card single process
372-   PADDLE_ENFORCE_GE (opts.device_id ,
373-                     0 ,
374-                     platform::errors::PreconditionNotMet (
375-                         " The barrier device id must greater or equal than 0." 
376-   platform::CustomPlace place (device_type_, opts.device_id );
377-   std::vector<phi::CustomPlace> places = {place};
378-   std::vector<phi::DenseTensor> barrierTensors;
379-   barrierTensors.reserve (places.size ());
380- 
381-   for  (auto & place : places) {
382-     phi::DeviceGuard guard (place);
383-     phi::DenseTensorMeta meta (phi::DataType::FLOAT32, phi::DDim ({1 }));
384-     auto  allocator = std::unique_ptr<phi::Allocator>(
385-         new  paddle::experimental::DefaultAllocator (place));
386-     barrierTensors.emplace_back (allocator.get (), meta);
387-   }
388-   auto  task = ProcessGroupCustom::AllReduce (barrierTensors, barrierTensors);
389-   auto  xccl_task = dynamic_cast <ProcessGroupCustom::CustomTask*>(task.get ());
390-   xccl_task->barrierTensors_  = std::move (barrierTensors);
391-   return  task;
392- }
393- 
394- phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm (const  Place& place) const  {
395-   std::vector<Place> places = {place};
396-   const  auto & iter = places_to_customcomm_.find (GetKeyFromPlaces (places));
397-   PADDLE_ENFORCE_NE (iter,
398-                     places_to_customcomm_.end (),
399-                     platform::errors::InvalidArgument (
400-                         " Cannot find nccl comm in process group." 
401-   return  iter->second [0 ]->GetCustomCCLComm ();
402- }
403- 
404436}  //   namespace distributed
405437}  //   namespace paddle
// (extraction artifact: GitHub page footer "0 commit comments" — not source code)