@@ -156,6 +156,8 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
156156 if args .should_save :
157157 config_to_save .save_pretrained (save_directory )
158158
159+ paddle .device .cuda .empty_cache ()
160+
159161
160162def load_unified_checkpoint (args , model , optimizer , resume_from_checkpoint : str , safe_serialization = False ) -> None :
161163 """Load potential model checkpoint
@@ -281,6 +283,7 @@ def unified_checkpoint_into_shards(
281283 Returns:
282284 tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
283285 """
286+ paddle .device .cuda .empty_cache ()
284287 assert hasattr (model_to_save , "config" )
285288
286289 state_dict = model_to_save .state_dict ()
@@ -311,6 +314,8 @@ def unified_checkpoint_into_shards(
311314 total_size_list ,
312315 )
313316
317+ paddle .device .cuda .empty_cache ()
318+
314319 return state_dict , shard_file , sharded_index
315320
316321
@@ -333,6 +338,8 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
333338 optim_state_dict , shard_optim_file , sharded_optim_index = results [0 ]
334339 master_weight_state_dict , shard_master_weight_file , sharded_master_weight_index = results [1 ]
335340
341+ paddle .device .cuda .empty_cache ()
342+
336343 save_directory = output_dir
337344 os .makedirs (save_directory , exist_ok = True )
338345
@@ -514,6 +521,7 @@ def unified_optimizer_into_shards(
514521 optimizer (Optimizer): optimizer to save.
515522 safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
516523 """
524+ paddle .device .cuda .empty_cache ()
517525 optim_state_dict = nested_copy (optimizer .state_dict ())
518526 master_weights = None
519527 if "master_weights" in optim_state_dict .keys ():
@@ -559,12 +567,15 @@ def unified_optimizer_into_shards(
559567 tp_actions ,
560568 filter_optim_keys ,
561569 )
570+ paddle .device .cuda .empty_cache ()
571+
562572 if master_weights is not None :
563573 master_weights = merge_tensor_parallel_for_optimizer (
564574 master_weights ,
565575 tp_actions ,
566576 filter_master_keys ,
567577 )
578+ paddle .device .cuda .empty_cache ()
568579
569580 # build index json file
570581 index_optimizer_file , index_master_weight_file = {}, {}
@@ -601,6 +612,7 @@ def unified_optimizer_into_shards(
601612 else :
602613 sharded_optim_index ["master_weights" ] = False
603614
615+ paddle .device .cuda .empty_cache ()
604616 if master_weights is None :
605617 return [(optim_state_dict , shard_optimizer_file , sharded_optim_index )]
606618 else :
0 commit comments