4 changes: 3 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -2656,7 +2656,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
dist.barrier()
if self.args.use_expert_parallel:
opt_state_dict = broadcast_moe_optimizer(
- opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
+ opt_state_dict,
+ model_state_dict=self.model.state_dict(),
+ broadcast_dp=not self.args.should_load_sharding_stage1_model,
)
else:
if not self.args.should_load_sharding_stage1_model:
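Note on the new argument: under expert parallelism the expert weights are not replicated across data-parallel ranks, so only the optimizer state of the shared (synchronized) parameters should take part in the broadcast. Passing self.model.state_dict() gives the helper access to the parameter objects it needs to tell the two groups apart. A minimal sketch of that distinction with stand-in objects (the no_sync attribute is what the filter in helper.py below keys on; real Paddle parameters carry a unique name, and the MoE layers are assumed to flag their expert weights):

from types import SimpleNamespace

# Hypothetical parameters: Paddle parameters expose a unique ``name``; in this
# code path expert weights are additionally flagged with ``no_sync``.
model_state_dict = {
    "shared.weight": SimpleNamespace(name="linear_0.w_0"),
    "expert.weight": SimpleNamespace(name="linear_1.w_0", no_sync=True),
}

# Only the un-flagged parameters are candidates for the data-parallel broadcast.
sync_names = [p.name for p in model_state_dict.values() if not getattr(p, "no_sync", False)]
print(sync_names)  # ['linear_0.w_0']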
31 changes: 27 additions & 4 deletions paddlenlp/trainer/utils/helper.py
@@ -229,8 +229,7 @@ def broadcast_dp_optimizer(state_dict):
return state_dict


- def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
-
+ def broadcast_moe_optimizer(state_dict, model_state_dict=None, broadcast_dp=True):
try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
@@ -242,7 +241,29 @@ def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
except:
dp_group = None
src_rank = 0
- data_parallel_rank = 0
+ data_parallel_rank = dist.get_rank()

def _filter_sync_optimizer_state(model_state_dict, opt_state_dict):
# get sync name
sync_vname = []
for k, v in model_state_dict.items():
if not getattr(v, "no_sync", False):
sync_vname.append(v.name)

filter_opt_state_dict = {"master_weights": {}}
filter_opt_state_dict["LR_Scheduler"] = opt_state_dict.get("LR_Scheduler", {})
for op_k, op_v in opt_state_dict.items():
if op_k not in ["master_weights", "LR_Scheduler"]:
for sync_v in sync_vname:
if op_k.startswith(sync_v):
filter_opt_state_dict[op_k] = op_v
break
elif op_k == "master_weights":
for k, v in op_v.items():
for sync_v in sync_vname:
if k.startswith(sync_v):
filter_opt_state_dict["master_weights"][k] = v
return filter_opt_state_dict

def _broadcast_moe_optimizer_state(state_dict):
# boardcast_keys
@@ -272,9 +293,11 @@ def _broadcast_moe_optimizer_state(state_dict):
return base_state_dict

if broadcast_dp:
- base_state_dict = broadcast_dp_optimizer(state_dict)
+ filter_opt_state_dict = _filter_sync_optimizer_state(model_state_dict, state_dict)
+ base_state_dict = broadcast_dp_optimizer(filter_opt_state_dict)
else:
base_state_dict = _broadcast_moe_optimizer_state(state_dict)

if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
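To make the filtering concrete, here is a toy run of the same logic on fabricated data. It assumes the usual Paddle layout in which optimizer accumulator keys and master_weights entries are prefixed by the owning parameter's name; the parameter names and values are made up, and the function is a condensed re-statement of _filter_sync_optimizer_state above, not the library code itself:

from types import SimpleNamespace

def filter_sync_optimizer_state(model_state_dict, opt_state_dict):
    # Keep only entries whose key starts with the name of a parameter that is
    # synchronized across data-parallel ranks (i.e. not flagged no_sync).
    sync_vname = [v.name for v in model_state_dict.values() if not getattr(v, "no_sync", False)]
    filtered = {
        "LR_Scheduler": opt_state_dict.get("LR_Scheduler", {}),
        "master_weights": {
            k: v
            for k, v in opt_state_dict.get("master_weights", {}).items()
            if any(k.startswith(s) for s in sync_vname)
        },
    }
    for k, v in opt_state_dict.items():
        if k not in ("master_weights", "LR_Scheduler") and any(k.startswith(s) for s in sync_vname):
            filtered[k] = v
    return filtered

model_state_dict = {
    "shared.weight": SimpleNamespace(name="linear_0.w_0"),
    "expert.weight": SimpleNamespace(name="linear_1.w_0", no_sync=True),
}
opt_state_dict = {
    "LR_Scheduler": {"last_lr": 1e-4},
    "linear_0.w_0_moment1_0": "shared moment",
    "linear_1.w_0_moment1_0": "expert moment",
    "master_weights": {"linear_0.w_0": "shared fp32", "linear_1.w_0": "expert fp32"},
}

print(filter_sync_optimizer_state(model_state_dict, opt_state_dict))
# Only the linear_0.w_0 entries survive; the expert state stays local and is
# merged back for data_parallel_rank > 0 after the broadcast.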
4 changes: 2 additions & 2 deletions paddlenlp/transformers/conversion_utils.py
@@ -1288,7 +1288,7 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:

if len(name_action_mappings) > 0:
for x in name_action_mappings.keys():
logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.")

return state_dict_to_save

@@ -1322,7 +1322,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
break
if key not in state_keys_map:
if not ignore_error:
logger.error(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
logger.debug(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
else:
state_keys_real.remove(state_keys_map[key])

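Both changes downgrade missing-key messages from warning/error to debug, presumably because such keys are expected in this code path and the old level was noisy. The practical effect, illustrated with the standard logging module (PaddleNLP ships its own logger wrapper, so the exact knob for re-enabling the output may differ):

import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("paddlenlp.demo")

msg = "tensor parallel conversion: could not find name linear_0.w_0 in loaded state dict!"
log.warning(msg)  # printed at the default level, as the old code did
log.debug(msg)    # silent: DEBUG is below the effective INFO level

log.setLevel(logging.DEBUG)  # opt back in when debugging a conversion problem
log.debug(msg)    # now printed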
2 changes: 1 addition & 1 deletion paddlenlp/transformers/model_utils.py
@@ -115,7 +115,7 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):

from paddlenlp.utils.safetensors import fast_load_file as safe_load_file

if sys.platform.startswith("cpu"):
Contributor
So it was "cpu" before?

Contributor Author
I changed it by mistake in a previously merged PR.

if sys.platform.startswith("win"):
from safetensors import safe_open
else:
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
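On the review comment above: sys.platform identifies the operating system ("linux", "darwin", "win32", ...), so the old "cpu" prefix could never match and the Windows fallback was dead code; the fix restores the intended platform switch for choosing a safetensors reader. A quick check:

import sys

print(sys.platform)                    # e.g. "linux", "darwin" or "win32" -- never "cpu"
print(sys.platform.startswith("cpu"))  # always False, so the old branch was unreachable
print(sys.platform.startswith("win"))  # True only on Windows, where the upstream
                                       # safetensors safe_open is used instead of the
                                       # paddlenlp fast_safe_open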