Skip to content

Commit 09d434a

Browse files
DrRyanHuangHermitSun
authored andcommitted
【PIR OpTest Fix No.31】 fix test_number_count_op (PaddlePaddle#60055)
* add numbercount * add number_count to NEED_GEN_STATIC_ONLY_APIS * revert ops_api_gen * add NEED_GEN_STATIC_ONLY_APIS to number_count * numbers => Numbers * numbers => Numbers in py * numbers => Numbers * fix data_type err && Numbers|numbers err * Number revert
1 parent 4922fc4 commit 09d434a

File tree

7 files changed

+28
-1
lines changed

7 files changed

+28
-1
lines changed

paddle/fluid/operators/number_count_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
4646
void Make() override {
4747
AddInput("numbers", "(Tensor) The input gate index tensor.");
4848
AddOutput("Out", "(Tensor) The output number count tensor.");
49-
AddAttr<int>("upper_range", "int), The number of different numbers.");
49+
AddAttr<int>("upper_range", "(int), The number of different numbers.");
5050

5151
AddComment(R"DOC(number_count Operator.count numbers.)DOC");
5252
}

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
'get_tensor_from_selected_rows',
102102
'print',
103103
'sequence_mask',
104+
'number_count',
104105
]
105106

106107
NO_NEED_GEN_STATIC_ONLY_APIS = [

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,15 @@
14281428
optional: dropout1_seed, dropout2_seed, linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, ln2_mean, ln2_variance, ln1_mean, ln1_variance, ln1_out
14291429
backward: fused_feedforward_grad
14301430

1431+
- op: number_count
1432+
args: (Tensor numbers, int upper_range)
1433+
output: Tensor(out)
1434+
infer_meta:
1435+
func: NumberCountInferMeta
1436+
kernel:
1437+
func: number_count
1438+
data_type: numbers
1439+
14311440
- op: sparse_momentum
14321441
args: (Tensor param, Tensor grad, Tensor velocity, Tensor index, Tensor learning_rate, Tensor master_param,float mu, Scalar axis=0, bool use_nesterov=false,str regularization_method="", float regularization_coeff=0.0f, bool multi_precision=false, float rescale_grad=1.0f)
14331442
output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3472,6 +3472,12 @@
34723472
outputs :
34733473
out : Out
34743474

3475+
- op: number_count
3476+
inputs :
3477+
{numbers: numbers}
3478+
outputs :
3479+
out : Out
3480+
34753481
- op: read_from_array
34763482
inputs:
34773483
array : X

paddle/phi/infermeta/unary.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5313,6 +5313,12 @@ void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out) {
53135313
out->set_strides(x.strides());
53145314
}
53155315

5316+
void NumberCountInferMeta(const MetaTensor& x,
5317+
int upper_range,
5318+
MetaTensor* out) {
5319+
out->share_meta(x);
5320+
}
5321+
53165322
} // namespace phi
53175323

53185324
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);

paddle/phi/infermeta/unary.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,10 @@ void UnStackInferMeta(const MetaTensor& x,
789789
int num,
790790
std::vector<MetaTensor*> outs);
791791

792+
void NumberCountInferMeta(const MetaTensor& x,
793+
int upper_range,
794+
MetaTensor* out);
795+
792796
void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out);
793797

794798
} // namespace phi

test/white_list/pir_op_test_white_list

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ test_nms_op
219219
test_nn_functional_hot_op
220220
test_nonzero_api
221221
test_norm_op
222+
test_number_count_op
222223
test_numel_op
223224
test_one_hot_v2_op
224225
test_one_hot_v2_op_static_build

0 commit comments

Comments
 (0)