
Commit ede8cdd

[recipe] fix: DAPO using KL in reward (volcengine#3916)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Currently, the DAPO recipe fails if we set `use_kl_in_reward` to True, because `response_mask`, `old_log_prob`, and/or `ref_log_prob` must be calculated before `apply_kl_penalty` is applied to the rewards. Although the DAPO paper removes the KL divergence term, this is still worth fixing: `apply_kl_penalty` is already in the code, and supporting it allows exploration in non-standard setups.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: Hollow Man <[email protected]>
1 parent 5c04ddc commit ede8cdd
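
To make the dependency concrete, below is a minimal, self-contained sketch of how a KL-in-reward penalty consumes the three tensors named in the description. This is an illustration only, not verl's actual `apply_kl_penalty` implementation; the function name, the k1 estimator, and the `kl_coef` value are assumptions made for the example.

```python
import torch


def apply_kl_penalty_sketch(
    token_level_scores: torch.Tensor,  # (batch, response_len) raw scores from the reward fn / RM
    old_log_prob: torch.Tensor,        # (batch, response_len) log-probs under the rollout policy
    ref_log_prob: torch.Tensor,        # (batch, response_len) log-probs under the reference policy
    response_mask: torch.Tensor,       # (batch, response_len) 1.0 for response tokens, 0.0 for padding
    kl_coef: float = 0.001,            # illustrative coefficient, not a verl default
) -> torch.Tensor:
    """Simplified stand-in for the `apply_kl_penalty` step referenced in the PR description."""
    # k1 estimator of per-token KL(pi_old || pi_ref): difference of log-probabilities.
    kld = old_log_prob - ref_log_prob
    # Subtract the penalty only on valid response tokens; padding contributes nothing.
    return token_level_scores - kl_coef * kld * response_mask


if __name__ == "__main__":
    b, t = 2, 4
    scores = torch.zeros(b, t)
    scores[:, -1] = 1.0  # e.g. an outcome reward placed on the final token
    rewards = apply_kl_penalty_sketch(scores, torch.randn(b, t), torch.randn(b, t), torch.ones(b, t))
    print(rewards.shape)  # torch.Size([2, 4])
```

Without `old_log_prob`, `ref_log_prob`, and `response_mask` already in the batch, this step has nothing to work with, which is why the recipe previously failed when `use_kl_in_reward` was enabled.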

File tree

3 files changed: +61 -35 lines changed

.github/workflows/e2e_dapo.yml

Lines changed: 2 additions & 2 deletions

```diff
@@ -51,7 +51,7 @@ on:
       # Megatron
       - "!verl/workers/**/megatron_*.py"
       - "!recipe/**"
-      - "recipe/dapo"
+      - "recipe/dapo/**"
   pull_request:
     branches:
       - main
@@ -68,7 +68,7 @@ on:
       # Megatron
       - "!verl/workers/**/megatron_*.py"
       # Home
-      - "recipe/dapo"
+      - "recipe/dapo/**"
       # Entrypoints
       - ".github/workflows/e2e_dapo.yml"
       - "examples/data_preprocess/gsm8k.py"
```

recipe/dapo/dapo_ray_trainer.py

Lines changed: 33 additions & 20 deletions

```diff
@@ -51,6 +51,32 @@ class RayDAPOTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """
 
+    def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict):
+        batch.batch["response_mask"] = compute_response_mask(batch)
+
+        # recompute old_log_probs
+        with marked_timer("old_log_prob", timing_raw, "blue"):
+            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+            entropys = old_log_prob.batch["entropys"]
+            response_masks = batch.batch["response_mask"]
+            loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+            entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+            old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+            metrics.update(old_log_prob_metrics)
+            old_log_prob.batch.pop("entropys")
+            batch = batch.union(old_log_prob)
+
+        if self.use_reference_policy:
+            # compute reference log_prob
+            with marked_timer("ref", timing_raw, "olive"):
+                if not self.ref_in_actor:
+                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+                else:
+                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
+                batch = batch.union(ref_log_prob)
+
+        return batch
+
     def fit(self):
         """
         The training loop of PPO.
@@ -177,6 +203,11 @@ def fit(self):
                 new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                 new_batch = new_batch.union(gen_batch_output)
 
+                if self.config.algorithm.use_kl_in_reward:
+                    # We need these metrics for apply_kl_penalty if using kl in reward
+                    new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw)
+                    # otherwise, we will compute those after dynamic sampling
+
                 with marked_timer("reward", timing_raw, "yellow"):
                     # compute scores. Support both model and function-based.
                     # We first compute the scores using reward model. Then, we call reward_fn to combine
@@ -269,9 +300,6 @@ def fit(self):
                     batch = batch[:traj_bsz]
 
                 # === Updating ===
-
-                batch.batch["response_mask"] = compute_response_mask(batch)
-
                 # Balance the number of valid tokens across DP ranks.
                 # NOTE: This usually changes the order of data in the `batch`,
                 # which won't affect the advantage calculation (since it's based on uid),
@@ -283,23 +311,8 @@ def fit(self):
                 # compute global_valid tokens
                 batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
 
-                # recompute old_log_probs
-                with marked_timer("old_log_prob", timing_raw, "blue"):
-                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
-                    entropys = old_log_prob.batch["entropys"]
-                    response_masks = batch.batch["response_mask"]
-                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
-                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
-                    old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
-                    metrics.update(old_log_prob_metrics)
-                    old_log_prob.batch.pop("entropys")
-                    batch = batch.union(old_log_prob)
-
-                if self.use_reference_policy:
-                    # compute reference log_prob
-                    with marked_timer("ref", timing_raw, "olive"):
-                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                        batch = batch.union(ref_log_prob)
+                if not self.config.algorithm.use_kl_in_reward:
+                    batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)
 
                 # compute values
                 if self.use_critic:
```
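
The change above is purely a reordering: the same work moves into `compute_kl_related_metrics`, which runs before the reward stage when `use_kl_in_reward` is enabled and after dynamic sampling otherwise. The toy below is a hedged illustration of why that ordering matters; plain dicts stand in for verl's `DataProto` and the stub functions are not verl API.

```python
from typing import Dict


def compute_kl_related_metrics(batch: Dict[str, bool]) -> Dict[str, bool]:
    # Stand-in for the new helper: make the tensors that apply_kl_penalty needs available.
    batch.update(response_mask=True, old_log_prob=True, ref_log_prob=True)
    return batch


def apply_kl_penalty(batch: Dict[str, bool]) -> None:
    # Stand-in for the reward-side KL penalty: it can only run once the tensors exist.
    missing = [k for k in ("response_mask", "old_log_prob", "ref_log_prob") if not batch.get(k)]
    assert not missing, f"apply_kl_penalty would fail, missing: {missing}"


def one_training_step(use_kl_in_reward: bool) -> None:
    batch: Dict[str, bool] = {}
    if use_kl_in_reward:
        batch = compute_kl_related_metrics(batch)  # before the reward stage (this PR's fix)
        apply_kl_penalty(batch)                    # KL folded into the token-level rewards
    # ... dynamic sampling runs here and may discard trajectories ...
    if not use_kl_in_reward:
        batch = compute_kl_related_metrics(batch)  # after dynamic sampling, as before this PR
    # ... actor/critic updates consume `batch` ...


if __name__ == "__main__":
    one_training_step(use_kl_in_reward=True)   # previously this path failed; now the tensors exist in time
    one_training_step(use_kl_in_reward=False)  # unchanged behaviour
```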

recipe/entropy/entropy_ray_trainer.py

Lines changed: 26 additions & 13 deletions

```diff
@@ -48,6 +48,25 @@ class RayEntropyTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """
 
+    def compute_kl_related_metrics(self, batch: DataProto, timing_raw: dict):
+        batch.batch["response_mask"] = compute_response_mask(batch)
+
+        # recompute old_log_probs
+        with simple_timer("old_log_prob", timing_raw):
+            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+            batch = batch.union(old_log_prob)
+
+        if self.use_reference_policy:
+            # compute reference log_prob
+            with simple_timer("ref", timing_raw):
+                if not self.ref_in_actor:
+                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+                else:
+                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
+                batch = batch.union(ref_log_prob)
+
+        return batch
+
     def fit(self):
         """
         The training loop of PPO.
@@ -154,6 +173,11 @@ def fit(self):
                 new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                 new_batch = new_batch.union(gen_batch_output)
 
+                if self.config.algorithm.use_kl_in_reward:
+                    # We need these metrics for apply_kl_penalty if using kl in reward
+                    new_batch = self.compute_kl_related_metrics(new_batch, timing_raw)
+                    # otherwise, we will compute those after dynamic sampling
+
                 with simple_timer("reward", timing_raw):
                     # compute scores. Support both model and function-based.
                     # We first compute the scores using reward model. Then, we call reward_fn to combine
@@ -249,9 +273,6 @@ def fit(self):
                     batch = batch[:traj_bsz]
 
                 # === Updating ===
-
-                batch.batch["response_mask"] = compute_response_mask(batch)
-
                 # balance the number of valid tokens on each dp rank.
                 # Note that this breaks the order of data inside the batch.
                 # Please take care when you implement group based adv computation such as GRPO and rloo
@@ -261,16 +282,8 @@ def fit(self):
                 # compute global_valid tokens
                 batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
 
-                # recompute old_log_probs
-                with simple_timer("old_log_prob", timing_raw):
-                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
-                    batch = batch.union(old_log_prob)
-
-                if self.use_reference_policy:
-                    # compute reference log_prob
-                    with simple_timer("ref", timing_raw):
-                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                        batch = batch.union(ref_log_prob)
+                if not self.config.algorithm.use_kl_in_reward:
+                    batch = self.compute_kl_related_metrics(batch, timing_raw)
 
                 # compute values
                 if self.use_critic:
```
