Skip to content

Commit af664a9

Browse files
authored
【Infer Symbolic Shape No.202】[BUAA] Add graph_sample_neighbors op (#67731)
* tmp_save * fixed * fixed * fixed
1 parent 6c8d72d commit af664a9

File tree

3 files changed

+77
-12
lines changed

3 files changed

+77
-12
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,10 +1414,11 @@ bool GraphKhopSamplerOpInferSymbolicShape(
14141414
const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
14151415
infer_context->GetShapeOrDataForValue(op->operand_source(3));
14161416

1417-
auto row_shape = row_shape_or_data.shape();
1418-
auto col_ptr_shape = col_ptr_shape_or_data.shape();
1419-
auto x_shape = x_shape_or_data.shape();
1420-
auto eids_shape = eids_shape_or_data.shape();
1417+
const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape();
1418+
const std::vector<symbol::DimExpr> &col_ptr_shape =
1419+
col_ptr_shape_or_data.shape();
1420+
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
1421+
const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape();
14211422

14221423
auto GKSShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
14231424
const std::string &tensor_name) {
@@ -1487,12 +1488,76 @@ bool GraphKhopSamplerOpInferSymbolicShape(
14871488
// return true;
14881489
// }
14891490

1490-
// bool GraphSampleNeighborsOpInferSymbolicShape(pir::Operation *op,
1491-
// pir::InferSymbolicShapeContext
1492-
// *infer_context) {
1493-
// // pass
1494-
// return true;
1495-
// }
1491+
bool GraphSampleNeighborsOpInferSymbolicShape(
1492+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1493+
const symbol::ShapeOrDataDimExprs &row_shape_or_data =
1494+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
1495+
const symbol::ShapeOrDataDimExprs &col_ptr_shape_or_data =
1496+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
1497+
const symbol::ShapeOrDataDimExprs &x_shape_or_data =
1498+
infer_context->GetShapeOrDataForValue(op->operand_source(2));
1499+
const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
1500+
infer_context->GetShapeOrDataForValue(op->operand_source(3));
1501+
const symbol::ShapeOrDataDimExprs &perm_buffer_shape_or_data =
1502+
infer_context->GetShapeOrDataForValue(op->operand_source(4));
1503+
1504+
const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape();
1505+
const std::vector<symbol::DimExpr> &col_ptr_shape =
1506+
col_ptr_shape_or_data.shape();
1507+
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
1508+
const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape();
1509+
const std::vector<symbol::DimExpr> &perm_buffer_shape =
1510+
perm_buffer_shape_or_data.shape();
1511+
1512+
auto GSNShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
1513+
const std::string &tensor_name) {
1514+
if (shape.size() == 2)
1515+
infer_context->AddEqualCstr(shape[1], symbol::DimExpr{1});
1516+
else
1517+
PADDLE_ENFORCE_EQ(
1518+
shape.size(),
1519+
1,
1520+
common::errors::InvalidArgument(
1521+
"The %s should be 1D, when it is not 2D, but we get %d",
1522+
tensor_name,
1523+
shape.size()));
1524+
};
1525+
1526+
GSNShapeCheck(row_shape, "Row");
1527+
GSNShapeCheck(col_ptr_shape, "Col_Ptr");
1528+
GSNShapeCheck(x_shape, "X");
1529+
1530+
bool return_eids = op->attribute<pir::BoolAttribute>("return_eids").data();
1531+
bool flag_perm_buffer =
1532+
op->attribute<pir::BoolAttribute>("flag_perm_buffer").data();
1533+
1534+
if (return_eids) {
1535+
GSNShapeCheck(eids_shape, "Eids");
1536+
symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName();
1537+
infer_context->SetShapeOrDataForValue(
1538+
op->result(2),
1539+
symbol::ShapeOrDataDimExprs{
1540+
symbol::TensorShapeOrDataDimExprs({out_unknown_2})});
1541+
} else {
1542+
infer_context->SetSymbolForValueByStaticShape(op->result(2));
1543+
}
1544+
1545+
if (flag_perm_buffer) {
1546+
GSNShapeCheck(perm_buffer_shape, "Perm_Buffer");
1547+
}
1548+
1549+
symbol::DimExpr out_unknown_0 = infer_context->GetNextSymName();
1550+
infer_context->SetShapeOrDataForValue(
1551+
op->result(0),
1552+
symbol::ShapeOrDataDimExprs{
1553+
symbol::TensorShapeOrDataDimExprs({out_unknown_0})});
1554+
infer_context->SetShapeOrDataForValue(
1555+
op->result(1),
1556+
symbol::ShapeOrDataDimExprs{
1557+
symbol::TensorShapeOrDataDimExprs({x_shape[0]})});
1558+
1559+
return true;
1560+
}
14961561

14971562
// bool GruOpInferSymbolicShape(pir::Operation *op,
14981563
// pir::InferSymbolicShapeContext *infer_context)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer)
6363
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
6464
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler)
6565
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphReindex)
66-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors)
66+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors)
6767
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gru)
6868
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GruUnit)
6969
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2298,7 +2298,7 @@
22982298
func : graph_sample_neighbors
22992299
data_type : row
23002300
optional : eids, perm_buffer
2301-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
2301+
interfaces : paddle::dialect::InferSymbolicShapeInterface
23022302

23032303
- op : grid_sample
23042304
args : (Tensor x, Tensor grid, str mode = "bilinear", str padding_mode = "zeros", bool align_corners = true)

0 commit comments

Comments
 (0)