Skip to content

Commit 26c824d

Browse files
authored
[0D-Tensor] Support elementwise_add (#53955)
* [0D-Tensor] Support elementwise_add
* Support elementwise_add for ZeroDim2 & ZeroDim3
1 parent 6fde205 commit 26c824d

File tree

5 files changed

+66
-20
lines changed

5 files changed

+66
-20
lines changed

paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -94,13 +94,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
9494
feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
9595
}
9696

97-
PADDLE_ENFORCE_NE(
98-
feed_map[feed_name].shape.size(),
99-
0UL,
100-
platform::errors::PreconditionNotMet(
101-
"The input variable %s's tensor shape cannot be empty,"
102-
"we need the variable's dtype and shape from tensor.",
103-
feed_name.c_str()));
97+
VLOG_IF(4, feed_map[feed_name].shape.size() == 0UL)
98+
<< "Shape is empty, Create 0D-Tensor for " << feed_name;
10499
}
105100
return feed_map;
106101
}

paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,40 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
5858
}
5959
}
6060

61+
// CINN ops in this white list support 0D-Tensor
62+
const std::unordered_set<std::string> white_op_list{"elementwise_add"};
63+
std::unordered_set<std::string> white_tensor_name;
64+
// enable white_op_list only when graph_node_size = 1, which means single op
65+
// test
66+
int graph_node_size = 0;
67+
for (const ir::Node* n : graph->Nodes()) {
68+
if (n->IsOp()) {
69+
graph_node_size++;
70+
VLOG(6) << "Graph has op node " << n->Op()->Type();
71+
if (white_op_list.find(n->Op()->Type()) != white_op_list.end()) {
72+
for (const ir::Node* var : n->inputs) {
73+
white_tensor_name.insert(var->Var()->Name());
74+
75+
std::vector<int64_t> shape = var->Var()->GetShape();
76+
if (shape.empty()) {
77+
VLOG(6) << "input var " << var->Name()
78+
<< " dims is empty, keep it's 0D-Tensor status";
79+
}
80+
}
81+
for (const ir::Node* var : n->outputs) {
82+
white_tensor_name.insert(var->Var()->Name());
83+
84+
std::vector<int64_t> shape = var->Var()->GetShape();
85+
if (shape.empty()) {
86+
VLOG(6) << "output var " << var->Name()
87+
<< " dims is empty, keep it's 0D-Tensor status";
88+
}
89+
}
90+
}
91+
}
92+
}
93+
VLOG(6) << "Graph has " << graph_node_size << " op node";
94+
6195
for (const ir::Node* n : graph->Nodes()) {
6296
if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
6397
if (n->Op()->HasAttr("shape")) {
@@ -85,6 +119,11 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
85119
}
86120
if (n->IsVar()) {
87121
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
122+
if (graph_node_size == 1 && white_tensor_name.find(n->Var()->Name()) !=
123+
white_tensor_name.end()) {
124+
VLOG(6) << "Keep 0D-Tensor status of var " << n->Var()->Name();
125+
continue;
126+
}
88127
std::vector<int64_t> shape = n->Var()->GetShape();
89128
if (shape.empty()) {
90129
shape.push_back(1);

paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass_test.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
2727
ir::Layers layers;
2828
auto* x = layers.data("x", {});
2929
auto* y = layers.data("y", {3, 4});
30-
auto* add_out_0 = layers.elementwise_add(x, y, nullptr, 0);
30+
auto* add_out_0 = layers.mul(x, y, nullptr, 0);
3131
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
3232
auto pass = ir::PassRegistry::Instance().Get("cinn_zero_tensor_trick_pass");
3333
VLOG(3) << DebugString(graph);
@@ -43,7 +43,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
4343
shape.empty(),
4444
false,
4545
platform::errors::PreconditionNotMet(
46-
"The shape of elementwise_add should not be empty after fuse"));
46+
"The shape of mul should not be empty after fuse"));
4747
}
4848
}
4949
}

paddle/fluid/operators/cinn/cinn_launch_context.cc

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -257,14 +257,29 @@ void CinnLaunchContext::CheckTensorEquivalent(
257257
// check dimension
258258
auto cinn_tensor = GetCinnTensorOfVar(var_name);
259259
auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
260-
PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
261-
cinn_dims,
262-
platform::errors::PreconditionNotMet(
263-
"Tensors' shape in variable(%s) are not equivalent, "
264-
"paddle is = [%s], but cinn is = [%s].",
265-
var_name,
266-
paddle_tensor.dims(),
267-
cinn_dims));
260+
if (paddle_tensor.dims().size() == 0) {
261+
// VLOG when paddle inputs 0D-Tensor
262+
VLOG(4) << "Paddle inputs 0D-Tensor, CINN changes 0D-Tensor " << var_name
263+
<< " to 1D-Tensor";
264+
PADDLE_ENFORCE_EQ(phi::make_ddim({1}),
265+
cinn_dims,
266+
phi::errors::PreconditionNotMet(
267+
"Tensor's shape of variable(%s) are not consistent, "
268+
"paddle inputs 0D-Tensor, cinn should get 1D-Tensor "
269+
"instead of [%s].",
270+
var_name,
271+
paddle_tensor.dims(),
272+
cinn_dims));
273+
} else {
274+
PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
275+
cinn_dims,
276+
phi::errors::PreconditionNotMet(
277+
"Tensor's shape of variable(%s) are not equivalent, "
278+
"paddle is = [%s], but cinn is = [%s].",
279+
var_name,
280+
paddle_tensor.dims(),
281+
cinn_dims));
282+
}
268283

269284
auto cinn_dtype =
270285
framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());

python/paddle/fluid/tests/unittests/test_elementwise_add_op.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -116,9 +116,6 @@ def init_input_output(self):
116116
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
117117
self.out = np.add(self.x, self.y)
118118

119-
def if_enable_cinn(self):
120-
self.enable_cinn = False
121-
122119

123120
class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp_ZeroDim1):
124121
def init_input_output(self):

0 commit comments

Comments (0)