Skip to content

Commit bca78dd

Browse files
authored
【Infer Symbolic Shape BUAA No.164】Add psroi_pool op (#67361)
* 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * 【Infer Symbolic Shape BUAA No.164】Add psroi_pool op * Update multiary_infer_sym.h * test
1 parent 022dfc6 commit bca78dd

File tree

3 files changed

+61
-7
lines changed

3 files changed

+61
-7
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,12 +1697,65 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
16971697
// return true;
16981698
// }
16991699

1700-
// bool PsroiPoolOpInferSymbolicShape(pir::Operation *op,
1701-
// pir::InferSymbolicShapeContext
1702-
// *infer_context) {
1703-
// // pass
1704-
// return true;
1705-
// }
1700+
bool PsroiPoolOpInferSymbolicShape(
1701+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1702+
const auto &x_shape_or_data =
1703+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
1704+
const std::vector<symbol::DimExpr> &input_dims = x_shape_or_data.shape();
1705+
const auto &rois_shape_or_data =
1706+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
1707+
const std::vector<symbol::DimExpr> &rois_dims = rois_shape_or_data.shape();
1708+
PADDLE_ENFORCE_EQ(
1709+
input_dims.size(),
1710+
4,
1711+
phi::errors::InvalidArgument("The format of input tensor is NCHW"));
1712+
PADDLE_ENFORCE_EQ(rois_dims.size(),
1713+
2,
1714+
phi::errors::InvalidArgument(
1715+
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
1716+
"given as [(x1, y1, x2, y2), ...]"));
1717+
infer_context->AddEqualCstr(rois_dims[1], symbol::DimExpr(4));
1718+
if (op->operand_source(2)) {
1719+
auto &rois_num_shape_or_data =
1720+
infer_context->GetShapeOrDataForValue(op->operand_source(2));
1721+
const std::vector<symbol::DimExpr> &rois_num_dims =
1722+
rois_num_shape_or_data.shape();
1723+
PADDLE_ENFORCE_EQ(
1724+
rois_num_dims.size(),
1725+
1,
1726+
phi::errors::InvalidArgument("The second dimension of RoisNum should "
1727+
"be 1, but received dimension is %d",
1728+
rois_num_dims.size()));
1729+
}
1730+
int pooled_height =
1731+
op->attribute<pir::Int32Attribute>("pooled_height").data();
1732+
int pooled_width = op->attribute<pir::Int32Attribute>("pooled_width").data();
1733+
int output_channels =
1734+
op->attribute<pir::Int32Attribute>("output_channels").data();
1735+
auto divisor =
1736+
symbol::DimExpr(output_channels * pooled_height * pooled_width);
1737+
infer_context->AddEqualCstr(input_dims[1], divisor);
1738+
PADDLE_ENFORCE_GT(pooled_height,
1739+
0,
1740+
phi::errors::InvalidArgument(
1741+
"The pooled output height must be greater than 0"));
1742+
PADDLE_ENFORCE_GT(pooled_width,
1743+
0,
1744+
phi::errors::InvalidArgument(
1745+
"The pooled output width must be greater than 0"));
1746+
PADDLE_ENFORCE_GT(output_channels,
1747+
1,
1748+
phi::errors::InvalidArgument(
1749+
"The pooled output channels must greater than 1"));
1750+
std::vector<symbol::DimExpr> out_dims = {
1751+
rois_dims[0], output_channels, pooled_height, pooled_width};
1752+
1753+
infer_context->SetShapeOrDataForValue(
1754+
op->result(0),
1755+
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});
1756+
1757+
return true;
1758+
}
17061759

17071760
// bool PyramidHashOpInferSymbolicShape(pir::Operation *op,
17081761
// pir::InferSymbolicShapeContext

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale)
8888
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_)
8989
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
9090
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce)
91-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool)
91+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool)
9292
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PyramidHash)
9393
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear)
9494
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear_)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3677,6 +3677,7 @@
36773677
data_type : x
36783678
optional : boxes_num
36793679
backward : psroi_pool_grad
3680+
interfaces : paddle::dialect::InferSymbolicShapeInterface
36803681

36813682
- op : put_along_axis
36823683
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true)

0 commit comments

Comments
 (0)