Skip to content

Commit 6948d59

Browse files
authored
DS Chat Step 3 Unit Test (#677)
This PR adds a Step 3 DS Chat Unit Test.
1 parent f610094 commit 6948d59

File tree

4 files changed

+117
-3
lines changed

4 files changed

+117
-3
lines changed

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def parse_args():
331331
"--test_stop_step",
332332
type=int,
333333
default=0,
334-
help="Training step at which to terminate training during testing.")
334+
help=
335+
"Training non-overflow step at which to terminate training during testing."
336+
)
335337

336338
parser = deepspeed.add_config_arguments(parser)
337339
args = parser.parse_args()
@@ -461,6 +463,8 @@ def main():
461463
# Train!
462464
print_rank_0("***** Running training *****", args.global_rank)
463465

466+
non_overflow_step_count = 0
467+
464468
for epoch in range(args.num_train_epochs):
465469
print_rank_0(
466470
f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Generation Batches {min(len(prompt_train_dataloader), len(unsupervised_train_dataloader))}",
@@ -547,7 +551,12 @@ def main():
547551
if args.actor_gradient_checkpointing:
548552
rlhf_engine.actor.gradient_checkpointing_disable()
549553

550-
if args.enable_test_mode and step == args.test_stop_step:
554+
actor_overflow, critic_overflow = trainer.get_overflow()
555+
556+
if not actor_overflow and not critic_overflow:
557+
non_overflow_step_count += 1
558+
559+
if args.enable_test_mode and non_overflow_step_count == args.test_stop_step:
551560
break
552561

553562
if args.enable_test_mode:

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,12 @@ def train_rlhf(self, inputs):
235235

236236
return actor_loss, critic_loss
237237

238+
def get_overflow(self):
    """Return the overflow flags of the actor and critic optimizers.

    Returns:
        tuple: ``(actor_overflow, critic_overflow)``, each read from the
        ``overflow`` attribute of the corresponding model's optimizer.
        An optimizer that has no ``overflow`` attribute is reported as
        ``False`` instead of raising ``AttributeError``.
    """
    # NOTE(review): presumably only loss-scaling optimizers maintain an
    # ``overflow`` flag -- confirm which optimizer types can reach here.
    actor_overflow = getattr(self.actor_model.optimizer, "overflow", False)
    critic_overflow = getattr(self.critic_model.optimizer, "overflow", False)

    return actor_overflow, critic_overflow
243+
238244
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
239245
## policy gradient loss
240246
log_ratio = (logprobs - old_logprobs) * mask

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ ENABLE_HYBRID_ENGINE=$5
1111
OFFLOAD=$6
1212
LORA=$7
1313
OUTPUT=$8
14+
TEST=$9
15+
TEST_STOP_STEP=${10}
1416

1517
if [ "$ACTOR_ZERO_STAGE" == "" ]; then
1618
ACTOR_ZERO_STAGE=2
@@ -40,6 +42,14 @@ else
4042
ACTOR_LORA_MODULE_NAME=""
4143
fi
4244

45+
if [ "$TEST" == true ]; then
46+
TEST="--enable_test_mode"
47+
TEST_STOP_STEP="--test_stop_step ${TEST_STOP_STEP}"
48+
else
49+
TEST=""
50+
TEST_STOP_STEP=""
51+
fi
52+
4353
mkdir -p $OUTPUT
4454

4555
Num_Padding_at_Beginning=1 # this is model related
@@ -74,7 +84,8 @@ cmd="deepspeed --num_nodes=1 main.py \
7484
--critic_zero_stage ${CRITIC_ZERO_STAGE} \
7585
--output_dir $OUTPUT \
7686
$ENABLE_HYBRID_ENGINE $OFFLOAD $UNPIN_ACTOR_PARAMETERS \
77-
$ACTOR_LORA_DIM $ACTOR_LORA_MODULE_NAME"
87+
$ACTOR_LORA_DIM $ACTOR_LORA_MODULE_NAME\
88+
$TEST $TEST_STOP_STEP"
7889

7990
echo "----------------------------- DS COMMAND -----------------------------"
8091
echo $cmd
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import pytest
7+
import os
8+
import subprocess
9+
10+
11+
def file_exists(directory_path, file_name):
    """Return True if ``file_name`` inside ``directory_path`` is a regular file.

    Args:
        directory_path: Directory expected to contain the file.
        file_name: Name of the file, joined onto ``directory_path``.

    Returns:
        bool: Result of ``os.path.isfile`` on the joined path, so a missing
        path or a directory of that name yields False.
    """
    candidate = os.path.join(directory_path, file_name)
    return os.path.isfile(candidate)
13+
14+
15+
@pytest.fixture(params=["2", "3"])
def zero_stage(request):
    """Parametrize a test over ZeRO stages "2" and "3".

    The value is kept as a string because it is forwarded as a
    command-line argument to the sweep script.
    """
    return request.param
18+
19+
20+
@pytest.fixture(params=["true", "false"])
def hybrid_engine(request):
    """Parametrize a test over the hybrid-engine switch ("true" / "false").

    The value is kept as a string because it is forwarded as a
    command-line argument to the sweep script.
    """
    return request.param
23+
24+
25+
@pytest.fixture(params=["true", "false"])
def offload(request):
    """Parametrize a test over the offload switch ("true" / "false").

    The value is kept as a string because it is forwarded as a
    command-line argument to the sweep script.
    """
    return request.param
28+
29+
30+
@pytest.fixture(params=["true", "false"])
def lora(request):
    """Parametrize a test over the LoRA switch ("true" / "false").

    The value is kept as a string because it is forwarded as a
    command-line argument to the sweep script.
    """
    return request.param
33+
34+
35+
def test_ds_chat(zero_stage, hybrid_engine, offload, lora):
    """Run the step-3 sweep script in test mode and check both models are saved.

    The test invokes ``run_single.sh`` with the parametrized configuration,
    asks it to stop after a fixed number of test steps, and then asserts that
    ``pytorch_model.bin`` exists under the ``actor/`` and ``critic/``
    sub-directories of the output path.

    Requires the ``CRITIC_CKPT_DIR`` environment variable; its value is
    passed to the script as the critic model.
    """
    critic_ckpt_dir = os.getenv("CRITIC_CKPT_DIR")
    assert critic_ckpt_dir, "Please set CRITIC_CKPT_DIR in your environment"

    actor_model = "facebook/opt-125m"
    critic_model = critic_ckpt_dir
    output_path = "z" + zero_stage + "_he_" + hybrid_engine + "_offload_" + offload + "_lora_" + lora
    enable_test_mode = "true"
    test_stop_step = "5"
    # Positional arguments for the sweep script; zero_stage appears twice
    # (actor and critic stage, as the skip messages below indicate).
    params = [
        actor_model,
        critic_model,
        zero_stage,
        zero_stage,
        hybrid_engine,
        offload,
        lora,
        output_path,
        enable_test_mode,
        test_stop_step,
    ]

    # Skip certain combinations
    if zero_stage == "2" and hybrid_engine == "true" and offload == "true" and lora == "false":
        pytest.skip(
            "The combination of [actor_zero_stage==2, critic_zero_stage==2, enable_hybrid_engine=True, offload=True, lora=False] is currently unsupported due to training instability!"
        )

    if zero_stage == "3" and hybrid_engine == "true" and offload == "true" and lora == "true":
        pytest.skip(
            "The combination of [actor_zero_stage==3, critic_zero_stage==3, enable_hybrid_engine=True, offload=True, lora=True] is currently unsupported due to training instability!"
        )

    # cd into execution dir, and always return afterwards so the relative
    # chdir below resolves from the same place for every parametrized run.
    wd = os.getcwd()
    os.chdir("../step3_rlhf_finetuning")
    try:
        sweep_script = "training_scripts/opt/single_node/sweep/run_single.sh"

        # Run bash script
        cmd = ["bash", sweep_script] + params
        result = subprocess.run(cmd)

        # Assertions
        try:
            result.check_returncode()
        except subprocess.CalledProcessError:
            # Dump the training log when there is one, but never let a
            # missing log hide the real failure.
            log_name = f"{output_path}.log"
            if file_exists(output_path, log_name):
                with open(os.path.join(output_path, log_name), "r") as f:
                    print(f.read())
            raise

        assert file_exists(f"{output_path}/actor/", "pytorch_model.bin"
                           ), "Actor model was not saved during step 3 training."
        assert file_exists(f"{output_path}/critic/", "pytorch_model.bin"
                           ), "Critic model was not saved during step 3 training."
    finally:
        os.chdir(wd)

0 commit comments

Comments
 (0)