Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
}

PADDLE_ENFORCE_NE(
feed_map[feed_name].shape.size(),
0UL,
platform::errors::PreconditionNotMet(
"The input variable %s's tensor shape cannot be empty,"
"we need the variable's dtype and shape from tensor.",
feed_name.c_str()));
VLOG_IF(4, feed_map[feed_name].shape.size() == 0UL)
<< "Shape is empty, Create 0D-Tensor for " << feed_name;
}
return feed_map;
}
Expand Down
39 changes: 39 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,40 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
}
}

// CINN ops in this white list support 0D-Tensor
const std::unordered_set<std::string> white_op_list{"elementwise_add"};
std::unordered_set<std::string> white_tensor_name;
// enable white_op_list only when graph_node_size = 1, which means single op
// test
int graph_node_size = 0;
for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp()) {
graph_node_size++;
VLOG(6) << "Graph has op node " << n->Op()->Type();
if (white_op_list.find(n->Op()->Type()) != white_op_list.end()) {
for (const ir::Node* var : n->inputs) {
white_tensor_name.insert(var->Var()->Name());

std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
VLOG(6) << "input var " << var->Name()
<< " dims is empty, keep it's 0D-Tensor status";
}
}
for (const ir::Node* var : n->outputs) {
white_tensor_name.insert(var->Var()->Name());

std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
VLOG(6) << "output var " << var->Name()
<< " dims is empty, keep it's 0D-Tensor status";
}
}
}
}
}
VLOG(6) << "Graph has " << graph_node_size << " op node";

for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
if (n->Op()->HasAttr("shape")) {
Expand Down Expand Up @@ -85,6 +119,11 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
}
if (n->IsVar()) {
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
if (graph_node_size == 1 && white_tensor_name.find(n->Var()->Name()) !=
white_tensor_name.end()) {
VLOG(6) << "Keep 0D-Tensor status of var " << n->Var()->Name();
continue;
}
std::vector<int64_t> shape = n->Var()->GetShape();
if (shape.empty()) {
shape.push_back(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
ir::Layers layers;
auto* x = layers.data("x", {});
auto* y = layers.data("y", {3, 4});
auto* add_out_0 = layers.elementwise_add(x, y, nullptr, 0);
auto* add_out_0 = layers.mul(x, y, nullptr, 0);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = ir::PassRegistry::Instance().Get("cinn_zero_tensor_trick_pass");
VLOG(3) << DebugString(graph);
Expand All @@ -43,7 +43,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
shape.empty(),
false,
platform::errors::PreconditionNotMet(
"The shape of elementwise_add should not be empty after fuse"));
"The shape of mul should not be empty after fuse"));
}
}
}
Expand Down
31 changes: 23 additions & 8 deletions paddle/fluid/operators/cinn/cinn_launch_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,29 @@ void CinnLaunchContext::CheckTensorEquivalent(
// check dimension
auto cinn_tensor = GetCinnTensorOfVar(var_name);
auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
cinn_dims,
platform::errors::PreconditionNotMet(
"Tensors' shape in variable(%s) are not equivalent, "
"paddle is = [%s], but cinn is = [%s].",
var_name,
paddle_tensor.dims(),
cinn_dims));
if (paddle_tensor.dims().size() == 0) {
// VLOG when paddle inputs 0D-Tensor
VLOG(4) << "Paddle inputs 0D-Tensor, CINN changes 0D-Tensor " << var_name
<< " to 1D-Tensor";
PADDLE_ENFORCE_EQ(phi::make_ddim({1}),
cinn_dims,
phi::errors::PreconditionNotMet(
"Tensor's shape of variable(%s) are not consistent, "
"paddle inputs 0D-Tensor, cinn should get 1D-Tensor "
"instead of [%s].",
var_name,
paddle_tensor.dims(),
cinn_dims));
} else {
PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
cinn_dims,
phi::errors::PreconditionNotMet(
"Tensor's shape of variable(%s) are not equivalent, "
"paddle is = [%s], but cinn is = [%s].",
var_name,
paddle_tensor.dims(),
cinn_dims));
}

auto cinn_dtype =
framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ def init_input_output(self):
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.add(self.x, self.y)

def if_enable_cinn(self):
self.enable_cinn = False


class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp_ZeroDim1):
def init_input_output(self):
Expand Down