-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【Infer Symbolic Shape No.139,140,141,142】[BUAA] Add generate_proposals,grid_sample,gru,gru_unit op #67413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
【Infer Symbolic Shape No.139,140,141,142】[BUAA] Add generate_proposals,grid_sample,gru,gru_unit op #67413
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
15e133c
fixed
Luohongzhige 2a78f90
generate
Luohongzhige a398e38
generate
Luohongzhige 0457aef
generate
Luohongzhige 7a3289c
generate
Luohongzhige b1396ba
generate
Luohongzhige b404fd5
generate
Luohongzhige 2ba9a9d
close check_symbol_infer for TestGenerateProposalsV2Op
Luohongzhige 11f4e7d
Merge branch 'develop' into cinn_2
Luohongzhige 193623e
Merge branch 'develop' into cinn_2
Luohongzhige 52182e0
Merge branch 'develop' into cinn_2
Luohongzhige b0249d4
fixed
Luohongzhige c9c2244
fixed
Luohongzhige 7835adc
fixed Gru
Luohongzhige cc500fc
Merge branch 'develop' into cinn_2
Luohongzhige c4dd4b1
fixed Gru
Luohongzhige 1b50638
Merge branch 'cinn_2' of https://github.com/Luohongzhige/Paddle into …
Luohongzhige 4e54190
Merge branch 'develop' into cinn_2
Luohongzhige File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1396,12 +1396,36 @@ bool FusedMultiTransformerOpInferSymbolicShape( | |
| return true; | ||
| } | ||
|
|
||
| // bool GenerateProposalsOpInferSymbolicShape(pir::Operation *op, | ||
| // pir::InferSymbolicShapeContext | ||
| // *infer_context) { | ||
| // // pass | ||
| // return true; | ||
| // } | ||
| bool GenerateProposalsOpInferSymbolicShape( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM for GenerateProposalsOp |
||
| pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
| symbol::DimExpr out_unknown = infer_context->GetNextSymName(); | ||
| std::vector<symbol::DimExpr> rpn_rois_shape = {out_unknown, | ||
| symbol::DimExpr(4)}; | ||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(0), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(rpn_rois_shape)}); | ||
|
|
||
| std::vector<symbol::DimExpr> rpn_roi_probs_shape = {out_unknown, | ||
| symbol::DimExpr(1)}; | ||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(1), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(rpn_roi_probs_shape)}); | ||
|
|
||
| const symbol::ShapeOrDataDimExprs &score_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
|
|
||
| std::vector<symbol::DimExpr> score_shape = score_shape_or_data.shape(); | ||
| auto rpn_rois_num_shape = std::vector<symbol::DimExpr>{score_shape[0]}; | ||
|
|
||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(2), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(rpn_rois_num_shape)}); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| bool GraphKhopSamplerOpInferSymbolicShape( | ||
| pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
|
|
@@ -1559,19 +1583,135 @@ bool GraphSampleNeighborsOpInferSymbolicShape( | |
| return true; | ||
| } | ||
|
|
||
| // bool GruOpInferSymbolicShape(pir::Operation *op, | ||
| // pir::InferSymbolicShapeContext *infer_context) | ||
| // { | ||
| // // pass | ||
| // return true; | ||
| // } | ||
| bool GruOpInferSymbolicShape(pir::Operation *op, | ||
| pir::InferSymbolicShapeContext *infer_context) { | ||
| const symbol::ShapeOrDataDimExprs &input_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
| const symbol::ShapeOrDataDimExprs &weight_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
| const symbol::ShapeOrDataDimExprs &hidden_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->result(3)); | ||
|
|
||
| // bool GruUnitOpInferSymbolicShape(pir::Operation *op, | ||
| // pir::InferSymbolicShapeContext | ||
| // *infer_context) { | ||
| // // pass | ||
| // return true; | ||
| // } | ||
| std::vector<symbol::DimExpr> input_shape = input_shape_or_data.shape(); | ||
| std::vector<symbol::DimExpr> weight_shape = weight_shape_or_data.shape(); | ||
| std::vector<symbol::DimExpr> hidden_shape = hidden_shape_or_data.shape(); | ||
|
|
||
| bool is_test = op->attribute<pir::BoolAttribute>("is_test").data(); | ||
|
|
||
| symbol::DimExpr input_size = input_shape[1]; | ||
| symbol::DimExpr frame_size = weight_shape[0]; | ||
|
|
||
| // Check if input_size is 3 times frame_size | ||
| infer_context->AddEqualCstr(input_size, frame_size * 3); | ||
|
|
||
| // Check if weight matrix has size [frame_size, frame_size * 3] | ||
| infer_context->AddEqualCstr(weight_shape[1], frame_size * 3); | ||
|
|
||
| if (op->operand(1)) { // Check if H0 is given | ||
| const symbol::ShapeOrDataDimExprs &h0_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
| std::vector<symbol::DimExpr> h0_shape = h0_shape_or_data.shape(); | ||
| infer_context->AddEqualCstr(h0_shape[1], frame_size); | ||
| } | ||
|
|
||
| if (op->operand(3)) { // Check if Bias is given | ||
| const symbol::ShapeOrDataDimExprs &bias_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
| std::vector<symbol::DimExpr> bias_shape = bias_shape_or_data.shape(); | ||
| infer_context->AddEqualCstr(bias_shape[0], 1); | ||
| infer_context->AddEqualCstr(bias_shape[1], frame_size * 3); | ||
| } | ||
|
|
||
| if (is_test) { | ||
| symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape); | ||
| infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape); | ||
|
|
||
| symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape( | ||
| {hidden_shape}); | ||
| infer_context->SetShapeOrDataForValue(op->result(1), | ||
| batch_reset_hidden_prev_shape); | ||
|
|
||
| symbol::TensorShapeOrDataDimExprs batch_hidden_shape({hidden_shape}); | ||
| infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape); | ||
| } else { | ||
| symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape); | ||
| infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape); | ||
|
|
||
| symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape( | ||
| {input_shape[0], frame_size}); | ||
| infer_context->SetShapeOrDataForValue(op->result(1), | ||
| batch_reset_hidden_prev_shape); | ||
|
|
||
| symbol::TensorShapeOrDataDimExprs batch_hidden_shape( | ||
| {input_shape[0], frame_size}); | ||
| infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape); | ||
| } | ||
|
|
||
| symbol::TensorShapeOrDataDimExprs hidden_shape_output( | ||
| {input_shape[0], frame_size}); | ||
| infer_context->SetShapeOrDataForValue(op->result(3), hidden_shape_output); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| bool GruUnitOpInferSymbolicShape( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM for GruUnitOpInferSymbolicShape |
||
| pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
| // Get symbolic shapes of the input tensors | ||
| const symbol::ShapeOrDataDimExprs &input_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
| const symbol::ShapeOrDataDimExprs &hidden_prev_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
| const symbol::ShapeOrDataDimExprs &weight_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
| const symbol::ShapeOrDataDimExprs &bias_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
|
|
||
| auto input_shape = input_shape_or_data.shape(); | ||
| auto hidden_prev_shape = hidden_prev_shape_or_data.shape(); | ||
| auto weight_shape = weight_shape_or_data.shape(); | ||
|
|
||
| // Validate input dimensions | ||
| symbol::DimExpr batch_size = input_shape[0]; | ||
| symbol::DimExpr input_size = input_shape[1]; | ||
| symbol::DimExpr frame_size = hidden_prev_shape[1]; | ||
| symbol::DimExpr weight_height = weight_shape[0]; | ||
| symbol::DimExpr weight_width = weight_shape[1]; | ||
|
|
||
| // Enforce dimension constraints using symbolic dimensions | ||
| infer_context->AddEqualCstr(input_size, frame_size * 3); | ||
| infer_context->AddEqualCstr(weight_height, frame_size); | ||
| infer_context->AddEqualCstr(weight_width, frame_size * 3); | ||
|
|
||
| // If bias is used, check its dimensions | ||
| if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
| auto bias_shape = bias_shape_or_data.shape(); | ||
| symbol::DimExpr bias_height = bias_shape[0]; | ||
| symbol::DimExpr bias_width = bias_shape[1]; | ||
| infer_context->AddEqualCstr(bias_height, 1); | ||
| infer_context->AddEqualCstr(bias_width, frame_size * 3); | ||
| } | ||
|
|
||
| // Set output dimensions | ||
| std::vector<symbol::DimExpr> gate_dims = {batch_size, frame_size * 3}; | ||
| std::vector<symbol::DimExpr> reset_hidden_prev_dims = {batch_size, | ||
| frame_size}; | ||
| std::vector<symbol::DimExpr> hidden_dims = {batch_size, frame_size}; | ||
|
|
||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(0), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(gate_dims)}); | ||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(1), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(reset_hidden_prev_dims)}); | ||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(2), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(hidden_dims)}); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| bool GroupNormOpInferSymbolicShape( | ||
| pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for GridSampleOpInferSymbolicShape