Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
'sparse_momentum',
'soft_relu',
'uniform_random_batch_size_like',
'match_matrix_tensor',
]


Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,15 @@
optional: dropout1_seed, dropout2_seed, linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, ln2_mean, ln2_variance, ln1_mean, ln1_variance, ln1_out
backward: fused_feedforward_grad

# Legacy op migrated to PIR: takes two LoD inputs (x, y) and a weight tensor w,
# producing `out` plus an auxiliary buffer `tmp` consumed by the backward op.
# `dim_t` defaults to 1 and must match w's middle dimension (see infer_meta).
- op: match_matrix_tensor
  args: (Tensor x, Tensor y, Tensor w, int dim_t=1)
  output: Tensor(out), Tensor(tmp)
  infer_meta:
    func: MatchMatrixTensorInferMeta
  kernel:
    func: match_matrix_tensor
  backward: match_matrix_tensor_grad

- op: number_count
args: (Tensor numbers, int upper_range)
output: Tensor(out)
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,15 @@
func: fused_elemwise_add_activation_grad
optional : x, intermediate_out

# Backward of match_matrix_tensor: consumes the forward inputs (x, y, w), the
# saved intermediate `tmp`, and `out_grad`, producing gradients for x, y and w.
# `dim_t` is forwarded unchanged from the forward op.
- backward_op: match_matrix_tensor_grad
  forward: match_matrix_tensor (Tensor x, Tensor y, Tensor w, int dim_t=1) -> Tensor(out), Tensor(tmp)
  args: (Tensor x, Tensor y, Tensor w, Tensor tmp, Tensor out_grad, int dim_t=1)
  output: Tensor(x_grad), Tensor(y_grad), Tensor(w_grad)
  infer_meta:
    func: MatchMatrixTensorGradInferMeta
  kernel:
    func: match_matrix_tensor_grad

- backward_op: unpool_grad
forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ const std::unordered_set<std::string> LegacyOpList = {
RowConvOp::name(),
RowConvGradOp::name(),
SoftReluOp::name(),
SoftReluGradOp::name()};
SoftReluGradOp::name(),
MatchMatrixTensorOp::name(),
MatchMatrixTensorGradOp::name()};

enum class AttrType {
UNDEFINED = 0,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3474,6 +3474,13 @@
attrs:
pivot : pivots

# Compatibility mapping between the new-style lowercase names (x, y, w, out, tmp)
# and the legacy Fluid capitalized names (X, Y, W, Out, Tmp) for this op and its grad.
- op: match_matrix_tensor
  backward: match_matrix_tensor_grad
  inputs:
    {x : X, y : Y, w : W}
  outputs:
    {out : Out, tmp : Tmp}

- op: memcpy
inputs:
x: X
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/core/meta_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ class MetaTensor {

virtual bool operator!() const { return tensor_ == nullptr; }

protected:
static void unspecified_bool_true() {}

protected:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
const LoD& lod() const;
const LoD& lod(int64_t index) const;

protected:
static void unspecified_bool_true() {}

TensorBase* tensor() const;

TensorBase* tensor_ = nullptr;
Expand Down
23 changes: 22 additions & 1 deletion paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,28 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
logits_grad->set_dtype(softmax.dtype());
}

void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
const MetaTensor& tmp,
const MetaTensor& out_grad,
int dim_t,
MetaTensor* x_grad,
MetaTensor* y_grad,
MetaTensor* w_grad) {
if (x_grad != nullptr) {
x_grad->set_dims(x.dims());
x_grad->share_lod(x);
}
if (y_grad != nullptr) {
y_grad->set_dims(y.dims());
y_grad->share_lod(y);
}
if (w_grad != nullptr) {
w_grad->set_dims(w.dims());
}
}

void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
Expand Down Expand Up @@ -1297,5 +1319,4 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
value_grad->share_lod(values);
}
}

} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
float scale,
MetaTensor* logits_grad);

// Infers dims/LoD of x_grad/y_grad/w_grad for match_matrix_tensor_grad from
// the corresponding forward inputs; any grad output pointer may be null.
void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& y,
                                    const MetaTensor& w,
                                    const MetaTensor& tmp,
                                    const MetaTensor& out_grad,
                                    int dim_t,
                                    MetaTensor* x_grad,
                                    MetaTensor* y_grad,
                                    MetaTensor* w_grad);

void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
Expand Down
138 changes: 138 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2856,6 +2856,144 @@ void LogspaceInferMeta(const MetaTensor& start,
out->set_dtype(dtype);
}

void MatchMatrixTensorInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
int dim_t,
MetaTensor* out,
MetaTensor* tmp,
MetaConfig config) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image 这个函数应该是放在`ternary.h/.cc`中

auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(x_dims.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(X) should be equal to 2, "
"but received %d.",
x_dims.size()));

auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(y_dims.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(Y) should be equal to 2, "
"but received %d.",
y_dims.size()));

auto w_dims = w.dims();
PADDLE_ENFORCE_EQ(w_dims.size(),
3,
phi::errors::InvalidArgument(
"The dimensions of Input(W) should be equal to 3, "
"but received %d.",
w_dims.size()));

PADDLE_ENFORCE_EQ(
w_dims[0],
x_dims[1],
phi::errors::InvalidArgument(
"The first dimension of Input(W) should be equal to the second "
"dimension of Input(X). But received the first dimension of Input(W) "
"is %d, the second dimension of Input(X) is %d.",
w_dims[0],
x_dims[1]));
PADDLE_ENFORCE_EQ(
w_dims[1],
dim_t,
phi::errors::InvalidArgument(
"The second dimension of Input(W) should be equal to 'dim_t', but "
"received the second dimension of Input(W) is %d, 'dim_t' is %d.",
w_dims[1],
dim_t));
PADDLE_ENFORCE_EQ(
w_dims[2],
y_dims[1],
phi::errors::InvalidArgument(
"The last dimension of Input(W) should be equal to "
"the second dimension of Input(Y). But received the last dimension "
"of Input(W) is %d, the second dimension of Input(Y) is %d.",
w_dims[2],
y_dims[1]));

int64_t out_dim_0 = -1;
int64_t tmp_dim_0 = -1;
if (config.is_runtime) {
const auto& x_lod = x.lod();
PADDLE_ENFORCE_EQ(x_lod.empty(),
false,
phi::errors::InvalidArgument(
"The Input(X) should hold LoD information, but "
"received Input(X).lod() is empty."));
const auto& x_lod_0 = x_lod[0];
PADDLE_ENFORCE_GE(x_lod_0.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(X)'s LoD data should be "
"equal to 2, but received %d.",
x_lod_0.size()));
PADDLE_ENFORCE_EQ(x_dims[0],
static_cast<int64_t>(x_lod_0.back()),
phi::errors::InvalidArgument(
"The last element of Input(X)'s LoD data should be "
"equal to the first dimension of Input(X). "
"But received the last element of Input(X)'s LoD "
"data is %d, the first dimension of Input(X) is %d.",
x_lod_0.back(),
x_dims[0]));

const auto& y_lod = y.lod();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果lod信息仅在runtime时计算,是否可以将此处的计算逻辑放在kernel里?这样就不必将MetaTensor中的lod接口暴露出来。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里主要是用lod信息推导tmp和out的ddim信息,这些可以在kernel里再设置吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是运行时的处理,infermeta里写和在kernel里写是等价的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

PADDLE_ENFORCE_EQ(y_lod.empty(),
false,
phi::errors::InvalidArgument(
"The Input(Y) should hold LoD information, but "
"received Input(Y).lod() is empty."));
const auto& y_lod_0 = y_lod[0];
PADDLE_ENFORCE_GE(y_lod_0.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(Y)'s LoD data should be "
"equal to 2, but received %d.",
y_lod_0.size()));
PADDLE_ENFORCE_EQ(y_dims[0],
static_cast<int64_t>(y_lod_0.back()),
phi::errors::InvalidArgument(
"The last element of Input(Y)'s LoD data should be "
"equal to the first dimension of Input(Y). "
"But received the last element of Input(Y)'s LoD "
"data is %d, the first dimension of Input(Y) is %d.",
y_lod_0.back(),
y_dims[0]));

PADDLE_ENFORCE_EQ(x_lod_0.size(),
y_lod_0.size(),
phi::errors::InvalidArgument(
"The dimensions of Input(X)'s and Input(Y)'s LoD "
"data should be equal. "
"But received the dimensions of Input(X)'s LoD is "
"%d, the dimensions of Input(Y)'s LoD is %d.",
x_lod_0.size(),
y_lod_0.size()));

out_dim_0 = 0;
for (size_t i = 1; i < x_lod_0.size(); i++) {
int64_t x_len = x_lod_0[i] - x_lod_0[i - 1];
int64_t y_len = y_lod_0[i] - y_lod_0[i - 1];
out_dim_0 += (x_len * y_len);
}
out_dim_0 *= dim_t;

tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
} else {
out->share_lod(x);
}

std::vector<int64_t> out_dims_vec{out_dim_0};
out_dims_vec.push_back(1);
std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
tmp_dims_vec.push_back(1);
out->set_dims(common::make_ddim(out_dims_vec));
tmp->set_dims(common::make_ddim(tmp_dims_vec));
}

void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,14 @@ void LogspaceInferMeta(const MetaTensor& start,
DataType dtype,
MetaTensor* out);

// Infers dims (and, at compile time, LoD) of `out` and `tmp` for
// match_matrix_tensor; runtime LoD-based sizing is controlled via `config`.
void MatchMatrixTensorInferMeta(const MetaTensor& x,
                                const MetaTensor& y,
                                const MetaTensor& w,
                                int dim_t,
                                MetaTensor* out,
                                MetaTensor* tmp,
                                MetaConfig config = MetaConfig());

void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ test_lu_op
test_lu_unpack_op
test_margin_cross_entropy_op
test_masked_select_op
test_match_matrix_tensor_op
test_matmul_bf16_mkldnn_op
test_matmul_mkldnn_op
test_matmul_v2_op
Expand Down