43
43
"""
44
44
45
45
import argparse
46
+ from copy import deepcopy
46
47
import json
47
48
import os
48
49
import platform
57
58
import torch .nn as nn
58
59
from torch .utils .mobile_optimizer import optimize_for_mobile
59
60
61
+ from sparseml .pytorch .utils import ModuleExporter
62
+ from sparseml .pytorch .sparsification .quantization import skip_onnx_input_quantize
63
+
60
64
FILE = Path (__file__ ).resolve ()
61
65
ROOT = FILE .parents [0 ] # YOLOv5 root directory
62
66
if str (ROOT ) not in sys .path :
63
67
sys .path .append (str (ROOT )) # add ROOT to PATH
64
68
ROOT = Path (os .path .relpath (ROOT , Path .cwd ())) # relative
65
69
66
- from models .common import Conv
70
+ from models .common import Conv , DetectMultiBackend
67
71
from models .experimental import attempt_load
68
- from models .yolo import Detect
72
+ from models .yolo import Detect , Model
69
73
from utils .activations import SiLU
70
74
from utils .datasets import LoadImages
71
75
from utils .general import (LOGGER , check_dataset , check_img_size , check_requirements , check_version , colorstr ,
72
- file_size , print_args , url2file )
73
- from utils .torch_utils import select_device
76
+ file_size , print_args , url2file , intersect_dicts )
77
+ from utils .torch_utils import select_device , torch_distributed_zero_first , is_parallel
78
+ from utils .downloads import attempt_download
79
+ from utils .sparse import SparseMLWrapper , check_download_sparsezoo_weights
80
+
74
81
75
82
76
83
def export_formats ():
@@ -118,14 +125,33 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
118
125
LOGGER .info (f'\n { prefix } starting export with onnx { onnx .__version__ } ...' )
119
126
f = file .with_suffix ('.onnx' )
120
127
121
- torch .onnx .export (model , im , f , verbose = False , opset_version = opset ,
122
- training = torch .onnx .TrainingMode .TRAINING if train else torch .onnx .TrainingMode .EVAL ,
123
- do_constant_folding = not train ,
124
- input_names = ['images' ],
125
- output_names = ['output' ],
126
- dynamic_axes = {'images' : {0 : 'batch' , 2 : 'height' , 3 : 'width' }, # shape(1,3,640,640)
127
- 'output' : {0 : 'batch' , 1 : 'anchors' } # shape(1,25200,85)
128
- } if dynamic else None )
128
+ # export through SparseML so quantized and pruned graphs can be corrected
129
+ save_dir = f .parent .absolute ()
130
+ save_name = str (f ).split (os .path .sep )[- 1 ]
131
+
132
+ # get the number of outputs so we know how to name and change dynamic axes
133
+ # nested outputs can be returned if model is exported with dynamic
134
+ def _count_outputs (outputs ):
135
+ count = 0
136
+ if isinstance (outputs , list ) or isinstance (outputs , tuple ):
137
+ for out in outputs :
138
+ count += _count_outputs (out )
139
+ else :
140
+ count += 1
141
+ return count
142
+
143
+ outputs = model (im )
144
+ num_outputs = _count_outputs (outputs )
145
+ input_names = ['input' ]
146
+ output_names = [f'out_{ i } ' for i in range (num_outputs )]
147
+ dynamic_axes = {k : {0 : 'batch' } for k in (input_names + output_names )} if dynamic else None
148
+ exporter = ModuleExporter (model , save_dir )
149
+ exporter .export_onnx (im , name = save_name , convert_qat = True ,
150
+ input_names = input_names , output_names = output_names , dynamic_axes = dynamic_axes )
151
+ try :
152
+ skip_onnx_input_quantize (f , f )
153
+ except :
154
+ pass
129
155
130
156
# Checks
131
157
model_onnx = onnx .load (f ) # load onnx model
@@ -407,14 +433,123 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
407
433
except Exception as e :
408
434
LOGGER .info (f'\n { prefix } export failure: { e } ' )
409
435
436
def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
    """Assemble the checkpoint dictionary for ``epoch``.

    The ``'model'`` entry is a float deep copy of the de-parallelized module,
    or that copy's ``state_dict`` when ``sparseml_wrapper.qat_active(epoch)``
    is true. Entries from ``ema.state_dict(...)``,
    ``sparseml_wrapper.state_dict()`` and ``kwargs`` are merged in afterwards,
    in that order, so a later source overrides an earlier key.

    Returns:
        dict with at least the keys ``'epoch'``, ``'model'``, ``'optimizer'``,
        ``'yaml'`` and ``'hyp'``.
    """
    # qat does not support pickled exports
    save_pickled = not sparseml_wrapper.qat_active(epoch)

    unwrapped = model.module if is_parallel(model) else model
    model_entry = deepcopy(unwrapped).float()
    # read the yaml before the module is possibly swapped for its state_dict
    model_yaml = model_entry.yaml
    if not save_pickled:
        model_entry = model_entry.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model': model_entry,
        'optimizer': optimizer.state_dict(),
        'yaml': model_yaml,
        'hyp': model.hyp,
    }
    checkpoint.update(ema.state_dict(save_pickled))
    checkpoint.update(sparseml_wrapper.state_dict())
    checkpoint.update(kwargs)
    return checkpoint
451
+
452
def load_checkpoint(
    type_,
    weights,
    device,
    cfg=None,
    hyp=None,
    nc=None,
    data=None,
    dnn=False,
    half=False,
    recipe=None,
    resume=None,
    rank=-1,
):
    """Load a checkpoint and build the model plus its SparseML wrapper.

    Args:
        type_: Loading mode; compared against ``'train'``, ``'ensemble'`` and
            ``'val'``. Any other value takes the config-based path with none of
            the mode-specific steps.
        weights: Checkpoint location, or a list/tuple of them (only the first
            is read with ``torch.load``). Resolved through ``attempt_download``
            and, if that returns a falsy value,
            ``check_download_sparsezoo_weights``.
        device: Device the model is moved to / loaded on.
        cfg: Model config; falls back to ``ckpt['yaml']`` and then to the
            pickled model's ``.yaml``.
        hyp: Optional hyperparameter mapping; only ``'anchors'`` is read here.
        nc: Number of classes; falls back to ``ckpt['nc']`` when falsy.
        data, dnn, half: Forwarded to ``DetectMultiBackend`` in ``'val'`` mode.
        recipe: Recipe to train with when not resuming.
        resume: When truthy, the recipe stored in the checkpoint (if any) is
            used as the training recipe.
        rank: Passed to ``torch_distributed_zero_first`` around the download.

    Returns:
        ``(model, extras)`` where ``extras`` has the keys ``'ckpt'``,
        ``'state_dict'``, ``'sparseml_wrapper'`` and ``'report'``.
    """
    with torch_distributed_zero_first(rank):
        # download if not found locally or from sparsezoo if stub
        weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)

    # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources
    ckpt_path = weights[0] if isinstance(weights, (list, tuple)) else weights
    ckpt = torch.load(ckpt_path, map_location="cpu")  # load checkpoint
    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
    pickled = isinstance(ckpt['model'], nn.Module)
    train_type = type_ == 'train'
    ensemble_type = type_ == 'ensemble'
    val_type = type_ == 'val'

    if pickled and ensemble_type:
        # load ensemble using pickled
        # NOTE: 'val' never reaches this branch; it always goes through the config path below
        cfg = None
        model = attempt_load(weights, map_location=device)
        state_dict = model.state_dict()
    else:
        # load model from config and weights
        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
            (ckpt['model'].yaml if pickled else None)
        model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
                      anchors=hyp.get('anchors') if hyp else None).to(device)
        model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model'
        state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]
        if val_type:
            model = DetectMultiBackend(model=model, device=device, dnn=dnn, data=data, fp16=half)

    # turn gradients for params back on in case they were removed
    for p in model.parameters():
        p.requires_grad = True

    # load sparseml recipe for applying pruning and quantization
    # .get() so checkpoints saved without a 'recipe' entry do not raise KeyError
    saved_recipe = ckpt.get('recipe')
    checkpoint_recipe = train_recipe = None
    if resume:
        train_recipe = saved_recipe
    elif saved_recipe or recipe:
        train_recipe, checkpoint_recipe = recipe, saved_recipe

    sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, checkpoint_recipe, train_recipe)
    # guard hyp the same way as for the Model(...) anchors argument above
    hyp_anchors = hyp.get('anchors') if hyp else None
    exclude_anchors = train_type and (cfg or hyp_anchors) and not resume
    loaded = False

    sparseml_wrapper.apply_checkpoint_structure(float("inf"))
    if train_type:
        # initialize the recipe for training and restore the weights before if no quantized weights
        quantized_state_dict = any(name.endswith('.zero_point') for name in state_dict)
        if not quantized_state_dict:
            state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors)
            loaded = True
        sparseml_wrapper.initialize(start_epoch)

    if not loaded:
        state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors)

    model.float()
    report = 'Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)

    return model, {
        'ckpt': ckpt,
        'state_dict': state_dict,
        'sparseml_wrapper': sparseml_wrapper,
        'report': report,
    }
532
+
533
+
534
def load_state_dict(model, state_dict, train, exclude_anchors):
    """Load ``state_dict`` into ``model`` and return the dict that was loaded.

    A leading ``"module."`` is stripped from every key first. When ``train`` is
    truthy the result is filtered through ``intersect_dicts`` against the
    model's own state dict (excluding ``'anchor'`` when ``exclude_anchors`` is
    truthy) and loaded with ``strict=False``; otherwise it is loaded strictly.
    """
    # fix older state_dict names not porting to the new model setup
    prefix = "module."
    renamed = {}
    for name, value in state_dict.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        renamed[name] = value

    if train:
        # load any missing weights from the model
        excluded = ['anchor'] if exclude_anchors else []
        renamed = intersect_dicts(renamed, model.state_dict(), exclude=excluded)

    model.load_state_dict(renamed, strict=not train)  # load

    return renamed
410
545
411
546
@torch .no_grad ()
412
547
def run (data = ROOT / 'data/coco128.yaml' , # 'dataset.yaml path'
413
548
weights = ROOT / 'yolov5s.pt' , # weights path
414
549
imgsz = (640 , 640 ), # image (height, width)
415
550
batch_size = 1 , # batch size
416
551
device = 'cpu' , # cuda device, i.e. 0 or 0,1,2,3 or cpu
417
- include = ('torchscript' , ' onnx' ), # include formats
552
+ include = ('onnx' ), # include formats
418
553
half = False , # FP16 half-precision export
419
554
inplace = False , # set YOLOv5 Detect() inplace=True
420
555
train = False , # model.train() mode
@@ -430,7 +565,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
430
565
topk_per_class = 100 , # TF.js NMS: topk per class to keep
431
566
topk_all = 100 , # TF.js NMS: topk for all classes to keep
432
567
iou_thres = 0.45 , # TF.js NMS: IoU threshold
433
- conf_thres = 0.25 # TF.js NMS: confidence threshold
568
+ conf_thres = 0.25 , # TF.js NMS: confidence threshold
569
+ remove_grid = False ,
434
570
):
435
571
t = time .time ()
436
572
include = [x .lower () for x in include ] # to lowercase
@@ -443,8 +579,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
443
579
# Load PyTorch model
444
580
device = select_device (device )
445
581
assert not (device .type == 'cpu' and half ), '--half only compatible with GPU export, i.e. use --device 0'
446
- model = attempt_load (weights , map_location = device , inplace = True , fuse = True ) # load FP32 model
447
- nc , names = model .nc , model .names # number of classes, class names
582
+ model , extras = load_checkpoint (type_ = 'ensemble' , weights = weights , device = device ) # load FP32 model
583
+ sparseml_wrapper = extras ['sparseml_wrapper' ]
584
+ nc , names = extras ["ckpt" ]["nc" ], model .names # number of classes, class names
448
585
449
586
# Checks
450
587
imgsz *= 2 if len (imgsz ) == 1 else 1 # expand
@@ -469,6 +606,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
469
606
m .onnx_dynamic = dynamic
470
607
if hasattr (m , 'forward_export' ):
471
608
m .forward = m .forward_export # assign custom forward (optional)
609
+ model .model [- 1 ].export = not remove_grid # set Detect() layer grid export
472
610
473
611
for _ in range (2 ):
474
612
y = model (im ) # dry runs
@@ -541,6 +679,7 @@ def parse_opt():
541
679
parser .add_argument ('--topk-all' , type = int , default = 100 , help = 'TF.js NMS: topk for all classes to keep' )
542
680
parser .add_argument ('--iou-thres' , type = float , default = 0.45 , help = 'TF.js NMS: IoU threshold' )
543
681
parser .add_argument ('--conf-thres' , type = float , default = 0.25 , help = 'TF.js NMS: confidence threshold' )
682
+ parser .add_argument ("--remove-grid" , action = "store_true" , help = "remove export of Detect() layer grid" )
544
683
parser .add_argument ('--include' , nargs = '+' ,
545
684
default = ['torchscript' , 'onnx' ],
546
685
help = 'torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs' )
@@ -556,4 +695,4 @@ def main(opt):
556
695
557
696
if __name__ == "__main__" :
558
697
opt = parse_opt ()
559
- main (opt )
698
+ main (opt )
0 commit comments