@@ -94,6 +94,7 @@ function llama_case_list_auto() {
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2
     llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
+    llama_pir_auto_recompute_DP2_MP2_PP2
     llama_pir_auto_fuse_ffn_attention_qkv_MP2
     llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
@@ -743,6 +744,107 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_pir_auto_recompute_DP2_MP2_PP2() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export PYTHONPATH=/paddle/Paddle/build_gpu/python/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export FLAGS_enable_pir_api=1
+    export FLAGS_dynamic_static_unified_comm=1
+    export FLAGS_enable_auto_parallel_align_mode=1
+
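+    # Force deterministic execution (TF32 off, deterministic cuDNN and embedding kernels) so the two runs are comparable.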
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_cudnn_deterministic=1
+    export FLAGS_embedding_deterministic=1
+
+    task_name="llama_pir_auto_recompute_DP2_MP2_PP2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+
+    loss1=0
+    loss2=0
+
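+    # Run the same 10-step training twice: first with recompute disabled, then enabled.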
+    for use_recompute in "0" "1"; do
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1,2,3,4,5,6,7" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --to_static true \
+            --pipeline_parallel_degree 2 \
+            --tensor_parallel_degree 2 \
+            --virtual_pp_degree 1 \
+            --pipeline_schedule_mode "VPP" \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps 10 \
+            --logging_steps 10 \
+            --eval_steps 10000 \
+            --save_steps 1000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --save_total_limit 2 \
+            --device gpu \
+            --dataloader_num_workers 4 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 1 \
+            --per_device_eval_batch_size 1 \
+            --recompute ${use_recompute} \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --bf16 true \
+            --fp16_opt_level "O2" \
+            --amp_master_grad true \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv true \
+            --use_flash_attention false \
+            --use_fused_rope true \
+            --use_fused_rms_norm false \
+            --max_seq_length 4096 \
+            --sequence_parallel false \
+            --sharding "stage1" \
+            --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
+            --sharding_parallel_config "enable_stage1_overlap" \
+            --tensor_parallel_config "enable_mp_async_allreduce" \
+            --pipeline_parallel_config "enable_send_recv_overlap" \
+            --auto_parallel_resume_form_hybrid_parallel true \
+            --num_hidden_layers 4 \
+            >> ${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+        mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+        echo "result: loss=$loss loss_md5=$loss_md5 ips=$ips mem=$mem"
+        if [ $use_recompute -eq 0 ]; then
+            loss1=($loss)
+        else
+            loss2=($loss)
+        fi
+    done
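+    # check_result compares the loss without recompute (loss1) against the loss with recompute (loss2); ips/mem are placeholders (-1) here.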
+    ips=-1
+    mem=-1
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_pir_auto_fuse_ffn_attention_qkv_MP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH