@@ -748,24 +748,47 @@ def _validate(self):
748748 ground_truths = [item .get ("ground_truth" , None ) for item in data .get ("reward_model" , {})]
749749 sample_gts .extend (ground_truths )
750750
751- test_gen_meta = asyncio .run (
752- self .val_data_system_client .async_get_meta (
753- data_fields = [
754- "input_ids" ,
755- "attention_mask" ,
756- "position_ids" ,
757- "index" ,
758- "tools_kwargs" ,
759- "interaction_kwargs" ,
760- "ability" ,
761- "raw_prompt_ids" ,
762- ],
763- batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
764- global_step = self .global_steps - 1 , # self.global_steps start from 1
765- get_n_samples = False ,
766- task_name = "generate_sequences" ,
751+ if not self .async_rollout_mode :
752+ test_gen_meta = asyncio .run (
753+ self .val_data_system_client .async_get_meta (
754+ data_fields = [
755+ "input_ids" ,
756+ "attention_mask" ,
757+ "position_ids" ,
758+ "index" ,
759+ "tools_kwargs" ,
760+ "interaction_kwargs" ,
761+ "ability" ,
762+ "raw_prompt_ids" ,
763+ ],
764+ batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
765+ global_step = self .global_steps - 1 , # self.global_steps start from 1
766+ get_n_samples = False ,
767+ task_name = "generate_sequences" ,
768+ )
769+ )
770+ else :
771+ test_gen_meta = asyncio .run (
772+ self .val_data_system_client .async_get_meta (
773+ data_fields = [
774+ "input_ids" ,
775+ "attention_mask" ,
776+ "position_ids" ,
777+ "index" ,
778+ "tools_kwargs" ,
779+ "interaction_kwargs" ,
780+ "ability" ,
781+ "raw_prompt_ids" ,
782+ "raw_prompt" ,
783+ "reward_model" ,
784+ "data_source" ,
785+ ],
786+ batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
787+ global_step = self .global_steps - 1 , # self.global_steps start from 1
788+ get_n_samples = False ,
789+ task_name = "async_generate_sequences" ,
790+ )
767791 )
768- )
769792
770793 test_gen_meta .extra_info = {
771794 "eos_token_id" : self .tokenizer .eos_token_id ,
@@ -1028,8 +1051,12 @@ def init_workers(self):
10281051 self .async_rollout_manager = AgentLoopManager (
10291052 config = self .config , worker_group = self .actor_rollout_wg , rm_wg = self .rm_wg
10301053 )
1054+
1055+ self .async_rollout_manager .create_transferqueue_client (
1056+ self .data_system_controller_infos , self .data_system_storage_unit_infos , role = "train"
1057+ )
10311058 self .async_rollout_manager .create_transferqueue_client (
1032- self .data_system_controller_infos , self .data_system_storage_unit_infos
1059+ self .val_data_system_controller_infos , self .val_data_system_storage_unit_infos , role = "val"
10331060 )
10341061
10351062 def _save_checkpoint (self ):
0 commit comments