Skip to content

Commit f54b5df

Browse files
committed
move gather_tree infer shape
1 parent 8492d3b commit f54b5df

File tree

3 files changed

+25
-15
lines changed

3 files changed

+25
-15
lines changed

paddle/fluid/operators/gather_tree_op.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/binary.h"
1619

1720
namespace paddle {
1821
namespace operators {
@@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
2326

24-
void InferShape(framework::InferShapeContext* ctx) const override {
25-
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
26-
OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
27-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
28-
29-
auto ids_dims = ctx->GetInputDim("Ids");
30-
auto parents_dims = ctx->GetInputDim("Parents");
31-
PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
32-
platform::errors::InvalidArgument(
33-
"The shape of Input(Parents) must be same with the "
34-
"shape of Input(Ids)."));
35-
ctx->SetOutputDim("Out", ids_dims);
36-
}
37-
3827
protected:
3928
framework::OpKernelType GetExpectedKernelType(
4029
const framework::ExecutionContext& ctx) const override {
@@ -72,4 +61,8 @@ selected ids.
7261
} // namespace paddle
7362

7463
namespace ops = paddle::operators;
75-
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
64+
DELCARE_INFER_SHAPE_FUNCTOR(gather_tree, GatherTreeInferShapeFunctor,
65+
PT_INFER_META(phi::GatherTreeMeta));
66+
67+
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker,
68+
GatherTreeInferShapeFunctor);

paddle/phi/infermeta/binary.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input,
348348
out->share_lod(input);
349349
}
350350

351+
void GatherTreeMeta(const MetaTensor& ids,
352+
const MetaTensor& parents,
353+
MetaTensor* out) {
354+
auto ids_dims = ids.dims();
355+
auto parents_dims = parents.dims();
356+
PADDLE_ENFORCE_EQ(ids_dims == parents_dims,
357+
true,
358+
phi::errors::InvalidArgument(
359+
"The shape of Input(Parents) must be same with the "
360+
"shape of Input(Ids)."));
361+
out->set_dims(ids_dims);
362+
}
363+
351364
} // namespace phi

paddle/phi/infermeta/binary.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input,
6868
const MetaTensor& label,
6969
MetaTensor* out,
7070
MetaConfig config = MetaConfig());
71+
72+
void GatherTreeMeta(const MetaTensor& ids,
73+
const MetaTensor& parents,
74+
MetaTensor* out);
7175
} // namespace phi

0 commit comments

Comments
 (0)