Skip to content

Commit c3f8dc3

Browse files
authored
fix xpu comm stream (#71289)
1 parent 7c4c514 commit c3f8dc3

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
#include "paddle/phi/core/tensor_base.h"
2828
#include "paddle/phi/core/visit_type.h"
2929

30+
#if defined(PADDLE_WITH_XPU)
31+
#include <xpu/runtime.h>
32+
#endif
33+
3034
namespace phi {
3135
class DeviceContext;
3236

@@ -86,6 +90,18 @@ phi::DDim InferShapeForReshardFromReplicate(
8690
#define DEVICE_CONTEXT CustomContext
8791
#endif
8892

93+
// DEVICE_WAIT(dev_ctx): synchronization helper for the device branch of the
// reshard dispatch code in this header.
//   - When PADDLE_WITH_XPU is defined: calls xpu_wait() and then
//     (dev_ctx)->Wait(), in that order.
//   - Otherwise: expands to an empty statement (no waiting is performed).
// Both variants are wrapped in do { ... } while (0) so the macro acts as a
// single statement and takes a trailing semicolon at the call site.
// `dev_ctx` must be usable with `->` (it is parenthesized before the call).
// NOTE(review): xpu_wait() is called with no stream argument; presumably it
// blocks on the default XPU stream -- confirm against <xpu/runtime.h>.
#if defined(PADDLE_WITH_XPU)
94+
#define DEVICE_WAIT(dev_ctx) \
95+
do { \
96+
xpu_wait(); \
97+
(dev_ctx)->Wait(); \
98+
} while (0)
99+
#else
100+
#define DEVICE_WAIT(dev_ctx) \
101+
do { \
102+
} while (0) // no need to wait on other devices.
103+
#endif
104+
89105
// Some reshard functions support fewer data types on xpu than on gpu. For
90106
// example, `Transpose`, `Split`, and `Divide` do not support double type.
91107
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -125,12 +141,14 @@ phi::DDim InferShapeForReshardFromReplicate(
125141
__VA_ARGS__); \
126142
})); \
127143
} else if (DEVICE_CONTEXT::classof(dev_ctx)) { \
144+
DEVICE_WAIT(dev_ctx); \
128145
VLOG(4) << "Call `" << #fn_name << "` in Resharding on device."; \
129146
PD_VISIT_RESHARD_TYPES( \
130147
dtype, #fn_name, ([&] { \
131148
fn_name<data_t>(static_cast<const DEVICE_CONTEXT&>(*dev_ctx), \
132149
__VA_ARGS__); \
133150
})); \
151+
DEVICE_WAIT(dev_ctx); \
134152
} else { \
135153
PADDLE_THROW(common::errors::Unimplemented( \
136154
"The %s in reshard only supported on CPU, GPU, and XPU for now.", \

paddle/phi/kernels/xpu/all_gather_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void AllGatherKernel(const Context& dev_ctx,
4646
errors::InvalidArgument(
4747
"nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
4848

49-
XPUStream stream = comm_ctx->GetStream();
49+
XPUStream stream = dev_ctx.stream();
5050
comm_ctx->AllGather(out, x, stream);
5151
#else
5252
PADDLE_THROW(common::errors::PreconditionNotMet(

paddle/phi/kernels/xpu/all_reduce_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ void AllReduceKernel(const Context& dev_ctx,
3737
common::errors::Unavailable(
3838
"BKCLCommContext is nullptr, collective op should "
3939
"has ring_id attr."));
40-
XPUStream stream = comm_ctx->GetStream();
40+
41+
XPUStream stream = dev_ctx.stream();
4142

4243
BKCLOp bkcl_reduce_type = BKCL_ADD;
4344
switch (static_cast<ReduceType>(reduce_type)) {

paddle/phi/kernels/xpu/all_to_all_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ void AllToAllKernel(const Context& dev_ctx,
3838
"BKCLCommContext is nullptr, collective op should "
3939
"has ring_id attr."));
4040

41-
XPUStream stream = comm_ctx->GetStream();
41+
XPUStream stream = dev_ctx.stream();
4242
int nranks = comm_ctx->GetSize();
4343
PADDLE_ENFORCE_EQ(
4444
x_dims[0] % nranks,

paddle/phi/kernels/xpu/barrier_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void BarrierKernel(const Context &dev_ctx,
4242
common::errors::Unavailable(
4343
"BKCLCommContext is nullptr, collective op should "
4444
"has ring_id attr."));
45-
XPUStream stream = comm_ctx->GetStream();
45+
XPUStream stream = dev_ctx.stream();
4646
BKCLOp bkcl_reduce_type = BKCL_ADD;
4747
comm_ctx->AllReduce(out, *in, bkcl_reduce_type, stream);
4848
XPUStreamSync(stream);

paddle/phi/kernels/xpu/reduce_scatter_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void ReduceScatterKernel(const Context& dev_ctx,
4747
"BKCLCommContext is nullptr, collective op should "
4848
"has ring_id attr."));
4949

50-
XPUStream stream = comm_ctx->GetStream();
50+
XPUStream stream = dev_ctx.stream();
5151
comm_ctx->ReduceScatter(out, x, BKCL_ADD, stream);
5252
#else
5353
PADDLE_THROW(common::errors::PreconditionNotMet(

0 commit comments

Comments
 (0)