Skip to content

Commit 500a1c0

Browse files
committed
fix tie_word_embeddings
1 parent 7c18d9d commit 500a1c0

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1034,7 +1034,11 @@ def get_expected_state_dict(model_to_save):
10341034
elif isinstance(model_to_save, PrefixModelForCausalLM):
10351035
state_dict = model_to_save.prefix_encoder.state_dict()
10361036

1037-
if hasattr(model_to_save, "_tied_weights_keys") and model_to_save._tied_weights_keys is not None:
1037+
if (
1038+
model_to_save.config.tie_word_embeddings
1039+
and hasattr(model_to_save, "_tied_weights_keys")
1040+
and model_to_save._tied_weights_keys is not None
1041+
):
10381042
for key in model_to_save._tied_weights_keys:
10391043
if key in state_dict:
10401044
state_dict.pop(key)

0 commit comments

Comments
 (0)