Skip to content

Commit 500efb5

Browse files
glenn-jocherClay Januhowski
authored andcommitted
Link fuse() to AutoShape() for Hub models (ultralytics#8599)
1 parent 895ffba commit 500efb5

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

hubconf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,14 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
3636

3737
if not verbose:
3838
LOGGER.setLevel(logging.WARNING)
39-
4039
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
4140
name = Path(name)
4241
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
4342
try:
4443
device = select_device(device)
4544

4645
if pretrained and channels == 3 and classes == 80:
47-
model = DetectMultiBackend(path, device=device) # download/load FP32 model
46+
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
4847
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
4948
else:
5049
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path

models/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def forward(self, x):
305305

306306
class DetectMultiBackend(nn.Module):
307307
# YOLOv5 MultiBackend class for python inference on various backends
308-
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
308+
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
309309
# Usage:
310310
# PyTorch: weights = *.pt
311311
# TorchScript: *.torchscript
@@ -331,7 +331,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
331331
names = yaml.safe_load(f)['names']
332332

333333
if pt: # PyTorch
334-
model = attempt_load(weights if isinstance(weights, list) else w, device=device)
334+
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
335335
stride = max(int(model.stride.max()), 32) # model stride
336336
names = model.module.names if hasattr(model, 'module') else model.names # get class names
337337
model.half() if fp16 else model.float()

0 commit comments

Comments
 (0)