Skip to content

Commit 99f8b1b

Browse files
glenn-jocherClay Januhowski
authored andcommitted
Add PyTorch Hub classification CI checks (ultralytics#9027)
* Add PyTorch Hub classification CI checks Add PyTorch Hub loading of official and custom trained classification models to CI checks. May help resolve ultralytics#8790 (comment) Signed-off-by: Glenn Jocher <[email protected]> * Update hubconf.py Signed-off-by: Glenn Jocher <[email protected]> Signed-off-by: Glenn Jocher <[email protected]>
1 parent 803b3cd commit 99f8b1b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

.github/workflows/ci-testing.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,8 @@ jobs:
133133
python classify/predict.py --imgsz 32 --weights $b --source ../datasets/mnist2560/test/7/60.png # predict
134134
python classify/predict.py --imgsz 32 --weights $m --source data/images/bus.jpg # predict
135135
python export.py --weights $b --img 64 --imgsz 224 --include torchscript # export
136+
python - <<EOF
137+
import torch
138+
for path in '$m', '$b':
139+
model = torch.hub.load('.', 'custom', path=path, source='local')
140+
EOF

hubconf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
3030

3131
from models.common import AutoShape, DetectMultiBackend
3232
from models.experimental import attempt_load
33-
from models.yolo import Model
33+
from models.yolo import DetectionModel
3434
from utils.downloads import attempt_download
3535
from utils.general import LOGGER, check_requirements, intersect_dicts, logging
3636
from utils.torch_utils import select_device
@@ -45,13 +45,13 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
4545
if pretrained and channels == 3 and classes == 80:
4646
try:
4747
model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model
48-
if autoshape:
48+
if autoshape and isinstance(model.model, DetectionModel):
4949
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
5050
except Exception:
5151
model = attempt_load(path, device=device, fuse=False) # arbitrary model
5252
else:
5353
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
54-
model = Model(cfg, channels, classes) # create model
54+
model = DetectionModel(cfg, channels, classes) # create model
5555
if pretrained:
5656
ckpt = torch.load(attempt_download(path), map_location=device) # load
5757
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32

0 commit comments

Comments
 (0)