
Commit 1f30143

Merge branch 'develop' into move_infershapes
2 parents: 2d9d4b7 + f5ec031

File tree

18 files changed: +435, -352 lines


paddle/fluid/operators/optimizers/adadelta_op.cc

Lines changed: 12 additions & 76 deletions
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/optimizers/adadelta_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,77 +26,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(Param) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(Grad) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("AvgSquaredGrad"), true,
-        platform::errors::InvalidArgument(
-            "Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("AvgSquaredUpdate"), true,
-        platform::errors::InvalidArgument(
-            "Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Param").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        true,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Param").front(),
-            ctx->GetInputsVarType("Param").front()));
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Grad").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        true,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Grad").front(),
-            ctx->GetInputsVarType("Grad").front()));
-
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("ParamOut"), true,
-        platform::errors::InvalidArgument(
-            "Output(ParamOut) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("AvgSquaredGradOut"), true,
-        platform::errors::InvalidArgument(
-            "Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("AvgSquaredUpdateOut"), true,
-        platform::errors::InvalidArgument(
-            "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
-
-    auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Grad"),
-        platform::errors::InvalidArgument(
-            "Param and grad input of AdadeltaOp should have same dimension."));
-    PADDLE_ENFORCE_NE(
-        phi::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
-        platform::errors::InvalidArgument(
-            "Maybe the Input variable AvgSquaredGrad has not "
-            "been initialized. You may need to confirm if you put "
-            "exe.run(startup_program) after optimizer.minimize "
-            "function."));
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
-                      platform::errors::InvalidArgument(
-                          "Param and AvgSquaredGrad input of AdadeltaOp "
-                          "should have same dimension"));
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
-                      platform::errors::InvalidArgument(
-                          "Param and AvgSquaredUpdate input of AdadeltaOp "
-                          "should have same dimension"));
-
-    ctx->SetOutputDim("ParamOut", param_dim);
-    ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
-    ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
@@ -149,7 +81,11 @@ param\_out = param + param\_update
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
+namespace ops = paddle::operators;
+DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor,
+                            PT_INFER_META(phi::AdadeltaInferMeta));
+REGISTER_OPERATOR(
+    adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    AdadeltaInferMetaFunctor);
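
Note: the shape checks deleted above are not lost; they move behind phi::AdadeltaInferMeta, declared in paddle/phi/infermeta/multiary.h (now included at the top of this file) and wired in through the AdadeltaInferMetaFunctor registered above. The phi side is not part of this diff, so the following is only a minimal sketch of what such a meta function looks like; the signature, the rho/epsilon attributes, and the MetaTensor calls are assumptions, not taken from this commit:

namespace phi {

// Sketch only (assumed, not from this commit): parameter order mirrors the
// op's input list. MetaTensor carries shape/dtype metadata, so the same
// checks can run at both compile time and runtime.
void AdadeltaInferMeta(const MetaTensor& param,
                       const MetaTensor& grad,
                       const MetaTensor& avg_squared_grad,
                       const MetaTensor& avg_squared_update,
                       float rho,
                       float epsilon,
                       MetaTensor* param_out,
                       MetaTensor* avg_squared_grad_out,
                       MetaTensor* avg_squared_update_out) {
  auto param_dims = param.dims();
  // Same dimension-consistency checks the deleted InferShape enforced.
  PADDLE_ENFORCE_EQ(
      param_dims, grad.dims(),
      errors::InvalidArgument(
          "Param and Grad input of AdadeltaOp should have same dimension."));
  PADDLE_ENFORCE_EQ(
      param_dims, avg_squared_grad.dims(),
      errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp "
                              "should have same dimension."));
  PADDLE_ENFORCE_EQ(
      param_dims, avg_squared_update.dims(),
      errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp "
                              "should have same dimension."));

  // Outputs keep the parameter's shape and dtype, as before.
  param_out->set_dims(param_dims);
  param_out->set_dtype(param.dtype());
  avg_squared_grad_out->set_dims(param_dims);
  avg_squared_grad_out->set_dtype(avg_squared_grad.dtype());
  avg_squared_update_out->set_dims(param_dims);
  avg_squared_update_out->set_dtype(avg_squared_update.dtype());
}

}  // namespace phi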

paddle/fluid/operators/optimizers/adadelta_op.cu

Lines changed: 0 additions & 19 deletions
This file was deleted.

paddle/fluid/operators/optimizers/adadelta_op.h

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/fluid/operators/optimizers/adamax_op.cc

Lines changed: 12 additions & 66 deletions
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/optimizers/adamax_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,67 +25,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("InfNorm"), "Input", "InfNorm", "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
-                   "Adamax");
-    OP_INOUT_CHECK(ctx->HasInput("Beta1Pow"), "Input", "Beta1Pow", "Adamax");
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Param").front(),
-        framework::proto::VarType::LOD_TENSOR,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Param").front(),
-            ctx->GetInputsVarType("Param").front()));
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputsVarType("Grad").front(),
-        framework::proto::VarType::LOD_TENSOR,
-        platform::errors::InvalidArgument(
-            "The input var's type should be LoDTensor, but the received is %s",
-            ctx->Inputs("Grad").front(),
-            ctx->GetInputsVarType("Grad").front()));
-
-    OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adamax");
-    OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
-                   "Adamax");
-    OP_INOUT_CHECK(ctx->HasOutput("InfNormOut"), "Output", "InfNormOut",
-                   "Adamax");
-
-    auto lr_dims = ctx->GetInputDim("LearningRate");
-    PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
-                      platform::errors::InvalidArgument(
-                          "Maybe the Input variable LearningRate has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function."));
-    PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "Learning rate should have 1 dimension"));
-    auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
-    PADDLE_ENFORCE_EQ(phi::product(beta1_pow_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "Beta1 power accumulator should have 1 dimension"));
-    auto param_dims = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Grad"),
-        platform::errors::InvalidArgument(
-            "Param and Grad input of AdamaxOp should have same dimension"));
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Moment"),
-        platform::errors::InvalidArgument(
-            "Param and Moment input of AdamaxOp should have same dimension"));
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("InfNorm"),
-        platform::errors::InvalidArgument(
-            "Param and InfNorm input of AdamaxOp should have same dimension"));
-
-    ctx->SetOutputDim("ParamOut", param_dims);
-    ctx->SetOutputDim("MomentOut", param_dims);
-    ctx->SetOutputDim("InfNormOut", param_dims);
-  }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
@@ -150,7 +92,11 @@ division by 0 error.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    adamax, ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, double>);
+DELCARE_INFER_SHAPE_FUNCTOR(adamax, AdamaxInferMetaFunctor,
+                            PT_INFER_META(phi::AdamaxInferMeta));
+
+REGISTER_OPERATOR(
+    adamax, ops::AdamaxOp, ops::AdamaxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    AdamaxInferMetaFunctor);
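
As with adadelta, the deleted checks (LearningRate and Beta1Pow must hold exactly one element; Param, Grad, Moment, and InfNorm must share one shape) are expected to be centralized in phi::AdamaxInferMeta. A hedged sketch of the declaration paddle/phi/infermeta/multiary.h would provide; the parameter ordering and the beta1/beta2/epsilon attributes mirror the op's inputs and attributes and are assumptions, not shown in this diff:

namespace phi {

// Assumed declaration only: inputs mirror AdamaxOp's input list and the
// float attributes mirror its beta1/beta2/epsilon op attributes.
void AdamaxInferMeta(const MetaTensor& param,
                     const MetaTensor& grad,
                     const MetaTensor& learning_rate,
                     const MetaTensor& moment,
                     const MetaTensor& inf_norm,
                     const MetaTensor& beta1_pow,
                     float beta1,
                     float beta2,
                     float epsilon,
                     MetaTensor* param_out,
                     MetaTensor* moment_out,
                     MetaTensor* inf_norm_out);

}  // namespace phi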

paddle/fluid/operators/optimizers/adamax_op.cu

Lines changed: 0 additions & 19 deletions
This file was deleted.
