
Commit a32b970

KSGulin authored and Benjamin committed

Update SparseML Integration to V6.1 (#26)

* SparseML integration
* Add SparseML dependency
* Update: add missing files
* Update requirements.txt
* Update: sparseml-nightly support
* Update: remove model versioning
* Partial update for multi-stage recipes
* Update: multi-stage recipe support
* Update: remove sparseml dep
* Fix: multi-stage recipe handling
* Fix: multi-stage support
* Fix: non-recipe runs
* Add: legacy hyperparam files
* Fix: add copy-paste to hyps
* Fix: nit
* Apply structure fixes

1 parent 011e7df · commit a32b970

20 files changed: +758 -103 lines
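For orientation: the changes below wire YOLOv5 training and export through SparseML's recipe system, which schedules pruning and quantization modifiers over the course of training. A minimal sketch of recipe-driven training, assuming SparseML's ScheduledModifierManager API and a hypothetical local recipe.yaml (illustrative only, not part of this commit):

    # Minimal sketch: applying a SparseML recipe to a PyTorch training setup.
    # Assumes sparseml is installed; "recipe.yaml" is a hypothetical local file.
    import torch
    from sparseml.pytorch.optim import ScheduledModifierManager

    model = torch.nn.Linear(10, 2)  # stand-in for the YOLOv5 model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    manager = ScheduledModifierManager.from_yaml("recipe.yaml")
    optimizer = manager.modify(model, optimizer, steps_per_epoch=100)

    # ... training loop runs here; pruning/quantization modifiers fire on schedule ...

    manager.finalize(model)  # remove modifier hooks once training completes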

data/hyps/hyp.finetune.yaml (39 additions, 0 deletions)

@@ -0,0 +1,39 @@
+# Hyperparameters for VOC finetuning
+# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
+# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
+
+
+# Hyperparameter Evolution Results
+# Generations: 306
+#                P        R     mAP.5  mAP.5:.95     box      obj      cls
+# Metrics:     0.6    0.936     0.896      0.684  0.0115  0.00805  0.00146
+
+lr0: 0.0032
+lrf: 0.12
+momentum: 0.843
+weight_decay: 0.00036
+warmup_epochs: 2.0
+warmup_momentum: 0.5
+warmup_bias_lr: 0.05
+box: 0.0296
+cls: 0.243
+cls_pw: 0.631
+obj: 0.301
+obj_pw: 0.911
+iou_t: 0.2
+anchor_t: 2.91
+# anchors: 3.63
+fl_gamma: 0.0
+hsv_h: 0.0138
+hsv_s: 0.664
+hsv_v: 0.464
+degrees: 0.373
+translate: 0.245
+scale: 0.898
+shear: 0.602
+perspective: 0.0
+flipud: 0.00856
+fliplr: 0.5
+mosaic: 1.0
+mixup: 0.243
+copy_paste: 0.0
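The new hyperparameter files are plain YAML, so their values can be sanity-checked directly; a minimal sketch, assuming PyYAML is installed and the repository root is the working directory:

    # Load and inspect the finetuning hyperparameters.
    import yaml

    with open("data/hyps/hyp.finetune.yaml") as f:
        hyp = yaml.safe_load(f)

    print(hyp["lr0"], hyp["momentum"])  # expected: 0.0032 0.843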

data/hyps/hyp.scratch.yaml (34 additions, 0 deletions)

@@ -0,0 +1,34 @@
+# Hyperparameters for COCO training from scratch
+# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml --img 640 --epochs 300
+# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
+
+
+lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.2  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 1.0  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.5  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.0  # image mixup (probability)
+copy_paste: 0.0  # segment copy-paste (probability)
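Either file can also be passed to training explicitly. A hedged example invocation, assuming this fork keeps upstream YOLOv5's --hyp flag:

    python train.py --data coco.yaml --cfg yolov5m.yaml --weights '' \
        --hyp data/hyps/hyp.scratch.yaml --batch 40 --img 640 --epochs 300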

detect.py (2 additions, 1 deletion)

@@ -45,6 +45,7 @@
                            increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import select_device, time_sync
+from export import load_checkpoint


 @torch.no_grad()
@@ -89,7 +90,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)

     # Load model
     device = select_device(device)
-    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+    model, extras = load_checkpoint(type_='val', weights=weights, device=device)  # load FP32 model
     stride, names, pt = model.stride, model.names, model.pt
     imgsz = check_img_size(imgsz, s=stride)  # check image size
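With this change, detection loads models through the load_checkpoint helper added to export.py below rather than constructing DetectMultiBackend directly. A hedged standalone sketch (the weights path is hypothetical; 'val'-type loading still wraps the model in DetectMultiBackend internally):

    # Hypothetical standalone use of the new loader (run from the repo root).
    from utils.torch_utils import select_device
    from export import load_checkpoint

    device = select_device('cpu')
    model, extras = load_checkpoint(type_='val', weights='yolov5s.pt', device=device)
    print(extras['report'])  # e.g. "Transferred 349/349 items from yolov5s.pt"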

export.py (156 additions, 17 deletions)

@@ -43,6 +43,7 @@
 """

 import argparse
+from copy import deepcopy
 import json
 import os
 import platform
@@ -57,20 +58,26 @@
 import torch.nn as nn
 from torch.utils.mobile_optimizer import optimize_for_mobile

+from sparseml.pytorch.utils import ModuleExporter
+from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize
+
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[0]  # YOLOv5 root directory
 if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

-from models.common import Conv
+from models.common import Conv, DetectMultiBackend
 from models.experimental import attempt_load
-from models.yolo import Detect
+from models.yolo import Detect, Model
 from utils.activations import SiLU
 from utils.datasets import LoadImages
 from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
-                           file_size, print_args, url2file)
-from utils.torch_utils import select_device
+                           file_size, print_args, url2file, intersect_dicts)
+from utils.torch_utils import select_device, torch_distributed_zero_first, is_parallel
+from utils.downloads import attempt_download
+from utils.sparse import SparseMLWrapper, check_download_sparsezoo_weights
+


 def export_formats():
@@ -118,14 +125,33 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
         LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
         f = file.with_suffix('.onnx')

-        torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
-                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
-                          do_constant_folding=not train,
-                          input_names=['images'],
-                          output_names=['output'],
-                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
-                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
-                                        } if dynamic else None)
+        # export through SparseML so quantized and pruned graphs can be corrected
+        save_dir = f.parent.absolute()
+        save_name = str(f).split(os.path.sep)[-1]
+
+        # get the number of outputs so we know how to name and change dynamic axes
+        # nested outputs can be returned if model is exported with dynamic
+        def _count_outputs(outputs):
+            count = 0
+            if isinstance(outputs, (list, tuple)):
+                for out in outputs:
+                    count += _count_outputs(out)
+            else:
+                count += 1
+            return count
+
+        outputs = model(im)
+        num_outputs = _count_outputs(outputs)
+        input_names = ['input']
+        output_names = [f'out_{i}' for i in range(num_outputs)]
+        dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if dynamic else None
+        exporter = ModuleExporter(model, save_dir)
+        exporter.export_onnx(im, name=save_name, convert_qat=True,
+                             input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
+        try:
+            skip_onnx_input_quantize(f, f)
+        except Exception:
+            pass

         # Checks
         model_onnx = onnx.load(f)  # load onnx model
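The _count_outputs helper above flattens arbitrarily nested model outputs so each ONNX output receives a stable out_i name. A self-contained illustration of the same counting logic, with hypothetical tensors:

    # Counting leaves in a nested output structure, as export_onnx does.
    import torch

    def count_outputs(outputs):
        if isinstance(outputs, (list, tuple)):
            return sum(count_outputs(out) for out in outputs)
        return 1

    t = torch.zeros(1)
    print(count_outputs((t, [t, t, t])))  # 4 -> output names out_0 .. out_3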
@@ -407,14 +433,123 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
     except Exception as e:
         LOGGER.info(f'\n{prefix} export failure: {e}')

+def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
+    pickle = not sparseml_wrapper.qat_active(epoch)  # qat does not support pickled exports
+    ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
+    yaml = ckpt_model.yaml
+    if not pickle:
+        ckpt_model = ckpt_model.state_dict()
+
+    return {'epoch': epoch,
+            'model': ckpt_model,
+            'optimizer': optimizer.state_dict(),
+            'yaml': yaml,
+            'hyp': model.hyp,
+            **ema.state_dict(pickle),
+            **sparseml_wrapper.state_dict(),
+            **kwargs}
+
+def load_checkpoint(
+    type_,
+    weights,
+    device,
+    cfg=None,
+    hyp=None,
+    nc=None,
+    data=None,
+    dnn=False,
+    half=False,
+    recipe=None,
+    resume=None,
+    rank=-1
+):
+    with torch_distributed_zero_first(rank):
+        # download if not found locally or from sparsezoo if stub
+        weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
+    ckpt = torch.load(weights[0] if isinstance(weights, (list, tuple))
+                      else weights, map_location="cpu")  # load checkpoint
+    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
+    pickled = isinstance(ckpt['model'], nn.Module)
+    train_type = type_ == 'train'
+    ensemble_type = type_ == 'ensemble'
+    val_type = type_ == 'val'
+
+    if pickled and ensemble_type:
+        cfg = None
+    if ensemble_type:
+        model = attempt_load(weights, map_location=device)  # load ensemble using pickled
+        state_dict = model.state_dict()
+    elif val_type:
+        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+        state_dict = model.model.state_dict()
+    else:
+        # load model from config and weights
+        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
+            (ckpt['model'].yaml if pickled else None)
+        model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
+                      anchors=hyp.get('anchors') if hyp else None).to(device)
+        model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model'
+        state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]
+        if val_type:
+            model = DetectMultiBackend(model=model, device=device, dnn=dnn, data=data, fp16=half)
+
+    # turn gradients for params back on in case they were removed
+    for p in model.parameters():
+        p.requires_grad = True
+
+    # load sparseml recipe for applying pruning and quantization
+    checkpoint_recipe = train_recipe = None
+    if resume:
+        train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None
+    elif ckpt.get('recipe') or recipe:
+        train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')
+
+    sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, checkpoint_recipe, train_recipe)
+    exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume
+    loaded = False
+
+    sparseml_wrapper.apply_checkpoint_structure(float("inf"))
+    if train_type:
+        # initialize the recipe for training and restore the weights beforehand if there are no quantized weights
+        quantized_state_dict = any(name.endswith('.zero_point') for name in state_dict.keys())
+        if not quantized_state_dict:
+            state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors)
+            loaded = True
+        sparseml_wrapper.initialize(start_epoch)
+
+    if not loaded:
+        state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors)
+
+    model.float()
+    report = 'Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)
+
+    return model, {
+        'ckpt': ckpt,
+        'state_dict': state_dict,
+        'sparseml_wrapper': sparseml_wrapper,
+        'report': report,
+    }
+
+
+def load_state_dict(model, state_dict, train, exclude_anchors):
+    # fix older state_dict names not porting to the new model setup
+    state_dict = {key if not key.startswith("module.") else key[7:]: val for key, val in state_dict.items()}
+
+    if train:
+        # load any missing weights from the model
+        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=['anchor'] if exclude_anchors else [])
+
+    model.load_state_dict(state_dict, strict=not train)  # load
+
+    return state_dict

 @torch.no_grad()
 def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         weights=ROOT / 'yolov5s.pt',  # weights path
         imgsz=(640, 640),  # image (height, width)
         batch_size=1,  # batch size
         device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
-        include=('torchscript', 'onnx'),  # include formats
+        include=('onnx',),  # include formats
         half=False,  # FP16 half-precision export
         inplace=False,  # set YOLOv5 Detect() inplace=True
         train=False,  # model.train() mode
@@ -430,7 +565,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
         topk_per_class=100,  # TF.js NMS: topk per class to keep
         topk_all=100,  # TF.js NMS: topk for all classes to keep
         iou_thres=0.45,  # TF.js NMS: IoU threshold
-        conf_thres=0.25  # TF.js NMS: confidence threshold
+        conf_thres=0.25,  # TF.js NMS: confidence threshold
+        remove_grid=False,
         ):
     t = time.time()
     include = [x.lower() for x in include]  # to lowercase
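The key-renaming step in load_state_dict above exists because checkpoints saved from nn.DataParallel-wrapped models prefix every key with "module.". A self-contained illustration of that fix:

    # Stripping the "module." prefix that nn.DataParallel adds to state_dict keys.
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 2))
    # simulate a checkpoint written from a DataParallel-wrapped model
    wrapped = {f"module.{k}": v for k, v in model.state_dict().items()}
    fixed = {k[7:] if k.startswith("module.") else k: v for k, v in wrapped.items()}
    model.load_state_dict(fixed, strict=True)  # loads cleanly after renaming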
@@ -443,8 +579,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
     # Load PyTorch model
     device = select_device(device)
     assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
-    model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 model
-    nc, names = model.nc, model.names  # number of classes, class names
+    model, extras = load_checkpoint(type_='ensemble', weights=weights, device=device)  # load FP32 model
+    sparseml_wrapper = extras['sparseml_wrapper']
+    nc, names = extras["ckpt"]["nc"], model.names  # number of classes, class names

     # Checks
     imgsz *= 2 if len(imgsz) == 1 else 1  # expand
@@ -469,6 +606,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
             m.onnx_dynamic = dynamic
             if hasattr(m, 'forward_export'):
                 m.forward = m.forward_export  # assign custom forward (optional)
+    model.model[-1].export = not remove_grid  # set Detect() layer grid export

     for _ in range(2):
         y = model(im)  # dry runs
@@ -541,6 +679,7 @@ def parse_opt():
     parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
+    parser.add_argument("--remove-grid", action="store_true", help="remove export of Detect() layer grid")
     parser.add_argument('--include', nargs='+',
                         default=['torchscript', 'onnx'],
                         help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
@@ -556,4 +695,4 @@ def main(opt):

 if __name__ == "__main__":
     opt = parse_opt()
-    main(opt)
\ No newline at end of file
+    main(opt)
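Putting the export-side changes together, a hedged example invocation (the weights path is hypothetical; --remove-grid is the flag added in this commit, and --dynamic is upstream YOLOv5's dynamic-axes flag):

    python export.py --weights yolov5s.pt --include onnx --dynamic --remove-grid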
