Skip to content

Commit 67bc4e2

Browse files
authored
[Unified Checkpoint] Fix expert parallel (#9741)
1 parent 142258c commit 67bc4e2

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

paddlenlp/trainer/unified_checkpoint/load_local.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -282,16 +282,15 @@ def load_resolved_archive_file(
282282
returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
283283
returned_optim_state_dict[key_name].name = key_name
284284

285-
# master weight cast (only in remove_master_weight)
286-
if has_master_weights and state_dict_master_weight[model_weight_key].dtype != paddle.float32:
287-
state_dict_master_weight[model_weight_key] = paddle.cast(
288-
state_dict_master_weight[model_weight_key], dtype=paddle.float32
289-
)
290-
291285
if has_master_weights:
292286
for key in list(state_dict_master_weight.keys()):
293287
static_name = struct2static_name_mappings[key]
294288
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
289+
# master weight cast (only in remove_master_weight)
290+
if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
291+
returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
292+
returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
293+
)
295294
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
296295

297296
return returned_optim_state_dict

0 commit comments

Comments
 (0)