@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/optimizers/adamax_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,67 +25,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("InfNorm"), "Input", "InfNorm", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
-                   "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Beta1Pow"), "Input", "Beta1Pow", "Adamax");
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Param").front(),
-        framework::proto::VarType::LOD_TENSOR,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Param").front(),
-            ctx->GetInputsVarType("Param").front()));
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Grad").front(),
-        framework::proto::VarType::LOD_TENSOR,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Grad").front(),
-            ctx->GetInputsVarType("Grad").front()));
-
-    OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adamax");
-    OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
-                   "Adamax");
-    OP_INOUT_CHECK(ctx->HasOutput("InfNormOut"), "Output", "InfNormOut",
-                   "Adamax");
-
-    auto lr_dims = ctx->GetInputDim("LearningRate");
-    PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
-                      platform::errors::InvalidArgument(
-                          "Maybe the Input variable LearningRate has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function."));
-    PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "Learning rate should have 1 dimension"));
-    auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
-    PADDLE_ENFORCE_EQ(phi::product(beta1_pow_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "Beta1 power accumulator should have 1 dimension"));
-    auto param_dims = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Grad"),
-        platform::errors::InvalidArgument(
-            "Param and Grad input of AdamaxOp should have same dimension"));
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Moment"),
-        platform::errors::InvalidArgument(
-            "Param and Moment input of AdamaxOp should have same dimension"));
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("InfNorm"),
-        platform::errors::InvalidArgument(
-            "Param and InfNorm input of AdamaxOp should have same dimension"));
-
-    ctx->SetOutputDim("ParamOut", param_dims);
-    ctx->SetOutputDim("MomentOut", param_dims);
-    ctx->SetOutputDim("InfNormOut", param_dims);
-  }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
@@ -150,7 +92,11 @@ division by 0 error.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    adamax, ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, double>);
+DELCARE_INFER_SHAPE_FUNCTOR(adamax, AdamaxInferMetaFunctor,
+                            PT_INFER_META(phi::AdamaxInferMeta));
+
+REGISTER_OPERATOR(
+    adamax, ops::AdamaxOp, ops::AdamaxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    AdamaxInferMetaFunctor);
0 commit comments