@@ -1414,10 +1414,11 @@ bool GraphKhopSamplerOpInferSymbolicShape(
14141414 const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
14151415 infer_context->GetShapeOrDataForValue (op->operand_source (3 ));
14161416
1417- auto row_shape = row_shape_or_data.shape ();
1418- auto col_ptr_shape = col_ptr_shape_or_data.shape ();
1419- auto x_shape = x_shape_or_data.shape ();
1420- auto eids_shape = eids_shape_or_data.shape ();
1417+ const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape ();
1418+ const std::vector<symbol::DimExpr> &col_ptr_shape =
1419+ col_ptr_shape_or_data.shape ();
1420+ const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape ();
1421+ const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape ();
14211422
14221423 auto GKSShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
14231424 const std::string &tensor_name) {
@@ -1487,12 +1488,76 @@ bool GraphKhopSamplerOpInferSymbolicShape(
14871488// return true;
14881489// }
14891490
1490- // bool GraphSampleNeighborsOpInferSymbolicShape(pir::Operation *op,
1491- // pir::InferSymbolicShapeContext
1492- // *infer_context) {
1493- // // pass
1494- // return true;
1495- // }
1491+ bool GraphSampleNeighborsOpInferSymbolicShape (
1492+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1493+ const symbol::ShapeOrDataDimExprs &row_shape_or_data =
1494+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1495+ const symbol::ShapeOrDataDimExprs &col_ptr_shape_or_data =
1496+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1497+ const symbol::ShapeOrDataDimExprs &x_shape_or_data =
1498+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
1499+ const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
1500+ infer_context->GetShapeOrDataForValue (op->operand_source (3 ));
1501+ const symbol::ShapeOrDataDimExprs &perm_buffer_shape_or_data =
1502+ infer_context->GetShapeOrDataForValue (op->operand_source (4 ));
1503+
1504+ const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape ();
1505+ const std::vector<symbol::DimExpr> &col_ptr_shape =
1506+ col_ptr_shape_or_data.shape ();
1507+ const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape ();
1508+ const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape ();
1509+ const std::vector<symbol::DimExpr> &perm_buffer_shape =
1510+ perm_buffer_shape_or_data.shape ();
1511+
1512+ auto GSNShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
1513+ const std::string &tensor_name) {
1514+ if (shape.size () == 2 )
1515+ infer_context->AddEqualCstr (shape[1 ], symbol::DimExpr{1 });
1516+ else
1517+ PADDLE_ENFORCE_EQ (
1518+ shape.size (),
1519+ 1 ,
1520+ common::errors::InvalidArgument (
1521+ " The %s should be 1D, when it is not 2D, but we get %d" ,
1522+ tensor_name,
1523+ shape.size ()));
1524+ };
1525+
1526+ GSNShapeCheck (row_shape, " Row" );
1527+ GSNShapeCheck (col_ptr_shape, " Col_Ptr" );
1528+ GSNShapeCheck (x_shape, " X" );
1529+
1530+ bool return_eids = op->attribute <pir::BoolAttribute>(" return_eids" ).data ();
1531+ bool flag_perm_buffer =
1532+ op->attribute <pir::BoolAttribute>(" flag_perm_buffer" ).data ();
1533+
1534+ if (return_eids) {
1535+ GSNShapeCheck (eids_shape, " Eids" );
1536+ symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName ();
1537+ infer_context->SetShapeOrDataForValue (
1538+ op->result (2 ),
1539+ symbol::ShapeOrDataDimExprs{
1540+ symbol::TensorShapeOrDataDimExprs ({out_unknown_2})});
1541+ } else {
1542+ infer_context->SetSymbolForValueByStaticShape (op->result (2 ));
1543+ }
1544+
1545+ if (flag_perm_buffer) {
1546+ GSNShapeCheck (perm_buffer_shape, " Perm_Buffer" );
1547+ }
1548+
1549+ symbol::DimExpr out_unknown_0 = infer_context->GetNextSymName ();
1550+ infer_context->SetShapeOrDataForValue (
1551+ op->result (0 ),
1552+ symbol::ShapeOrDataDimExprs{
1553+ symbol::TensorShapeOrDataDimExprs ({out_unknown_0})});
1554+ infer_context->SetShapeOrDataForValue (
1555+ op->result (1 ),
1556+ symbol::ShapeOrDataDimExprs{
1557+ symbol::TensorShapeOrDataDimExprs ({x_shape[0 ]})});
1558+
1559+ return true ;
1560+ }
14961561
14971562// bool GruOpInferSymbolicShape(pir::Operation *op,
14981563// pir::InferSymbolicShapeContext *infer_context)
0 commit comments