Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
'prune_gate_by_capacity',
'push_sparse_v2',
'push_sparse_v2_',
'partial_concat',
'partial_send',
'partial_recv',
'partial_allgather',
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,16 @@
func : partial_allgather
inplace : (x -> out)

- op : partial_concat
args : (Tensor[] x, int start_index = 0, int length = -1)
output : Tensor(out)
infer_meta :
func : PartialConcatInferMeta
kernel :
func : partial_concat
data_type : x
backward : partial_concat_grad

- op : partial_recv
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
output : Tensor(out)
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,16 @@
composite : pad_grad(x, out_grad, paddings, pad_value, x_grad)
backward : pad_double_grad

- backward_op : partial_concat_grad
forward : partial_concat (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : PartialConcatGradInferMeta
param : [x]
kernel :
func : partial_concat_grad

- backward_op : pool2d_double_grad
forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ const std::unordered_set<std::string> LegacyOpList = {
SoftReluGradOp::name(),
MatchMatrixTensorOp::name(),
MatchMatrixTensorGradOp::name(),
PartialConcatOp::name(),
PartialConcatGradOp::name(),
NceOp::name(),
NceGradOp::name(),
LrnOp::name(),
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,15 @@
outputs :
out : Out

- op : partial_concat
backward : partial_concat_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false]

- op : partial_recv
outputs :
out : Out
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,16 @@ void NanmedianGradInferMeta(const MetaTensor& x,
x_grad->set_dtype(x.dtype());
}

// Infer the meta (dims/dtype) of each input gradient for partial_concat_grad.
// Each x_grad[i] simply mirrors the shape and dtype of the corresponding
// forward input xs[i].
void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
                                std::vector<MetaTensor*> x_grads) {
  auto input_num = xs.size();
  for (size_t i = 0; i < input_num; i++) {
    // A grad output may be omitted (nullptr) when that gradient is not
    // required by the caller; skip it instead of dereferencing.
    if (x_grads[i] == nullptr) {
      continue;
    }
    auto x_dims = xs[i]->dims();
    x_grads[i]->set_dims(x_dims);
    x_grads[i]->set_dtype(xs[i]->dtype());
  }
}

void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ void NanmedianGradInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* x_grad);

void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
Expand Down
71 changes: 71 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4480,6 +4480,77 @@ void SumInferMeta(const MetaTensor& x,
SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config);
}

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config) {
int64_t batch_size = -1;
int64_t input_len = -1;

auto inputs_num = xs.size();
PADDLE_ENFORCE_GT(inputs_num,
0,
phi::errors::InvalidArgument(
"ShapeError: Input tensors count should > 0. But "
"received inputs' length is 0."));

// Only support two dimensions now, should be extended later
// when length is -1, need make sure all dimensions to be added are the same
for (size_t i = 0; i < inputs_num; i++) {
auto x_dim = xs[i]->dims();

PADDLE_ENFORCE_EQ(
x_dim.size(),
2,
phi::errors::InvalidArgument("Only support two dimensions input now."));

if (i == 0) {
batch_size = x_dim[0];
input_len = x_dim[1];
} else {
// each tensor's dim must eq
PADDLE_ENFORCE_EQ(x_dim[0],
batch_size,
phi::errors::InvalidArgument(
"The batch size of all inputs must be same"));
PADDLE_ENFORCE_EQ(x_dim[1],
input_len,
phi::errors::InvalidArgument(
"The input len of all inputs must be same"));
}
}

PADDLE_ENFORCE_EQ(
start_index >= -input_len && start_index < input_len,
true,
phi::errors::InvalidArgument(
"The start_index is expected to be in range of [%d, %d), but got %d",
-input_len,
input_len,
start_index));

if (start_index < 0) {
start_index += input_len;
}

if (length > 0) {
PADDLE_ENFORCE_GE(input_len,
start_index + length,
phi::errors::OutOfRange(
"start_index + length is larger than input length"));
}
Comment on lines +4600 to +4605
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么多出这部分判断?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么多出这部分判断?

check了一下,确实不需要,上面的Enforce已经覆盖了这个情况,已修改

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_index + length <= input_len 被哪些条件覆盖了,坦白讲我没看出来。这里其实是可以合入的,只是得下补充解释。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外PR等CI过了我就直接合入了,你先不要改,这里的改动你加点解释,如果有必要的话可以再提个PR补充。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_index

我原本以为上面的start_index >= -input_len && start_index < input_len,这个判断是判断start_index越界的,然后代入到下面start_index + length这个越界判断,已经覆盖了这个情况,刚意识到start_index + length应该等价于end_index,确实需要判断一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_index + length <= input_len 被哪些条件覆盖了,坦白讲我没看出来。这里其实是可以合入的,只是得下补充解释。

int start_index = static_cast<int>(ComputeStartIndex(

我想了一下,按理说,这段的逻辑应该也要加上这个判断,就比如说:
partial_len 如果大于0的话就直推出了partial_len * inputs_num这么多的大小,并没有考虑start_index + length <= input_len 这个条件

Copy link
Contributor Author

@cmcamdy cmcamdy Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个条件

也就是越界的话就是推大了,然后我看了一眼它的compute定义,似乎也没有考虑这个,就是这里:

memcpy(out_data + out_size * j + partial_len * i,

越界了memcpy可能拷贝到下一行的数据了


std::vector<int64_t> out_dims(2);
out_dims[0] = batch_size;
// colnum = input_num * length
out_dims[1] = (length < 0) ? input_len - start_index : length;
out_dims[1] *= inputs_num;
DDim out_dim = common::make_ddim(out_dims);
out->set_dims(out_dim);
out->set_dtype(xs[0]->dtype());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,12 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ test_one_hot_v2_op
test_one_hot_v2_op_static_build
test_overlap_add_op
test_pad3d_op
test_partial_concat_op
test_pass_quantization
test_pixel_shuffle_op
test_poisson_op
Expand Down