27 | 27 | #include "paddle/phi/core/tensor_base.h" |
28 | 28 | #include "paddle/phi/core/visit_type.h" |
29 | 29 |
| 30 | +#if defined(PADDLE_WITH_XPU) |
| 31 | +#include <xpu/runtime.h> |
| 32 | +#endif |
| 33 | + |
30 | 34 | namespace phi { |
31 | 35 | class DeviceContext; |
32 | 36 |
@@ -86,6 +90,18 @@ phi::DDim InferShapeForReshardFromReplicate( |
86 | 90 | #define DEVICE_CONTEXT CustomContext |
87 | 91 | #endif |
88 | 92 |
| 93 | +#if defined(PADDLE_WITH_XPU) |
| 94 | +#define DEVICE_WAIT(dev_ctx) \ |
| 95 | + do { \ |
| 96 | + xpu_wait(); \ |
| 97 | + (dev_ctx)->Wait(); \ |
| 98 | + } while (0) |
| 99 | +#else |
| 100 | +#define DEVICE_WAIT(dev_ctx) \ |
| 101 | + do { \ |
| 102 | + } while (0) // no need to wait on other devices. |
| 103 | +#endif |
| 104 | + |
89 | 105 | // Some reshard functions support fewer data types on XPU than on GPU. For
90 | 106 | // example, `Transpose`, `Split`, and `Divide` do not support the double type.
91 | 107 | #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) |
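The `DEVICE_WAIT` helper added above uses the usual `do { ... } while (0)` macro idiom: it expands to exactly one statement, so it can sit in un-braced `if`/`else` branches and still takes a trailing semicolon like any other call, and on non-XPU builds the empty body lets the compiler drop it entirely. A self-contained sketch of that idiom is below; `DEMO_DEVICE_WAIT`, `xpu_wait_stub`, and `DemoContext` are invented stand-ins for `DEVICE_WAIT`, `xpu_wait`, and the phi device context, not Paddle symbols.

```cpp
// Sketch only: invented stand-ins, mirroring the shape of the XPU branch
// (drain the raw runtime queue first, then wait on the framework context).
#include <iostream>

struct DemoContext {
  void Wait() const { std::cout << "DemoContext::Wait()\n"; }
};

void xpu_wait_stub() { std::cout << "xpu_wait()\n"; }

#define DEMO_DEVICE_WAIT(dev_ctx) \
  do {                            \
    xpu_wait_stub();              \
    (dev_ctx)->Wait();            \
  } while (0)

int main() {
  DemoContext ctx;
  // Because the macro expands to a single do/while statement, an un-braced
  // if/else around the call still parses as intended.
  if (true)
    DEMO_DEVICE_WAIT(&ctx);
  else
    std::cout << "skipped\n";
  return 0;
}
```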
@@ -125,12 +141,14 @@ phi::DDim InferShapeForReshardFromReplicate( |
125 | 141 | __VA_ARGS__); \ |
126 | 142 | })); \ |
127 | 143 | } else if (DEVICE_CONTEXT::classof(dev_ctx)) { \ |
| 144 | + DEVICE_WAIT(dev_ctx); \ |
128 | 145 | VLOG(4) << "Call `" << #fn_name << "` in Resharding on device."; \ |
129 | 146 | PD_VISIT_RESHARD_TYPES( \ |
130 | 147 | dtype, #fn_name, ([&] { \ |
131 | 148 | fn_name<data_t>(static_cast<const DEVICE_CONTEXT&>(*dev_ctx), \ |
132 | 149 | __VA_ARGS__); \ |
133 | 150 | })); \ |
| 151 | + DEVICE_WAIT(dev_ctx); \ |
134 | 152 | } else { \ |
135 | 153 | PADDLE_THROW(common::errors::Unimplemented( \ |
136 | 154 |         "The %s in reshard is only supported on CPU, GPU, and XPU for now.", \
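In the dispatch hunk above, the two `DEVICE_WAIT(dev_ctx)` calls bracket the typed kernel that `PD_VISIT_RESHARD_TYPES` selects: the first drains work queued before the reshard, the second blocks until the reshard kernel itself has finished before host code continues. A minimal, non-Paddle sketch of that wait/launch/wait bracketing follows; `FakeXpuContext`, `DEMO_DEVICE_WAIT`, and `SplitStub` are hypothetical names standing in for the real context type, `DEVICE_WAIT`, and a reshard kernel such as `Split`.

```cpp
// Sketch only: hypothetical names; shows the wait -> kernel -> wait pattern.
#include <iostream>

struct FakeXpuContext {
  void Wait() const { std::cout << "wait for queued device work\n"; }
};

#define DEMO_DEVICE_WAIT(dev_ctx) \
  do {                            \
    (dev_ctx)->Wait();            \
  } while (0)

template <typename T>
void SplitStub(const FakeXpuContext& ctx) {
  (void)ctx;  // a real kernel would launch asynchronous work on this context
  std::cout << "run Split kernel for one element type\n";
}

void DispatchSplit(const FakeXpuContext* dev_ctx) {
  DEMO_DEVICE_WAIT(dev_ctx);   // finish anything queued before the reshard
  SplitStub<float>(*dev_ctx);  // kernel instantiation chosen by the visitor
  DEMO_DEVICE_WAIT(dev_ctx);   // ensure the result is visible to the host
}

int main() {
  FakeXpuContext ctx;
  DispatchSplit(&ctx);
  return 0;
}
```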