@@ -247,24 +247,25 @@ def gen_metadata_and_prepare_source_state_dict(self):
247247 # Generate the optimizer states corresponding to the model weights.
248248 logger .info ("Requesting GPU memory space to concatenate tensors split by sharding1 v2." )
249249 optimizer_state_dict = {}
250- for key in cur_rank_need_load_model_state_keys :
251- for tp_rank in range (self .tp_degree ):
252- tp_rank_suffix = "_tp{:02d}" .format (tp_rank )
253- optimizer_state_dict [key + ".moment1" + tp_rank_suffix ] = paddle .zeros (
254- (param_flattened_shapes [key ],), "float32"
255- )
256- optimizer_state_dict [key + ".moment2" + tp_rank_suffix ] = paddle .zeros (
257- (param_flattened_shapes [key ],), "float32"
258- )
259- if self .optimizer_state_with_master_weights :
260- optimizer_state_dict [key + ".master_weight" + tp_rank_suffix ] = paddle .zeros (
250+ with paddle .base .dygraph .guard (place = paddle .CPUPlace ()):
251+ for key in cur_rank_need_load_model_state_keys :
252+ for tp_rank in range (self .tp_degree ):
253+ tp_rank_suffix = "_tp{:02d}" .format (tp_rank )
254+ optimizer_state_dict [key + ".moment1" + tp_rank_suffix ] = paddle .zeros (
261255 (param_flattened_shapes [key ],), "float32"
262256 )
263- # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
264- # Later, when these are compared with the global shape, we realize that they are replicated.
257+ optimizer_state_dict [key + ".moment2" + tp_rank_suffix ] = paddle .zeros (
258+ (param_flattened_shapes [key ],), "float32"
259+ )
260+ if self .optimizer_state_with_master_weights :
261+ optimizer_state_dict [key + ".master_weight" + tp_rank_suffix ] = paddle .zeros (
262+ (param_flattened_shapes [key ],), "float32"
263+ )
264+ # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
265+ # Later, when these are compared with the global shape, we realize that they are replicated.
265266
266- optimizer_state_dict [key + ".beta1_pow_acc" + tp_rank_suffix ] = paddle .zeros ((1 ,), "float32" )
267- optimizer_state_dict [key + ".beta2_pow_acc" + tp_rank_suffix ] = paddle .zeros ((1 ,), "float32" )
267+ optimizer_state_dict [key + ".beta1_pow_acc" + tp_rank_suffix ] = paddle .zeros ((1 ,), "float32" )
268+ optimizer_state_dict [key + ".beta2_pow_acc" + tp_rank_suffix ] = paddle .zeros ((1 ,), "float32" )
268269
269270 malloc_size = 0
270271 for opt_state_name , opt_state_value in optimizer_state_dict .items ():
0 commit comments