Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions recipe/transfer_queue/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@

from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.experimental.transfer_queue.client import AsyncTransferQueueClient, process_zmq_server_info
from verl.experimental.transfer_queue.client import (
AsyncTransferQueueClient,
process_zmq_server_info,
)
from verl.experimental.transfer_queue.controller import TransferQueueController
from verl.experimental.transfer_queue.metadata import BatchMeta
from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit
from verl.experimental.transfer_queue.utils.utils import get_placement_group
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
)
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.config import AlgoConfig
from verl.trainer.ppo import core_algos
Expand All @@ -60,13 +67,25 @@
process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.trainer.ppo.utils import (
Role,
WorkerType,
need_critic,
need_reference_policy,
need_reward_model,
)
from verl.utils.checkpoint.checkpoint_manager import (
find_latest_ckpt_path,
should_save_ckpt_esi,
)
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.seqlen_balancing import (
get_seqlen_balanced_partitions,
log_seqlen_unbalance,
)
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger

Expand Down Expand Up @@ -828,6 +847,10 @@ def init_workers(self):
self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg.init_model()

# set transferqueue server info for each worker group
for _, wg in all_wg.items():
wg.set_transferqueue_server_info(self.data_system_controller_infos, self.data_system_storage_unit_infos)

# create async rollout manager and request scheduler
self.async_rollout_mode = False
if self.config.actor_rollout_ref.rollout.mode == "async":
Expand Down
14 changes: 14 additions & 0 deletions verl/single_controller/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def _query_collect_info(self, mesh_name: str):
"""
assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
return self.__collect_dp_rank[mesh_name]

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def set_transferqueue_server_info(self, controller_infos, storage_infos):
"""Set the transfer queue server information for the worker.

Args:
controller_infos (list):
List of controller server information.
storage_infos (list):
List of storage unit server information.
"""
from verl.utils.transferqueue_utils import set_transferqueue_server_info

set_transferqueue_server_info(controller_infos, storage_infos)

@classmethod
def env_keys(cls):
Expand Down
35 changes: 35 additions & 0 deletions verl/utils/transferqueue_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from transfer_queue.utils.zmq_utils import ZMQServerInfo

_TRANSFER_QUEUE_CONTROLLER_INFOS = None
_TRANSFER_QUEUE_STORAGE_INFOS = None


def set_transferqueue_server_info(controller_infos: dict[Any, ZMQServerInfo], storage_infos: dict[Any, ZMQServerInfo]):
global _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
if _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None:
return
_TRANSFER_QUEUE_CONTROLLER_INFOS = controller_infos
_TRANSFER_QUEUE_STORAGE_INFOS = storage_infos


def get_transferqueue_server_info():
assert _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None, (
"TransferQueue server infos have not been set yet."
)
return _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
Loading