
Commit f5a7a2f

0x404 and BounharAbdelaziz
authored and committed
[algo] feat: add GSPO-token policy loss computation function (volcengine#2775)
### What does this PR do?

This PR implements the GSPO-token policy loss calculation proposed in the paper https://arxiv.org/pdf/2507.18071.

### Test

<img width="1341" height="637" alt="image" src="https://github.com/user-attachments/assets/bc5e2245-b0f5-4a1f-aa7c-4c2b28d95142" />

Compared GRPO and GSPO under the same settings. GRPO uses the following script:

```sh
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.policy_loss.loss_mode="vanilla" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=10 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_gspo_cmp' \
    trainer.experiment_name='qwen2.5-3B-GRPO' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```

GSPO uses the following script:

```sh
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.policy_loss.loss_mode="gspo" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=10 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_gspo_cmp' \
    trainer.experiment_name='qwen2.5-3B-GRPO' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```

### API and Usage Example

To use GSPO, users only need to set `actor_rollout_ref.actor.policy_loss.loss_mode` to `gspo`:

```shell
python3 -m verl.trainer.main_ppo \
    ... \
    actor_rollout_ref.actor.policy_loss.loss_mode="gspo" \
    ...
```

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: BounharAbdelaziz <[email protected]>
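For reference, below is a minimal PyTorch sketch of the GSPO-token objective described in the paper. The function name, argument names, and default clip values are illustrative assumptions for exposition, not the exact signature added to verl in this commit.

```python
# A minimal, illustrative sketch of the GSPO-token objective from the paper
# (https://arxiv.org/pdf/2507.18071); names and shapes are assumptions, not
# the exact verl API. Tensors are (batch, response_len); response_mask is 0/1.
import torch


def gspo_token_policy_loss_sketch(
    old_log_prob: torch.Tensor,   # log pi_old(y_t | x, y_<t)
    log_prob: torch.Tensor,       # log pi_theta(y_t | x, y_<t)
    advantages: torch.Tensor,     # per-token advantages (constant per sequence for GRPO-style A_i)
    response_mask: torch.Tensor,  # 1 for response tokens, 0 for padding
    cliprange_low: float = 3e-4,  # paper Sec. 5.1 recommendation
    cliprange_high: float = 4e-4,
) -> torch.Tensor:
    response_mask = response_mask.float()

    # Sequence-level importance ratio:
    # s_i(theta) = (pi_theta(y_i|x) / pi_old(y_i|x)) ** (1/|y_i|),
    # computed in log space as the length-normalized sum of per-token log-ratios.
    log_ratio = (log_prob - old_log_prob) * response_mask
    seq_lengths = response_mask.sum(dim=-1).clamp(min=1)
    log_seq_ratio = log_ratio.sum(dim=-1) / seq_lengths           # (batch,)

    # GSPO-token ratio: s_{i,t} = sg[s_i] * pi_theta(y_t) / sg[pi_theta(y_t)].
    # Numerically equal to s_i, but its gradient flows through each token's log-prob.
    seq_ratio = log_seq_ratio.exp().detach().unsqueeze(-1)        # sg[s_i], (batch, 1)
    token_ratio = seq_ratio * (log_prob - log_prob.detach()).exp()

    # Standard PPO-style clipped surrogate, with the (much smaller) GSPO clip range.
    unclipped = token_ratio * advantages
    clipped = torch.clamp(token_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high) * advantages
    token_loss = -torch.minimum(unclipped, clipped)

    # "seq-mean-token-mean" aggregation: mean over tokens per sequence, then over sequences.
    per_seq_loss = (token_loss * response_mask).sum(dim=-1) / seq_lengths
    return per_seq_loss.mean()
```

Numerically the GSPO-token ratio equals the sequence-level ratio `s_i`, but detaching `s_i` and multiplying by `pi_theta(y_t) / sg[pi_theta(y_t)]` routes the gradient through each token's log-probability, which is what allows token-level advantages while keeping sequence-level clipping.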
1 parent 981f258 commit f5a7a2f


4 files changed, +454 -0 lines changed


recipe/gspo/test_gspo_3b_math.sh

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
#!/usr/bin/env bash
#SBATCH --job-name=rl-gspo-3B
#SBATCH --partition=main
#SBATCH --nodes=1 # Number of nodes
#SBATCH --ntasks-per-node=1 # One task per node
#SBATCH --cpus-per-task=128 # cpu-cores per task
#SBATCH --gres=gpu:8
#SBATCH --mem=0
#SBATCH --exclusive
#SBATCH --time=500:00:00
#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out
#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err

set -xeuo pipefail

# activate the venv
echo "Activating verl environment..."
eval "$(conda shell.bash hook)"
conda deactivate
conda activate verl

# can make training faster, depends on your infrastructure
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1

# Set how many GPUs we actually have on this node.
export GPUS_PER_NODE=8

NNODES=${SLURM_JOB_NUM_NODES}
export NNODES

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export RAY_LOGGING_LEVEL=DEBUG
export HYDRA_FULL_ERROR=1
export WANDB_API_KEY=... # your wandb API key

echo "Using $NNODES nodes for training..."

# ------------------------------------- Setup xp params ---------------------------------------
project_name='RL-GSPO'

adv_estimator=grpo
loss_mode=gspo
loss_agg_mode="seq-mean-token-mean"
MODEL_PATH=Qwen/Qwen2.5-3B-Instruct
offload=false # it's a small model, offloading will just slow-down training
rollout_engine=vllm
rollout_mode=sync # can be async to speedup large scale xps
gpu_memory_utilization=0.8
reward_manager=dapo
adv_estimator=grpo
shuffle_dataset=true
first_time_dataset_prep=true # prepare dataset

test_freq=10
save_freq=10
total_epochs=10
total_training_steps=500
val_before_train=false

use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

clip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1
clip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1
train_batch_size=512
ppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 5.1
ppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory
n_resp_per_prompt=16

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
# dapo reward manager params
enable_overlong_buffer=false # true
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Paths and namings
SFT_MODEL=$(basename $MODEL_PATH)
exp_name="${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL"
CKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name}

# Sampling params at rollouts
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
sp_size=1
use_dynamic_bsz=true
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=true
gen_tp=1
entropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training.

# ------------------------------------- train/val data preparation ---------------------------------------
if [ "$first_time_dataset_prep" = true ]; then
    echo "Preprocessing GSM8K dataset..."
    python examples/data_preprocess/gsm8k.py --local_dir /data/gsm8k/
fi

gsm8k_train_path=/data/gsm8k/train.parquet
gsm8k_test_path=/data/gsm8k/test.parquet

# set the paths
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=${adv_estimator} \
    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
    data.train_files="${train_files}" \
    data.val_files="${test_files}" \
    data.shuffle=$shuffle_dataset \
    data.prompt_key=prompt \
    data.truncation='error' \
    data.filter_overlong_prompts=true \
    data.train_batch_size=${train_batch_size} \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.name=${rollout_engine} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=true \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=true \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
    reward_model.reward_manager=${reward_manager} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node="${GPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=${val_before_train} \
    trainer.test_freq=${test_freq} \
    trainer.save_freq=${save_freq} \
    trainer.total_epochs=${total_epochs} \
    trainer.total_training_steps=${total_training_steps} \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=2 \
    $@
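A note on the `clip_ratio_low=0.0003` / `clip_ratio_high=0.0004` values used above: GSPO clips a length-normalized, sequence-level importance ratio rather than per-token ratios, and that ratio stays extremely close to 1 for long responses, which is why the paper's recommended clip range is orders of magnitude tighter than the usual token-level PPO range of about 0.2. The numbers in the short illustration below are assumptions chosen for intuition, not values taken from the paper or this recipe.

```python
# Rough numeric illustration (assumed values, for intuition only): a modest
# total log-probability shift over a long response still maps to a
# length-normalized sequence ratio that barely moves away from 1.
import math

seq_len = 1024           # assumed response length
total_log_ratio = 0.5    # assumed sum over tokens of log(pi_theta / pi_old)

gspo_seq_ratio = math.exp(total_log_ratio / seq_len)
print(f"length-normalized sequence ratio ≈ {gspo_seq_ratio:.6f}")
# ≈ 1.000488, which already falls outside the upper clip bound 1 + 0.0004
```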
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
#!/usr/bin/env bash
#SBATCH --job-name=rl-gspo-3B
#SBATCH --partition=main
#SBATCH --nodes=1 # Number of nodes
#SBATCH --ntasks-per-node=1 # One task per node
#SBATCH --cpus-per-task=128 # cpu-cores per task
#SBATCH --gres=gpu:8
#SBATCH --mem=0
#SBATCH --exclusive
#SBATCH --time=500:00:00
#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out
#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err

set -xeuo pipefail

# activate the venv
echo "Activating verl environment..."
eval "$(conda shell.bash hook)"
conda deactivate
conda activate verl

# can make training faster, depends on your infrastructure
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1

# Set how many GPUs we actually have on this node.
export GPUS_PER_NODE=8

NNODES=${SLURM_JOB_NUM_NODES}
export NNODES

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export RAY_memory_monitor_refresh_ms=0
export RAY_LOGGING_LEVEL=DEBUG
export HYDRA_FULL_ERROR=1
export WANDB_API_KEY=... # your wandb API key

# Let Ray know how many nodes to expect
export RAY_NUM_NODES=$NNODES

echo "Using $NNODES nodes for training..."

# ------------------------------------- Setup xp params ---------------------------------------
project_name='RL-GSPO'

adv_estimator=grpo
loss_mode=gspo
loss_agg_mode="seq-mean-token-mean"
MODEL_PATH=Qwen/Qwen2.5-3B-Instruct
offload=false # it's a small model, offloading will just slow-down training
rollout_engine=vllm
rollout_mode=sync # can be async to speedup large scale xps
gpu_memory_utilization=0.8
reward_manager=dapo
adv_estimator=grpo
shuffle_dataset=true
first_time_dataset_prep=true # prepare dataset

test_freq=10
save_freq=10
total_epochs=10
total_training_steps=500
val_before_train=false

use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

clip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1
clip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1
train_batch_size=512
ppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 5.1
ppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory
n_resp_per_prompt=16

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
# dapo reward manager params
enable_overlong_buffer=false # true
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Paths and namings
SFT_MODEL=$(basename $MODEL_PATH)
exp_name="${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL"
CKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name}

# Sampling params at rollouts
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
sp_size=1
use_dynamic_bsz=true
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=true
gen_tp=1
entropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training.

# ------------------------------------- train/val data preparation ---------------------------------------
if [ "$first_time_dataset_prep" = true ]; then
    echo "Preprocessing GSM8K dataset..."
    python examples/data_preprocess/gsm8k.py --local_dir /data/gsm8k/
fi

gsm8k_train_path=/data/gsm8k/train.parquet
gsm8k_test_path=/data/gsm8k/test.parquet

# set the paths
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=${adv_estimator} \
    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
    data.train_files="${train_files}" \
    data.val_files="${test_files}" \
    data.shuffle=$shuffle_dataset \
    data.prompt_key=prompt \
    data.truncation='error' \
    data.filter_overlong_prompts=true \
    data.train_batch_size=${train_batch_size} \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.name=${rollout_engine} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=true \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=true \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
    reward_model.reward_manager=${reward_manager} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node="${GPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=${val_before_train} \
    trainer.test_freq=${test_freq} \
    trainer.save_freq=${save_freq} \
    trainer.total_epochs=${total_epochs} \
    trainer.total_training_steps=${total_training_steps} \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=2 \
    $@
