Skip to content

Commit a0cb320

Browse files
authored
[Phi]Move size, erfinv, pixel_shuffle infershape to phi (#39949)
* move size, erfinv, pixel_shuffle infershape to phi * fix erfinv infermeta
1 parent 2c66775 commit a0cb320

File tree

5 files changed

+80
-68
lines changed

5 files changed

+80
-68
lines changed

paddle/fluid/operators/erfinv_op.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,17 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/unary.h"
1619

1720
namespace paddle {
1821
namespace operators {
1922

2023
class ErfinvOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
23-
24-
void InferShape(framework::InferShapeContext* ctx) const override {
25-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "erfinv");
26-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "erfinv");
27-
28-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
29-
ctx->ShareLoD("X", /*->*/ "Out");
30-
}
3126
};
3227

3328
class ErfinvOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -78,10 +73,13 @@ DECLARE_INPLACE_OP_INFERER(ErfinvInplaceInferer, {"X", "Out"});
7873
} // namespace operators
7974
} // namespace paddle
8075

76+
DELCARE_INFER_SHAPE_FUNCTOR(erfinv, ErfinvInferShapeFunctor,
77+
PT_INFER_META(phi::UnchangedInferMeta));
78+
8179
REGISTER_OPERATOR(
8280
erfinv, paddle::operators::ErfinvOp, paddle::operators::ErfinvOpMaker,
8381
paddle::operators::ErfinvGradMaker<paddle::framework::OpDesc>,
8482
paddle::operators::ErfinvGradMaker<paddle::imperative::OpBase>,
85-
paddle::operators::ErfinvInplaceInferer);
83+
paddle::operators::ErfinvInplaceInferer, ErfinvInferShapeFunctor);
8684

8785
REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp);

paddle/fluid/operators/pixel_shuffle_op.cc

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,65 +10,18 @@ See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

1212
#include <memory>
13+
#include "paddle/fluid/framework/infershape_utils.h"
1314
#include "paddle/fluid/framework/op_registry.h"
1415
#include "paddle/fluid/framework/op_version_registry.h"
16+
#include "paddle/phi/core/infermeta_utils.h"
17+
#include "paddle/phi/infermeta/unary.h"
1518

1619
namespace paddle {
1720
namespace operators {
1821

1922
class PixelShuffleOp : public framework::OperatorWithKernel {
2023
public:
2124
using framework::OperatorWithKernel::OperatorWithKernel;
22-
23-
void InferShape(framework::InferShapeContext* ctx) const override {
24-
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
25-
platform::errors::NotFound(
26-
"Input(X) of PixelShuffleOp should not be null."));
27-
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
28-
platform::errors::NotFound(
29-
"Output(Out) of PixelShuffleOp should not be null."));
30-
31-
auto input_dims = ctx->GetInputDim("X");
32-
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
33-
platform::errors::InvalidArgument(
34-
"Input should be a 4-D tensor of format [N, C, H, W] "
35-
"or [N, H, W, C], but got %u.",
36-
input_dims.size()));
37-
38-
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
39-
40-
const std::string data_format =
41-
ctx->Attrs().Get<std::string>("data_format");
42-
const bool channel_last = (data_format == "NHWC");
43-
44-
if (!channel_last) {
45-
PADDLE_ENFORCE_EQ(
46-
input_dims[1] % (upscale_factor * upscale_factor), 0,
47-
platform::errors::InvalidArgument(
48-
"The square of upscale_factor[%u] should divide the "
49-
"number of channel[%u]",
50-
upscale_factor * upscale_factor, input_dims[1]));
51-
} else {
52-
PADDLE_ENFORCE_EQ(
53-
input_dims[3] % (upscale_factor * upscale_factor), 0,
54-
platform::errors::InvalidArgument(
55-
"The square of upscale_factor[%u] should divide the "
56-
"number of channel[%u]",
57-
upscale_factor * upscale_factor, input_dims[3]));
58-
}
59-
auto output_dims = input_dims;
60-
output_dims[0] = input_dims[0];
61-
if (!channel_last) {
62-
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
63-
output_dims[2] = input_dims[2] * upscale_factor;
64-
output_dims[3] = input_dims[3] * upscale_factor;
65-
} else {
66-
output_dims[1] = input_dims[1] * upscale_factor;
67-
output_dims[2] = input_dims[2] * upscale_factor;
68-
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
69-
}
70-
ctx->SetOutputDim("Out", output_dims);
71-
}
7225
};
7326

7427
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -171,9 +124,13 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
171124
} // namespace paddle
172125

173126
namespace ops = paddle::operators;
127+
DELCARE_INFER_SHAPE_FUNCTOR(pixel_shuffle, PixelShuffleInferShapeFunctor,
128+
PT_INFER_META(phi::PixelShuffleInferMeta));
129+
174130
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
175131
ops::PixelShuffleGradMaker<paddle::framework::OpDesc>,
176-
ops::PixelShuffleGradMaker<paddle::imperative::OpBase>);
132+
ops::PixelShuffleGradMaker<paddle::imperative::OpBase>,
133+
PixelShuffleInferShapeFunctor);
177134

178135
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
179136

paddle/fluid/operators/size_op.cc

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/unary.h"
1619

1720
namespace paddle {
1821
namespace operators {
1922

2023
class SizeOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
23-
24-
void InferShape(framework::InferShapeContext *ctx) const override {
25-
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Size");
26-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Size");
27-
28-
ctx->SetOutputDim("Out", {1});
29-
}
3026
};
3127

3228
class SizeOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -48,7 +44,10 @@ Return the number of elements in the input.
4844
} // namespace paddle
4945

5046
namespace ops = paddle::operators;
47+
DELCARE_INFER_SHAPE_FUNCTOR(size, SizeInferShapeFunctor,
48+
PT_INFER_META(phi::SizeInferMeta));
5149
REGISTER_OPERATOR(
5250
size, ops::SizeOp, ops::SizeOpMaker,
5351
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
54-
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
52+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
53+
SizeInferShapeFunctor);

paddle/phi/infermeta/unary.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,57 @@ void DiagInferMeta(const MetaTensor& x,
856856
}
857857
}
858858

859+
void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
860+
out->set_dtype(DataType::INT64);
861+
out->set_dims({1});
862+
}
863+
864+
void PixelShuffleInferMeta(const MetaTensor& x,
865+
int upscale_factor,
866+
const std::string& data_format,
867+
MetaTensor* out) {
868+
auto input_dims = x.dims();
869+
PADDLE_ENFORCE_EQ(input_dims.size(),
870+
4,
871+
phi::errors::InvalidArgument(
872+
"Input should be a 4-D tensor of format [N, C, H, W] "
873+
"or [N, H, W, C], but got %u.",
874+
input_dims.size()));
875+
876+
const bool channel_last = (data_format == "NHWC");
877+
878+
if (!channel_last) {
879+
PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor),
880+
0,
881+
phi::errors::InvalidArgument(
882+
"The square of upscale_factor[%u] should divide the "
883+
"number of channel[%u]",
884+
upscale_factor * upscale_factor,
885+
input_dims[1]));
886+
} else {
887+
PADDLE_ENFORCE_EQ(input_dims[3] % (upscale_factor * upscale_factor),
888+
0,
889+
phi::errors::InvalidArgument(
890+
"The square of upscale_factor[%u] should divide the "
891+
"number of channel[%u]",
892+
upscale_factor * upscale_factor,
893+
input_dims[3]));
894+
}
895+
auto output_dims = input_dims;
896+
output_dims[0] = input_dims[0];
897+
if (!channel_last) {
898+
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
899+
output_dims[2] = input_dims[2] * upscale_factor;
900+
output_dims[3] = input_dims[3] * upscale_factor;
901+
} else {
902+
output_dims[1] = input_dims[1] * upscale_factor;
903+
output_dims[2] = input_dims[2] * upscale_factor;
904+
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
905+
}
906+
out->set_dtype(x.dtype());
907+
out->set_dims(output_dims);
908+
}
909+
859910
} // namespace phi
860911

861912
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);

paddle/phi/infermeta/unary.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,11 @@ void DiagInferMeta(const MetaTensor& x,
129129
float padding_value,
130130
MetaTensor* out);
131131

132+
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
133+
134+
void PixelShuffleInferMeta(const MetaTensor& x,
135+
int upscale_factor,
136+
const std::string& data_format,
137+
MetaTensor* out);
138+
132139
} // namespace phi

0 commit comments

Comments
 (0)