@@ -122,6 +122,16 @@ class ProcessGroup {
122122 " ProcessGroup%s does not support broadcast" , GetBackendName ()));
123123 }
124124
125+ virtual std::shared_ptr<ProcessGroup::Task> Broadcast (
126+ std::vector<phi::DenseTensor>& /* input tensors */ , // NOLINT
127+ std::vector<phi::DenseTensor>& /* output tensors */ , // NOLINT
128+ const BroadcastOptions&,
129+ bool ) {
130+ PADDLE_THROW (platform::errors::InvalidArgument (
131+ " ProcessGroup%s does not support broadcast with sync_op flag" ,
132+ GetBackendName ()));
133+ }
134+
125135 virtual std::shared_ptr<ProcessGroup::Task> Barrier (
126136 const BarrierOptions& = BarrierOptions()) {
127137 PADDLE_THROW (platform::errors::InvalidArgument (
@@ -134,38 +144,89 @@ class ProcessGroup {
134144 " ProcessGroup%s does not support send" , GetBackendName ()));
135145 }
136146
147+ virtual std::shared_ptr<ProcessGroup::Task> Send (
148+ std::vector<phi::DenseTensor>&, int , bool ) { // NOLINT
149+ PADDLE_THROW (platform::errors::InvalidArgument (
150+ " ProcessGroup%s does not support send with sync_op flag" ,
151+ GetBackendName ()));
152+ }
153+
137154 virtual std::shared_ptr<ProcessGroup::Task> Recv (
138- std::vector<phi::DenseTensor>& tensors , int ) { // NOLINT
155+ std::vector<phi::DenseTensor>&, int ) { // NOLINT
139156 PADDLE_THROW (platform::errors::InvalidArgument (
140- " ProcessGroup%s does not support receive " , GetBackendName ()));
157+ " ProcessGroup%s does not support recv " , GetBackendName ()));
141158 }
142159
143- virtual std::shared_ptr<ProcessGroup::Task> Send_Partial (phi::DenseTensor&,
144- int ,
145- int ,
146- int ) { // NOLINT
160+ virtual std::shared_ptr<ProcessGroup::Task> Recv (
161+ std::vector<phi::DenseTensor>&, int , bool ) { // NOLINT
147162 PADDLE_THROW (platform::errors::InvalidArgument (
148- " ProcessGroup%s does not support send" , GetBackendName ()));
163+ " ProcessGroup%s does not support recv with sync_op flag" ,
164+ GetBackendName ()));
165+ }
166+
167+ virtual std::shared_ptr<ProcessGroup::Task> Send_Partial (
168+ phi::DenseTensor&, // NOLINT
169+ int ,
170+ int64_t ,
171+ int64_t ) {
172+ PADDLE_THROW (platform::errors::InvalidArgument (
173+ " ProcessGroup%s does not support send_partial" , GetBackendName ()));
174+ }
175+
176+ virtual std::shared_ptr<ProcessGroup::Task> Send_Partial (
177+ phi::DenseTensor&, int , int64_t , int64_t , bool ) { // NOLINT
178+ PADDLE_THROW (platform::errors::InvalidArgument (
179+ " ProcessGroup%s does not support send_partial with sync_op flag" ,
180+ GetBackendName ()));
149181 }
150182
151183 virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial (
152- phi::DenseTensor& tensors, int , int , int ) { // NOLINT
184+ phi::DenseTensor&, // NOLINT
185+ int ,
186+ int64_t ,
187+ int64_t ) {
153188 PADDLE_THROW (platform::errors::InvalidArgument (
154- " ProcessGroup%s does not support receive" , GetBackendName ()));
189+ " ProcessGroup%s does not support recv_partial" , GetBackendName ()));
190+ }
191+
192+ virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial (
193+ phi::DenseTensor&, int , int64_t , int64_t , bool ) { // NOLINT
194+ PADDLE_THROW (platform::errors::InvalidArgument (
195+ " ProcessGroup%s does not support recv_partial with sync_op flag" ,
196+ GetBackendName ()));
155197 }
156198
157199 virtual std::shared_ptr<ProcessGroup::Task> AllGather (
158200 std::vector<phi::DenseTensor>&, // NOLINT
159201 std::vector<phi::DenseTensor>&) { // NOLINT
160202 PADDLE_THROW (platform::errors::InvalidArgument (
161- " ProcessGroup%s does not support AllGather" , GetBackendName ()));
203+ " ProcessGroup%s does not support all_gather" , GetBackendName ()));
204+ }
205+
206+ virtual std::shared_ptr<ProcessGroup::Task> AllGather (
207+ std::vector<phi::DenseTensor>&, // NOLINT
208+ std::vector<phi::DenseTensor>&, // NOLINT
209+ bool ) {
210+ PADDLE_THROW (platform::errors::InvalidArgument (
211+ " ProcessGroup%s does not support all_gather with sync_op flag" ,
212+ GetBackendName ()));
162213 }
163214
164215 virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial (
165216 std::vector<phi::DenseTensor>& in_tensors, // NOLINT
166217 std::vector<phi::DenseTensor>& out_tensors, // NOLINT
167- int offset,
168- int length) { // NOLINT
218+ int64_t offset,
219+ int64_t length) {
220+ PADDLE_THROW (platform::errors::InvalidArgument (
221+ " ProcessGroup%s does not support AllGather_Partial" , GetBackendName ()));
222+ }
223+
224+ virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial (
225+ std::vector<phi::DenseTensor>& in_tensors, // NOLINT
226+ std::vector<phi::DenseTensor>& out_tensors, // NOLINT
227+ int64_t offset,
228+ int64_t length,
229+ bool ) {
169230 PADDLE_THROW (platform::errors::InvalidArgument (
170231 " ProcessGroup%s does not support AllGather_Partial" , GetBackendName ()));
171232 }
@@ -177,6 +238,14 @@ class ProcessGroup {
177238 " ProcessGroup%s does not support AllToAll" , GetBackendName ()));
178239 }
179240
241+ virtual std::shared_ptr<ProcessGroup::Task> AllToAll (
242+ std::vector<phi::DenseTensor>&, // NOLINT
243+ std::vector<phi::DenseTensor>&, // NOLINT
244+ bool ) {
245+ PADDLE_THROW (platform::errors::InvalidArgument (
246+ " ProcessGroup%s does not support alltoall" , GetBackendName ()));
247+ }
248+
180249 virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single (
181250 std::vector<phi::DenseTensor>&, // NOLINT
182251 std::vector<phi::DenseTensor>&, // NOLINT
@@ -186,26 +255,66 @@ class ProcessGroup {
186255 " ProcessGroup%s does not support AllToAll_Single" , GetBackendName ()));
187256 }
188257
258+ virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle (
259+ std::vector<phi::DenseTensor>&, // NOLINT
260+ std::vector<phi::DenseTensor>&, // NOLINT
261+ std::vector<int64_t >&,
262+ std::vector<int64_t >&,
263+ bool ) {
264+ PADDLE_THROW (platform::errors::InvalidArgument (
265+ " ProcessGroup%s does not support alltoall_single" , GetBackendName ()));
266+ }
267+
189268 virtual std::shared_ptr<ProcessGroup::Task> Reduce (
190269 std::vector<phi::DenseTensor>&, // NOLINT
191270 std::vector<phi::DenseTensor>&, // NOLINT
192271 const ReduceOptions& opts) {
193272 PADDLE_THROW (platform::errors::InvalidArgument (
194- " ProcessGroup%s does not support Reduce" , GetBackendName ()));
273+ " ProcessGroup%s does not support reduce" , GetBackendName ()));
274+ }
275+
276+ virtual std::shared_ptr<ProcessGroup::Task> Reduce (
277+ std::vector<phi::DenseTensor>& /* input tensors */ , // NOLINT
278+ std::vector<phi::DenseTensor>& /* output tensors */ , // NOLINT
279+ const ReduceOptions&,
280+ bool ) {
281+ PADDLE_THROW (platform::errors::InvalidArgument (
282+ " ProcessGroup%s does not support reduce with sync_op flag" ,
283+ GetBackendName ()));
284+ }
285+
286+ virtual std::shared_ptr<ProcessGroup::Task> Scatter (
287+ std::vector<phi::DenseTensor>&, // NOLINT
288+ std::vector<phi::DenseTensor>&, // NOLINT
289+ const ScatterOptions&) {
290+ PADDLE_THROW (platform::errors::InvalidArgument (
291+ " ProcessGroup%s does not support scatter" , GetBackendName ()));
195292 }
196293
197294 virtual std::shared_ptr<ProcessGroup::Task> Scatter (
198295 std::vector<phi::DenseTensor>&, // NOLINT
199296 std::vector<phi::DenseTensor>&, // NOLINT
200- const ScatterOptions&) { // NOLINT
297+ const ScatterOptions&,
298+ bool ) {
299+ PADDLE_THROW (platform::errors::InvalidArgument (
300+ " ProcessGroup%s does not support scatter with sync_op flag" ,
301+ GetBackendName ()));
302+ }
303+
304+ virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter (
305+ std::vector<phi::DenseTensor>&, // NOLINT
306+ std::vector<phi::DenseTensor>&, // NOLINT
307+ const ReduceScatterOptions&,
308+ bool ) {
201309 PADDLE_THROW (platform::errors::InvalidArgument (
202- " ProcessGroup%s does not support Scatter" , GetBackendName ()));
310+ " ProcessGroup%s does not support reduce_scatter with sync_op flag" ,
311+ GetBackendName ()));
203312 }
204313
205314 virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase (
206- phi::DenseTensor&, // NOLINT
207- phi::DenseTensor&, // NOLINT
208- const ReduceScatterOptions&) { // NOLINT
315+ phi::DenseTensor&, // NOLINT
316+ phi::DenseTensor&, // NOLINT
317+ const ReduceScatterOptions&) {
209318 PADDLE_THROW (platform::errors::InvalidArgument (
210319 " ProcessGroup%s does not support ReduceScatter" , GetBackendName ()));
211320 }
0 commit comments