@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15+ #include " paddle/fluid/framework/infershape_utils.h"
1516#include " paddle/fluid/framework/op_registry.h"
17+ #include " paddle/phi/core/infermeta_utils.h"
18+ #include " paddle/phi/infermeta/binary.h"
1619
1720namespace paddle {
1821namespace operators {
@@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel {
2124 public:
2225 using framework::OperatorWithKernel::OperatorWithKernel;
2326
24- void InferShape (framework::InferShapeContext* ctx) const override {
25- OP_INOUT_CHECK (ctx->HasInput (" Ids" ), " Input" , " Ids" , " GatherTree" );
26- OP_INOUT_CHECK (ctx->HasInput (" Parents" ), " Input" , " Parents" , " GatherTree" );
27- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " GatherTree" );
28-
29- auto ids_dims = ctx->GetInputDim (" Ids" );
30- auto parents_dims = ctx->GetInputDim (" Parents" );
31- PADDLE_ENFORCE_EQ (ids_dims == parents_dims, true ,
32- platform::errors::InvalidArgument (
33- " The shape of Input(Parents) must be same with the "
34- " shape of Input(Ids)." ));
35- ctx->SetOutputDim (" Out" , ids_dims);
36- }
37-
3827 protected:
3928 framework::OpKernelType GetExpectedKernelType (
4029 const framework::ExecutionContext& ctx) const override {
@@ -72,4 +61,8 @@ selected ids.
7261} // namespace paddle
7362
7463namespace ops = paddle::operators;
75- REGISTER_OPERATOR (gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
64+ DELCARE_INFER_SHAPE_FUNCTOR (gather_tree, GatherTreeInferShapeFunctor,
65+ PT_INFER_META (phi::GatherTreeMeta));
66+
67+ REGISTER_OPERATOR (gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker,
68+ GatherTreeInferShapeFunctor);
0 commit comments