Merged
3 changes: 2 additions & 1 deletion ct_train.py
@@ -78,6 +78,7 @@ def convert(self, value, param, ctx):
@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--tf32', help='Enable tf32 for A100/H100 training speed', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
@click.option('--enable_gradscaler', help='Enable torch.cuda.amp.GradScaler; NOTE: overrides the loss scaling set by --ls', metavar='BOOL', type=bool, default=False, show_default=True)
Review comment (Member):

Hi Zixiang @aiihn,

Thanks for your neat PR!

Would it be better to use a short abbreviation like amp as the option name? AMP already stands for Automatic Mixed Precision.
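
A minimal sketch of that rename, assuming only the option name changes and the boolean semantics stay the same (the 'amp' spelling is the reviewer's suggestion, not code from this PR):

    # Hypothetical renamed option; the wiring in main() would become enable_gradscaler=opts.amp.
    @click.option('--amp', help='Enable torch.cuda.amp.GradScaler (automatic mixed-precision loss scaling); overrides --ls', metavar='BOOL', type=bool, default=False, show_default=True)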

@click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
@@ -164,7 +165,7 @@ def main(**kwargs):
c.ema_halflife_kimg = int(opts.ema * 1000) if opts.ema is not None else opts.ema
c.ema_beta = opts.ema_beta
c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench, enable_tf32=opts.tf32)
c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench, enable_tf32=opts.tf32, enable_gradscaler=opts.enable_gradscaler)
c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump, ckpt_ticks=opts.ckpt, double_ticks=opts.double)
c.update(mid_t=opts.mid_t, metrics=opts.metrics, sample_ticks=opts.sample_every, eval_ticks=opts.eval_every)

56 changes: 44 additions & 12 deletions training/ct_training_loop.py
@@ -129,6 +129,7 @@ def training_loop(
metrics = None, # Metrics for evaluation.
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
enable_tf32 = False, # Enable tf32 for A100/H100 GPUs?
enable_gradscaler = False, # Enable torch.cuda.amp.GradScaler?
device = torch.device('cuda'),
):
# Initialize.
@@ -168,6 +169,14 @@ def training_loop(
optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe

dist.print0(f'GradScaler enabled: {enable_gradscaler}')
if enable_gradscaler:
# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-gradscaler
# https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
dist.print0(f'Setting up GradScaler...')
scaler = torch.cuda.amp.GradScaler()
dist.print0(f'Loss scaling is overwritten when GradScaler is enabled')

dist.print0('Setting up DDP...')
ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
ema = copy.deepcopy(net).eval().requires_grad_(False)
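
For context, the pattern described in the linked AMP recipe looks roughly like the sketch below. It is a generic illustration (model, loader, loss_fn and optimizer are placeholders, and this codebase drives fp16 through its own --fp16 flag rather than autocast), so only the scaler calls carry over to the loop further down:

    import torch

    scaler = torch.cuda.amp.GradScaler()
    for inputs, targets in loader:
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():             # run the forward pass in mixed precision
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()               # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)                      # unscales gradients and skips the step on inf/NaN
        scaler.update()                             # adjust the scale factor for the next iteration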
@@ -197,6 +206,13 @@ def training_loop(
data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
optimizer.load_state_dict(data['optimizer_state'])
if enable_gradscaler:
if 'gradscaler_state' in data:
dist.print0(f'Loading GradScaler state from "{resume_state_dump}"...')
# Training works even without loading the GradScaler state_dict, but restoring it improves reproducibility.
Review comment (Member):

Gotcha. Thanks for the comments!

scaler.load_state_dict(data['gradscaler_state'])
else:
dist.print0(f'GradScaler state not found in "{resume_state_dump}"; using the default state.')
del data # conserve memory
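
As a sketch of the reproducibility point in the comment above, the save/restore round trip looks like this (the filename is illustrative; the 'gradscaler_state' key matches the one added by this PR):

    # Saving (rank 0): keep the scaler state next to the optimizer state.
    state = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
    torch.save(state, 'training-state-latest.pt')

    # Resuming: restoring the scaler preserves its current loss scale and growth
    # counters, so a resumed run tracks an uninterrupted one more closely.
    data = torch.load('training-state-latest.pt', map_location='cpu')
    optimizer.load_state_dict(data['optimizer_state'])
    if 'gradscaler_state' in data:
        scaler.load_state_dict(data['gradscaler_state'])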

# Export sample images.
@@ -253,16 +269,24 @@ def update_scheduler(loss_fn):

loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
training_stats.report('Loss/loss', loss)
# loss.sum().mul(loss_scaling / batch_gpu_total).backward()
loss.mul(loss_scaling).mean().backward()

# Update weights.
# for g in optimizer.param_groups:
# g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
for param in net.parameters():
if param.grad is not None:
torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
optimizer.step()
if enable_gradscaler:
scaler.scale(loss.mean()).backward()
else:
# loss.sum().mul(loss_scaling / batch_gpu_total).backward()
loss.mul(loss_scaling).mean().backward()

if enable_gradscaler:
# TODO Is torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) needed when using GradScaler?
scaler.step(optimizer)
scaler.update()
else:
# Update weights.
Review comment (Member):

The TODO is also unclear to me. It seems still useful and compatible, per Claude.

It's fine to remove my commented code for lr rampup.

# for g in optimizer.param_groups:
# g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
for param in net.parameters():
if param.grad is not None:
torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
optimizer.step()
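
On the TODO above: if the gradient sanitisation is still wanted under GradScaler, the documented pattern is to call scaler.unscale_() first and modify the gradients before scaler.step(). A sketch under that assumption (not code in this PR):

    scaler.scale(loss.mean()).backward()
    scaler.unscale_(optimizer)              # bring gradients back to their true (unscaled) range
    for param in net.parameters():          # same sanitisation as the non-AMP branch
        if param.grad is not None:
            torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
    scaler.step(optimizer)                  # still skipped if inf/NaN was detected during unscaling
    scaler.update()

Since GradScaler already skips optimizer steps whose unscaled gradients contained inf/NaN, the explicit nan_to_num may be redundant here, which matches the uncertainty expressed in the TODO.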

# Update EMA.
if ema_halflife_kimg is not None:
@@ -317,7 +341,11 @@ def update_scheduler(loss_fn):

# Save full dump of the training state.
if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))
if enable_gradscaler:
data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
else:
data = dict(net=net, optimizer_state=optimizer.state_dict())
torch.save(data, os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))

# Save latest checkpoints
if (ckpt_ticks is not None) and (done or cur_tick % ckpt_ticks == 0) and cur_tick != 0:
@@ -335,7 +363,11 @@ def update_scheduler(loss_fn):
del data # conserve memory

if dist.get_rank() == 0:
torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-latest.pt'))
if enable_gradscaler:
data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
else:
data = dict(net=net, optimizer_state=optimizer.state_dict())
torch.save(data, os.path.join(run_dir, f'training-state-latest.pt'))

# Sample Img
if (sample_ticks is not None) and (done or cur_tick % sample_ticks == 0) and dist.get_rank() == 0: