Skip to content

Commit 7db97f8

Browse files
authored
fix validate in agent loop (#34)
1 parent 136c82a commit 7db97f8

File tree

3 files changed

+49
-22
lines changed

3 files changed

+49
-22
lines changed

recipe/transfer_queue/agent_loop.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,10 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
6767

6868
return timing
6969

70-
def create_transferqueue_client(self, controller_infos, storage_infos):
70+
def create_transferqueue_client(self, controller_infos, storage_infos, role):
7171
ray.get(
7272
[
73-
worker._create_transferqueue_client.remote(controller_infos, storage_infos)
73+
worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
7474
for worker in self.agent_loop_workers
7575
]
7676
)

recipe/transfer_queue/ray_trainer.py

Lines changed: 45 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -748,24 +748,47 @@ def _validate(self):
748748
ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})]
749749
sample_gts.extend(ground_truths)
750750

751-
test_gen_meta = asyncio.run(
752-
self.val_data_system_client.async_get_meta(
753-
data_fields=[
754-
"input_ids",
755-
"attention_mask",
756-
"position_ids",
757-
"index",
758-
"tools_kwargs",
759-
"interaction_kwargs",
760-
"ability",
761-
"raw_prompt_ids",
762-
],
763-
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
764-
global_step=self.global_steps - 1, # self.global_steps start from 1
765-
get_n_samples=False,
766-
task_name="generate_sequences",
751+
if not self.async_rollout_mode:
752+
test_gen_meta = asyncio.run(
753+
self.val_data_system_client.async_get_meta(
754+
data_fields=[
755+
"input_ids",
756+
"attention_mask",
757+
"position_ids",
758+
"index",
759+
"tools_kwargs",
760+
"interaction_kwargs",
761+
"ability",
762+
"raw_prompt_ids",
763+
],
764+
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
765+
global_step=self.global_steps - 1, # self.global_steps start from 1
766+
get_n_samples=False,
767+
task_name="generate_sequences",
768+
)
769+
)
770+
else:
771+
test_gen_meta = asyncio.run(
772+
self.val_data_system_client.async_get_meta(
773+
data_fields=[
774+
"input_ids",
775+
"attention_mask",
776+
"position_ids",
777+
"index",
778+
"tools_kwargs",
779+
"interaction_kwargs",
780+
"ability",
781+
"raw_prompt_ids",
782+
"raw_prompt",
783+
"reward_model",
784+
"data_source",
785+
],
786+
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
787+
global_step=self.global_steps - 1, # self.global_steps start from 1
788+
get_n_samples=False,
789+
task_name="async_generate_sequences",
790+
)
767791
)
768-
)
769792

770793
test_gen_meta.extra_info = {
771794
"eos_token_id": self.tokenizer.eos_token_id,
@@ -1028,8 +1051,12 @@ def init_workers(self):
10281051
self.async_rollout_manager = AgentLoopManager(
10291052
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
10301053
)
1054+
1055+
self.async_rollout_manager.create_transferqueue_client(
1056+
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
1057+
)
10311058
self.async_rollout_manager.create_transferqueue_client(
1032-
self.data_system_controller_infos, self.data_system_storage_unit_infos
1059+
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
10331060
)
10341061

10351062
def _save_checkpoint(self):

verl/experimental/agent_loop/agent_loop.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -725,13 +725,13 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
725725
meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
726726
)
727727

728-
def _create_transferqueue_client(self, controller_infos, storage_infos):
728+
def create_transferqueue_client(self, controller_infos, storage_infos, role):
729729
from verl.single_controller.ray.base import get_random_string
730730
from verl.utils.transferqueue_utils import create_transferqueue_client
731731

732732
client_name = get_random_string(length=6)
733733
create_transferqueue_client(
734-
client_id=f"worker_{client_name}",
734+
client_id=f"{role}_worker_{client_name}",
735735
controller_infos=controller_infos,
736736
storage_infos=storage_infos,
737737
)

0 commit comments

Comments
 (0)