
Commit 78ac9f2

zhangbo9674 authored and Mangodadada committed

Refine checkpoint converter (PaddlePaddle#9001)

* refine

1 parent 1a3f05c, commit 78ac9f2

File tree

3 files changed: +187 -14 lines


paddlenlp/trainer/auto_trainer.py

Lines changed: 3 additions & 1 deletion
@@ -722,6 +722,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             )
 
         if self.args.to_static:
+            if self.model_wrapped._mode is None:
+                self.model_wrapped.train()
             model_state_dict = {
                 key: value
                 for key, value in self.model_wrapped.state_dict("param").items()
@@ -757,7 +759,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
 
         if self.args.auto_parallel_resume_form_hybrid_parallel:
             CheckpointConverter(
-                resume_from_checkpoint, state_dict, parameter_to_structured_name
+                resume_from_checkpoint, state_dict, parameter_to_structured_name, self.args
             ).load_from_hybrid_parallel_checkpoint()
         else:
             ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
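The two changes above are related: in to_static mode the trainer now makes sure the dy2st-wrapped model is in a defined mode before exporting its parameter state dict, and the checkpoint converter now receives the trainer's arguments. A minimal, self-contained sketch of the new mode guard; `DummyWrappedModel` is a made-up stand-in for the real dy2st wrapper, not PaddleNLP code:

```python
# Hypothetical stand-in for the dy2st-wrapped model; only the attributes the
# guard touches are modeled here.
class DummyWrappedModel:
    def __init__(self):
        self._mode = None  # mirrors model_wrapped._mode before train()/eval() is called
        self._params = {"linear.weight": [1.0, 2.0]}

    def train(self):
        self._mode = "train"

    def state_dict(self, kind="param"):
        return dict(self._params)


model_wrapped = DummyWrappedModel()
to_static = True

if to_static:
    # New guard from this commit: give the wrapper a mode before asking it
    # for its parameter state dict.
    if model_wrapped._mode is None:
        model_wrapped.train()
    model_state_dict = {key: value for key, value in model_wrapped.state_dict("param").items()}

print(model_wrapped._mode, sorted(model_state_dict))  # train ['linear.weight']
```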

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 24 additions & 13 deletions
@@ -33,15 +33,22 @@
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
 SCHEDULER_NAME = "scheduler.pdparams"
+SCALAR_NAME = "scalar.pdparams"
 MODEL_META_FILE_NAME = "model_meta.json"
 OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]
 MODEL_STATE_FILE_MIN_SIZE = 512
 
 
 class CheckpointConverter:
-    def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None):
+    def __init__(
+        self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
+    ):
         self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
         self.path = hybrid_parallel_ckpt_path
+
+        if trainging_args.ignore_load_lr_and_optim:
+            state_dict.pop("optimizer")
+
         self.auto_parallel_state_dict = self.flatten_state_dict(state_dict)
         self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name)
         model_state_global_shape = {}
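With the extra `trainging_args` parameter, the converter can drop optimizer state up front when `ignore_load_lr_and_optim` is set, so only model weights go through the conversion. A self-contained sketch of that behaviour; `flatten_state_dict` below is a simplified stand-in for the converter's own method and the tensor names are made-up examples:

```python
# Simplified stand-in for CheckpointConverter.flatten_state_dict: flattens a
# nested dict into dotted keys.
def flatten_state_dict(nested, prefix=""):
    flat = {}
    for key, value in nested.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_state_dict(value, name))
        else:
            flat[name] = value
    return flat


state_dict = {
    "model": {"embed.weight": "w0", "linear.weight": "w1"},
    "optimizer": {"linear.weight.moment1": "m1"},
}
ignore_load_lr_and_optim = True

# New in this commit: drop the optimizer branch before flattening when the
# flag is set, so optimizer/LR state is not restored.
if ignore_load_lr_and_optim:
    state_dict.pop("optimizer")

print(flatten_state_dict(state_dict))
# {'model.embed.weight': 'w0', 'model.linear.weight': 'w1'}
```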
@@ -74,9 +81,9 @@ def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structure
         for k, v in self.auto_parallel_state_dict.items():
             if k in self.patch_dict:
                 del_keys.append(k)
-
         for k in del_keys:
             self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
+        for k in del_keys:
             self.auto_parallel_state_dict.pop(k)
 
         flags = [
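The `patch_dict` renaming is now done in two passes over `del_keys`: every value is first copied under its patched name, and the old keys are removed afterwards. A small illustration of the resulting behaviour; the state-dict keys and `patch_dict` below are made-up example data, not names from the repository:

```python
# Made-up example data to show the two-pass rename.
auto_parallel_state_dict = {
    "llama.embed_tokens.weight": "tensor_0",
    "lm_head.weight": "tensor_1",
}
patch_dict = {"lm_head.weight": "llama.lm_head.weight"}

del_keys = [k for k in auto_parallel_state_dict if k in patch_dict]

for k in del_keys:  # pass 1: copy each value under its patched name
    auto_parallel_state_dict[patch_dict[k]] = auto_parallel_state_dict[k]
for k in del_keys:  # pass 2: drop the old names
    auto_parallel_state_dict.pop(k)

print(auto_parallel_state_dict)
# {'llama.embed_tokens.weight': 'tensor_0', 'llama.lm_head.weight': 'tensor_1'}
```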
@@ -896,25 +903,26 @@ def rename(old_name, parameter_to_structured_name):
         return renamed_state_dict
 
     def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
-
         name_mapping = {}
         suffix_bucket = {}
         assert len(optimizer_state_dict) % len(model_state_keys) == 0
         for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
             suffix_bucket[suffix] = []
-        for satte_name, satte_value in optimizer_state_dict.items():
-            if "moment1" in satte_name:
-                suffix_bucket[".moment1"].append(satte_name)
-            elif "moment2" in satte_name:
-                suffix_bucket[".moment2"].append(satte_name)
-            elif "beta1_pow_acc" in satte_name:
-                suffix_bucket[".beta1_pow_acc"].append(satte_name)
-            elif "beta2_pow_acc" in satte_name:
-                suffix_bucket[".beta2_pow_acc"].append(satte_name)
+        for opt_name, opt_value in optimizer_state_dict.items():
+            if "moment1" in opt_name:
+                suffix_bucket[".moment1"].append(opt_name)
+            elif "moment2" in opt_name:
+                suffix_bucket[".moment2"].append(opt_name)
+            elif "beta1_pow_acc" in opt_name:
+                suffix_bucket[".beta1_pow_acc"].append(opt_name)
+            elif "beta2_pow_acc" in opt_name:
+                suffix_bucket[".beta2_pow_acc"].append(opt_name)
             else:
-                suffix_bucket[".master_weight"].append(satte_name)
+                suffix_bucket[".master_weight"].append(opt_name)
 
         for suffix, old_names in suffix_bucket.items():
+            if len(old_names) == 0:
+                continue
             assert len(old_names) == len(model_state_keys)
             for i in range(len(old_names)):
                 name_mapping[old_names[i]] = model_state_keys[i] + suffix
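Besides renaming the misspelled loop variables (`satte_name` → `opt_name`), the hunk above adds a skip for empty suffix buckets, which can happen, for example, when a checkpoint carries no `.master_weight` entries. A self-contained sketch of the bucketing and renaming logic under that assumption; the parameter and state names are made-up examples:

```python
OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]


def build_name_mapping(model_state_keys, optimizer_state_dict):
    # Bucket optimizer-state names by suffix, then map each bucket onto the
    # model parameter order to build an old-name -> structured-name mapping.
    name_mapping = {}
    suffix_bucket = {suffix: [] for suffix in OPTIMIZER_STATE_NAME_SUFFIX}
    assert len(optimizer_state_dict) % len(model_state_keys) == 0
    for opt_name in optimizer_state_dict:
        if "moment1" in opt_name:
            suffix_bucket[".moment1"].append(opt_name)
        elif "moment2" in opt_name:
            suffix_bucket[".moment2"].append(opt_name)
        elif "beta1_pow_acc" in opt_name:
            suffix_bucket[".beta1_pow_acc"].append(opt_name)
        elif "beta2_pow_acc" in opt_name:
            suffix_bucket[".beta2_pow_acc"].append(opt_name)
        else:
            suffix_bucket[".master_weight"].append(opt_name)

    for suffix, old_names in suffix_bucket.items():
        if len(old_names) == 0:  # new in this commit: skip buckets with no entries
            continue
        assert len(old_names) == len(model_state_keys)
        for i in range(len(old_names)):
            name_mapping[old_names[i]] = model_state_keys[i] + suffix
    return name_mapping


# Example: a checkpoint with no master weights; the ".master_weight" bucket
# stays empty and is skipped instead of tripping the length assert.
mapping = build_name_mapping(
    model_state_keys=["llama.layers.0.mlp.w_0"],
    optimizer_state_dict={
        "param_0.moment1": None,
        "param_0.moment2": None,
        "param_0.beta1_pow_acc": None,
        "param_0.beta2_pow_acc": None,
    },
)
print(mapping["param_0.moment1"])  # llama.layers.0.mlp.w_0.moment1
```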
@@ -1011,6 +1019,9 @@ def get_local_checkpoint_file_names(self):
                     cur_rank_optimizer_state_file_names.append(file_name)
         if SCHEDULER_NAME in cur_rank_model_state_file_names:
             cur_rank_model_state_file_names.remove(SCHEDULER_NAME)
+        if SCALAR_NAME in cur_rank_model_state_file_names:
+            cur_rank_model_state_file_names.remove(SCALAR_NAME)
+
         return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names
 
     def get_distribution_rank_from_file_name(self, file_name):
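`get_local_checkpoint_file_names` now filters out the newly handled `scalar.pdparams` alongside `scheduler.pdparams`, since neither file is a model-weight shard. A short sketch of the filtering; the shard file name below is a made-up example:

```python
SCHEDULER_NAME = "scheduler.pdparams"
SCALAR_NAME = "scalar.pdparams"

# Made-up file listing for one rank; only real weight shards should survive.
cur_rank_model_state_file_names = [
    "model_state.tp00_pp00.pdparams",
    "scheduler.pdparams",
    "scalar.pdparams",
]

if SCHEDULER_NAME in cur_rank_model_state_file_names:
    cur_rank_model_state_file_names.remove(SCHEDULER_NAME)
if SCALAR_NAME in cur_rank_model_state_file_names:  # new in this commit
    cur_rank_model_state_file_names.remove(SCALAR_NAME)

print(cur_rank_model_state_file_names)  # ['model_state.tp00_pp00.pdparams']
```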

scripts/distribute/ci_case_auto.sh

Lines changed: 160 additions & 0 deletions
@@ -54,6 +54,7 @@ function llama_case_list_auto() {
     llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
 
     llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
+    llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
 }
 
 function llm_gpt_case_list_auto() {
@@ -1062,6 +1063,165 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_enable_pir_api=1
+    export FLAGS_max_inplace_grad_add=3
+
+    echo "---- run hybrid and save ckpt ----"
+    dy_task_name="llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1"
+    dy_case_out_dir="dy_output/$dy_task_name"
+    dy_case_log_dir="dy_output/$dy_task_name""_log"
+    rm -rf $dy_case_out_dir
+    rm -rf $dy_case_log_dir
+
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1" \
+        --log_dir $dy_case_log_dir \
+        ../../run_pretrain.py \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $dy_case_out_dir \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --warmup_steps 30 \
+        --max_grad_norm 0.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 5 \
+        --logging_steps 1 \
+        --eval_steps 1000 \
+        --save_steps 3 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 1 \
+        --per_device_eval_batch_size 2 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 0 \
+        --fp16_opt_level "O2" \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --amp_master_grad false \
+        --enable_linear_fused_grad_add false \
+        --fuse_attention_ffn true \
+        --fuse_attention_qkv false \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention 0 \
+        --use_fused_rope false \
+        --use_fused_rms_norm 0 \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 1 \
+        --sharding_parallel_degree 1 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 1 \
+        --sharding "" \
+        --to_static 0 \
+        --num_hidden_layers 2 \
+        >>${log_path}/$FUNCNAME 2>&1
+    dy_loss=`cat $dy_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    dy_ips=-1
+    dy_mem=-1
+    echo "hybrid result: loss=$dy_loss ips=$dy_ips mem=$dy_mem"
+
+    echo "---- run auto parallel resueme from hybrid ckpt ----"
+    auto_task_name="llama_auto_parallel_bs2_fp32_DP2-MP1-PP1"
+    auto_case_out_dir="auto_output/$auto_task_name"
+    auto_case_log_dir="auto_output/$auto_task_name""_log"
+    rm -rf $auto_case_out_dir
+    rm -rf $auto_case_log_dir
+
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1" \
+        --log_dir $auto_case_log_dir \
+        run_pretrain_auto.py \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $auto_case_out_dir \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --warmup_steps 30 \
+        --max_grad_norm 0.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 4 \
+        --logging_steps 1 \
+        --eval_steps 1000 \
+        --save_steps 1000 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --enable_auto_parallel 1 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 1 \
+        --per_device_eval_batch_size 2 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 0 \
+        --fp16_opt_level "O2" \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --amp_master_grad false \
+        --fuse_attention_ffn true \
+        --fuse_attention_qkv false \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention 0 \
+        --use_fused_rope false \
+        --use_fused_rms_norm 0 \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 1 \
+        --sharding_parallel_degree 1 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 1 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding "" \
+        --to_static 1 \
+        --num_hidden_layers 2 \
+        --resume_from_checkpoint "dy_output/llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1/checkpoint-3" \
+        --auto_parallel_resume_form_hybrid_parallel 1 \
+        >>${log_path}/$FUNCNAME 2>&1
+    auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    auto_ips=-1
+    auto_mem=-1
+    echo "auto result: loss=$auto_loss ips=$auto_ips mem=$auto_mem"
+
+    check_result $FUNCNAME ${dy_loss} ${auto_loss} ${dy_ips} ${auto_ips} ${dy_mem} ${auto_mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
