 from verl import DataProto
 
 
-@dataclass
-class DynamicFilterState:
-    """State tracking for dynamic filtering during batch processing."""
+class DynamicFilter:
+    """Unified class for handling dynamic filtering during training with state management."""
 
-    num_gen_batches: int = 0
-    num_prompt_in_batch: int = 0
-    accumulated_batch: Optional[DataProto] = None
-    reward_step: int = 0
+    def __init__(self, config):
+        """Initialize the dynamic filter.
+
+        Args:
+            config: configuration from ray_trainer
+        """
+        # Configuration attributes
+        self.metric = config.algorithm.filter_groups.metric
+        self.filter_kwargs = config.algorithm.filter_groups.filter_kwargs
+        self.custom_filter_func = None
+        self.filter_function = config.algorithm.filter_groups.filter_function
+
+        # State attributes
+        self.num_gen_batches: int = 0
+        self.num_prompt_in_batch: int = 0
+        self.accumulated_batch: Optional[DataProto] = None
+        self.reward_step: int = 0
+
+        assert not config.reward_model.launch_reward_fn_async, (
+            "Dynamic filter does not support async reward functions yet."
+        )
+
+        if self.filter_function:
+            # Import custom filter function
+            module_path, func_name = self.filter_function.rsplit(".", 1)
+            module = importlib.import_module(module_path)
+            self.custom_filter_func = getattr(module, func_name)
 
     def clear(self) -> None:
         """Reset all state variables for the next training step."""
-
         if self.num_gen_batches > 0:
             print(f"Dynamic Filter: Used {self.num_gen_batches} generation batches to complete this step")
 
@@ -48,6 +69,7 @@ def clear(self) -> None:
         self.reward_step = 0
 
     def increment_reward_step(self, global_step) -> bool:
+        """Increment the reward step if it's less than the global step."""
         if self.reward_step < global_step:
             self.reward_step += 1
             return True
@@ -67,40 +89,13 @@ def accumulate_batch(self, batch: DataProto) -> None:
             batch if self.accumulated_batch is None else DataProto.concat([self.accumulated_batch, batch])
         )
 
-
-@dataclass
-class DynamicFilterManager:
-    """Manager class for handling dynamic filtering during training."""
-
-    def __init__(self, config):
-        """Initialize the filter manager.
-
-        Args:
-            config: configuration from ray_trainer
-        """
-        self.metric = config.algorithm.filter_groups.metric
-        self.filter_kwargs = config.algorithm.filter_groups.filter_kwargs
-        self.custom_filter_func = None
-        self.filter_function = config.algorithm.filter_groups.filter_function
-
-        assert not config.reward_model.launch_reward_fn_async, (
-            "Dynamic filter has not supported async reward function yet."
-        )
-
-        if self.filter_function:
-            # Import custom filter function
-            module_path, func_name = self.filter_function.rsplit(".", 1)
-            module = importlib.import_module(module_path)
-            self.custom_filter_func = getattr(module, func_name)
-
     def process_batch_with_filtering(
-        self, batch: DataProto, dynamic_filter_state: "DynamicFilterState", config
+        self, batch: DataProto, config
     ) -> tuple[DataProto, bool]:
         """Process a batch with dynamic filtering and accumulation logic.
 
         Args:
             batch: The input batch to process
-            dynamic_filter_state: State object tracking filtering progress
             config: configuration from ray_trainer
 
         Returns:
@@ -151,24 +146,24 @@ def process_batch_with_filtering(
 
         # Filter the batch and update state
         filtered_batch = batch[kept_traj_idxs]
-        dynamic_filter_state.add_prompts(kept_prompts_this_batch)
-        dynamic_filter_state.accumulate_batch(filtered_batch)
+        self.add_prompts(kept_prompts_this_batch)
+        self.accumulate_batch(filtered_batch)
 
         # Check if we have enough prompts or reached max generation batches
         if (
-            dynamic_filter_state.num_prompt_in_batch < train_batch_size
-            and dynamic_filter_state.num_gen_batches < max_num_gen_batches
+            self.num_prompt_in_batch < train_batch_size
+            and self.num_gen_batches < max_num_gen_batches
         ):
             return None, True  # Continue collecting more batches
 
         # If we reached max generation batches but still don't have enough prompts,
         # repeat batch content to fill the deficit
-        if dynamic_filter_state.num_gen_batches >= max_num_gen_batches:
-            prompt_deficit = train_batch_size - dynamic_filter_state.num_prompt_in_batch
-            repeated_batch = dynamic_filter_state.accumulated_batch[: prompt_deficit * rollout_n]
-            final_batch = DataProto.concat([dynamic_filter_state.accumulated_batch, repeated_batch])
+        if self.num_gen_batches >= max_num_gen_batches:
+            prompt_deficit = train_batch_size - self.num_prompt_in_batch
+            repeated_batch = self.accumulated_batch[: prompt_deficit * rollout_n]
+            final_batch = DataProto.concat([self.accumulated_batch, repeated_batch])
         else:
-            final_batch = dynamic_filter_state.accumulated_batch
+            final_batch = self.accumulated_batch
 
         # Align the batch to the expected trajectory batch size
         traj_bsz = train_batch_size * rollout_n
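
The `__init__` hunk above resolves a user-supplied filter from the dotted path in `config.algorithm.filter_groups.filter_function` via `importlib`. As a rough, hypothetical illustration of what such a plug-in could look like (the module path, function name, and signature below are assumptions; the actual call site and how `filter_kwargs` are forwarded sit outside this diff):

# Hypothetical module my_project/filters.py, selected via the config entry
#   algorithm.filter_groups.filter_function: "my_project.filters.keep_mixed_groups"
# The signature is assumed for illustration only; the real call site is not shown in this diff.
def keep_mixed_groups(group_metric_values, **filter_kwargs):
    """Keep prompt groups whose rollouts disagree on the configured metric."""
    kept_prompt_idxs = []
    for prompt_idx, values in enumerate(group_metric_values):
        if len(set(values)) > 1:  # drop zero-variance groups (all rollouts scored the same)
            kept_prompt_idxs.append(prompt_idx)
    return kept_prompt_idxs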
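
A second sketch, for orientation only, of how the unified class might be driven from a training loop in place of the former `DynamicFilterState`/`DynamicFilterManager` pair. The rollout and update callables are hypothetical placeholders, the assumption that a `None` batch means "keep generating" follows from the `return None, True` branch above, and where `increment_reward_step` is invoked in the real trainer is not shown in this diff.

def run_training_sketch(config, total_steps, generate_rollout_batch, update_policy):
    # generate_rollout_batch / update_policy are hypothetical placeholders, not verl APIs.
    dynamic_filter = DynamicFilter(config)
    for global_step in range(1, total_steps + 1):
        final_batch = None
        while final_batch is None:
            gen_batch = generate_rollout_batch()
            # (None, True) signals that more generation batches are still needed.
            final_batch, _need_more = dynamic_filter.process_batch_with_filtering(gen_batch, config)
        update_policy(final_batch)
        dynamic_filter.increment_reward_step(global_step)  # assumed placement in the step
        dynamic_filter.clear()  # reset accumulation state for the next step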