Implement Automatic Mixed Precision with GradScaler to Address NaN Loss Issues #13
Changes from all commits:
```diff
@@ -129,6 +129,7 @@ def training_loop(
     metrics = None, # Metrics for evaluation.
     cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
     enable_tf32 = False, # Enable tf32 for A100/H100 GPUs?
+    enable_gradscaler = False, # Enable torch.cuda.amp.GradScaler
     device = torch.device('cuda'),
 ):
     # Initialize.
```
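For context on the NaN-loss issue the PR title refers to: fp16 gradients below roughly 6e-8 flush to zero and large values overflow to inf, which is exactly what dynamic loss scaling guards against. A small self-contained illustration (not part of the PR):

```python
import torch

# Tiny gradient values underflow to zero in fp16 (subnormal limit ~6e-8),
# one way mixed-precision training silently degrades or turns into NaNs.
small = torch.tensor(1e-8)
print(small.half())                  # tensor(0., dtype=torch.float16)

# Scaling the loss (and hence its gradients) keeps them representable;
# unscaling in fp32 recovers the value. GradScaler automates this and
# adapts the scale factor as training proceeds.
scaled = (small * 2**16).half()      # ~6.55e-4, representable in fp16
print(scaled.float() / 2**16)        # tensor(1.0000e-08)
```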
```diff
@@ -168,6 +169,14 @@ def training_loop(
     optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
     augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
 
+    dist.print0(f'GradScaler enabled: {enable_gradscaler}')
+    if enable_gradscaler:
+        # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-gradscaler
+        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
+        dist.print0(f'Setting up GradScaler...')
+        scaler = torch.cuda.amp.GradScaler()
+        dist.print0(f'Loss scaling is overwritten when GradScaler is enabled')
+
     dist.print0('Setting up DDP...')
     ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
     ema = copy.deepcopy(net).eval().requires_grad_(False)
```
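For reference, the loop structure the linked recipe describes looks roughly like this (a minimal sketch with placeholder names; note the recipe pairs GradScaler with autocast, whereas this PR adds only the scaler):

```python
import torch

# Placeholder model/optimizer/data, not names from this repo.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loader = [(torch.randn(8, 16).cuda(), torch.randn(8, 16).cuda()) for _ in range(4)]
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # backward runs on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # grows/shrinks the scale factor
```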
```diff
@@ -197,6 +206,13 @@ def update_scheduler(loss_fn):
         data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
         misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
         optimizer.load_state_dict(data['optimizer_state'])
+        if enable_gradscaler:
+            if 'gradscaler_state' in data:
+                dist.print0(f'Loading GradScaler state from "{resume_state_dump}"...')
+                # Training works even without restoring the GradScaler state, but loading it improves reproducibility.
+                scaler.load_state_dict(data['gradscaler_state'])
+            else:
+                dist.print0(f'GradScaler state not found in "{resume_state_dump}", using the default state.')
         del data # conserve memory
 
         # Export sample images.
```

> **Member** (on the reproducibility comment above): Gotcha. Thanks for the comments!
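On the reproducibility point: without restoring the state, the scaler restarts from its default scale (65536) and re-adapts, so the first post-resume iterations can differ from an uninterrupted run. The state itself is a small dict (a sketch; the exact keys are a PyTorch implementation detail):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
state = scaler.state_dict()
# Recent PyTorch versions store entries such as 'scale', 'growth_factor',
# 'backoff_factor', 'growth_interval', and '_growth_tracker'.
fresh = torch.cuda.amp.GradScaler()
fresh.load_state_dict(state)  # resumes with the same running scale factor
```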
```diff
@@ -253,16 +269,24 @@ def update_scheduler(loss_fn):
                 loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
                 training_stats.report('Loss/loss', loss)
-                # loss.sum().mul(loss_scaling / batch_gpu_total).backward()
-                loss.mul(loss_scaling).mean().backward()
+                if enable_gradscaler:
+                    scaler.scale(loss.mean()).backward()
+                else:
+                    # loss.sum().mul(loss_scaling / batch_gpu_total).backward()
+                    loss.mul(loss_scaling).mean().backward()
 
-        # Update weights.
-        # for g in optimizer.param_groups:
-        #     g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
-        for param in net.parameters():
-            if param.grad is not None:
-                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
-        optimizer.step()
+        if enable_gradscaler:
+            # TODO: Is torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) needed when using GradScaler?
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            # Update weights.
+            # for g in optimizer.param_groups:
+            #     g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
+            for param in net.parameters():
+                if param.grad is not None:
+                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
+            optimizer.step()
 
         # Update EMA.
         if ema_halflife_kimg is not None:
```

> **Member** (on the TODO above): The TODO is unclear to me as well. Per Claude, the nan_to_num pass seems still useful and compatible with GradScaler. It's fine to remove my commented-out lr-rampup code.
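On the TODO: `scaler.step(optimizer)` internally unscales the gradients and skips the step whenever infs or NaNs are found, so the `nan_to_num` pass is not required for overflow protection. If the sanitization should be kept anyway, the "working with unscaled gradients" pattern from the AMP examples linked above applies: call `scaler.unscale_(optimizer)` first so the gradients are edited at their true magnitudes. A sketch of how the `enable_gradscaler` branch could look, reusing the loop's names:

```python
# Sketch only: gradient sanitization under GradScaler, following the
# unscaled-gradients pattern from the linked AMP examples.
scaler.scale(loss.mean()).backward()
scaler.unscale_(optimizer)   # grads are now at their true scale
for param in net.parameters():
    if param.grad is not None:
        torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
scaler.step(optimizer)       # does not unscale twice; the step is still
                             # skipped if unscale_ recorded any inf/NaN
scaler.update()
```

Note that inf/NaN findings are recorded during `unscale_`, so replacing them afterwards does not re-enable a skipped step; the skip-and-rescale behavior is essentially the protection the fp32 path's `nan_to_num` was approximating.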
```diff
@@ -317,7 +341,11 @@ def update_scheduler(loss_fn):
         # Save full dump of the training state.
         if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
-            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))
+            if enable_gradscaler:
+                data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
+            else:
+                data = dict(net=net, optimizer_state=optimizer.state_dict())
+            torch.save(data, os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))
 
         # Save latest checkpoints
         if (ckpt_ticks is not None) and (done or cur_tick % ckpt_ticks == 0) and cur_tick != 0:
```
```diff
@@ -335,7 +363,11 @@ def update_scheduler(loss_fn):
             del data # conserve memory
 
         if dist.get_rank() == 0:
-            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-latest.pt'))
+            if enable_gradscaler:
+                data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
+            else:
+                data = dict(net=net, optimizer_state=optimizer.state_dict())
+            torch.save(data, os.path.join(run_dir, f'training-state-latest.pt'))
 
         # Sample Img
         if (sample_ticks is not None) and (done or cur_tick % sample_ticks == 0) and dist.get_rank() == 0:
```
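Both dump sites now assemble the same dictionary; a small helper (hypothetical, not part of the PR) would keep the full-dump and latest-checkpoint paths in sync:

```python
def training_state_dict(net, optimizer, scaler=None):
    # Hypothetical refactor: build the dict the resume path expects,
    # attaching the GradScaler state only when a scaler is in use.
    data = dict(net=net, optimizer_state=optimizer.state_dict())
    if scaler is not None:
        data['gradscaler_state'] = scaler.state_dict()
    return data
```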
> **Member:** Hi Zixiang @aiihn, thanks for your neat PR! Would it be better to use a short abbreviation like `amp` as the option name? AMP already stands for Automatic Mixed Precision.