Commit 5e7b060

UnglvKitDe, pre-commit-ci[bot], and glenn-jocher authored and committed
Add tensor hooks and 10.0 gradient clipping (ultralytics#8598)
* Add tensor hooks and gradient clipping ultralytics#8578
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Remove retain_grad(), because it's not necessary
* Update train.py
* Simplify
* Update train.py
* Update train.py
* Update train.py
* Update train.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
1 parent fbe8dbb commit 5e7b060
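For context, the per-parameter tensor hook this commit registers can be exercised outside of YOLOv5. The sketch below is a minimal, self-contained illustration, not train.py itself: the toy model, random input, and deliberately NaN loss are placeholder assumptions. `Tensor.register_hook` attaches a function that receives each parameter's gradient during `backward()` and may return a replacement, and `torch.nan_to_num` maps NaN to 0.0 and +/-Inf to finite values.

```python
# Minimal sketch of the gradient-sanitizing hook (assumptions: toy model,
# random input, and a deliberately NaN loss; none of this is YOLOv5 code).
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

for _, v in model.named_parameters():
    v.requires_grad = True  # train all layers
    v.register_hook(lambda x: torch.nan_to_num(x))  # replace NaN grads with 0.0, +/-Inf with finite values

# Force a NaN loss so the backward pass would normally produce NaN gradients.
x = torch.randn(4, 10)
loss = model(x).sum() * float('nan')
loss.backward()
print(torch.isfinite(model[0].weight.grad).all())  # tensor(True): the hooks sanitized the gradients
```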

File tree

1 file changed: +4 -1 lines changed


train.py

Lines changed: 4 additions & 1 deletion
@@ -131,6 +131,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
     for k, v in model.named_parameters():
         v.requires_grad = True  # train all layers
+        v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0.0
         if any(x in k for x in freeze):
             LOGGER.info(f'freezing {k}')
             v.requires_grad = False
@@ -334,8 +335,10 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
             # Backward
             scaler.scale(loss).backward()

-            # Optimize
+            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
             if ni - last_opt_step >= accumulate:
+                scaler.unscale_(optimizer)  # unscale gradients
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
