@@ -277,8 +277,16 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs):
         """
         if directory is None:
             directory = os.path.join(self.config.train.checkpoint_dir, "hf_model")
+
         self.accelerator.wait_for_everyone()
-        self.accelerator.unwrap_model(self.model).save_pretrained(directory, **kwargs)
+        self.accelerator.unwrap_model(self.model).save_pretrained(
+            directory,
+            save_function=self.accelerator.save,
+            is_main_process=self.accelerator.is_main_process,
+            state_dict=self.accelerator.get_state_dict(self.model),
+            **kwargs,
+        )
+
         if self.accelerator.is_main_process:
             self.tokenizer.save_pretrained(directory)
 
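
For context, the extra keyword arguments are what make this export safe under sharded or multi-process training driven through Accelerate (for example DeepSpeed ZeRO-3): `get_state_dict` gathers the full weights across ranks and `is_main_process` keeps every rank but one from writing. Below is a minimal sketch of the same pattern outside the trainer; the helper name and the already-prepared `model`/`tokenizer` are assumptions, not part of this change.

# Minimal sketch of the distributed-safe save used above (helper name is assumed).
from accelerate import Accelerator

def save_hf_checkpoint(accelerator: Accelerator, model, tokenizer, directory: str):
    # Let every rank finish its last step before touching the filesystem.
    accelerator.wait_for_everyone()
    # `get_state_dict` reassembles the full (possibly sharded) weights;
    # `is_main_process` makes sure only rank 0 writes the files.
    accelerator.unwrap_model(model).save_pretrained(
        directory,
        save_function=accelerator.save,
        is_main_process=accelerator.is_main_process,
        state_dict=accelerator.get_state_dict(model),
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(directory)
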
@@ -540,17 +548,24 @@ def learn(self):  # noqa: C901
                     self.scheduler.step()
                     self.iter_count += 1
 
-                    if self.iter_count % self.config.train.checkpoint_interval == 0:
+                    if (
+                        self.iter_count % self.config.train.checkpoint_interval == 0
+                        or self.iter_count >= self.total_steps
+                    ):
                         subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
                         directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
-                        self.save(directory)
+                        logger.info(f"Saving intermediate checkpoint into {directory}")
+                        if self.config.train.save_optimizer:
+                            self.save(directory)
+                        else:
+                            self.save_pretrained(directory)
 
                     stats["time/forward"] = forward_time
                     stats["time/backward"] = backward_time
                     for group_number, lr in enumerate(self.scheduler.get_last_lr()):
                         stats[f"learning_rate_group_{group_number}"] = lr
 
-                    if self.iter_count % self.config.train.eval_interval == 0:
+                    if self.iter_count % self.config.train.eval_interval == 0 or self.iter_count >= self.total_steps:
                         results = self.evaluate()
                         stats.update(results)
                         if ray.is_initialized():
@@ -571,29 +586,22 @@ def learn(self):  # noqa: C901
                             if torch.distributed.is_initialized():
                                 torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX)
                             if do_save:
-                                best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint"
-                                logger.info(f"Saving the best state so far into {best_path}")
-                                self.save(best_path)
+                                directory = os.path.join(self.config.train.checkpoint_dir, "best_checkpoint")
+                                logger.info(f"Saving the best state so far into {directory}")
+                                if self.config.train.save_optimizer:
+                                    self.save(directory)
+                                else:
+                                    self.save_pretrained(directory)
 
                     desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
                     tbar.set_description(f"[{desc}]")
                     tbar.update()
 
-                    if self.iter_count >= self.total_steps:
-                        subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
-                        directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
-                        results = self.evaluate()
-                        stats.update(results)
-
-                        if ray.is_initialized():
-                            session.report(filter_non_scalars(stats), checkpoint=checkpoint)
-                        self.accelerator.log(stats, step=self.iter_count)
+                    self.accelerator.log(stats, step=self.iter_count)
 
-                        self.save(directory)
+                    if self.iter_count >= self.total_steps:
                         return results
 
-                    self.accelerator.log(stats, step=self.iter_count)
-
                 self.post_backward_callback()
 
             self.post_epoch_callback()
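
Taken together, the new `config.train.save_optimizer` flag decides, for both the intermediate and the best checkpoints, between a full resumable checkpoint (`self.save`) and a lighter weights-only export (`self.save_pretrained`). A hedged sketch of that toggle as a standalone helper, reusing the `save_hf_checkpoint` helper sketched above; the function name and arguments are assumptions, not the trainer's own API.

# Illustrative sketch of the save_optimizer toggle (names other than the
# Accelerate calls are assumed for the example).
from accelerate import Accelerator

def save_checkpoint(accelerator: Accelerator, model, tokenizer, directory: str, save_optimizer: bool):
    if save_optimizer:
        # Writes model, optimizer, scheduler and RNG states for everything that went
        # through `accelerator.prepare`, so training can be resumed exactly.
        accelerator.save_state(directory)
    else:
        # Weights-only Hugging Face export: enough for `from_pretrained`, not resumable.
        save_hf_checkpoint(accelerator, model, tokenizer, directory)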