Commit 44b8664

Perf statistics for save_checkpoint (#64)
* save_checkpoint perf monitoring
* Disable checkpoint save on exit
1 parent a16ac9e commit 44b8664

2 files changed (+42 −13 lines):
* Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
* Megatron-LM-v1.1.5-ZeRO3/megatron/utils.py


Megatron-LM-v1.1.5-ZeRO3/megatron/training.py

Lines changed: 23 additions & 13 deletions
```diff
@@ -38,7 +38,7 @@
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
-from megatron.utils import report_memory, flops_calculator
+from megatron.utils import report_memory, flops_calculator, throughput_calculator, checkpoint_throughput_calculator
 
 import deepspeed
 from deepspeed.runtime.utils import see_memory_usage
@@ -106,8 +106,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    valid_data_iterator, model,
                                    iteration, False)
 
-    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    # if args.save and iteration != 0:
+    #     save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
         # Run on test data.
@@ -175,8 +175,8 @@ def get_optimizer(model):
                              weight_decay=args.weight_decay)
     else:
         # Use torch Adam instead of Fused Adam from NVIDIA which seems to have some issue.
-        #optimizer = Adam(param_groups,
-        optimizer = torch.optim.AdamW(param_groups,
+        optimizer = Adam(param_groups,
+        #optimizer = torch.optim.AdamW(param_groups,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
@@ -384,6 +384,7 @@ def add_to_logging(name):
     add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
     add_to_logging('batch generator')
+    add_to_logging('save checkpoint')
 
     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
@@ -423,12 +424,14 @@ def add_to_logging(name):
                                  total_loss_dict[got_nan_key])
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[got_nan_key] = 0
+        timers.log(timers_to_log, normalizer=args.log_interval)
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
-        timers.log(timers_to_log, normalizer=args.log_interval)
+
         flops_calculator(model, args, elapsed_time)
+        throughput_calculator(model, args, elapsed_time)
 
     return report_memory_flag
 
@@ -462,11 +465,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         loss_scale = None
         if args.fp16:
             loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale
-        report_memory_flag = training_log(loss_dict, total_loss_dict,
-                                          optimizer.param_groups[0]['lr'],
-                                          iteration, loss_scale,
-                                          report_memory_flag, skipped_iter,
-                                          model=model)
 
         # Autoresume
         if args.adlr_autoresume and \
@@ -475,9 +473,21 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              lr_scheduler)
 
         # Checkpointing
-        if args.save and args.save_interval and \
-           iteration % args.save_interval == 0:
+        should_save_checkpoint = args.save and args.save_interval and \
+           iteration % args.save_interval == 0
+        timers('save checkpoint').start()
+        if should_save_checkpoint:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        timers('save checkpoint').stop()
+
+        if should_save_checkpoint:
+            checkpoint_throughput_calculator(model, args, timers('save checkpoint').elapsed(reset=False))
+
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, loss_scale,
+                                          report_memory_flag, skipped_iter,
+                                          model=model)
 
         # Evaluation
         # XXX temporarily disabled for ZeRO-3
```
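The `timers('save checkpoint').start()`, `.stop()`, and `.elapsed(reset=False)` calls above use Megatron's existing named-timer utility, which this commit does not modify. As a rough, self-contained illustration of the pattern being added (time the save only when one actually runs, then hand the latency to a throughput reporter), here is a minimal sketch; `SimpleTimer` and `timed_checkpoint_save` are hypothetical stand-ins, not Megatron APIs.

```python
# Minimal sketch of the timing pattern the diff wraps around save_checkpoint.
# SimpleTimer and timed_checkpoint_save are hypothetical stand-ins for
# Megatron's timer object and the call site in train(); they are not real APIs.
import time


class SimpleTimer:
    """Accumulates wall-clock seconds between start() and stop() calls."""

    def __init__(self):
        self._start = None
        self._elapsed = 0.0

    def start(self):
        self._start = time.time()

    def stop(self):
        self._elapsed += time.time() - self._start
        self._start = None

    def elapsed(self, reset=True):
        # reset=False mirrors the diff: read the value without clearing it,
        # so the same interval can still be logged later.
        value = self._elapsed
        if reset:
            self._elapsed = 0.0
        return value


def timed_checkpoint_save(should_save, save_fn, report_fn):
    """Time save_fn() and, only if a save actually ran, report its latency."""
    timer = SimpleTimer()
    timer.start()
    if should_save:
        save_fn()
    timer.stop()
    if should_save:
        report_fn(timer.elapsed(reset=False))
    return timer
```

For example, `timed_checkpoint_save(True, lambda: time.sleep(0.5), print)` reports a latency of roughly 0.5 seconds, the same quantity that `checkpoint_throughput_calculator` converts into GB/s.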

Megatron-LM-v1.1.5-ZeRO3/megatron/utils.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -194,3 +194,22 @@ def flops_calculator(model, args, iteration_time):
     effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (iteration_time * 1000.0 * gpus_per_model)
 
     print_rank_0(f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B")
+
+
+def throughput_calculator(model, args, iteration_time):
+    gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
+    samples_per_model = args.batch_size * args.seq_length
+    model_replica_count = torch.distributed.get_world_size() / gpus_per_model
+    approx_parameters_in_billions = get_parameters_in_billions(model)
+    samples_per_second = samples_per_model * model_replica_count / (iteration_time * 1000.0)
+
+    print_rank_0(f'Samples per second: {round(samples_per_second, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B')
+
+
+def checkpoint_throughput_calculator(model, args, latency_sec):
+    approx_parameters_in_billions = get_parameters_in_billions(model)
+    checkpoint_multiplier = 12 # fp16 weights (2), fp32 weights (4), fp32 momentum (4), fp32 variance (4)
+    checkpoint_giga_bytes = approx_parameters_in_billions * checkpoint_multiplier
+    giga_bytes_per_second = checkpoint_giga_bytes / latency_sec
+
+    print_rank_0(f'Checkpoint Save GB: {round(checkpoint_giga_bytes, 3)}, GB_PerSec: {round(giga_bytes_per_second, 2)}, Latency(secs): {round(latency_sec, 3)}')
```
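To make the arithmetic in `checkpoint_throughput_calculator` concrete, here is a small worked example that applies the same formula. The 1.5 B parameter count and 6-second save latency are made-up illustrative values, not measurements from this commit; the 12 bytes-per-parameter multiplier is the constant hard-coded in the new helper.

```python
# Worked example of the checkpoint_throughput_calculator arithmetic.
# The parameter count (1.5 B) and latency (6 s) are illustrative values only.
approx_parameters_in_billions = 1.5   # hypothetical model size
checkpoint_multiplier = 12            # bytes per parameter, as hard-coded in utils.py
latency_sec = 6.0                     # hypothetical wall-clock time for one save

checkpoint_giga_bytes = approx_parameters_in_billions * checkpoint_multiplier  # 1.5 * 12 = 18.0 GB
giga_bytes_per_second = checkpoint_giga_bytes / latency_sec                    # 18.0 / 6.0 = 3.0 GB/s

print(f'Checkpoint Save GB: {round(checkpoint_giga_bytes, 3)}, '
      f'GB_PerSec: {round(giga_bytes_per_second, 2)}, '
      f'Latency(secs): {round(latency_sec, 3)}')
```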
