Commit 5c84d6e

Change optimizer parameters group method (ultralytics#1239)
* Change optimizer parameters group method
* Add torch nn
* Change isinstance method (torch.Tensor to nn.Parameter)
* parameter freeze fix, PEP8 reformat
* freeze bug fix

Co-authored-by: Glenn Jocher <[email protected]>
1 parent 2ae4127 commit 5c84d6e
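
The "freeze bug fix" noted in the message reworks the per-parameter freeze loop so that requires_grad is reset to True before optionally freezing matched parameters. A minimal usage sketch of the new logic; the prefixes in the freeze list below are illustrative assumptions, not names taken from this diff:

# Hypothetical freeze list; 'model.0.' and 'model.1.' are illustrative prefixes.
freeze = ['model.0.', 'model.1.']  # prefixes of parameter names to freeze
for k, v in model.named_parameters():
    v.requires_grad = True  # train all layers by default
    if any(x in k for x in freeze):
        print('freezing %s' % k)
        v.requires_grad = False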

File tree: 1 file changed, +14 −14 lines

train.py

Lines changed: 14 additions & 14 deletions
@@ -10,6 +10,7 @@
 import math
 import numpy as np
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
@@ -80,27 +81,26 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

     # Freeze
-    freeze = ['', ]  # parameter names to freeze (full or partial)
-    if any(freeze):
-        for k, v in model.named_parameters():
-            if any(x in k for x in freeze):
-                print('freezing %s' % k)
-                v.requires_grad = False
+    freeze = []  # parameter names to freeze (full or partial)
+    for k, v in model.named_parameters():
+        v.requires_grad = True  # train all layers
+        if any(x in k for x in freeze):
+            print('freezing %s' % k)
+            v.requires_grad = False

     # Optimizer
     nbs = 64  # nominal batch size
     accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay

     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
-    for k, v in model.named_parameters():
-        v.requires_grad = True
-        if '.bias' in k:
-            pg2.append(v)  # biases
-        elif '.weight' in k and '.bn' not in k:
-            pg1.append(v)  # apply weight decay
-        else:
-            pg0.append(v)  # all else
+    for k, v in model.named_modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+            pg2.append(v.bias)  # biases
+        if isinstance(v, nn.BatchNorm2d):
+            pg0.append(v.weight)  # no decay
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+            pg1.append(v.weight)  # apply decay

     if opt.adam:
         optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
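
For context, a self-contained sketch of the new module-based grouping. The small nn.Sequential stands in for the YOLOv5 Model class, and the hyperparameter values are example assumptions standing in for hyp['lr0'], hyp['momentum'], and hyp['weight_decay']; the add_param_group calls mirror how the groups are consumed later in train.py, which is outside this diff:

import torch.nn as nn
import torch.optim as optim

# Dummy stand-in for the YOLOv5 Model; any nn.Module works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1),
)

pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in model.named_modules():
    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
        pg2.append(v.bias)  # biases
    if isinstance(v, nn.BatchNorm2d):
        pg0.append(v.weight)  # BatchNorm2d weights: no weight decay
    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
        pg1.append(v.weight)  # conv/linear weights: weight decay applied

# Example values; the real ones come from the hyp dict in train.py.
lr0, momentum, weight_decay = 0.01, 0.937, 0.0005

optimizer = optim.SGD(pg0, lr=lr0, momentum=momentum, nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': weight_decay})  # weights with decay
optimizer.add_param_group({'params': pg2})  # biases without decay

print(len(pg0), len(pg1), len(pg2))  # 1 BN weight, 2 conv weights, 2 biases

Grouping by module attributes instead of parameter-name substrings avoids relying on layer naming conventions and keeps BatchNorm weights and biases out of weight decay.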
