We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7c18d9d commit 500a1c0Copy full SHA for 500a1c0
paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -1034,7 +1034,11 @@ def get_expected_state_dict(model_to_save):
1034
elif isinstance(model_to_save, PrefixModelForCausalLM):
1035
state_dict = model_to_save.prefix_encoder.state_dict()
1036
1037
- if hasattr(model_to_save, "_tied_weights_keys") and model_to_save._tied_weights_keys is not None:
+ if (
1038
+ model_to_save.config.tie_word_embeddings
1039
+ and hasattr(model_to_save, "_tied_weights_keys")
1040
+ and model_to_save._tied_weights_keys is not None
1041
+ ):
1042
for key in model_to_save._tied_weights_keys:
1043
if key in state_dict:
1044
state_dict.pop(key)
0 commit comments