@@ -98,24 +98,26 @@ def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any):
9898
9999 if cfg .environment .use_deepspeed :
100100 if path is not None :
101- # gather model params from all ranks
102- # if hasattr(cfg.training, "lora") and cfg.training.lora:
103- # model.backbone.save_pretrained(path)
104- model .save_checkpoint (os .path .join (path , "ds_checkpoint" ))
105- if cfg .environment ._local_rank == 0 :
106- # load to cpu
107- state_dict = get_fp32_state_dict_from_zero_checkpoint (
108- os .path .join (path , "ds_checkpoint" )
101+ # gather model params from all ranks when using Deepspeed
102+ if not model .save_16bit_model (path , "checkpoint.pth" ):
103+ logger .warning (
104+ "deepspeed.save_16bit_model didn't save the model, since"
105+ " stage3_gather_16bit_weights_on_model_save=False."
106+ " Saving the full checkpoint instead"
109107 )
110- # save as normal checkpoint that can be loaded by `load_state_dict`
111- checkpoint = {"model" : state_dict }
112- torch .save (checkpoint , os .path .join (path , "checkpoint.pth" ))
113- shutil .rmtree (os .path .join (path , "ds_checkpoint" ))
108+ model .save_checkpoint (os .path .join (path , "ds_checkpoint" ))
109+ if cfg .environment ._local_rank == 0 :
110+ # load to cpu
111+ state_dict = get_fp32_state_dict_from_zero_checkpoint (
112+ os .path .join (path , "ds_checkpoint" )
113+ )
114+ # save as normal checkpoint that can be loaded by `load_state_dict`
115+ checkpoint = {"model" : state_dict }
116+ torch .save (checkpoint , os .path .join (path , "checkpoint.pth" ))
117+ shutil .rmtree (os .path .join (path , "ds_checkpoint" ))
114118 else :
115119 if cfg .environment ._local_rank == 0 :
116120 model = unwrap_model (model )
117- # if hasattr(cfg.training, "lora") and cfg.training.lora:
118- # model.backbone.save_pretrained(path)
119121 checkpoint = {"model" : model .state_dict ()}
120122 if path is not None :
121123 torch .save (checkpoint , os .path .join (path , "checkpoint.pth" ))
@@ -193,7 +195,9 @@ def load_checkpoint(
193195 if weights_path is None :
194196 weights_path = cfg .architecture .pretrained_weights
195197
196- model_weights = torch .load (weights_path , map_location = "cpu" )["model" ]
198+ model_weights = torch .load (weights_path , map_location = "cpu" )
199+ if "model" in model_weights .keys ():
200+ model_weights = model_weights ["model" ]
197201
198202 model = load_model_weights (model , model_weights , strict , cfg )
199203
@@ -224,6 +228,7 @@ def get_ds_config(cfg: Any):
224228 # zero3
225229 "stage3_prefetch_bucket_size" : cfg .environment .deepspeed_stage3_prefetch_bucket_size , # noqa: E501
226230 "stage3_param_persistence_threshold" : cfg .environment .deepspeed_stage3_param_persistence_threshold , # noqa: E501
231+ "stage3_gather_16bit_weights_on_model_save" : True ,
227232 # zero3 offload cpu
228233 # "stage3_max_live_parameters": cfg.environment.deepspeed_stage3_max_live_parameters, # noqa: E501
229234 # "stage3_max_reuse_distance": cfg.environment.deepspeed_stage3_max_reuse_distance, # noqa: E501
0 commit comments