Commit d203f2a

add test

1 parent cd5468c commit d203f2a
File tree

1 file changed: 102 additions, 0 deletions

scripts/distribute/ci_case_auto.sh
Lines changed: 102 additions & 0 deletions
@@ -94,6 +94,7 @@ function llama_case_list_auto() {
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2
     llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
+    llama_pir_auto_recompute_DP2_MP2_PP2
     llama_pir_auto_fuse_ffn_attention_qkv_MP2
     llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
@@ -743,6 +744,107 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_pir_auto_recompute_DP2_MP2_PP2(){
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export PYTHONPATH=/paddle/Paddle/build_gpu/python/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export FLAGS_enable_pir_api=1
+    export FLAGS_dynamic_static_unified_comm=1
+    export FLAGS_enable_auto_parallel_align_mode=1
+
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_cudnn_deterministic=1
+    export FLAGS_embedding_deterministic=1
+
+    task_name="llama_pir_auto_recompute_DP2_MP2_PP2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+
+    loss1=0
+    loss2=0
+
+    for use_recompute in "0" "1"; do
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1,2,3,4,5,6,7" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --to_static true \
+            --pipeline_parallel_degree 2 \
+            --tensor_parallel_degree 2 \
+            --virtual_pp_degree 1 \
+            --pipeline_schedule_mode "VPP" \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps 10 \
+            --logging_steps 10 \
+            --eval_steps 10000 \
+            --save_steps 1000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --save_total_limit 2 \
+            --device gpu \
+            --dataloader_num_workers 4 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 1 \
+            --per_device_eval_batch_size 1 \
+            --recompute ${use_recompute} \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --bf16 true \
+            --fp16_opt_level "O2" \
+            --amp_master_grad true \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv true \
+            --use_flash_attention false \
+            --use_fused_rope true \
+            --use_fused_rms_norm false \
+            --max_seq_length 4096 \
+            --sequence_parallel false \
+            --sharding "stage1" \
+            --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate " \
+            --sharding_parallel_config "enable_stage1_overlap" \
+            --tensor_parallel_config "enable_mp_async_allreduce" \
+            --pipeline_parallel_config "enable_send_recv_overlap" \
+            --auto_parallel_resume_form_hybrid_parallel true \
+            --num_hidden_layers 4 \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+        mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+        echo "result: loss=$loss loss_md5=$loss_md5 ips=$ips mem=$mem"
+        if [ $use_recompute -eq 0 ]; then
+            loss1=($loss)
+        else
+            loss2=($loss)
+        fi
+    done
+    ips=-1
+    mem=-1
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_pir_auto_fuse_ffn_attention_qkv_MP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
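The new case launches the same 10-step pretrain job twice, first with --recompute 0 and then with --recompute 1, under determinism settings (NVIDIA_TF32_OVERRIDE=0, FLAGS_cudnn_deterministic=1, FLAGS_embedding_deterministic=1), so enabling recompute should trade memory for compute without changing the loss. check_result is defined elsewhere in ci_case_auto.sh and is not shown in this diff; a minimal sketch of the loss comparison it is presumably doing on loss1 and loss2 (the helper name and exact semantics below are assumptions):

# Hypothetical stand-in for the loss comparison inside check_result;
# the real helper lives elsewhere in ci_case_auto.sh.
assert_loss_equal() {
    local case_name=$1 loss_off=$2 loss_on=$3
    if [ "$loss_off" != "$loss_on" ]; then
        echo "$case_name FAILED: loss $loss_off (recompute off) != $loss_on (recompute on)"
        return 1
    fi
    echo "$case_name PASSED: loss $loss_on"
}

# Example call mirroring the script's check_result invocation (values hypothetical):
assert_loss_equal "llama_pir_auto_recompute_DP2_MP2_PP2" "9.51684856" "9.51684856"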

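The metrics are scraped from the rank-0 worker log by selecting the "global_step: 10" summary line and splitting it on each field label with awk. A standalone illustration with a made-up log line (the field names match what the script greps for; the values are hypothetical):

# Hypothetical step-10 summary line as found in $case_log_dir/workerlog.0:
line='global_step: 10, loss: 9.51684856, loss_md5: 0ebf68698887b33b33a46518621cf412, interval_tokens_per_second_per_device: 1500.2, max_memory_reserved: 16106127360'
echo "$line" | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'                  # -> 9.51684856
echo "$line" | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'   # -> 16106127360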