@@ -13,8 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/fusion.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
 
 namespace paddle {
@@ -25,107 +28,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(
-        ctx->HasInput("Bias"), "Output", "Bias", "FusedGemmEpilogueOp");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Out"), "Output", "Out", "FusedGemmEpilogueOp");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-    auto bias_dims = ctx->GetInputDim("Bias");
-    auto trans_x = ctx->Attrs().Get<bool>("trans_x");
-    auto trans_y = ctx->Attrs().Get<bool>("trans_y");
-
-    PADDLE_ENFORCE_EQ(
-        y_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor Y's dimension of FusedGemmEpilogueOp "
-            "should be 2, but got %d.",
-            y_dims.size()));
-
-    PADDLE_ENFORCE_GE(
-        x_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor X's dimension of FusedGemmEpilogueOp "
-            "should be >= 2, but got %d.",
-            x_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        bias_dims.size(),
-        1,
-        platform::errors::InvalidArgument(
-            "The Input tensor bias's dimension of FusedGemmEpilogueOp "
-            "should be == 1, but got %d.",
-            bias_dims.size()));
-
-    PADDLE_ENFORCE_EQ(bias_dims[0],
-                      trans_y ? y_dims[0] : y_dims[1],
-                      platform::errors::InvalidArgument(
-                          "The Input tensor bias's dimension 0"
-                          " should be == Y[-1], but got bias's shape = [%s] "
-                          "and Y's shape = [%s]",
-                          bias_dims,
-                          y_dims));
-
-    auto x_mat_dims =
-        common::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1);
-
-    int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1];
-    int K_from_y = trans_y ? y_dims[1] : y_dims[0];
-
-    PADDLE_ENFORCE_EQ(
-        K_from_x,
-        K_from_y,
-        platform::errors::InvalidArgument(
-            "The last dimension of X should be equal with Y's first dimension."
-            "But received X[-1] = [%d], Y[0] = [%d].",
-            K_from_x,
-            K_from_y));
-
-    std::vector<int64_t> out_dims;
-    out_dims.reserve(static_cast<size_t>(x_dims.size()));
-    if (trans_x) {
-      for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
-    } else {
-      for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
-    }
-
-    if (trans_y) {
-      out_dims.push_back(y_dims[0]);
-    } else {
-      out_dims.push_back(y_dims[1]);
-    }
-    ctx->SetOutputDim("Out", common::make_ddim(out_dims));
-
-    auto activation = ctx->Attrs().Get<std::string>("activation");
-    if (ctx->HasOutput("ReserveSpace")) {
-      ctx->SetOutputDim("ReserveSpace", common::make_ddim(out_dims));
-
-      if (activation == "none") {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The ReserveSpace would not be used when activation = \"none\""));
-      } else {
-        int min_size_of_n = activation == "relu" ? 128 : 8;
-        int N_size = trans_y ? y_dims[0] : y_dims[1];
-        PADDLE_ENFORCE_EQ(
-            N_size % min_size_of_n,
-            0,
-            platform::errors::InvalidArgument(
-                "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
-                "should be multiple of %d when auxiliary_key given "
-                "and activation=%s, but got N = %d.",
-                min_size_of_n,
-                activation,
-                N_size));
-      }
-    }
-  }
-
   phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
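The InferShape deleted above is replaced by phi::FusedGemmEpilogueInferMeta, wired in through DECLARE_INFER_SHAPE_FUNCTOR at the bottom of this diff. As a minimal sketch only (not the actual contents of paddle/phi/infermeta/fusion.h; the parameter order and the set of checks there may differ), the phi-side function plausibly reproduces the same shape rule on MetaTensors:

// Sketch, assuming phi conventions (MetaTensor accessors, common::make_ddim).
void FusedGemmEpilogueInferMeta(const MetaTensor& x,
                                const MetaTensor& y,
                                const MetaTensor& bias,
                                bool trans_x,
                                bool trans_y,
                                const std::string& activation,
                                MetaTensor* out,
                                MetaTensor* reserve_space) {
  const auto& x_dims = x.dims();
  const auto& y_dims = y.dims();
  // Same rule the deleted InferShape encoded: keep the non-contracted
  // dims of X, then append N taken from Y (respecting trans_y).
  std::vector<int64_t> out_dims;
  if (trans_x) {
    for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
  } else {
    for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
  }
  out_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);
  out->set_dims(common::make_ddim(out_dims));
  out->set_dtype(x.dtype());
  // ReserveSpace mirrors Out's shape and is only meaningful when an
  // activation epilogue ("relu"/"gelu") is fused in.
  if (reserve_space && activation != "none") {
    reserve_space->set_dims(common::make_ddim(out_dims));
    reserve_space->set_dtype(x.dtype());
  }
}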
@@ -188,94 +90,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("DOut"), "Input", "DOut", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueGradOp");
-    OP_INOUT_CHECK(ctx->HasOutput("DY"), "Output", "DY", "FusedGemmEpilogueOp");
-
-    auto dout_dims = ctx->GetInputDim("DOut");
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-    auto trans_x = ctx->Attrs().Get<bool>("trans_x");
-    auto trans_y = ctx->Attrs().Get<bool>("trans_y");
-
-    PADDLE_ENFORCE_GE(
-        dout_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor DOut's dimension of FusedGemmEpilogueGradOp "
-            "should be >= 2, but got %d.",
-            dout_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        y_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor Y's dimension of FusedGemmEpilogueGradOp "
-            "should be 2, but got %d.",
-            y_dims.size()));
-
-    PADDLE_ENFORCE_GE(
-        x_dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "The Input tensor X's dimension of FusedGemmEpilogueGradOp "
-            "should be >= 2, but got %d.",
-            x_dims.size()));
-
-    PADDLE_ENFORCE_EQ(
-        dout_dims.size(),
-        x_dims.size(),
-        platform::errors::InvalidArgument(
-            "The Input tensor DOut's and X's dimension of "
-            "FusedGemmEpilogueGradOp "
-            "should be the same, but got DOut's dim = %d and X's = %d.",
-            dout_dims.size(),
-            x_dims.size()));
-
-    auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1);
-    auto x_mat_dims = common::flatten_to_2d(x_dims, x_dims.size() - 1);
-
-    PADDLE_ENFORCE_EQ(
-        dout_mat_dims[1],
-        trans_y ? y_dims[0] : y_dims[1],
-        platform::errors::InvalidArgument(
-            "The last dimension of DOut should be equal with Y's last"
-            " dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
-            dout_mat_dims[1],
-            y_dims[1]));
-
-    PADDLE_ENFORCE_EQ(
-        dout_mat_dims[0],
-        trans_x ? x_mat_dims[1] : x_mat_dims[0],
-        platform::errors::InvalidArgument(
-            "The first dimension of DOut should be equal with X's first"
-            " dimension. But received DOut[0] = [%d], Y[0] = [%d].",
-            dout_mat_dims[0],
-            x_mat_dims[0]));
-
-    auto activation_grad = ctx->Attrs().Get<std::string>("activation_grad");
-    if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) {
-      PADDLE_ENFORCE_EQ(true,
-                        false,
-                        platform::errors::InvalidArgument(
-                            "The ReserveSpace should not be empty. "
-                            "when activation == {relu_grad, gelu_grad}."));
-    }
-
-    if (ctx->HasOutput("DX")) {
-      ctx->SetOutputDim("DX", x_dims);
-    }
-    ctx->SetOutputDim("DY", y_dims);
-
-    if (ctx->HasOutput("DBias")) {
-      int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
-      ctx->SetOutputDim("DBias", common::make_ddim({dbias_dim}));
-    }
-  }
-
   phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
@@ -367,12 +181,19 @@ class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    fused_gemm_epilogue,
-    ops::FusedGemmEpilogueOp,
-    ops::FusedGemmEpilogueOpMaker,
-    ops::FusedGemmEpilogueOpGradMaker<paddle::framework::OpDesc>,
-    ops::FusedGemmEpilogueOpGradMaker<paddle::imperative::OpBase>);
+DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue,
+                            FusedGemmEpilogueInferShapeFunctor,
+                            PD_INFER_META(phi::FusedGemmEpilogueInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue_grad,
+                            FusedGemmEpilogueGradInferShapeFunctor,
+                            PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta));
+REGISTER_OPERATOR(fused_gemm_epilogue,
+                  ops::FusedGemmEpilogueOp,
+                  ops::FusedGemmEpilogueOpMaker,
+                  ops::FusedGemmEpilogueOpGradMaker<paddle::framework::OpDesc>,
+                  ops::FusedGemmEpilogueOpGradMaker<paddle::imperative::OpBase>,
+                  FusedGemmEpilogueInferShapeFunctor);
 REGISTER_OPERATOR(fused_gemm_epilogue_grad,
                   ops::FusedGemmEpilogueGradOp,
-                  ops::FusedGemmEpilogueGradOpMaker);
+                  ops::FusedGemmEpilogueGradOpMaker,
+                  FusedGemmEpilogueGradInferShapeFunctor);
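The grad op gets the same treatment via phi::FusedGemmEpilogueGradInferMeta. Again as a hedged sketch only (the parameter list in paddle/phi/infermeta/fusion.h is an assumption here, including whether ReserveSpace appears as an input), it would carry over the output-shape rules the deleted grad InferShape applied:

// Sketch, assuming phi conventions; signature may differ from the real header.
void FusedGemmEpilogueGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& y,
                                    const MetaTensor& reserve_space,
                                    const MetaTensor& out_grad,
                                    bool trans_x,
                                    bool trans_y,
                                    const std::string& activation_grad,
                                    MetaTensor* x_grad,
                                    MetaTensor* y_grad,
                                    MetaTensor* bias_grad) {
  // Gradients take the shapes of their forward counterparts, exactly as
  // the deleted grad InferShape set them (DX <- X, DY <- Y).
  if (x_grad) {
    x_grad->set_dims(x.dims());
    x_grad->set_dtype(x.dtype());
  }
  y_grad->set_dims(y.dims());
  y_grad->set_dtype(y.dtype());
  if (bias_grad) {
    // DBias has length N: y_dims[0] if trans_y, else y_dims[1].
    const auto& y_dims = y.dims();
    int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
    bias_grad->set_dims(common::make_ddim({dbias_dim}));
    bias_grad->set_dtype(y.dtype());
  }
}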