
Commit 40ed715

Mixed Precision ZeRO++ (deepspeedai#689)

HeyangQin authored and lekurile committed
Co-authored-by: Lev Kurilenko <[email protected]>
1 parent 62adcab commit 40ed715

File tree

6 files changed: +103 −26 lines


applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 12 additions & 0 deletions
@@ -304,6 +304,11 @@ def parse_args():
     parser.add_argument('--enable_ema',
                         action='store_true',
                         help='Enable EMA checkpoint for the model.')
+    ## Mixed Precision LoRA
+    parser.add_argument(
+        '--enable_mixed_precision_lora',
+        action='store_true',
+        help='Enable Mixed Precision LoRA for training and generation.')
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -444,6 +449,13 @@ def main():
         num_total_iters=num_total_iters,
         args=args)
 
+    # Mixed Precision LoRA
+    if args.enable_mixed_precision_lora:
+        assert args.actor_lora_dim > 0, "Mixed Precision LoRA requires LoRA to be enabled"
+        assert args.actor_zero_stage == 3, "Mixed Precision LoRA requires Zero stage 3"
+        rlhf_engine.actor.optimizer.quantize_nontrainable_params()
+        print_rank_0("Mixed Precision LoRA enabled")
+
     args.end_of_conversation_token = "<|endoftext|>"
 
     ppo_trainer = DeepSpeedPPOTrainerUnsupervised if unsupervised_training_enabled else DeepSpeedPPOTrainer
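
For readers wiring this up outside DeepSpeed-Chat, here is a minimal sketch of the enablement flow added above, factored into a helper. The helper name is hypothetical; args, rlhf_engine, and quantize_nontrainable_params() are the names used in this diff.

# Hypothetical helper mirroring the gating logic added to main() above.
# Assumes `args` carries the argparse flags from this diff and that
# `rlhf_engine.actor` is a DeepSpeed engine initialized with ZeRO stage 3.
def maybe_enable_mixed_precision_lora(args, rlhf_engine):
    if not args.enable_mixed_precision_lora:
        return
    # Preconditions from the asserts above: the feature quantizes the
    # frozen (non-LoRA) base weights, so LoRA must be active, and the
    # quantization hook lives on the ZeRO stage-3 optimizer.
    assert args.actor_lora_dim > 0, "Mixed Precision LoRA requires LoRA to be enabled"
    assert args.actor_zero_stage == 3, "Mixed Precision LoRA requires Zero stage 3"
    rlhf_engine.actor.optimizer.quantize_nontrainable_params()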

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def _init_actor(self, actor_model_name_or_path):
             max_out_tokens=self.args.max_prompt_seq_len +
             self.args.max_answer_seq_len,
             enable_tensorboard=self.args.enable_tensorboard,
+            enable_mixed_precision_lora=self.args.enable_mixed_precision_lora,
             tb_path=self.args.tensorboard_path,
             tb_name="step3_actor")
         ds_config[

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh

Lines changed: 11 additions & 4 deletions
@@ -10,9 +10,10 @@ CRITIC_ZERO_STAGE=$4
 ENABLE_HYBRID_ENGINE=$5
 OFFLOAD=$6
 LORA=$7
-OUTPUT=$8
-TEST=$9
-TEST_STOP_STEP=${10}
+MIXED_PRECISION_LORA=$8
+OUTPUT=$9
+TEST=${10}
+TEST_STOP_STEP=${11}
 
 if [ "$ACTOR_ZERO_STAGE" == "" ]; then
     ACTOR_ZERO_STAGE=2
@@ -42,6 +43,12 @@ else
     ACTOR_LORA_MODULE_NAME=""
 fi
 
+if [ "$MIXED_PRECISION_LORA" == true ]; then
+    MIXED_PRECISION_LORA="--enable_mixed_precision_lora"
+else
+    MIXED_PRECISION_LORA=""
+fi
+
 if [ "$TEST" == true ]; then
     TEST="--enable_test_mode"
     TEST_STOP_STEP="--test_stop_step ${TEST_STOP_STEP}"
@@ -83,7 +90,7 @@ cmd="deepspeed --num_nodes=1 main.py \
    --actor_zero_stage ${ACTOR_ZERO_STAGE} \
    --critic_zero_stage ${CRITIC_ZERO_STAGE} \
    --output_dir $OUTPUT \
-   $ENABLE_HYBRID_ENGINE $OFFLOAD $UNPIN_ACTOR_PARAMETERS \
+   $ENABLE_HYBRID_ENGINE $OFFLOAD $MIXED_PRECISION_LORA \
    $ACTOR_LORA_DIM $ACTOR_LORA_MODULE_NAME\
    $TEST $TEST_STOP_STEP"
 
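
Since run_single.sh consumes positional parameters, inserting MIXED_PRECISION_LORA at $8 shifts OUTPUT, TEST, and TEST_STOP_STEP to $9, ${10}, and ${11}. A sketch of a caller updated for the new order follows; the argument values are illustrative, and Python is used here only for consistency with the other examples.

import subprocess

# Argument order after this change; each comment names the positional
# parameter the value feeds in run_single.sh.
subprocess.run(
    [
        "bash", "training_scripts/opt/single_node/sweep/run_single.sh",
        "AdamG012/chat-opt-1.3b-sft-deepspeed",     # $1 ACTOR_MODEL_PATH
        "AdamG012/chat-opt-350m-reward-deepspeed",  # $2 CRITIC_MODEL_PATH
        "3",                                        # $3 ACTOR_ZERO_STAGE
        "3",                                        # $4 CRITIC_ZERO_STAGE
        "true",                                     # $5 ENABLE_HYBRID_ENGINE
        "false",                                    # $6 OFFLOAD
        "true",                                     # $7 LORA
        "true",                                     # $8 MIXED_PRECISION_LORA (new)
        "z3_he_true_offload_false_lora_true",       # $9 OUTPUT (was $8)
        "true",                                     # ${10} TEST (was $9)
        "5",                                        # ${11} TEST_STOP_STEP (was ${10})
    ],
    check=True,
)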

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_step3_sweep.sh

Lines changed: 67 additions & 21 deletions
@@ -6,30 +6,76 @@
 ACTOR_MODEL_PATH="AdamG012/chat-opt-1.3b-sft-deepspeed"
 CRITIC_MODEL_PATH="AdamG012/chat-opt-350m-reward-deepspeed"
 
-for z in {2..3}
-do
-    for he in true false
+# Sweep switches
+RUN_GENERIC_SWEEP=true
+RUN_MPL_SWEEP=true
+
+# Kill any existing Python processes
+pkill -9 python
+sleep 300
+
+# Run generic sweep w/o Mixed Precision ZeRO++
+if [ "$RUN_GENERIC_SWEEP" == true ]; then
+    echo "----------------------------- RUNNING GENERIC SWEEPS -----------------------------"
+    echo ""
+    for z in {2..3}
     do
-    for offload in true false
+        for he in true false
         do
-        for lora in true false
+            for offload in true false
             do
-            cmd="bash training_scripts/opt/single_node/sweep/run_single.sh \
-                $ACTOR_MODEL_PATH \
-                $CRITIC_MODEL_PATH \
-                ${z} \
-                ${z} \
-                ${he} \
-                ${offload} \
-                ${lora} \
-                z${z}_he_${he}_offload_${offload}_lora_${lora}"
-            echo "----------------------------- CALLING SHELL SCRIPT -----------------------------"
-            echo $cmd
-            $cmd
-            pkill -9 python
-            sleep 60
-            echo ""
+                for lora in true false
+                do
+                    mixed_precision_lora=false
+                    cmd="bash training_scripts/opt/single_node/sweep/run_single.sh \
+                        $ACTOR_MODEL_PATH \
+                        $CRITIC_MODEL_PATH \
+                        ${z} \
+                        ${z} \
+                        ${he} \
+                        ${offload} \
+                        ${lora} \
+                        ${mixed_precision_lora} \
+                        z${z}_he_${he}_offload_${offload}_lora_${lora}"
+                    echo "----------------------------- CALLING SHELL SCRIPT -----------------------------"
+                    echo $cmd
+                    $cmd
+                    pkill -9 python
+                    sleep 300
+                    echo ""
+                done
             done
         done
     done
-done
+    echo ""
+fi
+
+# Run Mixed Precision ZeRO++ sweep
+if [ "$RUN_MPL_SWEEP" == true ]; then
+    echo "----------------------------- RUNNING MIXED PRECISION ZERO++ SWEEPS -----------------------------"
+    echo ""
+    for he in true false
+    do
+        z=3
+        offload=false
+        lora=true
+        mixed_precision_lora=true
+        cmd="bash training_scripts/opt/single_node/sweep/run_single.sh \
+            $ACTOR_MODEL_PATH \
+            $CRITIC_MODEL_PATH \
+            ${z} \
+            ${z} \
+            ${he} \
+            ${offload} \
+            ${lora} \
+            ${mixed_precision_lora} \
+            z${z}_he_${he}_offload_${offload}_lora_${lora}_mpl_${mixed_precision_lora}"
+        echo "----------------------------- CALLING SHELL SCRIPT -----------------------------"
+        echo $cmd
+        $cmd
+        pkill -9 python
+        sleep 300
+        echo ""
+    done
+    echo ""
+fi
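
The restructured sweep is equivalent to enumerating two sets of runs: the full ZeRO/hybrid-engine/offload/LoRA cross-product with mixed precision disabled, plus a ZeRO-3, LoRA-only pair of runs with it enabled. A quick Python sketch of the same matrix, illustrative only and not part of the repo:

from itertools import product

configs = []
# Generic sweep: full cross-product, Mixed Precision ZeRO++ off.
for z, he, offload, lora in product((2, 3), (True, False), (True, False),
                                    (True, False)):
    configs.append(dict(z=z, he=he, offload=offload, lora=lora, mpl=False))
# Mixed Precision ZeRO++ sweep: fixed to ZeRO stage 3 with LoRA, no offload.
for he in (True, False):
    configs.append(dict(z=3, he=he, offload=False, lora=True, mpl=True))

print(len(configs))  # 16 generic runs + 2 mixed-precision runs = 18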

applications/DeepSpeed-Chat/training/tests/test_training.py

Lines changed: 5 additions & 1 deletion
@@ -33,14 +33,17 @@ def lora(request):
 
 
 def test_ds_chat(zero_stage, hybrid_engine, offload, lora):
+    # Assert that critic model directory exists
     critic_ckpt_dir = os.getenv("CRITIC_CKPT_DIR")
     assert critic_ckpt_dir, "Please set CRITIC_CKPT_DIR in your environment"
 
+    # Setup params
     actor_model = "facebook/opt-125m"
     critic_model = critic_ckpt_dir
-    output_path = "z" + zero_stage + "_he_" + hybrid_engine + "_offload_" + offload + "_lora_" + lora
+    mixed_precision_lora = "false"
     enable_test_mode = "true"
     test_stop_step = "5"
+    output_path = "z" + zero_stage + "_he_" + hybrid_engine + "_offload_" + offload + "_lora_" + lora
     params = [
         actor_model,
         critic_model,
@@ -49,6 +52,7 @@ def test_ds_chat(zero_stage, hybrid_engine, offload, lora):
         hybrid_engine,
         offload,
         lora,
+        mixed_precision_lora,
         output_path,
         enable_test_mode,
         test_stop_step,

applications/DeepSpeed-Chat/training/utils/ds_utils.py

Lines changed: 7 additions & 0 deletions
@@ -2,6 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
+
+import torch
+
 GLOBAL_BATCH_SIZE = 32
 MICRO_BATCH_SIZE = 4
 
@@ -15,6 +18,7 @@ def get_train_ds_config(offload,
                         tp_gather_partition_size=8,
                         max_out_tokens=512,
                         enable_tensorboard=False,
+                        enable_mixed_precision_lora=False,
                         tb_path="",
                         tb_name=""):
 
@@ -32,6 +36,9 @@ def get_train_ds_config(offload,
         "stage3_prefetch_bucket_size": 3e7,
         "memory_efficient_linear": False
     }
+    if enable_mixed_precision_lora:
+        zero_opt_dict["zero_quantized_nontrainable_weights"] = True
+        zero_opt_dict["zero_hpz_partition_size"] = torch.cuda.device_count()
     return {
         "train_batch_size": GLOBAL_BATCH_SIZE,
         "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
