@@ -262,14 +262,12 @@ std::vector<std::string> OnnxConverter::Convert(
262
262
for (const auto &node : model_proto_.graph ().node ()) {
263
263
NodeAttrHelper helper (node);
264
264
const auto &op = node.op_type ();
265
- if (has_reshape && op != " Gemm" ) {
266
- throw std::invalid_argument (
267
- " Reshape can only be the last layer or precede a gemm layer "
268
- " for now" );
269
- }
270
- has_reshape = false ;
271
265
VLOG (5 ) << " Node " << node.name ();
272
266
if (op == " Conv" ) {
267
+ if (has_reshape) {
268
+ throw std::invalid_argument (" Reshape before " + op +
269
+ " is not supported" );
270
+ }
273
271
VLOG (5 ) << " Start converting Conv" ;
274
272
auto strides = helper.get (" strides" , vector<int >{1 , 1 });
275
273
auto pads = helper.get (" pads" , vector<int >{0 , 0 , 0 , 0 });
@@ -319,6 +317,10 @@ std::vector<std::string> OnnxConverter::Convert(
319
317
VLOG (5 ) << " Converting Conv completed" ;
320
318
} else if (op == " AveragePool" || op == " MaxPool" ||
321
319
op == " GlobalAveragePool" || op == " GlobalMaxPool" ) {
320
+ if (has_reshape) {
321
+ throw std::invalid_argument (" Reshape before " + op +
322
+ " is not supported" );
323
+ }
322
324
VLOG (5 ) << " Start converting Pool" ;
323
325
auto input_name = m (node.input (0 ));
324
326
auto output_name = m (node.output (0 ));
@@ -407,6 +409,10 @@ std::vector<std::string> OnnxConverter::Convert(
407
409
layers_.push_back (layer);
408
410
VLOG (5 ) << " Converting Relu completed" ;
409
411
} else if (op == " Add" ) {
412
+ if (has_reshape) {
413
+ throw std::invalid_argument (" Reshape before " + op +
414
+ " is not supported" );
415
+ }
410
416
VLOG (5 ) << " Start converting Add" ;
411
417
auto input1_name = m (node.input (0 ));
412
418
auto input2_name = m (node.input (1 ));
@@ -420,6 +426,9 @@ std::vector<std::string> OnnxConverter::Convert(
420
426
layers_.push_back (layer);
421
427
VLOG (5 ) << " Converting Add completed" ;
422
428
} else if (op == " Gemm" ) {
429
+ if (has_reshape) {
430
+ has_reshape = false ;
431
+ }
423
432
VLOG (5 ) << " Start converting Gemm" ;
424
433
auto transA = helper.get (" transA" , 0 );
425
434
auto transB = helper.get (" transB" , 0 );
@@ -478,6 +487,10 @@ std::vector<std::string> OnnxConverter::Convert(
478
487
layers_.push_back (layer);
479
488
VLOG (5 ) << " Converting Softmax completed" ;
480
489
} else if (op == " Concat" ) {
490
+ if (has_reshape) {
491
+ throw std::invalid_argument (" Reshape before " + op +
492
+ " is not supported" );
493
+ }
481
494
VLOG (5 ) << " Start converting Concat" ;
482
495
vector<std::string> concat_inputs_str;
483
496
for (const auto &onnx_input : node.input ()) {
0 commit comments