File tree Expand file tree Collapse file tree 1 file changed +5
-6
lines changed
paddlenlp/trainer/unified_checkpoint Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -282,16 +282,15 @@ def load_resolved_archive_file(
282282 returned_optim_state_dict [key_name ] = state_dict_optim .pop (key )
283283 returned_optim_state_dict [key_name ].name = key_name
284284
285- # master weight cast (only in remove_master_weight)
286- if has_master_weights and state_dict_master_weight [model_weight_key ].dtype != paddle .float32 :
287- state_dict_master_weight [model_weight_key ] = paddle .cast (
288- state_dict_master_weight [model_weight_key ], dtype = paddle .float32
289- )
290-
291285 if has_master_weights :
292286 for key in list (state_dict_master_weight .keys ()):
293287 static_name = struct2static_name_mappings [key ]
294288 returned_optim_state_dict ["master_weights" ][static_name ] = state_dict_master_weight .pop (key )
289+ # master weight cast (only in remove_master_weight)
290+ if returned_optim_state_dict ["master_weights" ][static_name ].dtype != paddle .float32 :
291+ returned_optim_state_dict ["master_weights" ][static_name ] = paddle .cast (
292+ returned_optim_state_dict ["master_weights" ][static_name ], dtype = paddle .float32
293+ )
295294 returned_optim_state_dict ["master_weights" ][static_name ].name = "_" .join ([static_name , FP32_MASTER ])
296295
297296 return returned_optim_state_dict
You can’t perform that action at this time.
0 commit comments