[RFC]Implements Group Filtering Policy Optimization #3479
Conversation
Code Review
This pull request implements Group Filtering Policy Optimization (GFPO) to reduce response length inflation. The changes introduce new configuration options, core algorithm logic for filtering and sampling, and integration into the DPO trainer. My review has identified several critical issues, including incorrect configuration access, missing parameters in shell scripts, incorrect function arguments, and logical errors in loops that could lead to runtime errors or incorrect behavior. I've also pointed out a potential division-by-zero issue. I've provided suggestions to fix these problems.
recipe/dapo/dapo_ray_trainer.py
Outdated
if self.config.algorithm.filter_samples.enable:
    new_batch.batch["response_mask"] = compute_response_mask(new_batch)
    filtering_sampling_kept_traj_idxs = filtering_sampling(new_batch,
        metric=self.config.algorithm.filter_sample.metric,
        metric_name="token_level_scores",
        retain_count=self.config.algorithm.filter_sample.retain_count,
        adaptive=self.config.algorithm.filter_sample.adaptive,
        t_digest=t_digest,
        easy_count=self.config.algorithm.filter_sample.easy_count,
        medium_count=self.config.algorithm.filter_sample.medium_count,
        hard_count=self.config.algorithm.filter_sample.hard_count,
        very_hard_count=self.config.algorithm.filter_sample.hard_count,
    )
    new_batch = new_batch[filtering_sampling_kept_traj_idxs]
There are a couple of issues in this block:
- There's a typo in the configuration access on line 168. It should be self.config.algorithm.filter_sample.enable instead of self.config.algorithm.filter_samples.enable. This will cause a runtime error, as filter_samples is not defined in the configuration.
- On line 179, the very_hard_count parameter is being assigned the value of hard_count from the configuration. It should use self.config.algorithm.filter_sample.very_hard_count.
Suggested change:
if self.config.algorithm.filter_sample.enable:
    new_batch.batch["response_mask"] = compute_response_mask(new_batch)
    filtering_sampling_kept_traj_idxs = filtering_sampling(new_batch,
        metric=self.config.algorithm.filter_sample.metric,
        metric_name="token_level_scores",
        retain_count=self.config.algorithm.filter_sample.retain_count,
        adaptive=self.config.algorithm.filter_sample.adaptive,
        t_digest=t_digest,
        easy_count=self.config.algorithm.filter_sample.easy_count,
        medium_count=self.config.algorithm.filter_sample.medium_count,
        hard_count=self.config.algorithm.filter_sample.hard_count,
        very_hard_count=self.config.algorithm.filter_sample.very_hard_count,
    )
    new_batch = new_batch[filtering_sampling_kept_traj_idxs]
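As a related hardening note, here is a minimal sketch (assuming the trainer config is an OmegaConf DictConfig, as elsewhere in verl; the helper name is hypothetical) of resolving the sub-config once, so that a missing or misspelled key such as filter_samples fails with a clear message instead of deep inside the generation loop:

from omegaconf import OmegaConf

def get_filter_sample_cfg(config):
    # OmegaConf.select returns None (instead of raising) when any segment of the
    # dotted path is absent, so a typo like "filter_samples" is caught right here.
    cfg = OmegaConf.select(config, "algorithm.filter_sample")
    if cfg is None:
        raise ValueError("algorithm.filter_sample is missing from the config")
    return cfg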
recipe/dapo/run_dapo_qwen2.5_32b.sh
Outdated
algorithm.filter_sample.enable=True\
algorithm.filter_sample.metric="response length"\
algorithm.filter_sample.retain_count=8\
algorithm.filter_sample.adaptive=True\
algorithm.filter_sample.easy_count=4\
algorithm.filter_sample.medium_count=6\
algorithm.filter_sample.hard_count=8\
The algorithm.filter_sample.very_hard_count parameter is missing. When adaptive=True, the filtering_sampling function requires very_hard_count. If it's not provided, it will default to None and cause a TypeError when used in min(count, len(id_score)), as count could be None.
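For illustration, a minimal reproduction of that failure mode (the values below are made up):

count = None                     # very_hard_count falls back to None when the flag is omitted
id_score = [(0, 512), (1, 734)]  # hypothetical (trajectory index, score) pairs for one prompt
min(count, len(id_score))
# TypeError: '<' not supported between instances of 'int' and 'NoneType'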
Suggested change:
algorithm.filter_sample.enable=True\
algorithm.filter_sample.metric="response length"\
algorithm.filter_sample.retain_count=8\
algorithm.filter_sample.adaptive=True\
algorithm.filter_sample.easy_count=4\
algorithm.filter_sample.medium_count=6\
algorithm.filter_sample.hard_count=8\
algorithm.filter_sample.very_hard_count=8\
verl/trainer/ppo/core_algos.py
Outdated
    Returns:
        kept_traj_idxs: the desirable responses to train on.
    """
    id2response_and_score, id2average_reward = compute_scores(data, metric, adaptive, metric_name)
The arguments to compute_scores are passed positionally, but the order is incorrect: adaptive is passed as metric_name and metric_name is passed as adaptive. This will likely cause a KeyError when trying to access data.batch[adaptive] inside compute_scores. You should use keyword arguments to avoid this kind of error.
Suggested change:
id2response_and_score, id2average_reward = compute_scores(data, metric=metric, metric_name=metric_name, adaptive=adaptive)
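A further hardening option, sketched here only (the actual compute_scores signature in this PR may differ), is to make the trailing parameters keyword-only so that positional swaps fail loudly:

def compute_scores(data, metric, *, metric_name, adaptive):
    # Callers must now write metric_name=... and adaptive=..., so a call like
    # compute_scores(data, metric, adaptive, metric_name) raises a TypeError at
    # the call site instead of silently exchanging the two arguments.
    ...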
verl/trainer/ppo/core_algos.py
Outdated
else:
    for id in id2response_and_score.keys():
        id_score = id2response_and_score[id]
        for i in range(min(retain_count, len(id_score[i]))):
There is a bug in the loop condition. id_score is a list of tuples, so id_score[i] is a tuple and len(id_score[i]) will always be 2. The intention is likely to iterate up to retain_count or the number of available scores for that id, which is len(id_score).
Suggested change:
for i in range(min(retain_count, len(id_score))):
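For reference, a minimal sketch of the intended per-group selection, assuming id_score holds (trajectory index, score) tuples, that smaller scores are preferred (e.g. shorter responses), and that kept_traj_idxs is the list being built; the sort is made explicit rather than relying on pre-sorted input:

for group_id, id_score in id2response_and_score.items():
    # Keep at most retain_count trajectories per prompt, preferring the smallest scores.
    best = sorted(id_score, key=lambda pair: pair[1])[:retain_count]
    kept_traj_idxs.extend(traj_idx for traj_idx, _ in best)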
verl/trainer/ppo/core_algos.py
Outdated
        id2response_and_score[index[i]].append((i, response_length[i]))
elif metric == "token efficiency":
    for i in range(bsz):
        id2response_and_score[index[i]].append((i, -reward_value[i] / response_length[i]))
Division by response_length[i] can lead to a division-by-zero error if a response is empty (i.e., response_length[i] is 0). It's safer to add a small epsilon to the denominator to prevent this.
Suggested change:
id2response_and_score[index[i]].append((i, -reward_value[i] / (response_length[i] + 1e-8)))
recipe/dapo/run_dapo_qwen2.5_32b.sh
Outdated
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_sample.enable=True\
Adding default parameters for the group filtering policy to the existing DAPO scripts is not recommended; how about creating a new file?
Okay, I have created two new files: run_dapo_qwen2.5_32b_w_GFPO.sh and run_dapo_qwen_3_4b_w_GFPO.sh.
verl/trainer/ppo/core_algos.py
Outdated
    for i in range(bsz):
        id2response_and_score[index[i]].append((i, -reward_value[i] / (response_length[i] + 10**(-8))))
else:
    raise NotImplementedError
Suggest adding an explanatory message to this error, to make it easier to understand.
@@ -1 +0,0 @@
0.5.0.dev
Why was this content removed?
# Power used for weight scaling in "pow" method
weight_pow: 2.0

filter_sample:
The configurations in YAML are usually accompanied by the necessary explanatory comments to improve usability; for example, retain_count does not take effect when adaptive=True.
    for i in range(bsz):
        id2response_and_score[index[i]].append((i, -reward_value[i] / (response_length[i] + 10**(-8))))
else:
    raise NotImplementedError(f"metric {metric} not supported")
The judgment logic for metric and the loop logic are redundant and can be merged into a unified processing approach. It's also good for future extension of metrics:
score_func = None
if metric == "response length":
    score_func = lambda i: response_length[i]
elif metric == "token efficiency":
    score_func = lambda i: -reward_value[i] / (response_length[i] + 1e-8)
else:
    raise NotImplementedError(...)
for i in range(batch_size):
    id2response_and_score[index[i]].append((i, score_func(i)))
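Taking that a step further, a dict-based registry is one possible shape for this (a sketch only, not code from the PR); adding a new metric then only means adding one entry:

# Map metric name -> per-trajectory score function (lower scores are kept).
METRIC_FNS = {
    "response length": lambda i: response_length[i],
    "token efficiency": lambda i: -reward_value[i] / (response_length[i] + 1e-8),
}

if metric not in METRIC_FNS:
    raise NotImplementedError(f"metric {metric} not supported")
score_func = METRIC_FNS[metric]
for i in range(batch_size):
    id2response_and_score[index[i]].append((i, score_func(i)))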
LGTM on code implementation.
To prove the correctness of the implementation, I believe it is necessary to provide some experimental data and documentation.
Please address comments from other reviewers.
LGTM
@FightingZhen @eric-haibin-lin @CLAassistant Hi, are there any other questions?
Okay, I have addressed all the comments raised by the other reviewers so far.
What does this PR do?
Experimental Results
By running the script run_dapo_qwen_3_4b_w_GFPO.sh on a platform with 16 NPUs (A3), we obtained the following results (the experimental setup for DAPO is essentially the same as for GFPO, except that algorithm.filter_sample.enable=False):