@@ -28,6 +28,15 @@ export llm_gpt_case_path=$root_path/llm/gpt-3/auto_parallel
2828
2929unset CUDA_VISIBLE_DEVICES
3030
#######################################
# Report whether an NVIDIA A100 GPU is listed by `nvidia-smi`.
#
# Callers in this file branch on the result with `[ $(is_a100) ]`. A
# one-argument `[` is true for ANY non-empty string, so printing "0" for the
# negative case would still select the A100 baselines. The negative case
# therefore prints nothing at all.
#
# Arguments: none
# Outputs:   "1" on stdout when the `nvidia-smi` output mentions "A100";
#            nothing otherwise (including when `nvidia-smi` is missing or fails).
# Returns:   always 0, so `var=$(is_a100)` stays safe under `set -e`.
#######################################
function is_a100() {
    local smi_output
    # Capture first instead of piping into grep: the result then does not
    # depend on pipeline exit-status settings, and a failing or absent
    # nvidia-smi simply yields empty output.
    smi_output=$(nvidia-smi) || true
    case "${smi_output}" in
        *A100*) echo 1 ;;
    esac
    return 0
}
38+
39+
3140function gpt_case_list_auto() {
3241 gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
3342 gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
@@ -108,6 +117,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
108117 loss_base=10.507633305
109118 ips_base=3518
110119 mem_base=11750.6
120+ if [ $( is_a100) ]; then
121+ loss_base=10.530449009
122+ ips_base=16763
123+ mem_base=11750.6
124+ fi
111125 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
112126 echo " =========== $FUNCNAME run end ==========="
113127}
@@ -144,6 +158,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
144158 loss_base=10.570028400
145159 ips_base=35050
146160 mem_base=1988.9
161+ if [ $( is_a100) ]; then
162+ loss_base=10.559662151
163+ ips_base=83918
164+ mem_base=2022.7
165+ fi
147166 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
148167 echo " =========== $FUNCNAME run end ==========="
149168}
@@ -181,6 +200,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
181200 loss_base=10.570028400
182201 ips_base=35050
183202 mem_base=1988.9
203+ if [ $( is_a100) ]; then
204+ loss_base=10.559662151
205+ ips_base=83918
206+ mem_base=2022.7
207+ fi
184208 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
185209 echo " =========== $FUNCNAME run end ==========="
186210}
@@ -217,6 +241,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
217241 loss_base=10.700293922
218242 ips_base=32518
219243 mem_base=1535.7
244+ if [ $( is_a100) ]; then
245+ loss_base=10.679453373
246+ ips_base=79116
247+ mem_base=1488.2
248+ fi
220249 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
221250 echo " =========== $FUNCNAME run end ==========="
222251}
@@ -253,6 +282,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
253282 loss_base=10.672543240
254283 ips_base=18681
255284 mem_base=2135.7
285+ if [ $( is_a100) ]; then
286+ loss_base=10.651049423
287+ ips_base=41174
288+ mem_base=2064.5
289+ fi
256290 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
257291 echo " =========== $FUNCNAME run end ==========="
258292}
@@ -290,6 +324,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
290324 loss_base=10.672543240
291325 ips_base=18681
292326 mem_base=2135.7
327+ if [ $( is_a100) ]; then
328+ loss_base=10.651049423
329+ ips_base=41174
330+ mem_base=2064.5
331+ fi
293332 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
294333 echo " =========== $FUNCNAME run end ==========="
295334}
@@ -326,6 +365,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
326365 loss_base=10.720068359
327366 ips_base=15232
328367 mem_base=1999.2
368+ if [ $( is_a100) ]; then
369+ loss_base=10.657777309
370+ ips_base=30027
371+ mem_base=2002.0
372+ fi
329373 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
330374 echo " =========== $FUNCNAME run end ==========="
331375}
@@ -363,6 +407,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
363407 loss_base=10.720068359
364408 ips_base=15232
365409 mem_base=1999.2
410+ if [ $( is_a100) ]; then
411+ loss_base=10.657777309
412+ ips_base=30027
413+ mem_base=2002.0
414+ fi
366415 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
367416 echo " =========== $FUNCNAME run end ==========="
368417}
@@ -399,6 +448,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
399448 loss_base=10.720078850
400449 ips_base=15571
401450 mem_base=1999.2
451+ if [ $( is_a100) ]; then
452+ loss_base=10.657803535
453+ ips_base=29166
454+ mem_base=2002.0
455+ fi
402456 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
403457 echo " =========== $FUNCNAME run end ==========="
404458}
@@ -435,6 +489,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
435489 loss_base=10.681921577
436490 ips_base=13813
437491 mem_base=1747.6
492+ if [ $( is_a100) ]; then
493+ loss_base=10.662137604
494+ ips_base=24700
495+ mem_base=1750.5
496+ fi
438497 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
439498 echo " =========== $FUNCNAME run end ==========="
440499}
@@ -471,6 +530,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
471530 loss_base=10.579057693
472531 ips_base=19822
473532 mem_base=1709.8
533+ if [ $( is_a100) ]; then
534+ loss_base=10.586785984
535+ ips_base=42813
536+ mem_base=1743.8
537+ fi
474538 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
475539 echo " =========== $FUNCNAME run end ==========="
476540}
@@ -508,6 +572,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
508572 loss_base=10.579057693
509573 ips_base=19822
510574 mem_base=1709.8
575+ if [ $( is_a100) ]; then
576+ loss_base=10.586785984
577+ ips_base=42813
578+ mem_base=1743.8
579+ fi
511580 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
512581 echo " =========== $FUNCNAME run end ==========="
513582}
@@ -544,6 +613,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
544613 loss_base=10.579057693
545614 ips_base=20170
546615 mem_base=1709.8
616+ if [ $( is_a100) ]; then
617+ loss_base=10.586785984
618+ ips_base=42995
619+ mem_base=1743.8
620+ fi
547621 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
548622 echo " =========== $FUNCNAME run end ==========="
549623}
@@ -580,6 +654,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
580654 loss_base=10.585316849
581655 ips_base=15742
582656 mem_base=1591.6
657+ if [ $( is_a100) ]; then
658+ loss_base=10.555718899
659+ ips_base=34688
660+ mem_base=1625.6
661+ fi
583662 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
584663 echo " =========== $FUNCNAME run end ==========="
585664}
@@ -616,6 +695,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
616695 loss_base=10.672568035
617696 ips_base=19461
618697 mem_base=1384.7
698+ if [ $( is_a100) ]; then
699+ loss_base=10.651032448
700+ ips_base=42435
701+ mem_base=1377.5
702+ fi
619703 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
620704 echo " =========== $FUNCNAME run end ==========="
621705}
@@ -652,6 +736,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
652736 loss_base=10.672568035
653737 ips_base=19652
654738 mem_base=1384.7
739+ if [ $( is_a100) ]; then
740+ loss_base=10.651032448
741+ ips_base=43008
742+ mem_base=1377.5
743+ fi
655744 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
656745 echo " =========== $FUNCNAME run end ==========="
657746}
@@ -689,6 +778,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
689778 loss_base=10.672568035
690779 ips_base=19652
691780 mem_base=1384.7
781+ if [ $( is_a100) ]; then
782+ loss_base=10.651032448
783+ ips_base=43008
784+ mem_base=1377.5
785+ fi
692786 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
693787 echo " =========== $FUNCNAME run end ==========="
694788}
@@ -725,6 +819,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
725819 loss_base=10.696336079
726820 ips_base=16613
727821 mem_base=1280.5
822+ if [ $( is_a100) ]; then
823+ loss_base=10.705118465
824+ ips_base=37104
825+ mem_base=1217.3
826+ fi
728827 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
729828 echo " =========== $FUNCNAME run end ==========="
730829}
@@ -762,6 +861,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
762861 loss_base=10.696336079
763862 ips_base=16613
764863 mem_base=1280.5
864+ if [ $( is_a100) ]; then
865+ loss_base=10.705118465
866+ ips_base=37104
867+ mem_base=1217.3
868+ fi
765869 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
766870 echo " =========== $FUNCNAME run end ==========="
767871}
@@ -908,6 +1012,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
9081012 mem=-1
9091013 echo " result: loss=$loss ips=$ips mem=$mem "
9101014 loss_base=9.52110565
1015+ if [ $( is_a100) ]; then
1016+ loss_base=9.44003963
1017+ fi
9111018 ips_base=-1
9121019 mem_base=-1
9131020 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -974,6 +1081,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
9741081 mem=-1
9751082 echo " result: loss=$loss ips=$ips mem=$mem "
9761083 loss_base=9.42011833
1084+ if [ $( is_a100) ]; then
1085+ loss_base=9.44003963
1086+ fi
9771087 ips_base=-1
9781088 mem_base=-1
9791089 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1040,6 +1150,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
10401150 mem=-1
10411151 echo " result: loss=$loss ips=$ips mem=$mem "
10421152 loss_base=9.44299471
1153+ if [ $( is_a100) ]; then
1154+ loss_base=9.45633757
1155+ fi
10431156 ips_base=-1
10441157 mem_base=-1
10451158 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1106,6 +1219,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
11061219 mem=-1
11071220 echo " result: loss=$loss ips=$ips mem=$mem "
11081221 loss_base=9.45936012
1222+ if [ $( is_a100) ]; then
1223+ loss_base=9.46121407
1224+ fi
11091225 ips_base=-1
11101226 mem_base=-1
11111227 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1174,6 +1290,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
11741290 mem=-1
11751291 echo " result: loss=$loss ips=$ips mem=$mem "
11761292 loss_base=9.46707726
1293+ if [ $( is_a100) ]; then
1294+ loss_base=9.44474411
1295+ fi
11771296 ips_base=-1
11781297 mem_base=-1
11791298 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1243,6 +1362,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
12431362 mem=-1
12441363 echo " result: loss=$loss ips=$ips mem=$mem "
12451364 loss_base=10.0859375
1365+ if [ $( is_a100) ]; then
1366+ loss_base=10.125
1367+ fi
12461368 ips_base=-1
12471369 mem_base=-1
12481370 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1310,6 +1432,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
13101432 mem=-1
13111433 echo " result: loss=$loss ips=$ips mem=$mem "
13121434 loss_base=9.53389835
1435+ if [ $( is_a100) ]; then
1436+ loss_base=9.54253578
1437+ fi
13131438 ips_base=-1
13141439 mem_base=-1
13151440 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1377,6 +1502,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
13771502 mem=-1
13781503 echo " result: loss=$loss ips=$ips mem=$mem "
13791504 loss_base=9.39066124
1505+ if [ $( is_a100) ]; then
1506+ loss_base=9.41613197
1507+ fi
13801508 ips_base=-1
13811509 mem_base=-1
13821510 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1444,6 +1572,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
14441572 mem=-1
14451573 echo " result: loss=$loss ips=$ips mem=$mem "
14461574 loss_base=9.38235474
1575+ if [ $( is_a100) ]; then
1576+ loss_base=9.4053154
1577+ fi
14471578 ips_base=-1
14481579 mem_base=-1
14491580 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1512,6 +1643,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
15121643 mem=-1
15131644 echo " result: loss=$loss ips=$ips mem=$mem "
15141645 loss_base=9.38256836
1646+ if [ $( is_a100) ]; then
1647+ loss_base=9.4055137
1648+ fi
15151649 ips_base=-1
15161650 mem_base=-1
15171651 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
0 commit comments