Skip to content

Commit f039d09

Browse files
authored
fused rms spmd (#7830)
1 parent 4069f22 commit f039d09

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
#include "layer_norm_cuda.h" // NOLINT
2525
#include "paddle/extension.h"
2626

27+
#ifdef CUSTOM_OP_WITH_SPMD
28+
#include "paddle/phi/api/ext/spmd_infer.h"
29+
#include "paddle/phi/infermeta/spmd_rules/rules.h"
30+
#endif
31+
2732
#define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor")
2833

2934
static void GetRowsCols(const std::vector<int64_t> &shape,
@@ -214,14 +219,22 @@ PD_BUILD_OP(fused_rms_norm)
214219
.Attrs({"epsilon: float"})
215220
.SetKernelFn(PD_KERNEL(RMSLnFwd))
216221
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape))
217-
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype));
222+
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype))
223+
#ifdef CUSTOM_OP_WITH_SPMD
224+
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd))
225+
#endif
226+
;
218227

219228
PD_BUILD_GRAD_OP(fused_rms_norm)
220229
.Inputs({"x", "scale", "invvar", paddle::Grad("y")})
221230
.Outputs({paddle::Grad("x"), paddle::Grad("scale")})
222231
.Attrs({"epsilon: float"})
223232
.SetKernelFn(PD_KERNEL(RMSLnBwd))
224-
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape));
233+
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
234+
#ifdef CUSTOM_OP_WITH_SPMD
235+
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
236+
#endif
237+
;
225238

226239

227240
// https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679

0 commit comments

Comments
 (0)