Skip to content

Commit cd77f50

Browse files
authored
intersect_dicts() in hubconf.py fix (ultralytics#5542)
1 parent 0cdc250 commit cd77f50

File tree

4 files changed

+10
-12
lines changed

4 files changed

+10
-12
lines changed

hubconf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
3030
from models.experimental import attempt_load
3131
from models.yolo import Model
3232
from utils.downloads import attempt_download
33-
from utils.general import check_requirements, set_logging
33+
from utils.general import check_requirements, intersect_dicts, set_logging
3434
from utils.torch_utils import select_device
3535

3636
file = Path(__file__).resolve()
@@ -49,9 +49,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
4949
model = Model(cfg, channels, classes) # create model
5050
if pretrained:
5151
ckpt = torch.load(attempt_download(path), map_location=device) # load
52-
msd = model.state_dict() # model state_dict
5352
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
54-
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
53+
csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
5554
model.load_state_dict(csd, strict=False) # load
5655
if len(ckpt['model'].names) == classes:
5756
model.names = ckpt['model'].names # set class names attribute

train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@
4343
from utils.downloads import attempt_download
4444
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
4545
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
46-
labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
47-
print_mutation, strip_optimizer)
46+
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
47+
print_args, print_mutation, strip_optimizer)
4848
from utils.loggers import Loggers
4949
from utils.loggers.wandb.wandb_utils import check_wandb_resume
5050
from utils.loss import ComputeLoss
5151
from utils.metrics import fitness
5252
from utils.plots import plot_evolve, plot_labels
53-
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
54-
torch_distributed_zero_first)
53+
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
5554

5655
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
5756
RANK = int(os.getenv('RANK', -1))

utils/general.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def init_seeds(seed=0):
125125
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
126126

127127

128+
def intersect_dicts(da, db, exclude=()):
129+
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
130+
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
131+
132+
128133
def get_latest_run(search_dir='.'):
129134
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
130135
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)

utils/torch_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ def de_parallel(model):
153153
return model.module if is_parallel(model) else model
154154

155155

156-
def intersect_dicts(da, db, exclude=()):
157-
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
158-
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
159-
160-
161156
def initialize_weights(model):
162157
for m in model.modules():
163158
t = type(m)

0 commit comments

Comments
 (0)