@@ -1697,12 +1697,65 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
16971697// return true;
16981698// }
16991699
1700- // bool PsroiPoolOpInferSymbolicShape(pir::Operation *op,
1701- // pir::InferSymbolicShapeContext
1702- // *infer_context) {
1703- // // pass
1704- // return true;
1705- // }
1700+ bool PsroiPoolOpInferSymbolicShape (
1701+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1702+ const auto &x_shape_or_data =
1703+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1704+ const std::vector<symbol::DimExpr> &input_dims = x_shape_or_data.shape ();
1705+ const auto &rois_shape_or_data =
1706+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1707+ const std::vector<symbol::DimExpr> &rois_dims = rois_shape_or_data.shape ();
1708+ PADDLE_ENFORCE_EQ (
1709+ input_dims.size (),
1710+ 4 ,
1711+ phi::errors::InvalidArgument (" The format of input tensor is NCHW" ));
1712+ PADDLE_ENFORCE_EQ (rois_dims.size (),
1713+ 2 ,
1714+ phi::errors::InvalidArgument (
1715+ " ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
1716+ " given as [(x1, y1, x2, y2), ...]" ));
1717+ infer_context->AddEqualCstr (rois_dims[1 ], symbol::DimExpr (4 ));
1718+ if (op->operand_source (2 )) {
1719+ auto &rois_num_shape_or_data =
1720+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
1721+ const std::vector<symbol::DimExpr> &rois_num_dims =
1722+ rois_num_shape_or_data.shape ();
1723+ PADDLE_ENFORCE_EQ (
1724+ rois_num_dims.size (),
1725+ 1 ,
1726+ phi::errors::InvalidArgument (" The second dimension of RoisNum should "
1727+ " be 1, but received dimension is %d" ,
1728+ rois_num_dims.size ()));
1729+ }
1730+ int pooled_height =
1731+ op->attribute <pir::Int32Attribute>(" pooled_height" ).data ();
1732+ int pooled_width = op->attribute <pir::Int32Attribute>(" pooled_width" ).data ();
1733+ int output_channels =
1734+ op->attribute <pir::Int32Attribute>(" output_channels" ).data ();
1735+ auto divisor =
1736+ symbol::DimExpr (output_channels * pooled_height * pooled_width);
1737+ infer_context->AddEqualCstr (input_dims[1 ], divisor);
1738+ PADDLE_ENFORCE_GT (pooled_height,
1739+ 0 ,
1740+ phi::errors::InvalidArgument (
1741+ " The pooled output height must be greater than 0" ));
1742+ PADDLE_ENFORCE_GT (pooled_width,
1743+ 0 ,
1744+ phi::errors::InvalidArgument (
1745+ " The pooled output width must be greater than 0" ));
1746+ PADDLE_ENFORCE_GT (output_channels,
1747+ 1 ,
1748+ phi::errors::InvalidArgument (
1749+ " The pooled output channels must greater than 1" ));
1750+ std::vector<symbol::DimExpr> out_dims = {
1751+ rois_dims[0 ], output_channels, pooled_height, pooled_width};
1752+
1753+ infer_context->SetShapeOrDataForValue (
1754+ op->result (0 ),
1755+ symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs (out_dims)});
1756+
1757+ return true ;
1758+ }
17061759
17071760// bool PyramidHashOpInferSymbolicShape(pir::Operation *op,
17081761// pir::InferSymbolicShapeContext
0 commit comments