@@ -1479,12 +1479,36 @@ bool FusedMultiTransformerOpInferSymbolicShape(
14791479 return true ;
14801480}
14811481
1482- // bool GenerateProposalsOpInferSymbolicShape(pir::Operation *op,
1483- // pir::InferSymbolicShapeContext
1484- // *infer_context) {
1485- // // pass
1486- // return true;
1487- // }
1482+ bool GenerateProposalsOpInferSymbolicShape (
1483+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1484+ symbol::DimExpr out_unknown = infer_context->GetNextSymName ();
1485+ std::vector<symbol::DimExpr> rpn_rois_shape = {out_unknown,
1486+ symbol::DimExpr (4 )};
1487+ infer_context->SetShapeOrDataForValue (
1488+ op->result (0 ),
1489+ symbol::ShapeOrDataDimExprs{
1490+ symbol::TensorShapeOrDataDimExprs (rpn_rois_shape)});
1491+
1492+ std::vector<symbol::DimExpr> rpn_roi_probs_shape = {out_unknown,
1493+ symbol::DimExpr (1 )};
1494+ infer_context->SetShapeOrDataForValue (
1495+ op->result (1 ),
1496+ symbol::ShapeOrDataDimExprs{
1497+ symbol::TensorShapeOrDataDimExprs (rpn_roi_probs_shape)});
1498+
1499+ const symbol::ShapeOrDataDimExprs &score_shape_or_data =
1500+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1501+
1502+ std::vector<symbol::DimExpr> score_shape = score_shape_or_data.shape ();
1503+ auto rpn_rois_num_shape = std::vector<symbol::DimExpr>{score_shape[0 ]};
1504+
1505+ infer_context->SetShapeOrDataForValue (
1506+ op->result (2 ),
1507+ symbol::ShapeOrDataDimExprs{
1508+ symbol::TensorShapeOrDataDimExprs (rpn_rois_num_shape)});
1509+
1510+ return true ;
1511+ }
14881512
14891513bool GraphKhopSamplerOpInferSymbolicShape (
14901514 pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
@@ -1642,19 +1666,135 @@ bool GraphSampleNeighborsOpInferSymbolicShape(
16421666 return true ;
16431667}
16441668
1645- // bool GruOpInferSymbolicShape(pir::Operation *op,
1646- // pir::InferSymbolicShapeContext *infer_context)
1647- // {
1648- // // pass
1649- // return true;
1650- // }
1669+ bool GruOpInferSymbolicShape (pir::Operation *op,
1670+ pir::InferSymbolicShapeContext *infer_context) {
1671+ const symbol::ShapeOrDataDimExprs &input_shape_or_data =
1672+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1673+ const symbol::ShapeOrDataDimExprs &weight_shape_or_data =
1674+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
1675+ const symbol::ShapeOrDataDimExprs &hidden_shape_or_data =
1676+ infer_context->GetShapeOrDataForValue (op->result (3 ));
16511677
1652- // bool GruUnitOpInferSymbolicShape(pir::Operation *op,
1653- // pir::InferSymbolicShapeContext
1654- // *infer_context) {
1655- // // pass
1656- // return true;
1657- // }
1678+ std::vector<symbol::DimExpr> input_shape = input_shape_or_data.shape ();
1679+ std::vector<symbol::DimExpr> weight_shape = weight_shape_or_data.shape ();
1680+ std::vector<symbol::DimExpr> hidden_shape = hidden_shape_or_data.shape ();
1681+
1682+ bool is_test = op->attribute <pir::BoolAttribute>(" is_test" ).data ();
1683+
1684+ symbol::DimExpr input_size = input_shape[1 ];
1685+ symbol::DimExpr frame_size = weight_shape[0 ];
1686+
1687+ // Check if input_size is 3 times frame_size
1688+ infer_context->AddEqualCstr (input_size, frame_size * 3 );
1689+
1690+ // Check if weight matrix has size [frame_size, frame_size * 3]
1691+ infer_context->AddEqualCstr (weight_shape[1 ], frame_size * 3 );
1692+
1693+ if (op->operand (1 )) { // Check if H0 is given
1694+ const symbol::ShapeOrDataDimExprs &h0_shape_or_data =
1695+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1696+ std::vector<symbol::DimExpr> h0_shape = h0_shape_or_data.shape ();
1697+ infer_context->AddEqualCstr (h0_shape[1 ], frame_size);
1698+ }
1699+
1700+ if (op->operand (3 )) { // Check if Bias is given
1701+ const symbol::ShapeOrDataDimExprs &bias_shape_or_data =
1702+ infer_context->GetShapeOrDataForValue (op->operand_source (3 ));
1703+ std::vector<symbol::DimExpr> bias_shape = bias_shape_or_data.shape ();
1704+ infer_context->AddEqualCstr (bias_shape[0 ], 1 );
1705+ infer_context->AddEqualCstr (bias_shape[1 ], frame_size * 3 );
1706+ }
1707+
1708+ if (is_test) {
1709+ symbol::TensorShapeOrDataDimExprs batch_gate_shape (input_shape);
1710+ infer_context->SetShapeOrDataForValue (op->result (0 ), batch_gate_shape);
1711+
1712+ symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape (
1713+ {hidden_shape});
1714+ infer_context->SetShapeOrDataForValue (op->result (1 ),
1715+ batch_reset_hidden_prev_shape);
1716+
1717+ symbol::TensorShapeOrDataDimExprs batch_hidden_shape ({hidden_shape});
1718+ infer_context->SetShapeOrDataForValue (op->result (2 ), batch_hidden_shape);
1719+ } else {
1720+ symbol::TensorShapeOrDataDimExprs batch_gate_shape (input_shape);
1721+ infer_context->SetShapeOrDataForValue (op->result (0 ), batch_gate_shape);
1722+
1723+ symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape (
1724+ {input_shape[0 ], frame_size});
1725+ infer_context->SetShapeOrDataForValue (op->result (1 ),
1726+ batch_reset_hidden_prev_shape);
1727+
1728+ symbol::TensorShapeOrDataDimExprs batch_hidden_shape (
1729+ {input_shape[0 ], frame_size});
1730+ infer_context->SetShapeOrDataForValue (op->result (2 ), batch_hidden_shape);
1731+ }
1732+
1733+ symbol::TensorShapeOrDataDimExprs hidden_shape_output (
1734+ {input_shape[0 ], frame_size});
1735+ infer_context->SetShapeOrDataForValue (op->result (3 ), hidden_shape_output);
1736+
1737+ return true ;
1738+ }
1739+
1740+ bool GruUnitOpInferSymbolicShape (
1741+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1742+ // Get symbolic shapes of the input tensors
1743+ const symbol::ShapeOrDataDimExprs &input_shape_or_data =
1744+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1745+ const symbol::ShapeOrDataDimExprs &hidden_prev_shape_or_data =
1746+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1747+ const symbol::ShapeOrDataDimExprs &weight_shape_or_data =
1748+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
1749+ const symbol::ShapeOrDataDimExprs &bias_shape_or_data =
1750+ infer_context->GetShapeOrDataForValue (op->operand_source (3 ));
1751+
1752+ auto input_shape = input_shape_or_data.shape ();
1753+ auto hidden_prev_shape = hidden_prev_shape_or_data.shape ();
1754+ auto weight_shape = weight_shape_or_data.shape ();
1755+
1756+ // Validate input dimensions
1757+ symbol::DimExpr batch_size = input_shape[0 ];
1758+ symbol::DimExpr input_size = input_shape[1 ];
1759+ symbol::DimExpr frame_size = hidden_prev_shape[1 ];
1760+ symbol::DimExpr weight_height = weight_shape[0 ];
1761+ symbol::DimExpr weight_width = weight_shape[1 ];
1762+
1763+ // Enforce dimension constraints using symbolic dimensions
1764+ infer_context->AddEqualCstr (input_size, frame_size * 3 );
1765+ infer_context->AddEqualCstr (weight_height, frame_size);
1766+ infer_context->AddEqualCstr (weight_width, frame_size * 3 );
1767+
1768+ // If bias is used, check its dimensions
1769+ if (!bias_shape_or_data.isa <symbol::NullShapeOrDataDimExpr>()) {
1770+ auto bias_shape = bias_shape_or_data.shape ();
1771+ symbol::DimExpr bias_height = bias_shape[0 ];
1772+ symbol::DimExpr bias_width = bias_shape[1 ];
1773+ infer_context->AddEqualCstr (bias_height, 1 );
1774+ infer_context->AddEqualCstr (bias_width, frame_size * 3 );
1775+ }
1776+
1777+ // Set output dimensions
1778+ std::vector<symbol::DimExpr> gate_dims = {batch_size, frame_size * 3 };
1779+ std::vector<symbol::DimExpr> reset_hidden_prev_dims = {batch_size,
1780+ frame_size};
1781+ std::vector<symbol::DimExpr> hidden_dims = {batch_size, frame_size};
1782+
1783+ infer_context->SetShapeOrDataForValue (
1784+ op->result (0 ),
1785+ symbol::ShapeOrDataDimExprs{
1786+ symbol::TensorShapeOrDataDimExprs (gate_dims)});
1787+ infer_context->SetShapeOrDataForValue (
1788+ op->result (1 ),
1789+ symbol::ShapeOrDataDimExprs{
1790+ symbol::TensorShapeOrDataDimExprs (reset_hidden_prev_dims)});
1791+ infer_context->SetShapeOrDataForValue (
1792+ op->result (2 ),
1793+ symbol::ShapeOrDataDimExprs{
1794+ symbol::TensorShapeOrDataDimExprs (hidden_dims)});
1795+
1796+ return true ;
1797+ }
16581798
16591799bool GroupNormOpInferSymbolicShape (
16601800 pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
0 commit comments