
Commit 0b2a1e2

Merge branch 'dev'
2 parents: 83244e9 + 0fbe133


51 files changed: 1972 additions, 1972 deletions

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -8,7 +8,7 @@ ci:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: check-case-conflict
@@ -18,20 +18,20 @@ repos:
       - id: requirements-txt-fixer
 
   - repo: https://github.com/PyCQA/autoflake
-    rev: v2.0.2
+    rev: v2.3.1
     hooks:
       - id: autoflake
         args: [--remove-all-unused-imports, --in-place]
 
   - repo: https://github.com/PyCQA/isort
-    rev: 5.13.2
+    rev: 6.0.1
    hooks:
      - id: isort
        name: Format imports
        exclude: docs/
 
  - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 25.1.0
    hooks:
      - id: black
        name: Format code

dockerfile/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata
 
 RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb
 RUN pip uninstall xgboost transformer_engine flash_attn pynvml opencv-python-headless -y
-RUN pip install vllm==0.7.2
+RUN pip install vllm==0.8.3
 
 COPY docker-entrypoint.sh .
 RUN chmod a+x docker-entrypoint.sh

examples/scripts/experience_filter.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+import torch
+
+def experience_filter(experience_maker, experiences):
+    return experiences
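
The stub above simply returns the experiences unchanged. As a minimal sketch of what a real filter plugged into --custom_experience_filter could look like, assuming each experience exposes an info dict with a per-sample "reward" tensor (that attribute layout is an assumption, not something this diff guarantees), the version below drops experiences whose rewards show no variation, since such groups carry no learning signal under a group-relative baseline:

    import torch

    def experience_filter(experience_maker, experiences):
        # Hypothetical filter: keep only experiences whose per-sample rewards
        # vary; exp.info["reward"] is assumed to hold a 1-D reward tensor.
        kept = []
        for exp in experiences:
            rewards = torch.as_tensor(exp.info["reward"]).float()
            if rewards.numel() > 1 and rewards.std() > 1e-6:
                kept.append(exp)
        # Fall back to the original list if everything was filtered out.
        return kept if kept else experiences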

examples/scripts/train_grpo_ray_hybrid_engine.sh

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@ ray job submit --address="http://127.0.0.1:8265" \
    --micro_train_batch_size 4 \
    --train_batch_size 128 \
    --micro_rollout_batch_size 8 \
-   --rollout_batch_size 1024 \
-   --n_samples_per_prompt 1 \
+   --rollout_batch_size 128 \
+   --n_samples_per_prompt 8 \
    --max_epochs 1 \
    --prompt_max_len 1024 \
    --max_samples 100000 \
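
Note that 128 prompts x 8 samples per prompt still amounts to 1024 rollouts per step, the same volume as the previous 1024 x 1 setting; the samples are now grouped per prompt, which is what a group-relative (GRPO-style) baseline needs. This reading is inferred from the flags themselves, not stated in the commit.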
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+set -x
+
+read -r -d '' training_commands <<EOF
+openrlhf.cli.train_ppo \
+   --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
+   --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \
+   --custom_experience_filter examples/scripts/experience_filter.py \
+   --save_path ./checkpoint/llama-3-8b-rlhf \
+   --save_steps -1 \
+   --logging_steps 1 \
+   --eval_steps -1 \
+   --micro_train_batch_size 2 \
+   --train_batch_size 128 \
+   --micro_rollout_batch_size 4 \
+   --rollout_batch_size 1024 \
+   --max_steps 1000 \
+   --store_extra_buffers \
+   --max_epochs 1 \
+   --prompt_max_len 1024 \
+   --generate_max_len 1024 \
+   --zero_stage 2 \
+   --bf16 \
+   --actor_learning_rate 5e-7 \
+   --critic_learning_rate 9e-6 \
+   --init_kl_coef 0.01 \
+   --prompt_data OpenRLHF/prompt-collection-v0.1 \
+   --input_key context_messages \
+   --apply_chat_template \
+   --max_samples 100000 \
+   --normalize_reward \
+   --adam_offload \
+   --flash_attn \
+   --load_checkpoint \
+   --gradient_checkpointing
+EOF
+
+# --packing_samples
+# --use_wandb [WANDB_TOKENS] or True (use wandb login command)
+# --remote_rm_url http://localhost:5000/get_reward
+
+if [[ ${1} != "slurm" ]]; then
+    deepspeed --module $training_commands
+fi

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+set -x
+
+ray job submit --address="http://127.0.0.1:8265" \
+   --runtime-env-json='{"working_dir": "/openrlhf"}' \
+   -- python3 -m openrlhf.cli.train_ppo_ray \
+   --ref_num_nodes 1 \
+   --ref_num_gpus_per_node 8 \
+   --reward_num_nodes 1 \
+   --reward_num_gpus_per_node 8 \
+   --actor_num_nodes 1 \
+   --actor_num_gpus_per_node 8 \
+   --vllm_num_engines 4 \
+   --vllm_tensor_parallel_size 2 \
+   --colocate_all_models \
+   --vllm_gpu_memory_utilization 0.6 \
+   --advantage_estimator reinforce_baseline \
+   --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
+   --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
+   --save_path /openrlhf/examples/test_scripts/final/llama3-8b-rlhf \
+   --ckpt_path /openrlhf/examples/test_scripts/ckpt/llama3-8b-rlhf \
+   --save_hf_ckpt \
+   --micro_train_batch_size 4 \
+   --train_batch_size 128 \
+   --micro_rollout_batch_size 8 \
+   --rollout_batch_size 128 \
+   --n_samples_per_prompt 8 \
+   --init_kl_coef 1e-3 \
+   --gamma 1.0 \
+   --use_kl_loss \
+   --kl_estimator k2 \
+   --max_epochs 1 \
+   --prompt_max_len 1024 \
+   --max_samples 100000 \
+   --generate_max_len 1024 \
+   --zero_stage 3 \
+   --bf16 \
+   --actor_learning_rate 5e-7 \
+   --critic_learning_rate 9e-6 \
+   --prompt_data OpenRLHF/prompt-collection-v0.1 \
+   --input_key context_messages \
+   --apply_chat_template \
+   --normalize_reward \
+   --gradient_checkpointing \
+   --packing_samples \
+   --vllm_sync_backend nccl \
+   --enforce_eager \
+   --vllm_enable_sleep \
+   --deepspeed_enable_sleep
+
+# You could also try
+# --use_kl_loss \
+# --kl_estimator k3 | k2 \
+
+# also supports --advantage_estimator rloo | reinforce_baseline

openrlhf/cli/batch_inference.py

Lines changed: 1 addition & 8 deletions
@@ -54,9 +54,7 @@ class Empty:
         args.dataset_probs,
         dummy_strategy,
         args.seed,
-        return_eval=False,
         max_count=args.max_samples,
-        train_split=args.dataset_split,
     )
     if args.iter is None:
         prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
@@ -126,9 +124,7 @@ def tokenize_fn(texts):
         args.dataset_probs,
         strategy,
         args.seed,
-        return_eval=False,
         max_count=args.max_samples,
-        train_split=args.dataset_split,
     )
     if args.iter is None:
         prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
@@ -229,9 +225,7 @@ def batch_rm_inference(args):
         args.dataset_probs,
         strategy,
         args.seed,
-        return_eval=False,
         max_count=args.max_samples,
-        train_split=args.dataset_split,
     )
     dataset = dataset.select(range(min(args.max_samples, len(dataset))))
     dataset = SFTDataset(
@@ -316,8 +310,7 @@ def batch_rm_inference(args):
 
     # Custom dataset
     parser.add_argument("--dataset", type=str, default=None)
-    parser.add_argument("--dataset_probs", type=str, default="1.0")
-    parser.add_argument("--dataset_split", type=str, default="train")
+    parser.add_argument("--dataset_probs", type=str, default=None)
     parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key")
     parser.add_argument("--output_key", type=str, default="output", help="JSON dataset key")
     parser.add_argument(
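
For context, a minimal sketch of how the simplified call sites above use blending_datasets after this change: a single blended split is returned, and --dataset_probs may now be left as None. The dataset names below are hypothetical, and the exact behaviour when probabilities are omitted is an assumption rather than something this diff spells out.

    from openrlhf.utils import blending_datasets  # import path assumed

    # Blend two hypothetical prompt datasets, sampled 70/30; `strategy` is the
    # strategy object the script builds earlier, and `args` its parsed CLI args.
    prompts_data = blending_datasets(
        "org/dataset-a,org/dataset-b",  # comma-separated dataset names
        "0.7,0.3",                      # sampling probs, or None to disable weighting
        strategy,
        args.seed,
        max_count=args.max_samples,
    )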

openrlhf/cli/lora_combiner.py

Lines changed: 4 additions & 3 deletions
@@ -2,12 +2,13 @@
 
 import torch
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from openrlhf.models.lmm_kits.utils import get_generation_cls
 
 def apply_lora(model_name_or_path, lora_path, output_path, is_rm, bf16):
     print(f"Loading the base model from {model_name_or_path}")
-    model_cls = AutoModelForCausalLM if not is_rm else AutoModelForSequenceClassification
+    config = AutoConfig.from_pretrained(model_name_or_path)
+    model_cls = get_generation_cls(config) if not is_rm else AutoModelForSequenceClassification
     base = model_cls.from_pretrained(
         model_name_or_path, torch_dtype=torch.bfloat16 if bf16 else "auto", low_cpu_mem_usage=True
     )
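
This hunk only changes how the base model class is resolved (through AutoConfig plus get_generation_cls, so multimodal architectures load with the right class). For orientation, here is a hedged sketch of the merge step such a combiner typically performs afterwards, using standard peft/transformers APIs; the actual continuation of apply_lora is not shown in this diff:

    from peft import PeftModel
    from transformers import AutoTokenizer

    def merge_lora(base, model_name_or_path, lora_path, output_path):
        # Load the LoRA adapter on top of the already-loaded base model,
        # fold the adapter weights into the base weights, and save the result.
        lora_model = PeftModel.from_pretrained(base, lora_path)
        merged = lora_model.merge_and_unload()
        merged.save_pretrained(output_path)
        # Keep the tokenizer next to the merged checkpoint.
        AutoTokenizer.from_pretrained(model_name_or_path).save_pretrained(output_path)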

openrlhf/cli/train_dpo.py

Lines changed: 32 additions & 28 deletions
@@ -60,18 +60,15 @@ def train(args):
     optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2)
 
     # prepare for data and dataset
-    train_data, eval_data = blending_datasets(
+    train_data = blending_datasets(
         args.dataset,
         args.dataset_probs,
         strategy,
         args.seed,
         max_count=args.max_samples,
-        stopping_strategy="all_exhausted",
-        train_split=args.train_split,
-        eval_split=args.eval_split,
     )
+
     train_data = train_data.select(range(min(args.max_samples, len(train_data))))
-    eval_data = eval_data.select(range(min(args.max_samples, len(eval_data))))
     train_dataset = RewardDataset(
         train_data,
         tokenizer,
@@ -81,32 +78,40 @@ def train(args):
         is_dpo=True,
         multiple_of=args.ring_attn_size,
     )
-    eval_dataset = RewardDataset(
-        eval_data,
-        tokenizer,
-        args.max_len,
-        strategy,
-        input_template=args.input_template,
-        is_dpo=True,
-        multiple_of=args.ring_attn_size,
-    )
 
     # prepare dataloader
     train_dataloader = strategy.setup_dataloader(
         train_dataset,
         args.micro_train_batch_size,
         True,
         True,
-        train_dataset.packing_collate_fn if args.packing_samples else train_dataset.collate_fn,
+        train_dataset.collate_fn,
     )
 
-    eval_dataloader = strategy.setup_dataloader(
-        eval_dataset,
-        args.micro_train_batch_size,
-        True,
-        False,
-        eval_dataset.packing_collate_fn if args.packing_samples else eval_dataset.collate_fn,
-    )
+    eval_dataset = None
+    eval_dataloader = None
+    if getattr(args, "eval_dataset", None):
+        eval_data = blending_datasets(
+            args.eval_dataset,
+            None,  # No probability sampling for eval datasets
+            strategy,
+        )
+        eval_dataset = RewardDataset(
+            eval_data,
+            tokenizer,
+            args.max_len,
+            strategy,
+            input_template=args.input_template,
+            is_dpo=True,
+            multiple_of=args.ring_attn_size,
+        )
+        eval_dataloader = strategy.setup_dataloader(
+            eval_dataset,
+            args.micro_train_batch_size,
+            True,
+            False,
+            eval_dataset.collate_fn,
+        )
 
     # scheduler
     num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size
@@ -168,7 +173,7 @@ def train(args):
     parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_dpo")
     parser.add_argument("--max_ckpt_num", type=int, default=3)
     parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
-    parser.add_argument("--universal_ckpt", action="store_true", default=False)
+    parser.add_argument("--use_ds_universal_ckpt", action="store_true", default=False)
 
     # DeepSpeed
     parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU")
@@ -236,10 +241,10 @@ def train(args):
     # Custom dataset
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--ref_pretrain", type=str, default=None)
-    parser.add_argument("--dataset", type=str, default=None)
-    parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets")
-    parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset")
-    parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset")
+    parser.add_argument("--dataset", type=str, default=None, help="Path to the training dataset")
+    parser.add_argument("--dataset_probs", type=str, default=None, help="Sampling probabilities for training datasets")
+    parser.add_argument("--eval_dataset", type=str, default=None, help="Path to the evaluation dataset")
+    parser.add_argument("--max_samples", type=int, default=1000000, help="Maximum number of samples to use")
 
     parser.add_argument("--prompt_key", type=str, default=None)
     parser.add_argument("--chosen_key", type=str, default="chosen")
@@ -248,7 +253,6 @@ def train(args):
     parser.add_argument(
         "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
     )
-    parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples")
     parser.add_argument("--max_len", type=int, default=512)
 
     # wandb parameters
