     has_length,
     speed_metrics,
 )
+from .utils.ckpt_converter import CheckpointConverter
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,
 
 try:
@@ -720,20 +721,16 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             )
         )
 
-        ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
-
-        if not os.path.isdir(ckpt_path):
-            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
         if self.args.to_static:
-            opt_state_dict = {
+            model_state_dict = {
                 key: value
-                for key, value in self.model_wrapped.state_dict("opt").items()
+                for key, value in self.model_wrapped.state_dict("param").items()
                 if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
-            state_dict = {
-                MODEL_NAME: self.model_wrapped.state_dict("param"),
-                OPTIMIZER_NAME: opt_state_dict,
+            optim_state_dict = {
+                key: value
+                for key, value in self.model_wrapped.state_dict("opt").items()
+                if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
         else:
             model_state_dict = self.model_wrapped.state_dict()
@@ -746,12 +743,27 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             optim_state_dict = self.optimizer.state_dict()
             optim_state_dict.pop("LR_Scheduler", None)
 
-            state_dict = {
-                MODEL_NAME: model_state_dict,
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+        state_dict = {
+            MODEL_NAME: model_state_dict,
+            OPTIMIZER_NAME: optim_state_dict,
+        }
 
-        self._load_ckpt_func(state_dict, ckpt_path)
+        parameter_to_structured_name = {}
+        if self.args.to_static:
+            parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name
+        else:
+            for state_name, state_value in self.model_wrapped.state_dict().items():
+                parameter_to_structured_name[state_value.name] = state_name
+
+        if self.args.auto_parallel_resume_form_hybrid_parallel:
+            CheckpointConverter(
+                resume_from_checkpoint, state_dict, parameter_to_structured_name
+            ).load_from_hybrid_parallel_checkpoint()
+        else:
+            ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
+            if not os.path.isdir(ckpt_path):
+                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+            self._load_ckpt_func(state_dict, ckpt_path)
 
         # release memory
         del state_dict
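For readers skimming the dynamic-graph (`else`) branch above: it inverts the model's state_dict so that each tensor's internal name points back to its structured key, which the converter then uses to match parameters across the two checkpoint layouts. The sketch below mimics only that mapping step with plain Python objects; `FakeTensor` and the sample keys are hypothetical stand-ins, not PaddleNLP API.

```python
# Illustrative sketch only: plain-Python stand-ins for the mapping built in the diff.
# "FakeTensor" and the example keys are hypothetical; only the shape of the result
# ({internal tensor name: structured key}) mirrors parameter_to_structured_name above.
from dataclasses import dataclass


@dataclass
class FakeTensor:
    name: str  # dynamic-mode tensors expose an internal name distinct from the structured key


def build_parameter_to_structured_name(model_state_dict):
    """Invert a state_dict: map each tensor's internal name to its structured key."""
    mapping = {}
    for structured_name, tensor in model_state_dict.items():
        mapping[tensor.name] = structured_name
    return mapping


if __name__ == "__main__":
    model_state_dict = {
        "embedding.weight": FakeTensor(name="embedding_0.w_0"),
        "linear.weight": FakeTensor(name="linear_0.w_0"),
    }
    print(build_parameter_to_structured_name(model_state_dict))
    # {'embedding_0.w_0': 'embedding.weight', 'linear_0.w_0': 'linear.weight'}
```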