Skip to content

Commit a5a32e4

Browse files
authored
【Infer Symbolic Shape No.123】【BUAA】Add Bincount (#67435)
1 parent e77a4b0 commit a5a32e4

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,42 @@ bool Binomial_OpInferSymbolicShape(
234234
return BinomialOpInferSymbolicShape(op, infer_context);
235235
}
236236

237-
// bool BincountOpInferSymbolicShape(pir::Operation *op,
238-
// pir::InferSymbolicShapeContext
239-
// *infer_context) {
240-
// // pass
241-
// return true;
242-
// }
237+
bool BincountOpInferSymbolicShape(
238+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
239+
const auto &x_shape_or_data =
240+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
241+
const std::vector<symbol::DimExpr> &x_dims = x_shape_or_data.shape();
242+
243+
PADDLE_ENFORCE_EQ(x_dims.size(),
244+
1,
245+
common::errors::InvalidArgument(
246+
"The 'shape' of Input(X) must be 1-D tensor. But the "
247+
"dimension of Input(X) is [%d]",
248+
x_dims.size()));
249+
250+
if (op->operand_source(1)) {
251+
const auto &weights_shape_or_data =
252+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
253+
const std::vector<symbol::DimExpr> &weights_dims =
254+
weights_shape_or_data.shape();
255+
256+
PADDLE_ENFORCE_EQ(weights_dims.size(),
257+
1,
258+
common::errors::InvalidArgument(
259+
"The 'shape' of Input(Weights) must be 1-D tensor. "
260+
"But the dimension of Input(Weights) is [%d]",
261+
weights_dims.size()));
262+
infer_context->AddEqualCstr(weights_dims[0], x_dims[0]);
263+
}
264+
265+
symbol::DimExpr out_unknown = infer_context->GetNextSymName();
266+
const std::vector<symbol::DimExpr> out_dims = {out_unknown};
267+
symbol::ShapeOrDataDimExprs output_dims{
268+
symbol::TensorShapeOrDataDimExprs(out_dims)};
269+
infer_context->SetShapeOrDataForValue(op->result(0), output_dims);
270+
271+
return true;
272+
}
243273

244274
// bool BmmOpInferSymbolicShape(pir::Operation *op,
245275
// pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
2626
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip)
2727
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial)
2828
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_)
29-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
29+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
3030
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm)
3131
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve)
3232
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@
549549
kernel:
550550
func: bincount
551551
optional: weights
552+
interfaces : paddle::dialect::InferSymbolicShapeInterface
552553

553554
- op : binomial
554555
args : (Tensor count, Tensor prob)

test/legacy_test/test_bincount_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def init_test_case(self):
156156
self.Out = np.bincount(self.np_input, minlength=self.minlength)
157157

158158
def test_check_output(self):
159-
self.check_output(check_pir=True)
159+
self.check_output(check_pir=True, check_symbol_infer=False)
160160

161161

162162
class TestCase1(TestBincountOp):

0 commit comments

Comments
 (0)