Skip to content

Commit 0a4adf0

Browse files
authored
Resume with custom anchors fix (ultralytics#2361)
* Resume with custom anchors fix
* Update train.py
1 parent 46e23d6 commit 0a4adf0

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

train.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,10 +75,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
7575
with torch_distributed_zero_first(rank):
7676
attempt_download(weights) # download if not found locally
7777
ckpt = torch.load(weights, map_location=device) # load checkpoint
78-
if hyp.get('anchors'):
79-
ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
80-
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
81-
exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
78+
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
79+
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
8280
state_dict = ckpt['model'].float().state_dict() # to FP32
8381
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
8482
model.load_state_dict(state_dict, strict=False) # load
@@ -216,6 +214,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
216214
# Anchors
217215
if not opt.noautoanchor:
218216
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
217+
model.half().float() # pre-reduce anchor precision
219218

220219
# Model parameters
221220
hyp['box'] *= 3. / nl # scale to layers

0 commit comments

Comments (0)