
Commit 778e318: resolve comment
1 parent: 36fc97b

File tree

6 files changed: +67 -74 lines


examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh

Lines changed: 5 additions & 4 deletions
@@ -49,7 +49,7 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/BytedTsinghua-SIA/DAPO-Math-17k \
 data.val_files=$HOME/data/Maxwell-Jia/AIME_2024 \
 data.return_raw_chat=True \
-data.train_batch_size=32 \
+data.train_batch_size=16 \
 data.max_prompt_length=2048 \
 data.max_response_length=16384 \
 data.filter_overlong_prompts=True \
@@ -68,7 +68,7 @@ python3 -m verl.trainer.main_ppo \
 actor_rollout_ref.actor.clip_ratio_c=10.0 \
 actor_rollout_ref.actor.optim.lr=1e-6 \
 actor_rollout_ref.actor.use_dynamic_bsz=False \
-actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+actor_rollout_ref.actor.ppo_mini_batch_size=16 \
 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
 actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
 actor_rollout_ref.rollout.name=sglang \
@@ -84,17 +84,18 @@ python3 -m verl.trainer.main_ppo \
 actor_rollout_ref.rollout.multi_turn.max_assistant_turns=16 \
 actor_rollout_ref.rollout.multi_turn.tool_config_path=$PROJECT_DIR/recipe/retool/sandbox_fusion_tool_config.yaml \
 actor_rollout_ref.rollout.multi_turn.format=hermes \
-actor_rollout_ref.rollout.n=8 \
+actor_rollout_ref.rollout.n=4 \
 actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
 actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
 actor_rollout_ref.rollout.val_kwargs.n=30 \
 algorithm.filter_groups.enable=True \
+algorithm.filter_groups.max_num_gen_batches=2 \
 trainer.logger=['console','wandb'] \
 trainer.project_name=sglang-dapo-multiturn \
 trainer.experiment_name=qwen3_4b_sft_dapo_multiturn \
 trainer.n_gpus_per_node=8 \
 trainer.log_val_generations=20 \
-trainer.val_before_train=True \
+trainer.val_before_train=False \
 trainer.nnodes=1 \
 trainer.save_freq=-1 \
 trainer.test_freq=20 \
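
For orientation, these knobs set how many trajectories feed one training step. A hedged back-of-the-envelope in plain Python (not part of the script), using the traj_bsz = train_batch_size * rollout_n relation from the filtering code later in this commit:

# Hedged arithmetic implied by the settings above; not part of the script.
train_batch_size = 16        # data.train_batch_size: prompts per training step
rollout_n = 4                # actor_rollout_ref.rollout.n: responses sampled per prompt
max_num_gen_batches = 2      # algorithm.filter_groups.max_num_gen_batches

traj_bsz = train_batch_size * rollout_n       # 64 trajectories optimized per step
max_sampled = traj_bsz * max_num_gen_batches  # at most 128 trajectories sampled per step
print(traj_bsz, max_sampled)                  # 64 128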

verl/trainer/ppo/ray_trainer.py

Lines changed: 12 additions & 14 deletions
@@ -50,20 +50,18 @@
     compute_timing_metrics,
     process_validation_metrics,
 )
-from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.trainer.ppo.reward import compute_reward, compute_reward_async, extract_reward_extra_infos
 from verl.trainer.ppo.utils import (
     Role,
     WorkerType,
-    extract_reward_extra_infos,
     need_critic,
     need_reference_policy,
     need_reward_model,
 )
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.debug import marked_timer
-from verl.utils.filtering import DynamicFilterState
-from verl.utils.filtering.dynamic_filtering import DynamicFilterManager
+from verl.utils.filtering.dynamic_filtering import DynamicFilter
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
 from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
@@ -362,9 +360,8 @@ def __init__(
         if self.config.algorithm.use_kl_in_reward:
             self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
 
-        # initialize dynamic filter manager
-        self.dynamic_filter_manager = (
-            DynamicFilterManager(config=self.config)
+        self.dynamic_filter = (
+            DynamicFilter(config=self.config)
             if self.config.algorithm.filter_groups and self.config.algorithm.filter_groups.enable
             else None
         )
@@ -959,15 +956,16 @@ def fit(self):
             else False
         )
         next_step_profile = False
-        dynamic_filter_state = DynamicFilterState()
+
 
         for epoch in range(self.config.trainer.total_epochs):
             for batch_dict in self.train_dataloader:
                 metrics = {}
                 timing_raw = {}
 
                 # dynamic filter
-                dynamic_filter_state.increment_gen_batches()
+                if self.dynamic_filter:
+                    self.dynamic_filter.increment_gen_batches()
 
                 with marked_timer("start_profile", timing_raw):
                     self._start_profiling(
@@ -1051,16 +1049,15 @@ def fit(self):
                 )
 
                 # Compute reward metrics
-                if dynamic_filter_state.increment_reward_step(self.global_steps):
+                if self.dynamic_filter and self.dynamic_filter.increment_reward_step(self.global_steps):
                     reward_metrics = compute_reward_metrics(batch)
                     metrics.update(reward_metrics)
 
                 # Apply dynamic filtering after reward computation
-                if self.dynamic_filter_manager:
+                if self.dynamic_filter:
                     # Apply dynamic filtering and handle batch accumulation
-                    processed_batch, should_continue = self.dynamic_filter_manager.process_batch_with_filtering(
+                    processed_batch, should_continue = self.dynamic_filter.process_batch_with_filtering(
                         batch,
-                        dynamic_filter_state,
                         self.config,
                     )
 
@@ -1279,7 +1276,8 @@ def fit(self):
                 self.global_steps += 1
 
                 # Reset dynamic filter state for next training step
-                dynamic_filter_state.clear()
+                if self.dynamic_filter:
+                    self.dynamic_filter.clear()
 
                 if (
                     hasattr(self.config.actor_rollout_ref.actor, "profiler")
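
The increment_reward_step() guard above logs reward metrics at most once per global step even when dynamic filtering consumes several generation batches. A minimal runnable sketch of just that bookkeeping (stub class; it mirrors only the method shown in dynamic_filtering.py below, nothing else):

# Stub that mirrors only the reward_step bookkeeping from DynamicFilter.
class _RewardStepGuard:
    def __init__(self) -> None:
        self.reward_step = 0

    def increment_reward_step(self, global_step: int) -> bool:
        """Return True only for the first call at a given global step."""
        if self.reward_step < global_step:
            self.reward_step += 1
            return True
        return False

guard = _RewardStepGuard()
print(guard.increment_reward_step(1))  # True: first generation batch of step 1, log metrics
print(guard.increment_reward_step(1))  # False: extra generation batch, skip duplicate metrics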

verl/trainer/ppo/reward.py

Lines changed: 8 additions & 0 deletions
@@ -186,3 +186,11 @@ def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn
     )
 
     return compute_reward(data, reward_fn)
+
+def extract_reward_extra_infos(batch: DataProto, reward_extra_info_keys: list[str]) -> dict[str, list]:
+    """Extract reward extra info from batch.non_tensor_batch for dump_generations."""
+    reward_extra_infos_dict = {}
+    for key in reward_extra_info_keys:
+        reward_extra_infos_dict[key] = batch.non_tensor_batch[key]
+
+    return reward_extra_infos_dict
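
A hedged usage sketch: the moved helper only reads batch.non_tensor_batch[key], so a duck-typed stand-in is enough to show the shape of the result. The key names ("acc", "num_turns") are illustrative, not taken from this commit; a real call would pass a verl DataProto.

# Duck-typed stand-in for DataProto; only non_tensor_batch is needed here.
from dataclasses import dataclass, field

@dataclass
class _FakeBatch:
    non_tensor_batch: dict = field(default_factory=dict)

def extract_reward_extra_infos(batch, reward_extra_info_keys):
    """Same logic as the helper added above."""
    return {key: batch.non_tensor_batch[key] for key in reward_extra_info_keys}

# "acc" and "num_turns" are hypothetical example keys.
batch = _FakeBatch(non_tensor_batch={"acc": [1.0, 0.0], "num_turns": [3, 5]})
print(extract_reward_extra_infos(batch, ["acc", "num_turns"]))
# {'acc': [1.0, 0.0], 'num_turns': [3, 5]}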

verl/trainer/ppo/utils.py

Lines changed: 0 additions & 9 deletions
@@ -64,12 +64,3 @@ def need_critic(config: DictConfig) -> bool:
         stacklevel=2,
     )
     return False
-
-
-def extract_reward_extra_infos(batch: DataProto, reward_extra_info_keys: list[str]) -> dict[str, list]:
-    """Extract reward extra info from batch.non_tensor_batch for dump_generations."""
-    reward_extra_infos_dict = {}
-    for key in reward_extra_info_keys:
-        reward_extra_infos_dict[key] = batch.non_tensor_batch[key]
-
-    return reward_extra_infos_dict

verl/utils/filtering/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -18,6 +18,6 @@
 # Paper: https://arxiv.org/abs/2503.14476
 # - This implementation references the ReTool implementation: recipe/retool/ in VERL codebase
 
-from .dynamic_filtering import DynamicFilterState
+from .dynamic_filtering import DynamicFilter, keep_mixed_reward
 
-__all__ = ["DynamicFilterState"]
+__all__ = ["DynamicFilter", "keep_mixed_reward"]
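
Downstream code can now import both public names from the package; a trivial check, assuming a verl install that includes this commit:

# Assumes verl (with this commit) is installed; exercises only the new exports.
from verl.utils.filtering import DynamicFilter, keep_mixed_reward

print(DynamicFilter.__name__, keep_mixed_reward.__name__)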

verl/utils/filtering/dynamic_filtering.py

Lines changed: 40 additions & 45 deletions
@@ -27,18 +27,39 @@
 from verl import DataProto
 
 
-@dataclass
-class DynamicFilterState:
-    """State tracking for dynamic filtering during batch processing."""
+class DynamicFilter:
+    """Unified class for handling dynamic filtering during training with state management."""
 
-    num_gen_batches: int = 0
-    num_prompt_in_batch: int = 0
-    accumulated_batch: Optional[DataProto] = None
-    reward_step: int = 0
+    def __init__(self, config):
+        """Initialize the dynamic filter.
+
+        Args:
+            config: configuration from ray_trainer
+        """
+        # Configuration attributes
+        self.metric = config.algorithm.filter_groups.metric
+        self.filter_kwargs = config.algorithm.filter_groups.filter_kwargs
+        self.custom_filter_func = None
+        self.filter_function = config.algorithm.filter_groups.filter_function
+
+        # State attributes
+        self.num_gen_batches: int = 0
+        self.num_prompt_in_batch: int = 0
+        self.accumulated_batch: Optional[DataProto] = None
+        self.reward_step: int = 0
+
+        assert not config.reward_model.launch_reward_fn_async, (
+            "Dynamic filter has not supported async reward function yet."
+        )
+
+        if self.filter_function:
+            # Import custom filter function
+            module_path, func_name = self.filter_function.rsplit(".", 1)
+            module = importlib.import_module(module_path)
+            self.custom_filter_func = getattr(module, func_name)
 
     def clear(self) -> None:
         """Reset all state variables for the next training step."""
-
         if self.num_gen_batches > 0:
             print(f"Dynamic Filter: Used {self.num_gen_batches} generation batches to complete this step")
 
@@ -48,6 +69,7 @@ def clear(self) -> None:
         self.reward_step = 0
 
     def increment_reward_step(self, global_step) -> bool:
+        """Increment the reward step if it's less than the global step."""
         if self.reward_step < global_step:
             self.reward_step += 1
             return True
@@ -67,40 +89,13 @@ def accumulate_batch(self, batch: DataProto) -> None:
             batch if self.accumulated_batch is None else DataProto.concat([self.accumulated_batch, batch])
         )
 
-
-@dataclass
-class DynamicFilterManager:
-    """Manager class for handling dynamic filtering during training."""
-
-    def __init__(self, config):
-        """Initialize the filter manager.
-
-        Args:
-            config: configuration from ray_trainer
-        """
-        self.metric = config.algorithm.filter_groups.metric
-        self.filter_kwargs = config.algorithm.filter_groups.filter_kwargs
-        self.custom_filter_func = None
-        self.filter_function = config.algorithm.filter_groups.filter_function
-
-        assert not config.reward_model.launch_reward_fn_async, (
-            "Dynamic filter has not supported async reward function yet."
-        )
-
-        if self.filter_function:
-            # Import custom filter function
-            module_path, func_name = self.filter_function.rsplit(".", 1)
-            module = importlib.import_module(module_path)
-            self.custom_filter_func = getattr(module, func_name)
-
     def process_batch_with_filtering(
-        self, batch: DataProto, dynamic_filter_state: "DynamicFilterState", config
+        self, batch: DataProto, config
     ) -> tuple[DataProto, bool]:
         """Process a batch with dynamic filtering and accumulation logic.
 
         Args:
             batch: The input batch to process
-            dynamic_filter_state: State object tracking filtering progress
             config: configuration from ray_trainer
 
         Returns:
@@ -151,24 +146,24 @@ def process_batch_with_filtering(
 
         # Filter the batch and update state
        filtered_batch = batch[kept_traj_idxs]
-        dynamic_filter_state.add_prompts(kept_prompts_this_batch)
-        dynamic_filter_state.accumulate_batch(filtered_batch)
+        self.add_prompts(kept_prompts_this_batch)
+        self.accumulate_batch(filtered_batch)
 
         # Check if we have enough prompts or reached max generation batches
         if (
-            dynamic_filter_state.num_prompt_in_batch < train_batch_size
-            and dynamic_filter_state.num_gen_batches < max_num_gen_batches
+            self.num_prompt_in_batch < train_batch_size
+            and self.num_gen_batches < max_num_gen_batches
         ):
             return None, True  # Continue collecting more batches
 
         # If we reached max generation batches but still don't have enough prompts,
         # repeat batch content to fill the deficit
-        if dynamic_filter_state.num_gen_batches >= max_num_gen_batches:
-            prompt_deficit = train_batch_size - dynamic_filter_state.num_prompt_in_batch
-            repeated_batch = dynamic_filter_state.accumulated_batch[: prompt_deficit * rollout_n]
-            final_batch = DataProto.concat([dynamic_filter_state.accumulated_batch, repeated_batch])
+        if self.num_gen_batches >= max_num_gen_batches:
+            prompt_deficit = train_batch_size - self.num_prompt_in_batch
+            repeated_batch = self.accumulated_batch[: prompt_deficit * rollout_n]
+            final_batch = DataProto.concat([self.accumulated_batch, repeated_batch])
         else:
-            final_batch = dynamic_filter_state.accumulated_batch
+            final_batch = self.accumulated_batch
 
         # Align the batch to the expected trajectory batch size
         traj_bsz = train_batch_size * rollout_n
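
To make the accumulation arithmetic above concrete, here is a minimal toy sketch using plain Python lists in place of DataProto; the final trim to traj_bsz is implied by the trailing comment of this hunk rather than shown in it.

# Toy illustration of the deficit-padding logic above, with plain lists
# standing in for DataProto. Numbers mirror the example script
# (train_batch_size=16, rollout.n=4, max_num_gen_batches=2).
train_batch_size, rollout_n, max_num_gen_batches = 16, 4, 2
traj_bsz = train_batch_size * rollout_n            # 64 trajectories per step

num_prompt_in_batch = 12                           # prompts kept by the filter so far
num_gen_batches = 2                                # generation rounds already spent
accumulated = list(range(num_prompt_in_batch * rollout_n))  # 48 kept trajectories

if num_gen_batches >= max_num_gen_batches:
    # Out of generation budget: repeat kept trajectories to fill the deficit.
    prompt_deficit = train_batch_size - num_prompt_in_batch   # 4 prompts short
    repeated = accumulated[: prompt_deficit * rollout_n]      # 16 repeated trajectories
    final = accumulated + repeated
else:
    final = accumulated

final = final[:traj_bsz]                           # align to the trajectory batch size
assert len(final) == traj_bsz                      # 64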
