Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,62 @@ bool FusedSoftmaxMaskOpInferSymbolicShape(
return true;
}

// bool GridSampleOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool GridSampleOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GridSampleOpInferSymbolicShape

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
auto x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
auto grid_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();

PADDLE_ENFORCE_GE(x_shape.size(),
4,
common::errors::InvalidArgument(
"Input(X) of GridSampleOp should be 4-D Tensor, but "
"received X dimension size(%d)",
x_shape.size()));
PADDLE_ENFORCE_LE(x_shape.size(),
5,
common::errors::InvalidArgument(
"Input(X) of GridSampleOp should be 4-D Tensor, but "
"received X dimension size(%d)",
x_shape.size()));
PADDLE_ENFORCE_GE(grid_shape.size(),
4,
common::errors::InvalidArgument(
"Input(Grid) of GridSampleOp should be 4-D Tensor, "
"but received Grid dimension size(%d)",
grid_shape.size()));
PADDLE_ENFORCE_LE(grid_shape.size(),
5,
common::errors::InvalidArgument(
"Input(Grid) of GridSampleOp should be 4-D Tensor, "
"but received Grid dimension size(%d)",
grid_shape.size()));

if (grid_shape.size() == 4) {
infer_context->AddEqualCstr(grid_shape[3], symbol::DimExpr(2));
}
if (grid_shape.size() == 5) {
infer_context->AddEqualCstr(grid_shape[4], symbol::DimExpr(3));
}

infer_context->AddEqualCstr(grid_shape[0], x_shape[0]);

std::vector<symbol::DimExpr> out_shape;
if (grid_shape.size() == 4) {
out_shape = {x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]};
} else {
out_shape = {
x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], grid_shape[3]};
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}

bool GatherOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedSoftmaxMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherTree)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1396,12 +1396,36 @@ bool FusedMultiTransformerOpInferSymbolicShape(
return true;
}

// bool GenerateProposalsOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool GenerateProposalsOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GenerateProposalsOp

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
symbol::DimExpr out_unknown = infer_context->GetNextSymName();
std::vector<symbol::DimExpr> rpn_rois_shape = {out_unknown,
symbol::DimExpr(4)};
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(rpn_rois_shape)});

std::vector<symbol::DimExpr> rpn_roi_probs_shape = {out_unknown,
symbol::DimExpr(1)};
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(rpn_roi_probs_shape)});

const symbol::ShapeOrDataDimExprs &score_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));

std::vector<symbol::DimExpr> score_shape = score_shape_or_data.shape();
auto rpn_rois_num_shape = std::vector<symbol::DimExpr>{score_shape[0]};

infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(rpn_rois_num_shape)});

return true;
}

bool GraphKhopSamplerOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down Expand Up @@ -1559,19 +1583,135 @@ bool GraphSampleNeighborsOpInferSymbolicShape(
return true;
}

// bool GruOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }
// Symbolic shape inference for the gru op.
// Inputs: Input [T, 3*D] (batched gate pre-activations), Weight [D, 3*D],
// optional H0 [N, D] and Bias [1, 3*D]. Outputs: BatchGate,
// BatchResetHiddenPrev, BatchHidden, Hidden.
bool GruOpInferSymbolicShape(pir::Operation *op,
                             pir::InferSymbolicShapeContext *infer_context) {
  const symbol::ShapeOrDataDimExprs &input_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(0));
  const symbol::ShapeOrDataDimExprs &weight_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(2));
  // NOTE(review): this reads result(3)'s shape before it is assigned at the
  // bottom of this function — presumably the context pre-populates result
  // shapes; confirm, otherwise the is_test branch uses an unset shape.
  const symbol::ShapeOrDataDimExprs &hidden_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->result(3));

  std::vector<symbol::DimExpr> input_shape = input_shape_or_data.shape();
  std::vector<symbol::DimExpr> weight_shape = weight_shape_or_data.shape();
  std::vector<symbol::DimExpr> hidden_shape = hidden_shape_or_data.shape();

  bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();

  symbol::DimExpr input_size = input_shape[1];
  symbol::DimExpr frame_size = weight_shape[0];

  // Input packs the three gates, so its width is 3 * frame_size.
  infer_context->AddEqualCstr(input_size, frame_size * 3);

  // Weight matrix must be [frame_size, frame_size * 3].
  infer_context->AddEqualCstr(weight_shape[1], frame_size * 3);

  if (op->operand(1)) {  // Optional H0: initial hidden state [N, frame_size].
    const symbol::ShapeOrDataDimExprs &h0_shape_or_data =
        infer_context->GetShapeOrDataForValue(op->operand_source(1));
    std::vector<symbol::DimExpr> h0_shape = h0_shape_or_data.shape();
    infer_context->AddEqualCstr(h0_shape[1], frame_size);
  }

  if (op->operand(3)) {  // Optional Bias: must be [1, frame_size * 3].
    const symbol::ShapeOrDataDimExprs &bias_shape_or_data =
        infer_context->GetShapeOrDataForValue(op->operand_source(3));
    std::vector<symbol::DimExpr> bias_shape = bias_shape_or_data.shape();
    infer_context->AddEqualCstr(bias_shape[0], 1);
    infer_context->AddEqualCstr(bias_shape[1], frame_size * 3);
  }

  if (is_test) {
    // Inference mode: intermediate outputs mirror Input / Hidden shapes.
    symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape);
    infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape);

    symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape(
        {hidden_shape});
    infer_context->SetShapeOrDataForValue(op->result(1),
                                          batch_reset_hidden_prev_shape);

    symbol::TensorShapeOrDataDimExprs batch_hidden_shape({hidden_shape});
    infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape);
  } else {
    // Training mode: intermediates are [T, frame_size] workspaces.
    symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape);
    infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape);

    symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape(
        {input_shape[0], frame_size});
    infer_context->SetShapeOrDataForValue(op->result(1),
                                          batch_reset_hidden_prev_shape);

    symbol::TensorShapeOrDataDimExprs batch_hidden_shape(
        {input_shape[0], frame_size});
    infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape);
  }

  // Hidden: [T, frame_size].
  symbol::TensorShapeOrDataDimExprs hidden_shape_output(
      {input_shape[0], frame_size});
  infer_context->SetShapeOrDataForValue(op->result(3), hidden_shape_output);

  return true;
}

bool GruUnitOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GruUnitOpInferSymbolicShape

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// Get symbolic shapes of the input tensors
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const symbol::ShapeOrDataDimExprs &hidden_prev_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const symbol::ShapeOrDataDimExprs &weight_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const symbol::ShapeOrDataDimExprs &bias_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));

auto input_shape = input_shape_or_data.shape();
auto hidden_prev_shape = hidden_prev_shape_or_data.shape();
auto weight_shape = weight_shape_or_data.shape();

// Validate input dimensions
symbol::DimExpr batch_size = input_shape[0];
symbol::DimExpr input_size = input_shape[1];
symbol::DimExpr frame_size = hidden_prev_shape[1];
symbol::DimExpr weight_height = weight_shape[0];
symbol::DimExpr weight_width = weight_shape[1];

// Enforce dimension constraints using symbolic dimensions
infer_context->AddEqualCstr(input_size, frame_size * 3);
infer_context->AddEqualCstr(weight_height, frame_size);
infer_context->AddEqualCstr(weight_width, frame_size * 3);

// If bias is used, check its dimensions
if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
auto bias_shape = bias_shape_or_data.shape();
symbol::DimExpr bias_height = bias_shape[0];
symbol::DimExpr bias_width = bias_shape[1];
infer_context->AddEqualCstr(bias_height, 1);
infer_context->AddEqualCstr(bias_width, frame_size * 3);
}

// Set output dimensions
std::vector<symbol::DimExpr> gate_dims = {batch_size, frame_size * 3};
std::vector<symbol::DimExpr> reset_hidden_prev_dims = {batch_size,
frame_size};
std::vector<symbol::DimExpr> hidden_dims = {batch_size, frame_size};

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(gate_dims)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(reset_hidden_prev_dims)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(hidden_dims)});

return true;
}

bool GroupNormOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphReindex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gru)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GruUnit)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gru)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GruUnit)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(InstanceNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lerp)
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2277,6 +2277,7 @@
func : generate_proposals
data_type : anchors
optional : rpn_rois_num
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : graph_khop_sampler
args : (Tensor row, Tensor colptr, Tensor x, Tensor eids, int[] sample_sizes, bool return_eids)
Expand Down Expand Up @@ -2310,6 +2311,7 @@
func : grid_sample
data_type : x
backward : grid_sample_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : group_norm
args : (Tensor x, Tensor scale, Tensor bias, float epsilon = 1e-5, int groups = -1, str data_format = "NCHW")
Expand All @@ -2336,6 +2338,7 @@
optional: h0, bias
intermediate: batch_gate, batch_reset_hidden_prev, batch_hidden
backward: gru_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : gru_unit
args: (Tensor input, Tensor hidden_prev, Tensor weight, Tensor bias, int activation
Expand All @@ -2348,6 +2351,7 @@
optional: bias
intermediate: gate, reset_hidden_prev
backward: gru_unit_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : gumbel_softmax
args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1)
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/test_generate_proposals_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def set_data(self):
}

def test_check_output(self):
    # Verify op output under PIR; symbolic-shape inference checking is
    # disabled because generate_proposals_v2 has a data-dependent output
    # dim (proposal count after NMS) that static symbol checking rejects.
    # (The scrape showed both the pre- and post-diff call; only the final
    # version is kept so check_output runs exactly once.)
    self.check_output(check_pir=True, check_symbol_infer=False)

def setUp(self):
self.op_type = "generate_proposals_v2"
Expand Down