Commit 3976459

【Hackathon 5th No.112】move fused_gemm_epilogue to phi and add the yaml of identity_loss (#59363)
1 parent d155697 commit 3976459
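
This change follows the usual fluid-to-phi migration pattern: the hand-written InferShape overrides in fused_gemm_epilogue_op.cc are deleted, and the operator registration instead receives a functor generated by DECLARE_INFER_SHAPE_FUNCTOR from the phi InferMeta function, wrapped with PD_INFER_META (all three names appear in the diff below). As a minimal toy sketch of that mechanism — the names and signatures here are illustrative stand-ins, not Paddle's actual macro internals:

// Toy model (not Paddle code) of what DECLARE_INFER_SHAPE_FUNCTOR accomplishes:
// wrap a free InferMeta-style function in a callable object the op registry
// can store, so shape inference is written once in phi and only referenced
// from the fluid operator.
#include <functional>
#include <iostream>
#include <vector>

using Shape = std::vector<long>;

// Stand-in for a phi InferMeta function: pure shape-in, shape-out logic.
Shape MatmulInferMeta(const Shape& x, const Shape& y) {
  return {x[0], y[1]};  // [M, K] x [K, N] -> [M, N]
}

// Stand-in for the generated InferShapeFunctor: adapts the free function
// to the uniform interface the registry expects.
struct InferShapeFunctor {
  std::function<Shape(const Shape&, const Shape&)> fn;
  Shape operator()(const Shape& x, const Shape& y) const { return fn(x, y); }
};

int main() {
  InferShapeFunctor functor{MatmulInferMeta};  // plays the PD_INFER_META role
  for (long d : functor({8, 16}, {16, 32})) std::cout << d << ' ';  // 8 32
  std::cout << '\n';
}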

File tree

16 files changed: +555 −776 lines changed


paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc

Lines changed: 17 additions & 196 deletions
@@ -13,8 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/fusion.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
 
 namespace paddle {
@@ -25,107 +28,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(
-        ctx->HasInput("Bias"), "Output", "Bias", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Out"), "Output", "Out", "FusedGemmEpilogueOp");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-    auto bias_dims = ctx->GetInputDim("Bias");
-    auto trans_x = ctx->Attrs().Get<bool>("trans_x");
-    auto trans_y = ctx->Attrs().Get<bool>("trans_y");
-
-    PADDLE_ENFORCE_EQ(
-        y_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor Y's dimension of FusedGemmEpilogueOp "
-            " should be 2, but got %d.",
-            y_dims.size()));
-
-    PADDLE_ENFORCE_GE(
-        x_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor X's dimension of FusedGemmEpilogueOp "
-            " should be >= 2, but got %d.",
-            x_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        bias_dims.size(),
-        1,
-        platform::errors::InvalidArgument(
-            "The Input tensor bias's dimension of FusedGemmEpilogueOp "
-            " should be == 1, but got %d.",
-            bias_dims.size()));
-
-    PADDLE_ENFORCE_EQ(bias_dims[0],
-                      trans_y ? y_dims[0] : y_dims[1],
-                      platform::errors::InvalidArgument(
-                          "The Input tensor bias's dimension 0"
-                          " should be == Y[-1], but got bias's shape = [%s] "
-                          "and Y's shape = [%s]",
-                          bias_dims,
-                          y_dims));
-
-    auto x_mat_dims =
-        common::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1);
-
-    int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1];
-    int K_from_y = trans_y ? y_dims[1] : y_dims[0];
-
-    PADDLE_ENFORCE_EQ(
-        K_from_x,
-        K_from_y,
-        platform::errors::InvalidArgument(
-            "The last dimension of X should be equal with Y's first dimension."
-            "But received X[-1] = [%d], Y[0] = [%d].",
-            K_from_x,
-            K_from_y));
-
-    std::vector<int64_t> out_dims;
-    out_dims.reserve(static_cast<size_t>(x_dims.size()));
-    if (trans_x) {
-      for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
-    } else {
-      for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
-    }
-
-    if (trans_y) {
-      out_dims.push_back(y_dims[0]);
-    } else {
-      out_dims.push_back(y_dims[1]);
-    }
-    ctx->SetOutputDim("Out", common::make_ddim(out_dims));
-
-    auto activation = ctx->Attrs().Get<std::string>("activation");
-    if (ctx->HasOutput("ReserveSpace")) {
-      ctx->SetOutputDim("ReserveSpace", common::make_ddim(out_dims));
-
-      if (activation == "none") {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The ReserveSpace would not be used when activation = \"none\""));
-      } else {
-        int min_size_of_n = activation == "relu" ? 128 : 8;
-        int N_size = trans_y ? y_dims[0] : y_dims[1];
-        PADDLE_ENFORCE_EQ(
-            N_size % min_size_of_n,
-            0,
-            platform::errors::InvalidArgument(
-                "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
-                "should be multiple of %d when auxiliary_key given "
-                "and activation=%s, but got N = %d.",
-                min_size_of_n,
-                activation,
-                N_size));
-      }
-    }
-  }
-
   phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -188,94 +90,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("DOut"), "Input", "DOut", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasOutput("DY"), "Output", "DY", "FusedGemmEpilogueOp");
-
-    auto dout_dims = ctx->GetInputDim("DOut");
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-    auto trans_x = ctx->Attrs().Get<bool>("trans_x");
-    auto trans_y = ctx->Attrs().Get<bool>("trans_y");
-
-    PADDLE_ENFORCE_GE(
-        dout_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor DOut's dimension of FusedGemmEpilogueGradOp "
-            " should be >= 2, but got %d.",
-            dout_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        y_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor Y's dimension of FusedGemmEpilogueGradOp "
-            " should be 2, but got %d.",
-            y_dims.size()));
-
-    PADDLE_ENFORCE_GE(
-        x_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor X's dimension of FusedGemmEpilogueGradOp "
-            " should be >= 2, but got %d.",
-            x_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        dout_dims.size(),
-        x_dims.size(),
-        platform::errors::InvalidArgument(
-            "The Input tensor DOut's and X's dimension of "
-            "FusedGemmEpilogueGradOp "
-            " should be the same, but got DOut's dim = %d and X's = %d.",
-            dout_dims.size(),
-            x_dims.size()));
-
-    auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1);
-    auto x_mat_dims = common::flatten_to_2d(x_dims, x_dims.size() - 1);
-
-    PADDLE_ENFORCE_EQ(
-        dout_mat_dims[1],
-        trans_y ? y_dims[0] : y_dims[1],
-        platform::errors::InvalidArgument(
-            "The last dimension of DOut should be equal with Y's last"
-            "dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
-            dout_mat_dims[1],
-            y_dims[1]));
-
-    PADDLE_ENFORCE_EQ(
-        dout_mat_dims[0],
-        trans_x ? x_mat_dims[1] : x_mat_dims[0],
-        platform::errors::InvalidArgument(
-            "The first dimension of DOut should be equal with X's first"
-            "dimension. But received DOut[0] = [%d], Y[0] = [%d].",
-            dout_mat_dims[0],
-            x_mat_dims[0]));
-
-    auto activation_grad = ctx->Attrs().Get<std::string>("activation_grad");
-    if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) {
-      PADDLE_ENFORCE_EQ(true,
-                        false,
-                        platform::errors::InvalidArgument(
-                            "The ReserveSpace should not be empty. "
-                            "when activation == {relu_grad, gelu_grad}."));
-    }
-
-    if (ctx->HasOutput("DX")) {
-      ctx->SetOutputDim("DX", x_dims);
-    }
-    ctx->SetOutputDim("DY", y_dims);
-
-    if (ctx->HasOutput("DBias")) {
-      int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
-      ctx->SetOutputDim("DBias", common::make_ddim({dbias_dim}));
-    }
-  }
-
   phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
@@ -367,12 +181,19 @@ class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    fused_gemm_epilogue,
-    ops::FusedGemmEpilogueOp,
-    ops::FusedGemmEpilogueOpMaker,
-    ops::FusedGemmEpilogueOpGradMaker<paddle::framework::OpDesc>,
-    ops::FusedGemmEpilogueOpGradMaker<paddle::imperative::OpBase>);
+DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue,
+                            FusedGemmEpilogueInferShapeFunctor,
+                            PD_INFER_META(phi::FusedGemmEpilogueInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue_grad,
+                            FusedGemmEpilogueGradInferShapeFunctor,
+                            PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta));
+REGISTER_OPERATOR(fused_gemm_epilogue,
+                  ops::FusedGemmEpilogueOp,
+                  ops::FusedGemmEpilogueOpMaker,
+                  ops::FusedGemmEpilogueOpGradMaker<paddle::framework::OpDesc>,
+                  ops::FusedGemmEpilogueOpGradMaker<paddle::imperative::OpBase>,
+                  FusedGemmEpilogueInferShapeFunctor);
 REGISTER_OPERATOR(fused_gemm_epilogue_grad,
                   ops::FusedGemmEpilogueGradOp,
-                  ops::FusedGemmEpilogueGradOpMaker);
+                  ops::FusedGemmEpilogueGradOpMaker,
+                  FusedGemmEpilogueGradInferShapeFunctor);
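
For reference, the shape logic that the deleted InferShape enforced (and that, per this commit, now lives behind phi::FusedGemmEpilogueInferMeta) can be mirrored in a standalone sketch. This is an illustration of the checks in the deleted code above, not Paddle code; plain assertions stand in for PADDLE_ENFORCE_*:

// Standalone mirror of the deleted forward-op shape checks.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<int64_t> InferOutShape(const std::vector<int64_t>& x_dims,
                                   const std::vector<int64_t>& y_dims,
                                   const std::vector<int64_t>& bias_dims,
                                   bool trans_x,
                                   bool trans_y,
                                   const std::string& activation,
                                   bool has_reserve_space) {
  assert(y_dims.size() == 2);     // Y must be a 2-D weight matrix
  assert(x_dims.size() >= 2);     // X may carry leading batch dims
  assert(bias_dims.size() == 1);  // Bias is a vector over N

  // N is Y's output dimension; Bias must match it.
  const int64_t n = trans_y ? y_dims[0] : y_dims[1];
  assert(bias_dims[0] == n);

  // The reduction dim K must agree between X and Y. After the original
  // code's flatten_to_2d, K_from_x reduces to x[0] (trans_x) or x[-1].
  const int64_t k_from_x = trans_x ? x_dims.front() : x_dims.back();
  const int64_t k_from_y = trans_y ? y_dims[1] : y_dims[0];
  assert(k_from_x == k_from_y);

  // Out keeps X's non-K dims and appends N.
  std::vector<int64_t> out(trans_x ? x_dims.begin() + 1 : x_dims.begin(),
                           trans_x ? x_dims.end() : x_dims.end() - 1);
  out.push_back(n);

  // ReserveSpace (the auxiliary output for the activation's backward pass)
  // is only legal with an activation, and the original check requires N to
  // be a multiple of 128 for relu and 8 otherwise.
  if (has_reserve_space) {
    assert(activation != "none");
    const int64_t min_n = (activation == "relu") ? 128 : 8;
    assert(n % min_n == 0);
  }
  return out;
}

int main() {
  // X: [batch, M, K] = [4, 8, 16], Y: [K, N] = [16, 128], Bias: [128]
  for (int64_t d : InferOutShape({4, 8, 16}, {16, 128}, {128},
                                 false, false, "relu", true)) {
    std::cout << d << ' ';  // prints: 4 8 128
  }
  std::cout << '\n';
  return 0;
}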
