@@ -27,6 +27,15 @@ export llama_data_path=/llama_data
2727
2828unset CUDA_VISIBLE_DEVICES
2929
30+ function is_a100() {
31+ if [ $( nvidia-smi| grep A100| wc -l) -ne 0 ]; then
32+ echo 1
33+ else
34+ echo 0
35+ fi
36+ }
37+
38+
3039function gpt_case_list_auto() {
3140 gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
3241 gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
@@ -100,6 +109,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
100109 loss_base=10.507633305
101110 ips_base=3518
102111 mem_base=11750.6
112+ if [ $( is_a100) ]; then
113+ loss_base=10.530449009
114+ ips_base=16763
115+ mem_base=11750.6
116+ fi
103117 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
104118 echo " =========== $FUNCNAME run end ==========="
105119}
@@ -136,6 +150,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
136150 loss_base=10.570028400
137151 ips_base=35050
138152 mem_base=1988.9
153+ if [ $( is_a100) ]; then
154+ loss_base=10.559662151
155+ ips_base=83918
156+ mem_base=2022.7
157+ fi
139158 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
140159 echo " =========== $FUNCNAME run end ==========="
141160}
@@ -173,6 +192,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
173192 loss_base=10.570028400
174193 ips_base=35050
175194 mem_base=1988.9
195+ if [ $( is_a100) ]; then
196+ loss_base=10.559662151
197+ ips_base=83918
198+ mem_base=2022.7
199+ fi
176200 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
177201 echo " =========== $FUNCNAME run end ==========="
178202}
@@ -209,6 +233,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
209233 loss_base=10.700293922
210234 ips_base=32518
211235 mem_base=1535.7
236+ if [ $( is_a100) ]; then
237+ loss_base=10.679453373
238+ ips_base=79116
239+ mem_base=1488.2
240+ fi
212241 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
213242 echo " =========== $FUNCNAME run end ==========="
214243}
@@ -245,6 +274,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
245274 loss_base=10.672543240
246275 ips_base=18681
247276 mem_base=2135.7
277+ if [ $( is_a100) ]; then
278+ loss_base=10.651049423
279+ ips_base=41174
280+ mem_base=2064.5
281+ fi
248282 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
249283 echo " =========== $FUNCNAME run end ==========="
250284}
@@ -282,6 +316,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
282316 loss_base=10.672543240
283317 ips_base=18681
284318 mem_base=2135.7
319+ if [ $( is_a100) ]; then
320+ loss_base=10.651049423
321+ ips_base=41174
322+ mem_base=2064.5
323+ fi
285324 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
286325 echo " =========== $FUNCNAME run end ==========="
287326}
@@ -318,6 +357,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
318357 loss_base=10.720068359
319358 ips_base=15232
320359 mem_base=1999.2
360+ if [ $( is_a100) ]; then
361+ loss_base=10.657777309
362+ ips_base=30027
363+ mem_base=2002.0
364+ fi
321365 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
322366 echo " =========== $FUNCNAME run end ==========="
323367}
@@ -355,6 +399,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
355399 loss_base=10.720068359
356400 ips_base=15232
357401 mem_base=1999.2
402+ if [ $( is_a100) ]; then
403+ loss_base=10.657777309
404+ ips_base=30027
405+ mem_base=2002.0
406+ fi
358407 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
359408 echo " =========== $FUNCNAME run end ==========="
360409}
@@ -391,6 +440,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
391440 loss_base=10.720078850
392441 ips_base=15571
393442 mem_base=1999.2
443+ if [ $( is_a100) ]; then
444+ loss_base=10.657803535
445+ ips_base=29166
446+ mem_base=2002.0
447+ fi
394448 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
395449 echo " =========== $FUNCNAME run end ==========="
396450}
@@ -427,6 +481,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
427481 loss_base=10.681921577
428482 ips_base=13813
429483 mem_base=1747.6
484+ if [ $( is_a100) ]; then
485+ loss_base=10.662137604
486+ ips_base=24700
487+ mem_base=1750.5
488+ fi
430489 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
431490 echo " =========== $FUNCNAME run end ==========="
432491}
@@ -463,6 +522,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
463522 loss_base=10.579057693
464523 ips_base=19822
465524 mem_base=1709.8
525+ if [ $( is_a100) ]; then
526+ loss_base=10.586785984
527+ ips_base=42813
528+ mem_base=1743.8
529+ fi
466530 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
467531 echo " =========== $FUNCNAME run end ==========="
468532}
@@ -500,6 +564,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
500564 loss_base=10.579057693
501565 ips_base=19822
502566 mem_base=1709.8
567+ if [ $( is_a100) ]; then
568+ loss_base=10.586785984
569+ ips_base=42813
570+ mem_base=1743.8
571+ fi
503572 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
504573 echo " =========== $FUNCNAME run end ==========="
505574}
@@ -536,6 +605,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
536605 loss_base=10.579057693
537606 ips_base=20170
538607 mem_base=1709.8
608+ if [ $( is_a100) ]; then
609+ loss_base=10.586785984
610+ ips_base=42995
611+ mem_base=1743.8
612+ fi
539613 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
540614 echo " =========== $FUNCNAME run end ==========="
541615}
@@ -572,6 +646,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
572646 loss_base=10.585316849
573647 ips_base=15742
574648 mem_base=1591.6
649+ if [ $( is_a100) ]; then
650+ loss_base=10.555718899
651+ ips_base=34688
652+ mem_base=1625.6
653+ fi
575654 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
576655 echo " =========== $FUNCNAME run end ==========="
577656}
@@ -608,6 +687,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
608687 loss_base=10.672568035
609688 ips_base=19461
610689 mem_base=1384.7
690+ if [ $( is_a100) ]; then
691+ loss_base=10.651032448
692+ ips_base=42435
693+ mem_base=1377.5
694+ fi
611695 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
612696 echo " =========== $FUNCNAME run end ==========="
613697}
@@ -644,6 +728,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
644728 loss_base=10.672568035
645729 ips_base=19652
646730 mem_base=1384.7
731+ if [ $( is_a100) ]; then
732+ loss_base=10.651032448
733+ ips_base=43008
734+ mem_base=1377.5
735+ fi
647736 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
648737 echo " =========== $FUNCNAME run end ==========="
649738}
@@ -681,6 +770,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
681770 loss_base=10.672568035
682771 ips_base=19652
683772 mem_base=1384.7
773+ if [ $( is_a100) ]; then
774+ loss_base=10.651032448
775+ ips_base=43008
776+ mem_base=1377.5
777+ fi
684778 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
685779 echo " =========== $FUNCNAME run end ==========="
686780}
@@ -717,6 +811,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
717811 loss_base=10.696336079
718812 ips_base=16613
719813 mem_base=1280.5
814+ if [ $( is_a100) ]; then
815+ loss_base=10.705118465
816+ ips_base=37104
817+ mem_base=1217.3
818+ fi
720819 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
721820 echo " =========== $FUNCNAME run end ==========="
722821}
@@ -754,6 +853,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
754853 loss_base=10.696336079
755854 ips_base=16613
756855 mem_base=1280.5
856+ if [ $( is_a100) ]; then
857+ loss_base=10.705118465
858+ ips_base=37104
859+ mem_base=1217.3
860+ fi
757861 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
758862 echo " =========== $FUNCNAME run end ==========="
759863}
@@ -900,6 +1004,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
9001004 mem=-1
9011005 echo " result: loss=$loss ips=$ips mem=$mem "
9021006 loss_base=9.52110565
1007+ if [ $( is_a100) ]; then
1008+ loss_base=9.44003963
1009+ fi
9031010 ips_base=-1
9041011 mem_base=-1
9051012 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -966,6 +1073,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
9661073 mem=-1
9671074 echo " result: loss=$loss ips=$ips mem=$mem "
9681075 loss_base=9.42011833
1076+ if [ $( is_a100) ]; then
1077+ loss_base=9.44003963
1078+ fi
9691079 ips_base=-1
9701080 mem_base=-1
9711081 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1032,6 +1142,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
10321142 mem=-1
10331143 echo " result: loss=$loss ips=$ips mem=$mem "
10341144 loss_base=9.44299471
1145+ if [ $( is_a100) ]; then
1146+ loss_base=9.45633757
1147+ fi
10351148 ips_base=-1
10361149 mem_base=-1
10371150 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1098,6 +1211,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
10981211 mem=-1
10991212 echo " result: loss=$loss ips=$ips mem=$mem "
11001213 loss_base=9.45936012
1214+ if [ $( is_a100) ]; then
1215+ loss_base=9.46121407
1216+ fi
11011217 ips_base=-1
11021218 mem_base=-1
11031219 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1166,6 +1282,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
11661282 mem=-1
11671283 echo " result: loss=$loss ips=$ips mem=$mem "
11681284 loss_base=9.46707726
1285+ if [ $( is_a100) ]; then
1286+ loss_base=9.44474411
1287+ fi
11691288 ips_base=-1
11701289 mem_base=-1
11711290 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1235,6 +1354,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
12351354 mem=-1
12361355 echo " result: loss=$loss ips=$ips mem=$mem "
12371356 loss_base=10.0859375
1357+ if [ $( is_a100) ]; then
1358+ loss_base=10.125
1359+ fi
12381360 ips_base=-1
12391361 mem_base=-1
12401362 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1302,6 +1424,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
13021424 mem=-1
13031425 echo " result: loss=$loss ips=$ips mem=$mem "
13041426 loss_base=9.53389835
1427+ if [ $( is_a100) ]; then
1428+ loss_base=9.54253578
1429+ fi
13051430 ips_base=-1
13061431 mem_base=-1
13071432 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1369,6 +1494,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
13691494 mem=-1
13701495 echo " result: loss=$loss ips=$ips mem=$mem "
13711496 loss_base=9.39066124
1497+ if [ $( is_a100) ]; then
1498+ loss_base=9.41613197
1499+ fi
13721500 ips_base=-1
13731501 mem_base=-1
13741502 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1436,6 +1564,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
14361564 mem=-1
14371565 echo " result: loss=$loss ips=$ips mem=$mem "
14381566 loss_base=9.38235474
1567+ if [ $( is_a100) ]; then
1568+ loss_base=9.4053154
1569+ fi
14391570 ips_base=-1
14401571 mem_base=-1
14411572 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1504,6 +1635,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
15041635 mem=-1
15051636 echo " result: loss=$loss ips=$ips mem=$mem "
15061637 loss_base=9.38256836
1638+ if [ $( is_a100) ]; then
1639+ loss_base=9.4055137
1640+ fi
15071641 ips_base=-1
15081642 mem_base=-1
15091643 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
0 commit comments