Skip to content

Commit 089a3c3

Browse files
authored
[Benchmark] Fix amp level bug in some gpt tests (#9116)
1 parent 5c1779c commit 089a3c3

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp/benchmark_common/run_benchmark.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ function _set_params(){
3535
sharding_degree=${10:-"1"} # (可选)
3636
sharding_stage=${11:-"1"} # (可选)sharding case
3737
level=${12:-"o1"} # o1|o2|o3
38+
39+
if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then
40+
if [ ${level} == "o3" ]; then
41+
level="o2"
42+
echo "amp level changed to o2 in pir mode"
43+
else
44+
echo "amp level is o3"
45+
fi
46+
else
47+
echo "FLAGS_enable_pir_api = 0"
48+
fi
49+
3850
local_batch_size=${13:-"8"} # (可选)本地batch size
3951
schedule_mode=${14:-"1F1B"} # (可选)schedule mode
4052
base_batch_size=$global_batch_size

legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp_pir/benchmark_common/run_benchmark.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ function _set_params(){
3535
sharding_degree=${10:-"1"} # (可选)
3636
sharding_stage=${11:-"1"} # (可选)sharding case
3737
level=${12:-"o1"} # o1|o2|o3
38+
39+
if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then
40+
if [ ${level} == "o3" ]; then
41+
level="o2"
42+
echo "amp level changed to o2 in pir mode"
43+
else
44+
echo "amp level is o3"
45+
fi
46+
else
47+
echo "FLAGS_enable_pir_api = 0"
48+
fi
49+
3850
local_batch_size=${13:-"8"} # (可选)本地batch size
3951
schedule_mode=${14:-"1F1B"} # (可选)schedule mode
4052
base_batch_size=$global_batch_size

0 commit comments

Comments
 (0)