@@ -10,65 +10,18 @@ See the License for the specific language governing permissions and
1010limitations under the License. */
1111
1212#include < memory>
13+ #include " paddle/fluid/framework/infershape_utils.h"
1314#include " paddle/fluid/framework/op_registry.h"
1415#include " paddle/fluid/framework/op_version_registry.h"
16+ #include " paddle/phi/core/infermeta_utils.h"
17+ #include " paddle/phi/infermeta/unary.h"
1518
1619namespace paddle {
1720namespace operators {
1821
1922class PixelShuffleOp : public framework ::OperatorWithKernel {
2023 public:
2124 using framework::OperatorWithKernel::OperatorWithKernel;
22-
23- void InferShape (framework::InferShapeContext* ctx) const override {
24- PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true ,
25- platform::errors::NotFound (
26- " Input(X) of PixelShuffleOp should not be null." ));
27- PADDLE_ENFORCE_EQ (ctx->HasOutput (" Out" ), true ,
28- platform::errors::NotFound (
29- " Output(Out) of PixelShuffleOp should not be null." ));
30-
31- auto input_dims = ctx->GetInputDim (" X" );
32- PADDLE_ENFORCE_EQ (input_dims.size (), 4 ,
33- platform::errors::InvalidArgument (
34- " Input should be a 4-D tensor of format [N, C, H, W] "
35- " or [N, H, W, C], but got %u." ,
36- input_dims.size ()));
37-
38- auto upscale_factor = ctx->Attrs ().Get <int >(" upscale_factor" );
39-
40- const std::string data_format =
41- ctx->Attrs ().Get <std::string>(" data_format" );
42- const bool channel_last = (data_format == " NHWC" );
43-
44- if (!channel_last) {
45- PADDLE_ENFORCE_EQ (
46- input_dims[1 ] % (upscale_factor * upscale_factor), 0 ,
47- platform::errors::InvalidArgument (
48- " The square of upscale_factor[%u] should divide the "
49- " number of channel[%u]" ,
50- upscale_factor * upscale_factor, input_dims[1 ]));
51- } else {
52- PADDLE_ENFORCE_EQ (
53- input_dims[3 ] % (upscale_factor * upscale_factor), 0 ,
54- platform::errors::InvalidArgument (
55- " The square of upscale_factor[%u] should divide the "
56- " number of channel[%u]" ,
57- upscale_factor * upscale_factor, input_dims[3 ]));
58- }
59- auto output_dims = input_dims;
60- output_dims[0 ] = input_dims[0 ];
61- if (!channel_last) {
62- output_dims[1 ] = input_dims[1 ] / (upscale_factor * upscale_factor);
63- output_dims[2 ] = input_dims[2 ] * upscale_factor;
64- output_dims[3 ] = input_dims[3 ] * upscale_factor;
65- } else {
66- output_dims[1 ] = input_dims[1 ] * upscale_factor;
67- output_dims[2 ] = input_dims[2 ] * upscale_factor;
68- output_dims[3 ] = input_dims[3 ] / (upscale_factor * upscale_factor);
69- }
70- ctx->SetOutputDim (" Out" , output_dims);
71- }
7225};
7326
7427class PixelShuffleOpMaker : public framework ::OpProtoAndCheckerMaker {
@@ -171,9 +124,13 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
171124} // namespace paddle
172125
173126namespace ops = paddle::operators;
127+ DELCARE_INFER_SHAPE_FUNCTOR (pixel_shuffle, PixelShuffleInferShapeFunctor,
128+ PT_INFER_META (phi::PixelShuffleInferMeta));
129+
174130REGISTER_OPERATOR (pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
175131 ops::PixelShuffleGradMaker<paddle::framework::OpDesc>,
176- ops::PixelShuffleGradMaker<paddle::imperative::OpBase>);
132+ ops::PixelShuffleGradMaker<paddle::imperative::OpBase>,
133+ PixelShuffleInferShapeFunctor);
177134
178135REGISTER_OPERATOR (pixel_shuffle_grad, ops::PixelShuffleGradOp);
179136
0 commit comments