Commit 75f5031

[AutoParallel] add vpp align and pp amp test (#9176)
* add vpp test
* Update ci_case_auto.sh
* Update ci_case_auto.sh
* add new test
* add new test
* Update ci_case_auto.sh
1 parent 1fc9429 commit 75f5031

scripts/distribute/ci_case_auto.sh

Lines changed: 217 additions & 0 deletions
@@ -58,6 +58,8 @@ function llama_case_list_auto() {
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP
     llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1
+    llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4
+    llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4
 }
 
 function llm_gpt_case_list_auto() {
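
Registration note: each name listed in llama_case_list_auto appears to be executed as a same-named shell function defined further down in ci_case_auto.sh, so the two added lines are what wire the new tests into the CI case list. A minimal standalone sketch of that pattern, using made-up demo_* names rather than the script's real cases:

# Standalone illustration of the "case list invokes same-named functions"
# pattern assumed above; demo_case_a and demo_case_b are hypothetical names.
function demo_case_a() { echo "running demo_case_a"; }
function demo_case_b() { echo "running demo_case_b"; }

function demo_case_list() {
    demo_case_a
    demo_case_b
}

demo_case_list    # executes both demo cases in order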
@@ -1354,6 +1356,221 @@ function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() {
     check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
 }
 
+function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() {
+    echo "=========== $FUNCNAME run begin ==========="
+    # Only A100 supports this case.
+    if [ $IS_A100 -eq 0 ]; then
+        return
+    fi
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_max_inplace_grad_add=3
+
+    task_name="llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1_MP1_PP4"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    loss1=0
+    loss2=0
+    use_pir=1
+
+    max_step=10
+    to_static=1
+
+    for pp_mode in "FThenB" "VPP"; do
+        export FLAGS_enable_pir_api=${use_pir}
+        export FLAGS_enable_pir_in_executor=${use_pir}
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        rm -rf ${log_path}/$FUNCNAME
+        if [ "$pp_mode" == "FThenB" ]; then
+            vpp_degree=1
+        else
+            vpp_degree=2
+        fi
+
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1,2,3" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --warmup_steps 30 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps $max_step \
+            --logging_steps 1 \
+            --eval_steps 1000 \
+            --save_steps 50000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --save_total_limit 2 \
+            --device gpu \
+            --disable_tqdm true \
+            --dataloader_num_workers 1 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 4 \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --fp16 0 \
+            --fp16_opt_level "O2" \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv true \
+            --fuse_sequence_parallel_allreduce false \
+            --use_flash_attention 0 \
+            --use_fused_rope false \
+            --use_fused_rms_norm 0 \
+            --max_seq_length 2048 \
+            --hidden_size 1024 \
+            --sep_parallel_degree 1 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 4 \
+            --sharding_parallel_degree 1 \
+            --tensor_parallel_degree 1 \
+            --sharding "" \
+            --to_static ${to_static} \
+            --num_hidden_layers 8 \
+            --data_parallel_config "gradient_sync_after_accumulate" \
+            --pipeline_schedule_mode $pp_mode \
+            --virtual_pp_degree $vpp_degree \
+            >>${log_path}/$FUNCNAME 2>&1
+
+        loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
+        if [ "$pp_mode" == "FThenB" ]; then
+            loss1=$loss
+        else
+            loss2=$loss
+        fi
+        echo "result: $pp_mode loss=$loss"
+    done
+    ips=-1
+    mem=-1
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
+function llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_max_inplace_grad_add=3
+
+    task_name="llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1_MP1_PP4"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    loss1=0
+    loss2=0
+    loss1_array=()
+    loss2_array=()
+    use_pir=1
+
+    max_step=15
+    to_static=1
+
+    for to_static in "0" "1"; do
+        export FLAGS_enable_pir_api=${use_pir}
+        export FLAGS_enable_pir_in_executor=${use_pir}
+
+        case_out_dir="output/$task_name"
+        case_log_dir="output/$task_name""_log$to_static"
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        rm -rf ${log_path}/$FUNCNAME
+
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1,2,3" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --warmup_steps 30 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps $max_step \
+            --logging_steps 1 \
+            --eval_steps 1000 \
+            --save_steps 50000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --save_total_limit 2 \
+            --device gpu \
+            --disable_tqdm true \
+            --dataloader_num_workers 1 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 2 \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --bf16 true \
+            --fp16_opt_level "O2" \
+            --amp_master_grad true \
+            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv true \
+            --fuse_sequence_parallel_allreduce false \
+            --use_flash_attention 0 \
+            --use_fused_rope false \
+            --use_fused_rms_norm 0 \
+            --max_seq_length 2048 \
+            --hidden_size 1024 \
+            --sep_parallel_degree 1 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 4 \
+            --sharding_parallel_degree 1 \
+            --tensor_parallel_degree 1 \
+            --sharding "" \
+            --to_static ${to_static} \
+            --num_hidden_layers 8 \
+            --data_parallel_config "gradient_sync_after_accumulate" \
+            --pipeline_schedule_mode "FThenB" \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=$(grep "global_step: 15," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
+        if [ $to_static -eq 0 ]; then
+            loss1=($loss)
+        else
+            loss2=($loss)
+        fi
+        echo "result: to_static=$to_static loss=$loss"
+    done
+    ips=-1
+    mem=-1
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH

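Both new cases share the same result-collection pattern: the last step's loss is scraped from workerlog.0 with the grep/awk pipeline shown above, and the two runs (FThenB vs VPP, or dygraph vs to_static) are then compared via check_result, which is defined elsewhere in ci_case_auto.sh. The sketch below reproduces that flow in isolation; the sample log line and the 1e-6 tolerance are assumptions for illustration, not check_result's actual implementation:

# Hypothetical log line; the real workerlog.0 format may differ in detail.
sample_log='global_step: 10, loss: 9.265625, lr: 3e-05, interval_samples_per_second: 12.3'

# Same scraping pipeline as the test functions above (requires GNU grep -P).
loss_a=$(echo "$sample_log" | grep "global_step: 10," | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
loss_b=9.265625   # stand-in for the loss from the second run being aligned against

# Illustrative alignment check; the exact comparison used by check_result is not shown here.
if awk -v a="$loss_a" -v b="$loss_b" 'BEGIN { d = a - b; if (d < 0) d = -d; exit !(d <= 1e-6) }'; then
    echo "losses aligned: $loss_a vs $loss_b"
else
    echo "loss mismatch: $loss_a vs $loss_b"
fi

Since ips, mem, ips_base and mem_base are all pinned to -1 in both functions, the intent is presumably that check_result gates only on the loss comparison for these cases.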