Commit 65affa2

feat: Support conversion between dataproto and batchmeta
1 parent eb31070 commit 65affa2

File tree

7 files changed: +248 -37 lines
recipe/transfer_queue/fsdp_workers.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+FSDP workers extended with TransferQueue client creation.
+"""
+
+import verl.workers.fsdp_workers as workers
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.utils.transferqueue_utils import create_transferqueue_client
+
+
+class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )
+
+
+class CriticWorker(workers.CriticWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )
+
+
+# TODO(sgm): we may need to extract it to dp_reward_model.py
+class RewardModelWorker(workers.RewardModelWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )
+
+
+# ================================= Async related workers =================================
+class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
+    pass

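These recipe subclasses exist so the trainer can ask every rank to build its own TransferQueue client in one worker-group call. Below is a minimal, self-contained sketch (not part of the commit; MockWorker and MockWorkerGroup are hypothetical stand-ins for verl's worker-group machinery) of what Dispatch.ONE_TO_ALL means for this method: the same argument dicts are broadcast unchanged to every rank, and each rank derives a unique client_id from its own rank.

# Hypothetical stand-ins illustrating ONE_TO_ALL dispatch: the worker group
# forwards identical (controller_infos, storage_infos) to every rank.
class MockWorker:
    def __init__(self, rank):
        self.rank = rank

    def create_transferqueue_client(self, controller_infos, storage_infos):
        print(f"worker_{self.rank}: connected to {sorted(controller_infos)}")

class MockWorkerGroup:
    def __init__(self, world_size):
        self.workers = [MockWorker(r) for r in range(world_size)]

    def create_transferqueue_client(self, controller_infos, storage_infos):
        # ONE_TO_ALL: arguments are not sharded, every worker sees the full dicts
        for w in self.workers:
            w.create_transferqueue_client(controller_infos, storage_infos)

wg = MockWorkerGroup(world_size=2)
wg.create_transferqueue_client({"controller_0": "zmq-info"}, {"storage_0": "zmq-info"})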
recipe/transfer_queue/main_ppo.py

Lines changed: 12 additions & 6 deletions
@@ -113,7 +113,10 @@ def add_actor_rollout_worker(self, config):
         from verl.single_controller.ray import RayWorkerGroup

         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .fsdp_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )

             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -123,7 +126,10 @@ def add_actor_rollout_worker(self, config):
             ray_worker_group_cls = RayWorkerGroup

         elif config.actor_rollout_ref.actor.strategy == "megatron":
-            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .megatron_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )

             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -146,7 +152,7 @@ def add_critic_worker(self, config):
         if config.critic.strategy in {"fsdp", "fsdp2"}:
             use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
             if use_legacy_worker_impl in ["auto", "enable"]:
-                from verl.workers.fsdp_workers import CriticWorker
+                from .fsdp_workers import CriticWorker
             elif use_legacy_worker_impl == "disable":
                 from verl.workers.roles import CriticWorker

@@ -155,7 +161,7 @@ def add_critic_worker(self, config):
             raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")

         elif config.critic.strategy == "megatron":
-            from verl.workers.megatron_workers import CriticWorker
+            from .megatron_workers import CriticWorker

         else:
             raise NotImplementedError
@@ -197,9 +203,9 @@ def add_reward_model_worker(self, config):
         use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
         if use_legacy_worker_impl in ["auto", "enable"]:
             if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                from verl.workers.fsdp_workers import RewardModelWorker
+                from .fsdp_workers import RewardModelWorker
            elif config.reward_model.strategy == "megatron":
-                from verl.workers.megatron_workers import RewardModelWorker
+                from .megatron_workers import RewardModelWorker
             else:
                 raise NotImplementedError
         elif use_legacy_worker_impl == "disable":
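The point of switching to relative imports is that recipe/transfer_queue ships modules with the same class names as verl.workers.*, so every call site below each import stays untouched while resolution moves to the TransferQueue-aware subclasses. A hedged sanity check, assuming the repository root is on sys.path so the recipe package is importable:

# Not part of the commit: verifies the substitution the relative import relies on.
from recipe.transfer_queue.fsdp_workers import ActorRolloutRefWorker as RecipeWorker
from verl.workers.fsdp_workers import ActorRolloutRefWorker as StockWorker

assert issubclass(RecipeWorker, StockWorker)               # recipe worker extends the stock one
assert hasattr(RecipeWorker, "create_transferqueue_client")  # and adds the client hook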
recipe/transfer_queue/megatron_workers.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Megatron workers extended with TransferQueue client creation.
+"""
+
+import verl.workers.megatron_workers as workers
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.utils.transferqueue_utils import create_transferqueue_client
+
+
+class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )
+
+
+class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
+    pass
+
+
+class CriticWorker(workers.CriticWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )
+
+
+class RewardModelWorker(workers.RewardModelWorker):
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
+    def create_transferqueue_client(self, controller_infos, storage_infos):
+        create_transferqueue_client(
+            client_id=f"worker_{self.rank}",
+            controller_infos=controller_infos,
+            storage_infos=storage_infos,
+        )

recipe/transfer_queue/ray_trainer.py

Lines changed: 9 additions & 5 deletions
@@ -88,6 +88,10 @@
 )
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
+from verl.utils.transferqueue_utils import (
+    create_transferqueue_client,
+    get_transferqueue_client,
+)


 @dataclass
@@ -421,12 +425,12 @@ def _initialize_data_system(self):

         # 4. create client
         # each client should be allocated to exactly one controller
-        self.data_system_client = AsyncTransferQueueClient(
+        create_transferqueue_client(
             client_id="Trainer",
-            controller_infos=self.data_system_controller_infos[0],
+            controller_infos=self.data_system_controller_infos,
             storage_infos=self.data_system_storage_unit_infos,
         )
-
+        self.data_system_client = get_transferqueue_client()
         return self.data_system_client

     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@@ -847,9 +851,9 @@ def init_workers(self):
         self.actor_rollout_wg = all_wg["actor_rollout"]
         self.actor_rollout_wg.init_model()

-        # set transferqueue server info for each worker group
+        # set transferqueue server info for each worker
         for _, wg in all_wg.items():
-            wg.set_transferqueue_server_info(self.data_system_controller_infos, self.data_system_storage_unit_infos)
+            wg.create_transferqueue_client(self.data_system_controller_infos, self.data_system_storage_unit_infos)

         # create async rollout manager and request scheduler
         self.async_rollout_mode = False

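The trainer now goes through the module-level singleton instead of owning its own AsyncTransferQueueClient, so trainer code and worker methods can all retrieve the same instance later. A minimal, self-contained sketch of the create-then-get contract (Client here is a stand-in, not the real transfer_queue API):

# Stand-in demonstration of the singleton contract in transferqueue_utils:
# create once, fetch the same instance everywhere, refuse double creation.
_CLIENT = None

class Client:  # placeholder for AsyncTransferQueueClient
    def __init__(self, client_id, controller_infos, storage_infos):
        self.client_id = client_id

def create_client(client_id, controller_infos, storage_infos):
    global _CLIENT
    assert _CLIENT is None, "client has already been created."
    _CLIENT = Client(client_id, controller_infos, storage_infos)

def get_client():
    return _CLIENT

create_client("Trainer", controller_infos={}, storage_infos={})
assert get_client() is get_client()  # one shared instance per process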
verl/single_controller/base/decorator.py

Lines changed: 3 additions & 0 deletions
@@ -429,10 +429,13 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
     A decorator that wraps the original function with distributed execution
     configuration.
     """
+    from verl.utils.transferqueue_utils import batchmeta_dataproto_pipe
+
     _check_dispatch_mode(dispatch_mode=dispatch_mode)
     _check_execute_mode(execute_mode=execute_mode)

     def decorator(func):
+        func = batchmeta_dataproto_pipe()(func)
         @wraps(func)
         def inner(*args, **kwargs):
             if materialize_futures:

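The effect of these two added lines is that every method registered for remote dispatch is first wrapped by batchmeta_dataproto_pipe(), so the BatchMeta-to-DataProto conversion sits innermost, directly around the user method, while dispatch stays outermost. A self-contained sketch of the wrapping order with stand-in decorators (not the real verl implementations):

from functools import wraps

def pipe():  # stand-in for batchmeta_dataproto_pipe
    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            print("pipe: BatchMeta -> DataProto")        # runs closest to the method
            out = func(*args, **kwargs)
            print("pipe: write output, return BatchMeta")
            return out
        return inner
    return decorator

def register(func):  # stand-in for verl's register decorator
    func = pipe()(func)  # the conversion layer is applied first

    @wraps(func)
    def inner(*args, **kwargs):
        print("register: dispatch")                      # runs outermost
        return func(*args, **kwargs)
    return inner

@register
def update_actor(data):
    return data

update_actor("meta")
# register: dispatch
# pipe: BatchMeta -> DataProto
# pipe: write output, return BatchMeta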
verl/single_controller/base/worker.py

Lines changed: 0 additions & 14 deletions
@@ -129,20 +129,6 @@ def _query_collect_info(self, mesh_name: str):
         """
         assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
         return self.__collect_dp_rank[mesh_name]
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def set_transferqueue_server_info(self, controller_infos, storage_infos):
-        """Set the transfer queue server information for the worker.
-
-        Args:
-            controller_infos (list):
-                List of controller server information.
-            storage_infos (list):
-                List of storage unit server information.
-        """
-        from verl.utils.transferqueue_utils import set_transferqueue_server_info
-
-        set_transferqueue_server_info(controller_infos, storage_infos)

     @classmethod
     def env_keys(cls):

verl/utils/transferqueue_utils.py

Lines changed: 114 additions & 12 deletions
@@ -12,24 +12,126 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
+import inspect
+from functools import wraps
 from typing import Any

+import numpy as np
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorDict
+from transfer_queue import AsyncTransferQueueClient, BatchMeta
+
 from verl.experimental.transfer_queue import ZMQServerInfo
+from verl.protocol import DataProto
+
+# _TRANSFER_QUEUE_CONTROLLER_INFOS = None
+# _TRANSFER_QUEUE_STORAGE_INFOS = None
+_TRANSFER_QUEUE_CLIENT = None
+
+
+def create_transferqueue_client(
+    client_id: str,
+    controller_infos: dict[Any, ZMQServerInfo],
+    storage_infos: dict[Any, ZMQServerInfo],
+) -> None:
+    global _TRANSFER_QUEUE_CLIENT
+    assert _TRANSFER_QUEUE_CLIENT is None, "TransferQueue client has already been created."
+    _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
+

-_TRANSFER_QUEUE_CONTROLLER_INFOS = None
-_TRANSFER_QUEUE_STORAGE_INFOS = None
+def get_transferqueue_client() -> AsyncTransferQueueClient:
+    return _TRANSFER_QUEUE_CLIENT


-def set_transferqueue_server_info(controller_infos: dict[Any, ZMQServerInfo], storage_infos: dict[Any, ZMQServerInfo]):
-    global _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
-    if _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None:
-        return
-    _TRANSFER_QUEUE_CONTROLLER_INFOS = controller_infos
-    _TRANSFER_QUEUE_STORAGE_INFOS = storage_infos
+def _find_batchmeta(*args, **kwargs):
+    for arg in args:
+        if isinstance(arg, BatchMeta):
+            return arg
+    for v in kwargs.values():
+        if isinstance(v, BatchMeta):
+            return v
+    return None


-def get_transferqueue_server_info():
-    assert _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None, (
-        "TransferQueue server infos have not been set yet."
+def _batchmeta_to_dataproto(batchmeta: BatchMeta):
+    tensordict = asyncio.run(_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta))
+
+    batch = {}
+    non_tensor_batch = {}
+    batch_size = None
+    for k, v in tensordict.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v
+            if batch_size is None:
+                batch_size = v.shape[:1]
+        elif isinstance(v, NonTensorStack):
+            non_tensor_batch[k] = np.array([elem.data for elem in v], dtype=object)
+        else:
+            non_tensor_batch[k] = v
+    return DataProto(
+        batch=TensorDict(batch, batch_size=batch_size),
+        non_tensor_batch=non_tensor_batch,
+        meta_info=batchmeta.extra_info.copy(),
     )
-    return _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
+
+
+def _dataproto_to_tensordict(data: DataProto):
+    result_dict = {}
+
+    if data.batch is not None:
+        result_dict.update(data.batch)
+
+    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
+    if data.non_tensor_batch is not None:
+        for k, v in data.non_tensor_batch.items():
+            result_dict[k] = NonTensorData(data=v, batch_size=batch_size)
+
+    if data.meta_info == {} or data.meta_info is None:
+        result_dict["meta_info"] = NonTensorData(data=[None] * batch_size[0], batch_size=batch_size)
+    else:
+        result_dict["meta_info"] = NonTensorData(data=[data.meta_info] * batch_size[0], batch_size=batch_size)
+    return TensorDict(result_dict, batch_size=batch_size)
+
+
+def _update_batchmeta_with_output(output: DataProto, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    asyncio.run(_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta))
+
+
+async def _async_update_batchmeta_with_output(output, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+
+
+def batchmeta_dataproto_pipe():
+    def decorator(func):
+        @wraps(func)
+        def inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = func(*args, **kwargs)
+                _update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        @wraps(func)
+        async def async_inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return await func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = await func(*args, **kwargs)
+                await _async_update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
+        return wrapper
+    return decorator

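To make the conversion rules above concrete, here is a self-contained sketch of the field routing performed by _batchmeta_to_dataproto, with a plain dict standing in for the fetched TensorDict (no TransferQueue server required): torch.Tensor fields become the tensor batch and fix the batch size, everything else lands in the non-tensor batch.

import numpy as np
import torch

def route_fields(tensordict):
    # Mirrors _batchmeta_to_dataproto: tensors -> batch (batch_size taken from
    # the first tensor's leading dim), everything else -> non_tensor_batch.
    batch, non_tensor_batch, batch_size = {}, {}, None
    for k, v in tensordict.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v
            if batch_size is None:
                batch_size = v.shape[:1]
        else:
            non_tensor_batch[k] = v
    return batch, non_tensor_batch, batch_size

td = {
    "responses": torch.ones(4, 16, dtype=torch.long),
    "data_source": np.array(["gsm8k"] * 4, dtype=object),
}
batch, non_tensor, bs = route_fields(td)
assert set(batch) == {"responses"}
assert set(non_tensor) == {"data_source"}
assert bs == torch.Size([4])

In the opposite direction, _dataproto_to_tensordict replicates meta_info once per row ([data.meta_info] * batch_size[0]), so per-batch metadata survives row-wise placement across storage units and later slicing.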