Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1414,10 +1414,11 @@ bool GraphKhopSamplerOpInferSymbolicShape(
const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));

auto row_shape = row_shape_or_data.shape();
auto col_ptr_shape = col_ptr_shape_or_data.shape();
auto x_shape = x_shape_or_data.shape();
auto eids_shape = eids_shape_or_data.shape();
const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape();
const std::vector<symbol::DimExpr> &col_ptr_shape =
col_ptr_shape_or_data.shape();
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape();

auto GKSShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
const std::string &tensor_name) {
Expand Down Expand Up @@ -1487,12 +1488,76 @@ bool GraphKhopSamplerOpInferSymbolicShape(
// return true;
// }

// bool GraphSampleNeighborsOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool GraphSampleNeighborsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &row_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const symbol::ShapeOrDataDimExprs &col_ptr_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const symbol::ShapeOrDataDimExprs &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const symbol::ShapeOrDataDimExprs &eids_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));
const symbol::ShapeOrDataDimExprs &perm_buffer_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(4));

const std::vector<symbol::DimExpr> &row_shape = row_shape_or_data.shape();
const std::vector<symbol::DimExpr> &col_ptr_shape =
col_ptr_shape_or_data.shape();
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
const std::vector<symbol::DimExpr> &eids_shape = eids_shape_or_data.shape();
const std::vector<symbol::DimExpr> &perm_buffer_shape =
perm_buffer_shape_or_data.shape();

auto GSNShapeCheck = [&](const std::vector<symbol::DimExpr> &shape,
const std::string &tensor_name) {
if (shape.size() == 2)
infer_context->AddEqualCstr(shape[1], symbol::DimExpr{1});
else
PADDLE_ENFORCE_EQ(
shape.size(),
1,
common::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
shape.size()));
};

GSNShapeCheck(row_shape, "Row");
GSNShapeCheck(col_ptr_shape, "Col_Ptr");
GSNShapeCheck(x_shape, "X");

bool return_eids = op->attribute<pir::BoolAttribute>("return_eids").data();
bool flag_perm_buffer =
op->attribute<pir::BoolAttribute>("flag_perm_buffer").data();

if (return_eids) {
GSNShapeCheck(eids_shape, "Eids");
symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName();
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs({out_unknown_2})});
} else {
infer_context->SetSymbolForValueByStaticShape(op->result(2));
}

if (flag_perm_buffer) {
GSNShapeCheck(perm_buffer_shape, "Perm_Buffer");
}

symbol::DimExpr out_unknown_0 = infer_context->GetNextSymName();
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs({out_unknown_0})});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs({x_shape[0]})});

return true;
}

// bool GruOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphReindex)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gru)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GruUnit)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,7 @@
func : graph_sample_neighbors
data_type : row
optional : eids, perm_buffer
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : grid_sample
args : (Tensor x, Tensor grid, str mode = "bilinear", str padding_mode = "zeros", bool align_corners = true)
Expand Down