@@ -10,6 +10,7 @@
 import math
 import numpy as np
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
@@ -80,27 +81,26 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create
 
     # Freeze
-    freeze = ['', ]  # parameter names to freeze (full or partial)
-    if any(freeze):
-        for k, v in model.named_parameters():
-            if any(x in k for x in freeze):
-                print('freezing %s' % k)
-                v.requires_grad = False
+    freeze = []  # parameter names to freeze (full or partial)
+    for k, v in model.named_parameters():
+        v.requires_grad = True  # train all layers
+        if any(x in k for x in freeze):
+            print('freezing %s' % k)
+            v.requires_grad = False
 
     # Optimizer
     nbs = 64  # nominal batch size
     accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
 
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
-    for k, v in model.named_parameters():
-        v.requires_grad = True
-        if '.bias' in k:
-            pg2.append(v)  # biases
-        elif '.weight' in k and '.bn' not in k:
-            pg1.append(v)  # apply weight decay
-        else:
-            pg0.append(v)  # all else
+    for k, v in model.named_modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+            pg2.append(v.bias)  # biases
+        if isinstance(v, nn.BatchNorm2d):
+            pg0.append(v.weight)  # no decay
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+            pg1.append(v.weight)  # apply decay
 
     if opt.adam:
         optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
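The rewritten Freeze block now defaults every parameter to requires_grad=True and matches entries of freeze as substrings of parameter names, so full or partial names both work. A minimal runnable sketch of that pattern, using an illustrative stand-in model rather than the YOLOv5 Model:

import torch.nn as nn

# Hypothetical stand-in model; parameter names here are illustrative only
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))

freeze = ['0.']  # freeze everything in the first module (substring match, so partial names work)
for k, v in model.named_parameters():
    v.requires_grad = True  # default: train all layers
    if any(x in k for x in freeze):
        print('freezing %s' % k)  # e.g. "freezing 0.weight", "freezing 0.bias"
        v.requires_grad = False   # excluded from gradient updates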
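The accumulate / weight_decay lines keep updates tied to a nominal batch of 64 images: small batches accumulate gradients over more steps, and weight decay is rescaled so its effective strength per 64 images stays constant. A quick worked example of that arithmetic, with an illustrative base decay of 0.0005:

nbs = 64               # nominal batch size
weight_decay = 0.0005  # illustrative base value for hyp['weight_decay']

for total_batch_size in (16, 32, 64, 128):
    accumulate = max(round(nbs / total_batch_size), 1)  # gradient-accumulation steps
    scaled = weight_decay * total_batch_size * accumulate / nbs
    print(total_batch_size, accumulate, total_batch_size * accumulate, scaled)
    # 16  -> accumulate 4, effective batch 64,  decay unchanged (0.0005)
    # 128 -> accumulate 1, effective batch 128, decay doubled   (0.0010)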
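The new `import torch.nn as nn` supports the isinstance checks in the rewritten grouping loop, which collects BatchNorm weights into pg0 (no decay), other weights into pg1 (decayed), and biases into pg2 (no decay). The groups are attached to the optimizer just after this hunk in train.py (not shown here); a minimal sketch of that pattern on a toy model, with illustrative hyperparameter values:

import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in model and hyperparameters; values are illustrative only
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
hyp = {'lr0': 0.01, 'momentum': 0.937, 'weight_decay': 0.0005}

pg0, pg1, pg2 = [], [], []  # BN weights / decayed weights / biases
for k, v in model.named_modules():
    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
        pg2.append(v.bias)      # biases
    if isinstance(v, nn.BatchNorm2d):
        pg0.append(v.weight)    # no decay
    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
        pg1.append(v.weight)    # apply decay

optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # decay only these weights
optimizer.add_param_group({'params': pg2})  # biases, no decay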