@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
30
30
31
31
from models .common import AutoShape , DetectMultiBackend
32
32
from models .experimental import attempt_load
33
- from models .yolo import Model
33
+ from models .yolo import DetectionModel
34
34
from utils .downloads import attempt_download
35
35
from utils .general import LOGGER , check_requirements , intersect_dicts , logging
36
36
from utils .torch_utils import select_device
@@ -45,13 +45,13 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
45
45
if pretrained and channels == 3 and classes == 80 :
46
46
try :
47
47
model = DetectMultiBackend (path , device = device , fuse = autoshape ) # detection model
48
- if autoshape :
48
+ if autoshape and isinstance ( model . model , DetectionModel ) :
49
49
model = AutoShape (model ) # for file/URI/PIL/cv2/np inputs and NMS
50
50
except Exception :
51
51
model = attempt_load (path , device = device , fuse = False ) # arbitrary model
52
52
else :
53
53
cfg = list ((Path (__file__ ).parent / 'models' ).rglob (f'{ path .stem } .yaml' ))[0 ] # model.yaml path
54
- model = Model (cfg , channels , classes ) # create model
54
+ model = DetectionModel (cfg , channels , classes ) # create model
55
55
if pretrained :
56
56
ckpt = torch .load (attempt_download (path ), map_location = device ) # load
57
57
csd = ckpt ['model' ].float ().state_dict () # checkpoint state_dict as FP32
0 commit comments