@@ -54,6 +54,7 @@ function llama_case_list_auto() {
     llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
 
     llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
+    llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
 }
 
 function llm_gpt_case_list_auto() {
@@ -1062,6 +1063,170 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
+    echo "=========== $FUNCNAME run begin ==========="
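+    # Env flags: detailed error call stacks, TF32 disabled, PIR executor enabled.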
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_enable_pir_api=1
+    export FLAGS_max_inplace_grad_add=3
+
+    echo "---- run hybrid and save ckpt ----"
+    dy_task_name="llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1"
+    dy_case_out_dir="dy_output/$dy_task_name"
+    dy_case_log_dir="dy_output/$dy_task_name""_log"
+    rm -rf $dy_case_out_dir
+    rm -rf $dy_case_log_dir
+
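+    # Step 1: dygraph (to_static=0) hybrid-parallel DP2 run on 2 GPUs for 5 steps; save_steps=3 writes checkpoint-3.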
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1" \
+        --log_dir $dy_case_log_dir \
+        ../../run_pretrain.py \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $dy_case_out_dir \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --warmup_steps 30 \
+        --max_grad_norm 0.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 5 \
+        --logging_steps 1 \
+        --eval_steps 1000 \
+        --save_steps 3 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 1 \
+        --per_device_eval_batch_size 2 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 0 \
+        --fp16_opt_level "O2" \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --amp_master_grad false \
+        --enable_linear_fused_grad_add false \
+        --fuse_attention_ffn true \
+        --fuse_attention_qkv false \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention 0 \
+        --use_fused_rope false \
+        --use_fused_rms_norm 0 \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 1 \
+        --sharding_parallel_degree 1 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 1 \
+        --sharding "" \
+        --to_static 0 \
+        --num_hidden_layers 2 \
+        >> ${log_path}/$FUNCNAME 2>&1
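+    # Parse the loss logged at global_step 4 from the rank-0 worker log (ips/mem are not collected for this case).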
+    dy_loss=`cat $dy_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    dy_ips=-1
+    dy_mem=-1
+    echo "hybrid result: loss=$dy_loss ips=$dy_ips mem=$dy_mem"
+
+    echo "---- run auto parallel resume from hybrid ckpt ----"
+    auto_task_name="llama_auto_parallel_bs2_fp32_DP2-MP1-PP1"
+    auto_case_out_dir="auto_output/$auto_task_name"
+    auto_case_log_dir="auto_output/$auto_task_name""_log"
+    rm -rf $auto_case_out_dir
+    rm -rf $auto_case_log_dir
+
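+    # Step 2: static auto-parallel run (to_static=1, enable_auto_parallel=1) that resumes from the hybrid checkpoint-3 and trains to step 4.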
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1" \
+        --log_dir $auto_case_log_dir \
+        run_pretrain_auto.py \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $auto_case_out_dir \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --warmup_steps 30 \
+        --max_grad_norm 0.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 4 \
+        --logging_steps 1 \
+        --eval_steps 1000 \
+        --save_steps 1000 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --enable_auto_parallel 1 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 1 \
+        --per_device_eval_batch_size 2 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 0 \
+        --fp16_opt_level "O2" \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --amp_master_grad false \
+        --fuse_attention_ffn true \
+        --fuse_attention_qkv false \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention 0 \
+        --use_fused_rope false \
+        --use_fused_rms_norm 0 \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 1 \
+        --sharding_parallel_degree 1 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 1 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding "" \
+        --to_static 1 \
+        --num_hidden_layers 2 \
+        --resume_from_checkpoint "dy_output/llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1/checkpoint-3" \
+        --auto_parallel_resume_form_hybrid_parallel 1 \
+        >> ${log_path}/$FUNCNAME 2>&1
+    auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    auto_ips=-1
+    auto_mem=-1
+    echo "auto result: loss=$auto_loss ips=$auto_ips mem=$auto_mem"
+
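+    # check_result compares the step-4 loss of the hybrid run against the auto-parallel run; ips/mem are passed as -1 placeholders.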
+    check_result $FUNCNAME ${dy_loss} ${auto_loss} ${dy_ips} ${auto_ips} ${dy_mem} ${auto_mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH