Commit 9186347

Rename Step 3 DS-Chat args for clarification (#698)
This PR renames some DS-Chat Step 3 arguments for clarification. Arguments renamed:

1. `--per_device_train_batch_size` --> `--per_device_generation_batch_size`
2. `--per_device_mini_train_batch_size` --> `--per_device_training_batch_size`
3. `--generation_batch_numbers` --> `--generation_batches`
1 parent 927690e commit 9186347

16 files changed, +61 -62 lines changed


applications/DeepSpeed-Chat/training/README.md

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ The RLHF finetuning is the most complicated step among the three step training.

* ```ema checkpoint``` We observe that the EMA checkpoint can generally bring better model generation quality, as stated in InstructGPT.

* ```PPO related hyperparameters``` PPO training has a lot of hyperparameters, see [here](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L61-L66). For now, we hard-code them for users, but you may want to adjust them for your own usage.

* ```mix unsupervised training``` InstructGPT suggests mixing PPO and unsupervised training to prevent the loss of the model's benchmark quality. However, when we directly applied the hyperparameter from InstructGPT, the model could not converge, so we stopped exploring this. Users are nonetheless encouraged to test it and tune the hyperparameter for their own usage.

-* ```diverging issue``` We have found that it is very unstable to use different generation batch sizes (`--per_device_train_batch_size`) and PPO training batch sizes (`--per_device_mini_batch_size`), more than one PPO training epoch (`--ppo_epochs`), or more than one generation batch (`--generation_batch_numbers`). These all point to the same problem: we are not able to update the actor model multiple times after generating experience data. Therefore, in all of our successful runs, we have set `per_device_train_batch_size=per_device_mini_batch_size` and `ppo_epochs=generation_batch_numbers=1`. This is unexpected for a standard RL training pipeline, and we have tried different methods to overcome it, but all have failed. One of the most likely reasons for this instability is that the `log_probs` and `old_log_probs` used in the `actor_loss_fn` function can quickly diverge even within two consecutive iterations, which causes the corresponding `ratio` to be huge. Setting a strict upper bound can alleviate this problem, but it cannot fully resolve the convergence issue.
+* ```diverging issue``` We have found that it is very unstable to use different generation batch sizes (`--per_device_generation_batch_size`) and PPO training batch sizes (`--per_device_training_batch_size`), more than one PPO training epoch (`--ppo_epochs`), or more than one generation batch (`--generation_batches`). These all point to the same problem: we are not able to update the actor model multiple times after generating experience data. Therefore, in all of our successful runs, we have set `per_device_generation_batch_size=per_device_training_batch_size` and `ppo_epochs=generation_batches=1`. This is unexpected for a standard RL training pipeline, and we have tried different methods to overcome it, but all have failed. One of the most likely reasons for this instability is that the `log_probs` and `old_log_probs` used in the `actor_loss_fn` function can quickly diverge even within two consecutive iterations, which causes the corresponding `ratio` to be huge. Setting a strict upper bound can alleviate this problem, but it cannot fully resolve the convergence issue.

### About our testing

We did most of our accuracy/quality testing on OPT-1.3B (SFT and Actor model) and OPT-350m (RW and Critic model). In particular, we used 16 V100-32G GPUs (one DGX-2 node) to run our experiments.
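For readers unfamiliar with the `ratio` mentioned in the `diverging issue` bullet above, the sketch below shows a generic PPO-style clipped actor loss. It is a minimal illustration, not the actual `ppo_trainer.py` implementation; the clip range, masking convention, and tensor shapes are assumptions.

```python
import torch

def actor_loss_fn_sketch(log_probs, old_log_probs, advantages, mask, cliprange=0.2):
    # The importance ratio exp(log_probs - old_log_probs) blows up when the
    # current policy drifts away from the policy that generated the experience,
    # which is the divergence described in the bullet above.
    ratio = torch.exp((log_probs - old_log_probs) * mask)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    # Standard PPO clipping; a hard upper bound on `ratio` itself is the
    # "strict upper bound" mentioned above, which alleviates but does not
    # fully resolve the convergence issue.
    return torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
```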

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md

Lines changed: 1 addition & 1 deletion
@@ -8,6 +8,6 @@ an apple-to-apple comparison is critical for the machine learning community, par

We randomly select 40% of the training data from the six open-sourced training datasets, i.e., ``"Dahoas/rm-static", "Dahoas/full-hh-rlhf", "Dahoas/synthetic-instruct-gptj-pairwise", "yitingxie/rlhf-reward-datasets", "openai/webgpt_comparisons"``, and ``"stanfordnlp/SHP"``. This gives 264,292 training samples in total. We fix the query (prompt) sequence length at **256** tokens and generate fixed-length answers of **256** tokens, so the total number of training tokens per epoch is 135,317,504. During benchmark testing, we set the number of training epochs to 1.

-As mentioned in the instability section of the [RLHF Training Tutorial](./README.md#🙋-instablity-of-rlhf-training-and-others), we found that it is not stable to update the actor model multiple times using the generated data. Therefore, we set ``per_device_train_batch_size=per_device_mini_batch_size`` and ``ppo_epochs=generation_batch_numbers=1`` for all of our benchmark results. During testing, we also set an upper bound for the maximum global training tokens at 524,288 (batch size of 1024 with a sequence length of 512). This is the largest batch size we found during our exploration that provides a stable RLHF training experience. Users and practitioners may find better training hyperparameters to further increase this. Additionally, during testing, whenever the global training token batch size does not exceed our limit of 524,288, we always use the largest training batch size that does not result in an out-of-memory error to benchmark the time.
+As mentioned in the instability section of the [RLHF Training Tutorial](./README.md#🙋-instablity-of-rlhf-training-and-others), we found that it is not stable to update the actor model multiple times using the generated data. Therefore, we set ``per_device_generation_batch_size=per_device_training_batch_size`` and ``ppo_epochs=generation_batches=1`` for all of our benchmark results. During testing, we also set an upper bound for the maximum global training tokens at 524,288 (batch size of 1024 with a sequence length of 512). This is the largest batch size we found during our exploration that provides a stable RLHF training experience. Users and practitioners may find better training hyperparameters to further increase this. Additionally, during testing, whenever the global training token batch size does not exceed our limit of 524,288, we always use the largest training batch size that does not result in an out-of-memory error to benchmark the time.

We hope this clearly explains our benchmark settings; please do not hesitate to contact us if you need more information. If you would like to reproduce our performance results or compare against DeepSpeed-RLHF, we encourage you to use the same or similar settings so that the results are more comparable.
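As a quick sanity check of the token counts quoted above, the arithmetic can be reproduced in a few lines. The values are copied from the paragraph; nothing here comes from the benchmark code itself.

```python
# Token accounting for the benchmark settings described above.
num_samples = 264_292                      # 40% of the six open-sourced datasets
prompt_len, answer_len = 256, 256          # fixed prompt and answer lengths
tokens_per_epoch = num_samples * (prompt_len + answer_len)
assert tokens_per_epoch == 135_317_504

# Upper bound on global training tokens per step.
max_global_batch_size, seq_len = 1024, 512
assert max_global_batch_size * seq_len == 524_288
```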

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/README.md

Lines changed: 3 additions & 3 deletions
@@ -41,8 +41,8 @@ We provide most of unique arguments used in DeepSpeed RLHF other than the previo

| ------------------------------------------------------------------ | -------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| --unsupervised_dataset_name and --unsupervised_dataset_config_name | Hugging Face datasets standard setting to collect the data, e.g., using Wikitext-103 | When both are provided, during each PPO training step we also add the pretraining objective. Based on InstructGPT, this enhances the model's benchmark performance. |
| --unsup_coef | Used to balance the RLHF/PPO loss and the unsupervised loss | |
-| --per_device_train_batch_size and --per_device_mini_batch_size | The first one is the generation batch size and the second one is the PPO training batch size | Usually, the first one needs to be divisible by the second one. |
-| --generation_batch_numbers | Generate N batches, then do PPO training | This setting is common in RL, i.e., we first generate a batch of experience data and then do RL training on it |
+| --per_device_generation_batch_size and --per_device_training_batch_size | The first one is the generation batch size and the second one is the PPO training batch size | Usually, the first one needs to be divisible by the second one. |
+| --generation_batches | Generate N batches, then do PPO training | This setting is common in RL, i.e., we first generate a batch of experience data and then do RL training on it |
| --ppo_epochs | For the generated experience data, how many PPO epochs we want to perform | |
| --max_prompt_seq_len and --max_answer_seq_len | The length of the query and the length of the answer | |
| --enable_hybrid_engine | Enable it to use DeepSpeed Hybrid Engine | This will significantly speed up your training |

@@ -69,7 +69,7 @@ Users can either use the `prompt_eval.py` script from Step 1 of the SFT process

RLHF is a relatively new field, and as expected, we have encountered some training instabilities during our exploration. We are sharing our findings here and actively working on solutions. We also welcome solutions from the community.

-We have found that it is very unstable to use different generation batch sizes (`--per_device_train_batch_size`) and PPO training batch sizes (`--per_device_mini_batch_size`), more than one PPO training epoch (`--ppo_epochs`), or more than one generation batch (`--generation_batch_numbers`). These all point to the same problem: we are not able to update the actor model multiple times after generating experience data. Therefore, in all of our successful runs, we have set `per_device_train_batch_size=per_device_mini_batch_size` and `ppo_epochs=generation_batch_numbers=1`. This is unexpected for a standard RL training pipeline, and we have tried different methods to overcome it, but all have failed. One of the most likely reasons for this instability is that the `log_probs` and `old_log_probs` used in the `actor_loss_fn` function can quickly diverge even within two consecutive iterations, which causes the corresponding `ratio` to be huge. Setting a strict upper bound can alleviate this problem, but it cannot fully resolve the convergence issue.
+We have found that it is very unstable to use different generation batch sizes (`--per_device_generation_batch_size`) and PPO training batch sizes (`--per_device_training_batch_size`), more than one PPO training epoch (`--ppo_epochs`), or more than one generation batch (`--generation_batches`). These all point to the same problem: we are not able to update the actor model multiple times after generating experience data. Therefore, in all of our successful runs, we have set `per_device_generation_batch_size=per_device_training_batch_size` and `ppo_epochs=generation_batches=1`. This is unexpected for a standard RL training pipeline, and we have tried different methods to overcome it, but all have failed. One of the most likely reasons for this instability is that the `log_probs` and `old_log_probs` used in the `actor_loss_fn` function can quickly diverge even within two consecutive iterations, which causes the corresponding `ratio` to be huge. Setting a strict upper bound can alleviate this problem, but it cannot fully resolve the convergence issue.

We have also found that adding unsupervised training is not easy. We tried using the coefficient (`--unsup_coef=27.8`) provided by InstructGPT, but it caused instability in the RLHF training. According to InstructGPT, unsupervised training mainly affects the model quality on standard benchmarks instead of the RLHF performance. We did not put much effort into tuning this parameter.
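To make the relationship between the renamed batch-size arguments concrete, here is a toy sketch of how one generated experience batch would be split into PPO training micro-batches. This is not the DS-Chat implementation; the function name and data shapes are purely illustrative.

```python
def split_generation_batch(experience, generation_batch_size, training_batch_size):
    # --per_device_generation_batch_size should be divisible by
    # --per_device_training_batch_size so the generated experience splits
    # evenly into PPO micro-batches.
    assert generation_batch_size % training_batch_size == 0
    return [experience[i:i + training_batch_size]
            for i in range(0, generation_batch_size, training_batch_size)]

# In our stable runs the two batch sizes are equal and
# ppo_epochs == generation_batches == 1, so each generated sample is used
# for exactly one actor update.
print(split_generation_batch(list(range(16)), 16, 16))  # one micro-batch of 16
```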

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 11 additions & 11 deletions
@@ -109,20 +109,20 @@ def parse_args():
        "OPT model has a fixed number (1) of padding tokens at the beginning of the input. We did not see this in other models but keep it as an option for now."
    )
    parser.add_argument(
-        "--per_device_train_batch_size",
+        "--per_device_generation_batch_size",
        type=int,
        default=16,
        help=
        "Batch size (per device) for the training dataloader and generation purpose."
    )
    parser.add_argument(
-        "--per_device_mini_train_batch_size",
+        "--per_device_training_batch_size",
        type=int,
        default=16,
        help=
        "Mini Batch size (per device) for the training dataloader and training purpose."
    )
-    parser.add_argument("--generation_batch_numbers",
+    parser.add_argument("--generation_batches",
                        type=int,
                        default=1,
                        help="Generate x batches to go to training mode.")

@@ -387,19 +387,19 @@ def create_datasets(args, tokenizer, train_phase=3):
        prompt_train_dataset,
        collate_fn=data_collator,
        sampler=prompt_train_sampler,
-        batch_size=args.per_device_train_batch_size)
+        batch_size=args.per_device_generation_batch_size)
    if unsupervised_training_enabled:
        unsupervised_train_dataloader = DataLoader(
            unsupervised_train_dataset,
            collate_fn=default_data_collator,
            sampler=unsupervised_train_sampler,
-            batch_size=args.per_device_train_batch_size)
+            batch_size=args.per_device_generation_batch_size)
    else:
        unsupervised_train_dataloader = [None] * len(
            prompt_train_dataloader)  # basically a dummy dataloader

    num_update_steps_per_epoch = min(len(prompt_train_dataloader), len(unsupervised_train_dataloader)) * \
-        (args.per_device_train_batch_size / args.per_device_mini_train_batch_size) * \
+        (args.per_device_generation_batch_size / args.per_device_training_batch_size) * \
        args.ppo_epochs / args.gradient_accumulation_steps
    num_total_iters = int(args.num_train_epochs * num_update_steps_per_epoch)

@@ -450,10 +450,10 @@ def main():
    trainer = ppo_trainer(rlhf_engine, args)

    # the first number is how many experience batches to generate; the second is the training batch size, i.e., the micro-batch size used
-    exp_mini_dataset = MiniDataset(args.generation_batch_numbers,
-                                   args.per_device_mini_train_batch_size)
-    unsup_mini_dataset = MiniDataset(args.generation_batch_numbers,
-                                     args.per_device_mini_train_batch_size)
+    exp_mini_dataset = MiniDataset(args.generation_batches,
+                                   args.per_device_training_batch_size)
+    unsup_mini_dataset = MiniDataset(args.generation_batches,
+                                     args.per_device_training_batch_size)

    # Train!
    print_rank_0("***** Running training *****", args.global_rank)

@@ -472,7 +472,7 @@ def main():
                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
            else:
                unsup_dataset = unsup_mini_dataset.add(
-                    [[None] * args.per_device_train_batch_size])
+                    [[None] * args.per_device_generation_batch_size])
            # prompts = batch_prompt['prompt']
            # length = prompts.size(-1)
            # if length > args.max_prompt_seq_len:
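For reference, this is how the renamed arguments enter the `num_update_steps_per_epoch` formula in the hunk above. The numbers are made up purely to illustrate the computation and do not come from the repository.

```python
# Hypothetical values, only to show how the renamed arguments combine.
num_prompt_batches = 100                    # stand-in for len(prompt_train_dataloader)
per_device_generation_batch_size = 16
per_device_training_batch_size = 8
ppo_epochs = 1
gradient_accumulation_steps = 1

num_update_steps_per_epoch = num_prompt_batches * \
    (per_device_generation_batch_size / per_device_training_batch_size) * \
    ppo_epochs / gradient_accumulation_steps
print(num_update_steps_per_epoch)  # 200.0 actor updates per epoch
```

Note that the stable configuration recommended in the READMEs uses equal generation and training batch sizes, which with these numbers would give 100 updates per epoch instead.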
