Skip to content

Commit af8069d

Browse files
glenn-jochereladco
authored andcommitted
Use export_formats() in export.py (ultralytics#6705)
* Use `export_formats()` in export.py * list fix
1 parent aca9f8f commit af8069d

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

export.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
433433
conf_thres=0.25 # TF.js NMS: confidence threshold
434434
):
435435
t = time.time()
436-
include = [x.lower() for x in include]
437-
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
438-
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
436+
include = [x.lower() for x in include] # to lowercase
437+
formats = tuple(export_formats()['Argument'][1:]) # --include arguments
438+
flags = [x in include for x in formats]
439+
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
440+
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
441+
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
439442

440443
# Load PyTorch model
441444
device = select_device(device)
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
475478
# Exports
476479
f = [''] * 10 # exported filenames
477480
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
478-
if 'torchscript' in include:
481+
if jit:
479482
f[0] = export_torchscript(model, im, file, optimize)
480-
if 'engine' in include: # TensorRT required before ONNX
483+
if engine: # TensorRT required before ONNX
481484
f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
482-
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
485+
if onnx or xml: # OpenVINO requires ONNX
483486
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
484-
if 'openvino' in include:
487+
if xml: # OpenVINO
485488
f[3] = export_openvino(model, im, file)
486-
if 'coreml' in include:
489+
if coreml:
487490
_, f[4] = export_coreml(model, im, file)
488491

489492
# TensorFlow Exports
490-
if any(tf_exports):
491-
pb, tflite, edgetpu, tfjs = tf_exports[1:]
493+
if any((saved_model, pb, tflite, edgetpu, tfjs)):
492494
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
493495
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
494496
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'

0 commit comments

Comments
 (0)