@@ -24,6 +24,11 @@
 #include "layer_norm_cuda.h" // NOLINT
 #include "paddle/extension.h"
 
+#ifdef CUSTOM_OP_WITH_SPMD
+#include "paddle/phi/api/ext/spmd_infer.h"
+#include "paddle/phi/infermeta/spmd_rules/rules.h"
+#endif
+
 #define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor")
 
 static void GetRowsCols(const std::vector<int64_t> &shape,
@@ -214,14 +219,22 @@ PD_BUILD_OP(fused_rms_norm)
     .Attrs({"epsilon: float"})
     .SetKernelFn(PD_KERNEL(RMSLnFwd))
     .SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype))
+#ifdef CUSTOM_OP_WITH_SPMD
+    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd))
+#endif
+    ;
 
 PD_BUILD_GRAD_OP(fused_rms_norm)
     .Inputs({"x", "scale", "invvar", paddle::Grad("y")})
     .Outputs({paddle::Grad("x"), paddle::Grad("scale")})
     .Attrs({"epsilon: float"})
     .SetKernelFn(PD_KERNEL(RMSLnBwd))
-    .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape));
+    .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
+#ifdef CUSTOM_OP_WITH_SPMD
+    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
+#endif
+    ;
 
 
 // https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679
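
As a reading aid, here is how the backward-op registration reads once the hunk above is applied. The code is assembled from the lines shown in the diff; only the comments are added here for explanation. The forward-op registration follows the same pattern, attaching phi::distributed::RmsNormInferSpmd after its SetInferDtypeFn call.

// Backward-op registration for fused_rms_norm after this patch.
PD_BUILD_GRAD_OP(fused_rms_norm)
    .Inputs({"x", "scale", "invvar", paddle::Grad("y")})
    .Outputs({paddle::Grad("x"), paddle::Grad("scale")})
    .Attrs({"epsilon: float"})
    .SetKernelFn(PD_KERNEL(RMSLnBwd))
    .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
#ifdef CUSTOM_OP_WITH_SPMD
    // The SPMD inference rule (sharding propagation for auto-parallel) is
    // attached only when the build defines CUSTOM_OP_WITH_SPMD; without that
    // macro the setter is compiled out and the op registers exactly as before.
    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
#endif
    // The terminating semicolon sits on its own line so the builder chain is
    // closed correctly whether or not the #ifdef branch is compiled in.
    ;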