
Commit 327d788

add ci
1 parent 1249c74 commit 327d788

File tree: 2 files changed, +107 −12 lines

llm/llama/auto_parallel/run_pretrain_3D_auto.py

Lines changed: 28 additions & 2 deletions
@@ -140,6 +140,21 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
+    vocab_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Vocabulary size of the Llama model. Defines the number of different tokens that can be represented by the `inputs_ids`."
+        },
+    )
+    hidden_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the hidden representations."})
+    intermediate_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the MLP representations."})
+    num_hidden_layers: Optional[int] = field(
+        default=None, metadata={"help": "Number of hidden layers in the Transformer encoder."}
+    )
+    num_attention_heads: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of attention heads for each attention layer in the Transformer encoder."},
+    )
     use_flash_attention: bool = field(
         default=False,
         metadata={"help": "use_flash_attention"},
@@ -443,6 +458,17 @@ def main():
     if model_args.no_recompute_layers is not None:
         model_args.no_recompute_layers.sort()
 
+    config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size
+    config.intermediate_size = (
+        model_args.intermediate_size if model_args.intermediate_size is not None else config.intermediate_size
+    )
+    config.num_hidden_layers = (
+        model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers
+    )
+    config.num_attention_heads = (
+        model_args.num_attention_heads if model_args.num_attention_heads is not None else config.num_attention_heads
+    )
+
     config.use_flash_attention = model_args.use_flash_attention
     config.use_fused_rms_norm = model_args.use_fused_rms_norm
     config.fuse_attention_qkv = model_args.fuse_attention_qkv
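
The four assignments above share one pattern: overwrite a config attribute only when the matching argument was given, otherwise keep the pretrained value. A hedged sketch of the same pattern written as a loop (an equivalent formulation, not a drop-in change to the script):

    # Equivalent sketch of the override pattern above.
    def apply_overrides(config, model_args):
        # Overwrite config.<name> only when model_args.<name> was explicitly provided.
        for name in ("hidden_size", "intermediate_size", "num_hidden_layers", "num_attention_heads"):
            value = getattr(model_args, name, None)
            if value is not None:
                setattr(config, name, value)
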
@@ -606,8 +632,8 @@ def loss_func(loss, outputs):
                 )
                 tr_loss = 0
 
-            if global_step // training_args.gradient_accumulation_steps >= 1:
-                sys.exit(0)
+            if global_step // training_args.gradient_accumulation_steps >= training_args.max_steps:
+                break
 
             global_step += 1
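
The fix above replaces the hard-coded early exit (sys.exit(0) after the first optimizer step) with a break once training_args.max_steps optimizer steps are reached; as the condition suggests, global_step counts micro-batches, and dividing by gradient_accumulation_steps converts it into optimizer steps. A standalone sketch of that termination check (names assumed from the diff):

    # Standalone sketch of the termination check above.
    def reached_max_steps(global_step: int, gradient_accumulation_steps: int, max_steps: int) -> bool:
        # global_step counts micro-batches; one optimizer step = gradient_accumulation_steps micro-batches.
        return global_step // gradient_accumulation_steps >= max_steps

    assert reached_max_steps(20, gradient_accumulation_steps=2, max_steps=10)
    assert not reached_max_steps(19, gradient_accumulation_steps=2, max_steps=10)
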

scripts/distribute/ci_case_auto.sh

Lines changed: 79 additions & 10 deletions
@@ -45,11 +45,13 @@ function gpt_case_list_auto() {
 }
 
 function llama_case_list_auto() {
-    llama_auto_recompute_bs8_fp32_DP1-MP1-PP1
-    llama_auto_recompute_bs16_fp32_DP2-MP1-PP1
-    llama_auto_recompute_bs16_fp32_DP2-MP2-PP1
-    llama_auto_recompute_bs16_fp32_DP2-MP2-PP2
-    llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
+    llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2
+
+    llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
+    llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1
+    llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1
+    llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2
+    llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
 }
 
 function gpt_case_list_auto_pir() {
@@ -834,7 +836,7 @@ function gpt_auto_sp_acc_check() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
-function llama_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
+function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
     export FLAGS_call_stack_level=2
@@ -900,7 +902,7 @@ function llama_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
-function llama_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
+function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
     export FLAGS_call_stack_level=2
@@ -966,7 +968,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
-function llama_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
+function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
     export FLAGS_call_stack_level=2
@@ -1032,7 +1034,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
-function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
+function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
     export FLAGS_call_stack_level=2
@@ -1098,7 +1100,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
-function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
+function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
     export FLAGS_call_stack_level=2
@@ -1165,6 +1167,73 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
+
+function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+
+    task_name="llama_auto_bs16_dp2mp2pp2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_3D_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 2 \
+        --gradient_accumulation_steps 2 \
+        --use_flash_attention 0 \
+        --use_fused_rms_norm 0 \
+        --fp16 0 \
+        --fp16_opt_level "O2" \
+        --scale_loss 1024 \
+        --pipeline_parallel_degree 2 \
+        --tensor_parallel_degree 2 \
+        --sharding_parallel_degree 1 \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 5000 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --sharding "" \
+        --eval_steps 1000000 \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 0 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --parallel_mode "auto" \
+        --max_grad_norm 1.0 \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.2 | grep 'global_step 10' | awk -F '; loss' '{print $2}' | awk -F 'lr' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=9.543781280517578
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 ############ case end ############
 
 function check_result() {
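
For reference, the loss check at the end of the new case greps the `global_step 10` line from workerlog.2 and cuts the text between `; loss` and `lr` with awk, then compares it against loss_base via check_result. A hedged Python equivalent of that log parsing (the log-line format is assumed from the shell pipeline, not taken from actual trainer output):

    # Hedged Python equivalent of the awk pipeline: extract the value between "; loss" and "lr".
    import re

    def extract_loss(log_line: str):
        match = re.search(r"; loss\s*([0-9.eE+-]+)\s*lr", log_line)
        return float(match.group(1)) if match else None

    sample = "global_step 10, ... ; loss 9.543781280517578 lr 1e-05"  # assumed format
    print(extract_loss(sample))  # 9.543781280517578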
