56 changes: 56 additions & 0 deletions recipe/transfer_queue/fsdp_workers.py
@@ -0,0 +1,56 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""

import verl.workers.fsdp_workers as workers
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.transferqueue_utils import create_transferqueue_client


class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


Copilot AI (Sep 30, 2025): There is unnecessary trailing whitespace after line 32 that should be removed.

class CriticWorker(workers.CriticWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(workers.RewardModelWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )

Copilot AI (Sep 30, 2025): There is unnecessary trailing whitespace after line 52 that should be removed.


# ================================= Async related workers =================================
class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
    pass
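
Note on the pattern above (it is repeated for Megatron below): each recipe worker subclasses the stock verl worker and registers create_transferqueue_client as a Dispatch.ONE_TO_ALL RPC, so a single trainer-side call makes every rank construct its own process-local TransferQueue client. A minimal sketch of the trainer-side fan-out, assuming the all_wg dict and server-info variables from ray_trainer.py later in this diff:

    # Sketch only: one blocking ONE_TO_ALL call per worker group; each rank then
    # runs create_transferqueue_client(client_id=f"worker_{rank}", ...) locally.
    for _, wg in all_wg.items():
        wg.create_transferqueue_client(controller_infos, storage_infos)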
18 changes: 12 additions & 6 deletions recipe/transfer_queue/main_ppo.py
@@ -113,7 +113,10 @@ def add_actor_rollout_worker(self, config):
         from verl.single_controller.ray import RayWorkerGroup
 
         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .fsdp_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )
 
             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -123,7 +126,10 @@ def add_actor_rollout_worker(self, config):
             ray_worker_group_cls = RayWorkerGroup
 
         elif config.actor_rollout_ref.actor.strategy == "megatron":
-            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .megatron_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )
 
             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -146,7 +152,7 @@ def add_critic_worker(self, config):
         if config.critic.strategy in {"fsdp", "fsdp2"}:
             use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
             if use_legacy_worker_impl in ["auto", "enable"]:
-                from verl.workers.fsdp_workers import CriticWorker
+                from .fsdp_workers import CriticWorker
             elif use_legacy_worker_impl == "disable":
                 from verl.workers.roles import CriticWorker
 
@@ -155,7 +161,7 @@ def add_critic_worker(self, config):
                 raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
 
         elif config.critic.strategy == "megatron":
-            from verl.workers.megatron_workers import CriticWorker
+            from .megatron_workers import CriticWorker
 
         else:
             raise NotImplementedError
@@ -197,9 +203,9 @@ def add_reward_model_worker(self, config):
         use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
         if use_legacy_worker_impl in ["auto", "enable"]:
             if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                from verl.workers.fsdp_workers import RewardModelWorker
+                from .fsdp_workers import RewardModelWorker
             elif config.reward_model.strategy == "megatron":
-                from verl.workers.megatron_workers import RewardModelWorker
+                from .megatron_workers import RewardModelWorker
             else:
                 raise NotImplementedError
         elif use_legacy_worker_impl == "disable":
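
In short, main_ppo.py now imports its workers from the recipe-local modules (relative imports) rather than from verl.workers.*, so the TransferQueue-aware subclasses above are what the PPO task runner actually instantiates.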
54 changes: 54 additions & 0 deletions recipe/transfer_queue/megatron_workers.py
@@ -0,0 +1,54 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""

import verl.workers.megatron_workers as workers
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.transferqueue_utils import create_transferqueue_client


class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
    pass


class CriticWorker(workers.CriticWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


class RewardModelWorker(workers.RewardModelWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )
14 changes: 9 additions & 5 deletions recipe/transfer_queue/ray_trainer.py
@@ -88,6 +88,10 @@
 )
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
+from verl.utils.transferqueue_utils import (
+    create_transferqueue_client,
+    get_transferqueue_client,
+)
 
 
 @dataclass
@@ -421,12 +425,12 @@ def _initialize_data_system(self):
 
         # 4. create client
         # each client should be allocated to exactly one controller
-        self.data_system_client = AsyncTransferQueueClient(
+        create_transferqueue_client(
             client_id="Trainer",
-            controller_infos=self.data_system_controller_infos[0],
+            controller_infos=self.data_system_controller_infos,
             storage_infos=self.data_system_storage_unit_infos,
         )
-
+        self.data_system_client = get_transferqueue_client()
         return self.data_system_client
 
     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@@ -847,9 +851,9 @@ def init_workers(self):
         self.actor_rollout_wg = all_wg["actor_rollout"]
         self.actor_rollout_wg.init_model()
 
-        # set transferqueue server info for each worker group
+        # set transferqueue server info for each worker
         for _, wg in all_wg.items():
-            wg.set_transferqueue_server_info(self.data_system_controller_infos, self.data_system_storage_unit_infos)
+            wg.create_transferqueue_client(self.data_system_controller_infos, self.data_system_storage_unit_infos)
 
         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
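
Two things change here: the trainer now builds its own client through the same module-level singleton as the workers (create_transferqueue_client with client_id="Trainer", then get_transferqueue_client), and instead of broadcasting raw server infos for workers to stash (the set_transferqueue_server_info path removed from worker.py below), init_workers has every worker group construct its client up front via the blocking RPC added in the recipe workers.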
3 changes: 3 additions & 0 deletions verl/single_controller/base/decorator.py
@@ -429,10 +429,13 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
     A decorator that wraps the original function with distributed execution
     configuration.
     """
+    from verl.utils.transferqueue_utils import batchmeta_dataproto_pipe
+
     _check_dispatch_mode(dispatch_mode=dispatch_mode)
     _check_execute_mode(execute_mode=execute_mode)
 
     def decorator(func):
+        func = batchmeta_dataproto_pipe()(func)
         @wraps(func)
         def inner(*args, **kwargs):
             if materialize_futures:
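
Note what this wiring does: every method that passes through @register is first wrapped by batchmeta_dataproto_pipe() (defined in verl/utils/transferqueue_utils.py below). When a call carries a BatchMeta argument, the wrapper fetches the actual batch from the TransferQueue as a DataProto, runs the original method, writes the output back to the queue, and returns the updated BatchMeta; calls without a BatchMeta run unchanged. A hedged sketch of the effect, with compute_values as a hypothetical worker method:

    # Sketch: what the new line inside register() amounts to.
    from verl.utils.transferqueue_utils import batchmeta_dataproto_pipe

    @batchmeta_dataproto_pipe()
    def compute_values(data):  # body still written against DataProto
        return data            # DataProto out

    # compute_values(batch_meta) now fetches the DataProto behind batch_meta,
    # runs the body, puts the result back, and returns the updated BatchMeta.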
14 changes: 0 additions & 14 deletions verl/single_controller/base/worker.py
@@ -129,20 +129,6 @@ def _query_collect_info(self, mesh_name: str):
         """
         assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
         return self.__collect_dp_rank[mesh_name]
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def set_transferqueue_server_info(self, controller_infos, storage_infos):
-        """Set the transfer queue server information for the worker.
-
-        Args:
-            controller_infos (list):
-                List of controller server information.
-            storage_infos (list):
-                List of storage unit server information.
-        """
-        from verl.utils.transferqueue_utils import set_transferqueue_server_info
-
-        set_transferqueue_server_info(controller_infos, storage_infos)
 
     @classmethod
     def env_keys(cls):
128 changes: 115 additions & 13 deletions verl/utils/transferqueue_utils.py
@@ -12,24 +12,126 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
+import inspect
+from functools import wraps
 from typing import Any
 
-from verl.experimental.transfer_queue import ZMQServerInfo
+import numpy as np
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorDict
 
-_TRANSFER_QUEUE_CONTROLLER_INFOS = None
-_TRANSFER_QUEUE_STORAGE_INFOS = None
+from verl.experimental.transfer_queue import (
+    AsyncTransferQueueClient,
+    BatchMeta,
+    ZMQServerInfo,
+)
+from verl.protocol import DataProto
 
+_TRANSFER_QUEUE_CLIENT = None
 
-def set_transferqueue_server_info(controller_infos: dict[Any, ZMQServerInfo], storage_infos: dict[Any, ZMQServerInfo]):
-    global _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
-    if _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None:
-        return
-    _TRANSFER_QUEUE_CONTROLLER_INFOS = controller_infos
-    _TRANSFER_QUEUE_STORAGE_INFOS = storage_infos
 
+def create_transferqueue_client(
+    client_id: str,
+    controller_infos: dict[Any, ZMQServerInfo],
+    storage_infos: dict[Any, ZMQServerInfo],
+) -> None:
+    global _TRANSFER_QUEUE_CLIENT
+    _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
 
-def get_transferqueue_server_info():
-    assert _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None, (
-        "TransferQueue server infos have not been set yet."
-    )
-    return _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
 
+def get_transferqueue_client() -> AsyncTransferQueueClient:
+    return _TRANSFER_QUEUE_CLIENT
+
+
+def _find_batchmeta(*args, **kwargs):
+    for arg in args:
+        if isinstance(arg, BatchMeta):
+            return arg
+    for v in kwargs.values():
+        if isinstance(v, BatchMeta):
+            return v
+    return None
+
+
+def _batchmeta_to_dataproto(batchmeta: BatchMeta):
+    tensordict = asyncio.run(_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta))
+
+    batch = {}
+    non_tensor_batch = {}
+    batch_size = None
+    for k, v in tensordict.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v
+            if batch_size is None:
+                batch_size = v.shape[:1]
+        elif isinstance(v, NonTensorStack):
+            non_tensor_batch[k] = np.array([elem.data for elem in v], dtype=object)
+        else:
+            non_tensor_batch[k] = v
+    return DataProto(
+        batch=TensorDict(batch, batch_size=batch_size),
+        non_tensor_batch=non_tensor_batch,
+        meta_info=batchmeta.extra_info.copy(),
+    )
+
+
+def _dataproto_to_tensordict(data: DataProto):
+    result_dict = {}
+
+    if data.batch is not None:
+        result_dict.update(data.batch)
+
+    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
Copilot AI (Sep 30, 2025): There is trailing whitespace at the end of line 85 that should be removed.

Copilot AI (Sep 30, 2025): This line will raise an IndexError if data.non_tensor_batch is empty. Should check if data.non_tensor_batch exists and has values before accessing [0].

Suggested change:
-    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
+    if data.batch is not None:
+        batch_size = data.batch.batch_size
+    elif data.non_tensor_batch is not None and len(data.non_tensor_batch) > 0:
+        first_value = next(iter(data.non_tensor_batch.values()))
+        if hasattr(first_value, '__len__') and len(first_value) > 0:
+            batch_size = (len(first_value),)
+        else:
+            batch_size = (0,)
+    else:
+        batch_size = (0,)

+    if data.non_tensor_batch is not None:
+        for k, v in data.non_tensor_batch.items():
+            result_dict[k] = NonTensorData(data=v, batch_size=batch_size)
+
+    if data.meta_info == {} or data.meta_info is None:
+        result_dict["meta_info"] = NonTensorData(data=[None] * batch_size[0], batch_size=batch_size)
+    else:
+        result_dict["meta_info"] = NonTensorData(data=[data.meta_info] * batch_size[0], batch_size=batch_size)
+    return TensorDict(result_dict, batch_size=batch_size)
+
+
+def _update_batchmeta_with_output(output: DataProto, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    asyncio.run(_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta))
+
+
+async def _async_update_batchmeta_with_output(output, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+
+
+def batchmeta_dataproto_pipe():
+    def decorator(func):
+        @wraps(func)
+        def inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = func(*args, **kwargs)
+                _update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        @wraps(func)
+        async def async_inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return await func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = await func(*args, **kwargs)
+                await _async_update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
+        return wrapper
+
+    return decorator
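
Putting the pieces together, the intended per-process flow is roughly the following (a sketch, not part of the diff; controller_infos and storage_infos stand for the ZMQServerInfo dicts produced by the trainer's _initialize_data_system):

    # 1. Build the module-level singleton client once per process.
    create_transferqueue_client("worker_0", controller_infos, storage_infos)
    client = get_transferqueue_client()

    # 2. Thereafter, any @register-ed method invoked with a BatchMeta is fed a
    #    DataProto (via async_get_data) and its DataProto output is written back
    #    (via async_put), so only lightweight metadata crosses process boundaries.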