@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
433
433
conf_thres = 0.25 # TF.js NMS: confidence threshold
434
434
):
435
435
t = time .time ()
436
- include = [x .lower () for x in include ]
437
- tf_exports = list (x in include for x in ('saved_model' , 'pb' , 'tflite' , 'edgetpu' , 'tfjs' )) # TensorFlow exports
438
- file = Path (url2file (weights ) if str (weights ).startswith (('http:/' , 'https:/' )) else weights )
436
+ include = [x .lower () for x in include ] # to lowercase
437
+ formats = tuple (export_formats ()['Argument' ][1 :]) # --include arguments
438
+ flags = [x in include for x in formats ]
439
+ assert sum (flags ) == len (include ), f'ERROR: Invalid --include { include } , valid --include arguments are { formats } '
440
+ jit , onnx , xml , engine , coreml , saved_model , pb , tflite , edgetpu , tfjs = flags # export booleans
441
+ file = Path (url2file (weights ) if str (weights ).startswith (('http:/' , 'https:/' )) else weights ) # PyTorch weights
439
442
440
443
# Load PyTorch model
441
444
device = select_device (device )
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
475
478
# Exports
476
479
f = ['' ] * 10 # exported filenames
477
480
warnings .filterwarnings (action = 'ignore' , category = torch .jit .TracerWarning ) # suppress TracerWarning
478
- if 'torchscript' in include :
481
+ if jit :
479
482
f [0 ] = export_torchscript (model , im , file , optimize )
480
- if ' engine' in include : # TensorRT required before ONNX
483
+ if engine : # TensorRT required before ONNX
481
484
f [1 ] = export_engine (model , im , file , train , half , simplify , workspace , verbose )
482
- if ( ' onnx' in include ) or ( 'openvino' in include ) : # OpenVINO requires ONNX
485
+ if onnx or xml : # OpenVINO requires ONNX
483
486
f [2 ] = export_onnx (model , im , file , opset , train , dynamic , simplify )
484
- if 'openvino' in include :
487
+ if xml : # OpenVINO
485
488
f [3 ] = export_openvino (model , im , file )
486
- if ' coreml' in include :
489
+ if coreml :
487
490
_ , f [4 ] = export_coreml (model , im , file )
488
491
489
492
# TensorFlow Exports
490
- if any (tf_exports ):
491
- pb , tflite , edgetpu , tfjs = tf_exports [1 :]
493
+ if any ((saved_model , pb , tflite , edgetpu , tfjs )):
492
494
if int8 or edgetpu : # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
493
495
check_requirements (('flatbuffers==1.12' ,)) # required before `import tensorflow`
494
496
assert not (tflite and tfjs ), 'TFLite and TF.js models must be exported separately, please pass only one type.'
0 commit comments