Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2875,12 +2875,45 @@ bool Unsqueeze_OpInferSymbolicShape(
// return true;
// }

// bool UnstackOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool UnstackOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
int rank = x_shape.size();

int axis = op->attribute<pir::Int32Attribute>("axis").data();
int num = op->attribute<pir::Int32Attribute>("num").data();

PADDLE_ENFORCE_GE(axis,
-rank,
common::errors::InvalidArgument(
"The attribute axis is out of range, it must be inside "
"[-rank, rank), where rank = %d",
rank));
PADDLE_ENFORCE_LT(axis,
rank,
common::errors::InvalidArgument(
"The attribute axis is out of range, it must be inside "
"[-rank, rank), where rank = %d",
rank));
if (axis < 0) axis += rank;

infer_context->AddEqualCstr(x_shape[axis], num);

symbol::TensorListShapeOrDataDimExprs out_list_shape_or_data;

std::vector<symbol::DimExpr> out_shape = x_shape;
out_shape.erase(out_shape.begin() + axis);

symbol::TensorShapeOrDataDimExprs out_shape_or_data =
symbol::TensorShapeOrDataDimExprs(out_shape);
for (int i = 0; i < num; i++) {
out_list_shape_or_data.push_back(out_shape_or_data);
}
infer_context->SetShapeOrDataForValue(op->result(0), out_list_shape_or_data);
return true;
}

// bool WeightQuantizeOpInferSymbolicShape(
// pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniqueConsecutive)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unsqueeze)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unsqueeze_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unfold)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unstack)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unstack)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightQuantize)

} // namespace paddle::dialect
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4981,6 +4981,7 @@
kernel :
func : unstack
backward : unstack_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update=false)
Expand Down