Commit 4fedd09

xingmingyyj authored and Mangodadada committed
[auto_parallel] Add checkpoint convertor (PaddlePaddle#8847)
* Add the checkpoint conversion module
1 parent 78c2863 commit 4fedd09

File tree: 3 files changed, +1160 −15 lines


paddlenlp/trainer/auto_trainer.py

Lines changed: 27 additions & 15 deletions
@@ -40,6 +40,7 @@
     has_length,
     speed_metrics,
 )
+from .utils.ckpt_converter import CheckpointConverter
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,
 
 try:
@@ -720,20 +721,16 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                )
            )
 
-        ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
-
-        if not os.path.isdir(ckpt_path):
-            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
         if self.args.to_static:
-            opt_state_dict = {
+            model_state_dict = {
                 key: value
-                for key, value in self.model_wrapped.state_dict("opt").items()
+                for key, value in self.model_wrapped.state_dict("param").items()
                 if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
-            state_dict = {
-                MODEL_NAME: self.model_wrapped.state_dict("param"),
-                OPTIMIZER_NAME: opt_state_dict,
+            optim_state_dict = {
+                key: value
+                for key, value in self.model_wrapped.state_dict("opt").items()
+                if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
         else:
             model_state_dict = self.model_wrapped.state_dict()
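As an aside, the to_static branch now applies the same key filter to both the parameter and the optimizer state dict. Below is a toy, self-contained illustration of that filter; the pattern list is invented for the example and only stands in for the real FREE_SVAE_LOAD_KEY_PATTERNS constant used by the trainer.

# Toy illustration of the key filter used in the to_static branch above.
# The pattern list is made up for this sketch; the real FREE_SVAE_LOAD_KEY_PATTERNS
# is defined elsewhere in the trainer.
FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "@GRAD"]

raw_state = {
    "llama.embed_tokens.weight": "tensor-0",
    "learning_rate_0": "tensor-1",                 # dropped: matches "learning_rate_"
    "llama.embed_tokens.weight@GRAD": "tensor-2",  # dropped: matches "@GRAD"
}

filtered_state = {
    key: value
    for key, value in raw_state.items()
    if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
}

assert list(filtered_state) == ["llama.embed_tokens.weight"]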
@@ -746,12 +743,27 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             optim_state_dict = self.optimizer.state_dict()
             optim_state_dict.pop("LR_Scheduler", None)
 
-            state_dict = {
-                MODEL_NAME: model_state_dict,
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+        state_dict = {
+            MODEL_NAME: model_state_dict,
+            OPTIMIZER_NAME: optim_state_dict,
+        }
 
-        self._load_ckpt_func(state_dict, ckpt_path)
+        parameter_to_structured_name = {}
+        if self.args.to_static:
+            parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name
+        else:
+            for state_name, state_value in self.model_wrapped.state_dict().items():
+                parameter_to_structured_name[state_value.name] = state_name
+
+        if self.args.auto_parallel_resume_form_hybrid_parallel:
+            CheckpointConverter(
+                resume_from_checkpoint, state_dict, parameter_to_structured_name
+            ).load_from_hybrid_parallel_checkpoint()
+        else:
+            ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
+            if not os.path.isdir(ckpt_path):
+                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+            self._load_ckpt_func(state_dict, ckpt_path)
 
         # release memory
         del state_dict
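Taken together, the new control flow in _load_from_checkpoint can be summarized by the sketch below. This is a hedged simplification, not the trainer code itself: model, optimizer, load_ckpt_func, dist_ckpt_path and the MODEL_NAME/OPTIMIZER_NAME values are placeholders for the trainer's own attributes and constants, and only the dynamic-graph (non-to_static) path is shown.

# Simplified sketch of the resume dispatch added in this commit (assumptions flagged inline).
import os

from paddlenlp.trainer.utils.ckpt_converter import CheckpointConverter

MODEL_NAME = "model"          # placeholder; the trainer uses its own MODEL_NAME constant
OPTIMIZER_NAME = "optimizer"  # placeholder; the trainer uses its own OPTIMIZER_NAME constant


def resume_from_ckpt(model, optimizer, resume_from_checkpoint,
                     from_hybrid_parallel, load_ckpt_func, dist_ckpt_path="dist_ckpt"):
    """Dispatch between the hybrid-parallel converter and the native auto-parallel loader."""
    model_state_dict = model.state_dict()
    optim_state_dict = optimizer.state_dict()
    optim_state_dict.pop("LR_Scheduler", None)

    state_dict = {MODEL_NAME: model_state_dict, OPTIMIZER_NAME: optim_state_dict}

    # Dynamic-graph case: map each parameter's internal name to its structured name,
    # mirroring the loop added in the diff above.
    parameter_to_structured_name = {value.name: name for name, value in model.state_dict().items()}

    if from_hybrid_parallel:
        # Re-shard a hybrid-parallel (dp/mp/pp/sharding) checkpoint into the auto-parallel
        # layout; the converter fills `state_dict` in place.
        CheckpointConverter(
            resume_from_checkpoint, state_dict, parameter_to_structured_name
        ).load_from_hybrid_parallel_checkpoint()
    else:
        ckpt_path = os.path.join(resume_from_checkpoint, dist_ckpt_path)
        if not os.path.isdir(ckpt_path):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
        load_ckpt_func(state_dict, ckpt_path)

    return state_dict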

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 0 deletions
@@ -353,6 +353,8 @@ class TrainingArguments:
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
             scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details.
+        auto_parallel_resume_form_hybrid_parallel (`bool`, *optional*):
+            Whether hybrid parallel checkpoints can be loaded in auto parallel mode.
         flatten_param_grads (`bool`, *optional*):
             Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`.
         skip_profile_timer (`bool`, *optional*):
@@ -783,6 +785,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether hybrid parallel checkpoints can be loaded in auto parallel mode."},
+    )
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
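For context, a minimal usage sketch of the new flag follows. The output_dir and checkpoint path are placeholders, and the commented trainer lines only indicate where the flag takes effect; in practice the flag would typically be passed on the command line (e.g. --auto_parallel_resume_form_hybrid_parallel true) or through the training config that feeds TrainingArguments.

# Hedged usage sketch of the new training argument (placeholder paths; trainer setup elided).
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./output",
    # Load a checkpoint written by a hybrid-parallel (dp/mp/pp/sharding) run in auto-parallel mode.
    auto_parallel_resume_form_hybrid_parallel=True,  # flag name as defined in this commit
)

# trainer = AutoTrainer(model=model, args=args, train_dataset=train_ds, ...)  # assumed setup
# trainer.train(resume_from_checkpoint="./checkpoints/hybrid_parallel_run")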
