Skip to content

Commit c766981

Browse files
committed
Enhance the support for reshape
1 parent 9e375c9 commit c766981

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tools/onnx2bnn/OnnxConverter.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,12 @@ std::vector<std::string> OnnxConverter::Convert(
262262
for (const auto &node : model_proto_.graph().node()) {
263263
NodeAttrHelper helper(node);
264264
const auto &op = node.op_type();
265-
if (has_reshape && op != "Gemm") {
266-
throw std::invalid_argument(
267-
"Reshape can only be the last layer or precede a gemm layer "
268-
"for now");
269-
}
270-
has_reshape = false;
271265
VLOG(5) << "Node " << node.name();
272266
if (op == "Conv") {
267+
if (has_reshape) {
268+
throw std::invalid_argument("Reshape before " + op +
269+
" is not supported");
270+
}
273271
VLOG(5) << "Start converting Conv";
274272
auto strides = helper.get("strides", vector<int>{1, 1});
275273
auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
@@ -319,6 +317,10 @@ std::vector<std::string> OnnxConverter::Convert(
319317
VLOG(5) << "Converting Conv completed";
320318
} else if (op == "AveragePool" || op == "MaxPool" ||
321319
op == "GlobalAveragePool" || op == "GlobalMaxPool") {
320+
if (has_reshape) {
321+
throw std::invalid_argument("Reshape before " + op +
322+
" is not supported");
323+
}
322324
VLOG(5) << "Start converting Pool";
323325
auto input_name = m(node.input(0));
324326
auto output_name = m(node.output(0));
@@ -407,6 +409,10 @@ std::vector<std::string> OnnxConverter::Convert(
407409
layers_.push_back(layer);
408410
VLOG(5) << "Converting Relu completed";
409411
} else if (op == "Add") {
412+
if (has_reshape) {
413+
throw std::invalid_argument("Reshape before " + op +
414+
" is not supported");
415+
}
410416
VLOG(5) << "Start converting Add";
411417
auto input1_name = m(node.input(0));
412418
auto input2_name = m(node.input(1));
@@ -420,6 +426,9 @@ std::vector<std::string> OnnxConverter::Convert(
420426
layers_.push_back(layer);
421427
VLOG(5) << "Converting Add completed";
422428
} else if (op == "Gemm") {
429+
if (has_reshape) {
430+
has_reshape = false;
431+
}
423432
VLOG(5) << "Start converting Gemm";
424433
auto transA = helper.get("transA", 0);
425434
auto transB = helper.get("transB", 0);
@@ -478,6 +487,10 @@ std::vector<std::string> OnnxConverter::Convert(
478487
layers_.push_back(layer);
479488
VLOG(5) << "Converting Softmax completed";
480489
} else if (op == "Concat") {
490+
if (has_reshape) {
491+
throw std::invalid_argument("Reshape before " + op +
492+
" is not supported");
493+
}
481494
VLOG(5) << "Start converting Concat";
482495
vector<std::string> concat_inputs_str;
483496
for (const auto &onnx_input : node.input()) {

0 commit comments

Comments
 (0)