Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
from models.experimental import attempt_load
from models.yolo import Model
from utils.downloads import attempt_download
from utils.general import check_requirements, set_logging
from utils.general import check_requirements, intersect_dicts, set_logging
from utils.torch_utils import select_device

file = Path(__file__).resolve()
Expand All @@ -49,9 +49,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model = Model(cfg, channels, classes) # create model
if pretrained:
ckpt = torch.load(attempt_download(path), map_location=device) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
Expand Down
7 changes: 3 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,14 @@
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
print_mutation, strip_optimizer)
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
torch_distributed_zero_first)
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
Expand Down
5 changes: 5 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def init_seeds(seed=0):
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
Expand Down
5 changes: 0 additions & 5 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,6 @@ def de_parallel(model):
return model.module if is_parallel(model) else model


def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def initialize_weights(model):
for m in model.modules():
t = type(m)
Expand Down