 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
-from megatron.utils import report_memory, flops_calculator
+from megatron.utils import report_memory, flops_calculator, throughput_calculator, checkpoint_throughput_calculator
 
 import deepspeed
 from deepspeed.runtime.utils import see_memory_usage
@@ -106,8 +106,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    valid_data_iterator, model,
                                    iteration, False)
 
-    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    # if args.save and iteration != 0:
+    #     save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
         # Run on test data.
@@ -175,8 +175,8 @@ def get_optimizer(model):
                                        weight_decay=args.weight_decay)
     else:
         # Use torch Adam instead of Fused Adam from NVIDIA which seems to have some issue.
-        # optimizer = Adam(param_groups,
-        optimizer = torch.optim.AdamW(param_groups,
+        optimizer = Adam(param_groups,
+        # optimizer = torch.optim.AdamW(param_groups,
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(args.adam_beta1, args.adam_beta2),
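Note on the hunk above: the active optimizer flips back from torch.optim.AdamW to the Adam name bound earlier in the file (its import is outside this diff, likely a fused Adam variant). Both constructors take the same keyword arguments, which is why only the first line of the call needs to change; the practical difference is that AdamW applies decoupled weight decay while plain Adam treats weight_decay as L2 regularization folded into the gradient. A standalone illustration with stock PyTorch optimizers, using placeholder hyperparameters:

```python
import torch

model = torch.nn.Linear(16, 16)                 # toy module, illustration only
param_groups = [{"params": model.parameters()}]

# Same keyword signature for both, so swapping them is a one-line edit.
adam = torch.optim.Adam(param_groups, lr=1e-4, weight_decay=0.01,
                        betas=(0.9, 0.999), eps=1e-8)    # L2-style decay
adamw = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.01,
                          betas=(0.9, 0.999), eps=1e-8)  # decoupled decay
```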
@@ -384,6 +384,7 @@ def add_to_logging(name):
     add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
     add_to_logging('batch generator')
+    add_to_logging('save checkpoint')
 
     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
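Registering 'save checkpoint' here means the new checkpoint timer (started and stopped in the train() hunk further down) is included when the collected timers are dumped each logging interval. A minimal stand-in for the start/stop/elapsed/log pattern this commit relies on; this is a sketch only, not Megatron's actual Timers class, whose normalizer simply averages the accumulated time over the logging interval:

```python
import time

# Minimal stand-in for timers('name').start()/.stop()/.elapsed() and
# timers.log(names, normalizer=...) as used in this commit; not the real
# Megatron Timers implementation.
class Timers:
    def __init__(self):
        self._elapsed = {}
        self._started = {}

    def __call__(self, name):
        self._elapsed.setdefault(name, 0.0)
        timers = self

        class _Handle:
            def start(self):
                timers._started[name] = time.time()

            def stop(self):
                timers._elapsed[name] += time.time() - timers._started.pop(name)

            def elapsed(self, reset=True):
                value = timers._elapsed[name]
                if reset:
                    timers._elapsed[name] = 0.0
                return value

        return _Handle()

    def log(self, names, normalizer=1.0):
        # Report average time per iteration over the logging interval.
        for name in names:
            print('{}: {:.2f} ms'.format(
                name, self._elapsed.get(name, 0.0) * 1000.0 / normalizer))
```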
@@ -423,12 +424,14 @@ def add_to_logging(name):
             total_loss_dict[got_nan_key])
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[got_nan_key] = 0
+        timers.log(timers_to_log, normalizer=args.log_interval)
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
-        timers.log(timers_to_log, normalizer=args.log_interval)
+
         flops_calculator(model, args, elapsed_time)
+        throughput_calculator(model, args, elapsed_time)
 
     return report_memory_flag
 
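The new throughput_calculator call takes the same (model, args, elapsed_time) arguments as the existing flops_calculator, where elapsed_time covers the iterations since the last log. A rough sketch of the kind of number such a helper could report; this is a hypothetical stand-in, assuming args carries the per-iteration batch size and log interval, and is not the megatron.utils implementation:

```python
# Hypothetical stand-in, not megatron.utils.throughput_calculator.
# Assumes `elapsed_time` is the wall-clock seconds spent on the last
# `args.log_interval` iterations and `args.batch_size` is samples per iteration.
def throughput_calculator(model, args, elapsed_time):
    samples = args.batch_size * args.log_interval
    print('samples/sec: {:.2f} | ms per iteration: {:.1f}'.format(
        samples / elapsed_time,
        elapsed_time * 1000.0 / args.log_interval))
```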
@@ -462,11 +465,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         loss_scale = None
         if args.fp16:
             loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale
-        report_memory_flag = training_log(loss_dict, total_loss_dict,
-                                          optimizer.param_groups[0]['lr'],
-                                          iteration, loss_scale,
-                                          report_memory_flag, skipped_iter,
-                                          model=model)
 
         # Autoresume
         if args.adlr_autoresume and \
@@ -475,9 +473,21 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                               lr_scheduler)
 
         # Checkpointing
-        if args.save and args.save_interval and \
-           iteration % args.save_interval == 0:
+        should_save_checkpoint = args.save and args.save_interval and \
+           iteration % args.save_interval == 0
+        timers('save checkpoint').start()
+        if should_save_checkpoint:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        timers('save checkpoint').stop()
+
+        if should_save_checkpoint:
+            checkpoint_throughput_calculator(model, args, timers('save checkpoint').elapsed(reset=False))
+
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, loss_scale,
+                                          report_memory_flag, skipped_iter,
+                                          model=model)
 
         # Evaluation
         # XXX temporarily disabled for ZeRO-3
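In the hunk above, checkpoint_throughput_calculator receives the raw timer value via elapsed(reset=False), so the same reading is still available when the timer is logged later. One plausible use of that number is checkpoint bytes written divided by save time; the sketch below is a hedged illustration under that assumption, with an invented directory-size helper, and is not the megatron.utils implementation:

```python
import os

# Hypothetical sketch, not megatron.utils.checkpoint_throughput_calculator.
def _dir_size_bytes(path):
    # Total size of everything under the checkpoint directory (invented helper).
    return sum(os.path.getsize(os.path.join(root, name))
               for root, _, names in os.walk(path) for name in names)

def checkpoint_throughput_calculator(model, args, save_seconds):
    # Assumes args.save is the checkpoint output directory, as in Megatron's CLI.
    size_gb = _dir_size_bytes(args.save) / 1e9
    print('checkpoint save: {:.1f} GB in {:.2f} s ({:.2f} GB/s)'.format(
        size_gb, save_seconds, size_gb / save_seconds))
```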