Skip to content

Commit f1df0b1

Browse files
awan-10 and lekurile
authored
10-20x faster load checkpoint (for critic/reward model) (#675)
Co-authored-by: Lev Kurilenko <[email protected]>
1 parent f6a988e commit f1df0b1

File tree

22 files changed

+273
-28
lines changed

22 files changed

+273
-28
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
OUTPUT=$1
77
ZERO_STAGE=$2
88
if [ "$OUTPUT" == "" ]; then
9-
OUTPUT=./output_step1_llama_7b_epoch4_lr9.65e-6_test
9+
OUTPUT=./output_step1_llama2_7b
1010
fi
1111
if [ "$ZERO_STAGE" == "" ]; then
1212
ZERO_STAGE=3
1313
fi
1414
mkdir -p $OUTPUT
1515

16-
deepspeed --include="worker-1" main.py \
16+
deepspeed main.py \
1717
--data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
1818
--data_split 2,4,4 \
19-
--model_name_or_path decapoda-research/llama-7b-hf \
19+
--model_name_or_path meta-llama/Llama-2-7b-hf \
2020
--per_device_train_batch_size 4 \
2121
--per_device_eval_batch_size 4 \
2222
--max_seq_len 512 \
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/bin/bash
2+
# Copyright (c) Microsoft Corporation.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
# DeepSpeed Team
6+
OUTPUT=$1
7+
ZERO_STAGE=$2
8+
if [ "$OUTPUT" == "" ]; then
9+
OUTPUT=./output_step1_llama2_7b_lora
10+
fi
11+
if [ "$ZERO_STAGE" == "" ]; then
12+
ZERO_STAGE=3
13+
fi
14+
mkdir -p $OUTPUT
15+
16+
deepspeed main.py \
17+
--data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
18+
--data_split 2,4,4 \
19+
--model_name_or_path meta-llama/Llama-2-7b-hf \
20+
--per_device_train_batch_size 4 \
21+
--per_device_eval_batch_size 4 \
22+
--max_seq_len 512 \
23+
--learning_rate 9.65e-6 \
24+
--weight_decay 0. \
25+
--num_train_epochs 4 \
26+
--gradient_accumulation_steps 1 \
27+
--lr_scheduler_type cosine \
28+
--num_warmup_steps 0 \
29+
--seed 1234 \
30+
--gradient_checkpointing \
31+
--zero_stage $ZERO_STAGE \
32+
--deepspeed \
33+
--lora_dim 128 \
34+
--lora_module_name "layers." \
35+
--output_dir $OUTPUT \
36+
&> $OUTPUT/training.log
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ fi
1414
mkdir -p $OUTPUT
1515

1616
deepspeed main.py \
17-
--data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
17+
--data_path Dahoas/rm-static \
1818
--data_split 2,4,4 \
19-
--model_name_or_path decapoda-research/llama-7b-hf \
19+
--model_name_or_path meta-llama/Llama-2-7b-hf \
2020
--per_device_train_batch_size 8 \
2121
--per_device_eval_batch_size 8 \
2222
--max_seq_len 512 \
@@ -31,5 +31,6 @@ deepspeed main.py \
3131
--gradient_checkpointing \
3232
--zero_stage $ZERO_STAGE \
3333
--deepspeed \
34+
--offload \
3435
--output_dir $OUTPUT \
3536
&> $OUTPUT/training.log
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
#
# Step-2 reward-model training launcher for Llama-2-7B (LoRA + CPU offload).
# Usage: script.sh [OUTPUT_DIR] [ZERO_STAGE]
#   OUTPUT_DIR - where checkpoints and training.log are written
#                (default: ./output_step2_llama_7b_epoch1_lr9.65e-6)
#   ZERO_STAGE - DeepSpeed ZeRO stage (default: 3)
set -euo pipefail

# Fall back to defaults when the positional args are unset or empty.
OUTPUT=${1:-./output_step2_llama_7b_epoch1_lr9.65e-6}
ZERO_STAGE=${2:-3}

# Create the output dir up front so the log redirection below cannot fail.
mkdir -p "$OUTPUT"

deepspeed main.py \
   --data_path Dahoas/rm-static \
   --data_split 2,4,4 \
   --model_name_or_path meta-llama/Llama-2-7b-hf \
   --per_device_train_batch_size 8 \
   --per_device_eval_batch_size 8 \
   --max_seq_len 512 \
   --learning_rate 9.65e-6 \
   --weight_decay 0.1 \
   --num_padding_at_beginning 0 \
   --num_train_epochs 1 \
   --gradient_accumulation_steps 1 \
   --lr_scheduler_type cosine \
   --num_warmup_steps 0 \
   --seed 1234 \
   --gradient_checkpointing \
   --zero_stage "$ZERO_STAGE" \
   --deepspeed \
   --offload \
   --lora_dim 128 \
   --lora_module_name "layers." \
   --output_dir "$OUTPUT" \
   &> "$OUTPUT/training.log"

0 commit comments

Comments
 (0)