Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
'add_n_',
'all_reduce',
'all_reduce_',
'assign_pos',
'batch_fc',
'barrier',
'c_allgather',
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@
inplace : (output -> out)
backward : assign_out__grad

- op : assign_pos
args : (Tensor x, Tensor cum_count, Tensor eff_num_len)
output : Tensor(out)
infer_meta :
func : AssignPosInferMeta
kernel :
func : assign_pos

- op : assign_value
args : (int[] shape, DataType dtype, Scalar[] values, Place place = {})
output : Tensor(out)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@
get_expected_kernel_type :
assign : GetAssignExpectedKernelType

- op : assign_pos
inputs :
{x : X}
outputs :
out : Out

- op : assign_value
outputs :
out : Out
Expand Down
19 changes: 19 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,25 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void AssignPosInferMeta(const MetaTensor& x,
const MetaTensor& cum_count,
const MetaTensor& eff_num_len,
MetaTensor* out) {
phi::DataType X_dtype = x.dtype();
phi::DataType cum_count_dtype = cum_count.dtype();

PADDLE_ENFORCE_EQ(cum_count_dtype,
X_dtype,
phi::errors::InvalidArgument(
"The dtype of the cum_count and X should be same"));
PADDLE_ENFORCE_EQ(cum_count_dtype,
phi::DataType::INT64,
phi::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
out->set_dtype(X_dtype);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不需要设置ddim吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原本是空的,这里我觉得不改变现状比较好,后续有优化需求单独做

}

void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& step,
MetaTensor* out);

void AssignPosInferMeta(const MetaTensor& x,
const MetaTensor& cum_count,
const MetaTensor& eff_num_len,
MetaTensor* out);

void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_no_check_list
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
test_assign_pos_op
test_bernoulli_op
test_dirichlet_op
test_empty_op
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ test_arg_min_max_op_static_build
test_arg_min_max_v2_op
test_argsort_op
test_assign_op
test_assign_pos_op
test_assign_value_op
test_atan2_op
test_auc_op
Expand Down