@@ -136,21 +136,25 @@ def __init__(self, args):
         self._process_master_weight = None
         self._process_optimizer_weight = None
         self._lock = None
-        self._shared_save_path = None
         self._shared_save_model_flag = None
         self._shared_save_master_weight_flag = None
         self._shared_save_optimizer_flag = None

         if "async_save" in self.args.unified_checkpoint_config:
             self._lock = multiprocessing.Lock()
             self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
             self._shared_save_model_flag = multiprocessing.Array("i", 1)
             self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
             self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

-    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
         if is_sync:
             for k in list(state_dict.keys()):
                 if isinstance(state_dict[k], paddle.Tensor):
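A note on the char arrays added above: each `multiprocessing.Array("c", 100000)` is a fixed-size, zero-initialized byte buffer shared with the background save process and used as a crude shared string for a file or signal path. The minimal sketch below is a standalone illustration, not code from this PR; the helper names and the example path are made up. It shows the write/read convention the rest of the diff relies on: overwrite the buffer with the encoded value, then strip the trailing "\x00" padding when reading it back.

import multiprocessing

shared_path = multiprocessing.Array("c", 100000)  # zero-initialized shared byte buffer

def write_path(shared_array, value: str):
    # clear the buffer, then copy in the UTF-8 bytes of the new value
    for i in range(len(shared_array)):
        shared_array[i] = b"\x00"
    encoded = value.encode("utf-8")
    shared_array[: len(encoded)] = encoded

def read_path(shared_array) -> str:
    # strip the zero padding that fills the unused tail of the buffer
    return shared_array[:].decode("utf-8").rstrip("\x00")

write_path(shared_path, "/tmp/ckpt/model-00001-of-00002.safetensors")
assert read_path(shared_path) == "/tmp/ckpt/model-00001-of-00002.safetensors"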
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_model
                 shared_save_flag = self._shared_save_model_flag
                 shared_save_path = self._shared_save_model_path
+                shared_save_signal_path = self._shared_save_model_signal_path
                 if self._process_model_weight is None:
                     self._process_model_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_model_weight.name,
                             self._shared_save_model_flag,
                             self._shared_save_model_path,
+                            self._shared_save_model_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_model_weight.start()
+                process = self._process_model_weight
             elif state_dict_type == "master_weight":
                 if self._shm_master_weight is None:
                     self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_master_weight
                 shared_save_flag = self._shared_save_master_weight_flag
                 shared_save_path = self._shared_save_master_weight_path
+                shared_save_signal_path = self._shared_save_master_weight_signal_path
                 if self._process_master_weight is None:
                     self._process_master_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_master_weight.name,
                             self._shared_save_master_weight_flag,
                             self._shared_save_master_weight_path,
+                            self._shared_save_master_weight_signal_path,
                             self._lock,
                             "model_weight"
                             if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                         ),
                     )
                     self._process_master_weight.start()
+                process = self._process_master_weight
             elif state_dict_type == "optimizer_weight":
                 if self._shm_optimizer_weight is None:
                     self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                 meta_dict = self._meta_dict_optim
                 shared_save_flag = self._shared_save_optimizer_flag
                 shared_save_path = self._shared_save_optimizer_path
+                shared_save_signal_path = self._shared_save_optimizer_signal_path
                 if self._process_optimizer_weight is None:
                     self._process_optimizer_weight = multiprocessing.Process(
                         target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                             self._shm_optimizer_weight.name,
                             self._shared_save_optimizer_flag,
                             self._shared_save_optimizer_path,
+                            self._shared_save_optimizer_signal_path,
                             self._lock,
                             state_dict_type,
                             self.global_rank,
                         ),
                     )
                     self._process_optimizer_weight.start()
+                process = self._process_optimizer_weight

             while True:  # wait until no process is saving.
                 flag_value = shared_save_flag[0]
                 if flag_value == 0:
                     break
+                if not process.is_alive():
+                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
                 time.sleep(0.5)
                 logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
             # only save model weight or save master weight, we enter this loop.
             self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
             _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
             with self._lock:
                 shared_save_flag[0] = 1
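Why the `is_alive()` check matters: without it, a crashed save worker leaves the flag stuck at 1 and the trainer spins in this loop forever. The pattern added above is shown here as a standalone sketch; the helper name and signature are hypothetical, not part of this PR. It is simply "poll until idle, but fail fast if the worker died":

import time
import multiprocessing

def wait_until_idle(flag, process: multiprocessing.Process, what: str, poll_s: float = 0.5):
    """Block until flag[0] == 0, but fail fast if the worker died."""
    while True:
        if flag[0] == 0:  # worker finished the previous save request
            return
        if not process.is_alive():  # worker crashed: waiting any longer would hang forever
            raise RuntimeError(f"The process that saves {what} has been killed unexpectedly.")
        time.sleep(poll_s)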
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
         shm_name,
         shared_save_flag,
         shared_save_path,
+        shared_save_signal_path,
         lock,
         state_dict_type,
         global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
                 continue
             if flag_value == 1:  # need to save
                 path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                 logger.info(f"Start to async save {path}")
                 state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                 safe_save_file(state_dict, path, {"format": "np"})
                 del state_dict
-                saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                 paddle.save(global_rank, saved_signal_path)
                 with lock:
                     shared_save_flag[0] = 0
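With this change the per-rank done marker is written under the dedicated signal directory instead of next to the weights file (previously `os.path.dirname(path)`). A small illustration of the resulting marker path; the directory values and layout are made up for illustration only:

import os

signal_path = "/tmp/ckpt_signals/checkpoint-500"  # hypothetical signal directory
state_dict_type = "optimizer_weight"
global_rank = 3

saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
print(saved_signal_path)  # /tmp/ckpt_signals/checkpoint-500/.optimizer_weight.done.3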
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
         encoded_value = new_value.encode("utf-8")
         shared_array[: len(encoded_value)] = encoded_value

-    def save_unified_checkpoint(self, model, optimizer, output_dir):
+    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
         """save unified checkpoint

         Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)  # only for async save

         # save model weights
         if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
             self._file_save_async_or_sync(
                 state_dict,
                 path=os.path.join(save_directory, shard_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="model_weight",
             )
@@ -394,7 +415,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
         if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

-    def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir):
+    def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
         paddle.device.cuda.empty_cache()

         # gather global master_weights status.
@@ -446,12 +467,14 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(output_dir, optimizer_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         self._file_save_async_or_sync(
             master_weights,
             path=os.path.join(output_dir, master_weights_name),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="master_weight",
         )
@@ -501,17 +524,19 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

         return returned_optim_state_dict

-    def save_unified_optimizer(self, model, optimizer, output_dir):
+    def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
         """save unified optimizer

         Args:
             model (PretrainedModel): model used to get key mapping.
             optimizer (Optimizer): optimizer to save
             output_dir (str): Save directory.
+            signal_dir (str): Asynchronous saving signal directory.

         """
+
         if paddle.distributed.get_world_size() <= 1:
-            save_single_card_optimizer(model, optimizer, output_dir)
+            save_single_card_optimizer(model, optimizer, output_dir)  # no need to save signal
             return

         if (
@@ -530,7 +555,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
             optim_state_dict.pop("LR_Scheduler")

         if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
-            self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir)
+            self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir)
             return

         # Split into naive optimizer params and master weights.
@@ -547,20 +572,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
         paddle.device.cuda.empty_cache()
         save_directory = output_dir
         os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)

         is_sync_save = True
         if "async_save" in self.args.unified_checkpoint_config:
             is_sync_save = False
         self._file_save_async_or_sync(
             optim_state_dict,
             path=os.path.join(save_directory, shard_optim_file),
+            signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
         )
         if master_weight_state_dict is not None:
             self._file_save_async_or_sync(
                 master_weight_state_dict,
                 path=os.path.join(save_directory, shard_master_weight_file),
+                signal_path=signal_dir,
                 is_sync=is_sync_save,
                 state_dict_type="master_weight",
             )
@@ -634,14 +663,20 @@ def unlink_shared_memory(self):

         if self._shared_save_model_flag is not None:
             while self._shared_save_model_flag[0] > 0:  # async process is saving
+                if not self._process_model_weight.is_alive():
+                    raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_model_flag[0] = -1
         if self._shared_save_master_weight_flag is not None:
             while self._shared_save_master_weight_flag[0] > 0:
+                if not self._process_master_weight.is_alive():
+                    raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_master_weight_flag[0] = -1
         if self._shared_save_optimizer_flag is not None:
             while self._shared_save_optimizer_flag[0] > 0:
+                if not self._process_optimizer_weight.is_alive():
+                    raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
                 time.sleep(0.5)
             self._shared_save_optimizer_flag[0] = -1

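For context on the flag values drained here (a reading of the surrounding code, not something spelled out in this diff): 0 appears to mean the worker is idle, 1 requests a save, and -1, written only after the in-flight save has drained, presumably asks the worker to exit. A minimal sketch of a worker loop under that assumption; the function name is hypothetical and locking is omitted for brevity:

import time

def save_worker_loop(shared_save_flag, do_save):
    # Sketch only; flag semantics are inferred: 0 = idle, 1 = save requested,
    # -1 = shutdown (written by unlink_shared_memory after draining).
    while True:
        flag_value = shared_save_flag[0]
        if flag_value == 1:
            do_save()                # persist the shared-memory state dict
            shared_save_flag[0] = 0  # report back as idle
        elif flag_value == -1:
            break                    # parent requested shutdown
        time.sleep(0.1)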
@@ -658,7 +693,8 @@ def unlink_shared_memory(self):
             self._shm_optimizer_weight.unlink()
             self._shm_optimizer_weight = None

-        dist.barrier()
+        if paddle.distributed.get_world_size() > 1:
+            dist.barrier()


 def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):