
Commit 14ad39e

Commit message: update
Parent: 9988fec

4 files changed: +60 −24 lines

recipe/transfer_queue/main_ppo.py (10 additions, 4 deletions)

@@ -113,7 +113,10 @@ def add_actor_rollout_worker(self, config):
         from verl.single_controller.ray import RayWorkerGroup
 
         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .fsdp_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )
 
             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -123,7 +126,10 @@ def add_actor_rollout_worker(self, config):
             ray_worker_group_cls = RayWorkerGroup
 
         elif config.actor_rollout_ref.actor.strategy == "megatron":
-            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+            from .megatron_workers import (
+                ActorRolloutRefWorker,
+                AsyncActorRolloutRefWorker,
+            )
 
             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
@@ -197,9 +203,9 @@ def add_reward_model_worker(self, config):
         use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
         if use_legacy_worker_impl in ["auto", "enable"]:
            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                from verl.workers.fsdp_workers import RewardModelWorker
+                from .fsdp_workers import RewardModelWorker
             elif config.reward_model.strategy == "megatron":
-                from verl.workers.megatron_workers import RewardModelWorker
+                from .megatron_workers import RewardModelWorker
             else:
                 raise NotImplementedError
         elif use_legacy_worker_impl == "disable":
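
The net effect in this file is that the recipe now resolves its worker classes from its own package rather than from core `verl.workers`, so recipe-local subclasses (for example, TransferQueue-aware workers) are picked up without modifying core verl. A condensed sketch of the resulting selection logic follows; the `rollout.mode == "async"` condition is an assumption, since the diff truncates the body of `actor_rollout_cls = (...)`.

```python
# Condensed sketch of the worker-class selection after this commit.
# The async/sync condition is assumed; the diff cuts off the body of
# `actor_rollout_cls = (...)`.
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
    from .fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
elif config.actor_rollout_ref.actor.strategy == "megatron":
    from .megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
else:
    raise NotImplementedError

actor_rollout_cls = (
    AsyncActorRolloutRefWorker
    if config.actor_rollout_ref.rollout.mode == "async"
    else ActorRolloutRefWorker
)
```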

recipe/transfer_queue/ray_trainer.py (1 addition, 1 deletion)

@@ -847,7 +847,7 @@ def init_workers(self):
         self.actor_rollout_wg = all_wg["actor_rollout"]
         self.actor_rollout_wg.init_model()
 
-        # set transferqueue server info for each worker group
+        # set transferqueue server info for each worker
         for _, wg in all_wg.items():
             wg.set_transferqueue_server_info(self.data_system_controller_infos, self.data_system_storage_unit_infos)
 
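The reworded comment reflects what the call actually does: `set_transferqueue_server_info` is registered with `Dispatch.ONE_TO_ALL` (see the next file), so one call on the worker group fans out to every worker in it, and each worker caches the TransferQueue endpoints locally. Below is a plausible sketch of the setter on the worker side, inferred from the module globals visible in the `transferqueue_utils.py` diff further down; its real body is not part of this commit.

```python
# Plausible sketch of verl.utils.transferqueue_utils.set_transferqueue_server_info.
# Inferred from the _TRANSFER_QUEUE_* module globals shown later in this
# commit; the actual implementation is not included in the diff.
def set_transferqueue_server_info(controller_infos, storage_infos):
    global _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
    _TRANSFER_QUEUE_CONTROLLER_INFOS = controller_infos  # controller ZMQServerInfo list
    _TRANSFER_QUEUE_STORAGE_INFOS = storage_infos        # storage-unit ZMQServerInfo list
```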

verl/single_controller/base/worker.py (0 additions, 14 deletions)

@@ -129,20 +129,6 @@ def _query_collect_info(self, mesh_name: str):
         """
         assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
         return self.__collect_dp_rank[mesh_name]
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def set_transferqueue_server_info(self, controller_infos, storage_infos):
-        """Set the transfer queue server information for the worker.
-
-        Args:
-            controller_infos (list):
-                List of controller server information.
-            storage_infos (list):
-                List of storage unit server information.
-        """
-        from verl.utils.transferqueue_utils import set_transferqueue_server_info
-
-        set_transferqueue_server_info(controller_infos, storage_infos)
 
     @classmethod
     def env_keys(cls):
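
Because `ray_trainer.py` above still calls `wg.set_transferqueue_server_info(...)`, the method removed from the base `Worker` presumably reappears on the recipe-local worker classes, which is consistent with the switch to relative imports in `main_ppo.py`. A hypothetical sketch of that recipe-side registration follows; the commit does not show `recipe/transfer_queue/fsdp_workers.py`, so the class name and placement are assumptions.

```python
# Hypothetical recipe-side home for the method removed here. Nothing in
# this block is shown by the commit; it only illustrates the
# @register(Dispatch.ONE_TO_ALL) pattern the deleted code used.
from verl.single_controller.base.decorator import Dispatch, register
from verl.workers.fsdp_workers import ActorRolloutRefWorker as _BaseWorker


class ActorRolloutRefWorker(_BaseWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def set_transferqueue_server_info(self, controller_infos, storage_infos):
        from verl.utils.transferqueue_utils import set_transferqueue_server_info

        set_transferqueue_server_info(controller_infos, storage_infos)
```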

verl/utils/transferqueue_utils.py (49 additions, 5 deletions)

@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import inspect
 from functools import wraps
 from typing import Any
 
+import numpy as np
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorDict
 from transfer_queue import BatchMeta
 
 from verl.experimental.transfer_queue import ZMQServerInfo
+from verl.protocol import DataProto
 
 _TRANSFER_QUEUE_CONTROLLER_INFOS = None
 _TRANSFER_QUEUE_STORAGE_INFOS = None
@@ -50,15 +55,55 @@ def _find_batchmeta(*args, **kwargs):
 
 
 def _batchmeta_to_dataproto(batchmeta: BatchMeta):
-    ...
+    tensordict = asyncio.run(client.async_get_data(batchmeta))
+
+    batch = {}
+    non_tensor_batch = {}
+    batch_size = None
+    for k, v in tensordict.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v
+            if batch_size is None:
+                batch_size = v.shape[:1]
+        elif isinstance(v, NonTensorStack):
+            non_tensor_batch[k] = np.array([elem.data for elem in v], dtype=object)
+        else:
+            non_tensor_batch[k] = v
+    return DataProto(
+        batch=TensorDict(batch, batch_size=batch_size),
+        non_tensor_batch=non_tensor_batch,
+        meta_info=batchmeta.extra_info.copy(),
+    )
+
+
+def _dataproto_to_tensordict(data: DataProto):
+    result_dict = {}
 
+    if data.batch is not None:
+        result_dict.update(data.batch)
 
-def _update_batchmeta_with_output(output, batchmeta: BatchMeta):
-    ...
+    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
+    if data.non_tensor_batch is not None:
+        for k, v in data.non_tensor_batch.items():
+            result_dict[k] = NonTensorData(data=v, batch_size=batch_size)
+
+    if data.meta_info == {} or data.meta_info is None:
+        result_dict["meta_info"] = NonTensorData(data=[None] * batch_size[0], batch_size=batch_size)
+    else:
+        result_dict["meta_info"] = NonTensorData(data=[data.meta_info] * batch_size[0], batch_size=batch_size)
+    return TensorDict(result_dict, batch_size=batch_size)
+
+
+def _update_batchmeta_with_output(output: DataProto, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    asyncio.run(client.async_put(data=tensordict, metadata=batchmeta))
 
 
 async def _async_update_batchmeta_with_output(output, batchmeta: BatchMeta):
-    ...
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    await client.async_put(data=tensordict, metadata=batchmeta)
 
 
 def batchmeta_dataproto_pipe():
@@ -90,4 +135,3 @@ async def async_inner(*args, **kwargs):
         wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
         return wrapper
     return decorator
-
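
Together these helpers give `batchmeta_dataproto_pipe()` a complete data path: fetch the rows a `BatchMeta` points to out of TransferQueue as a `DataProto`, run the wrapped worker method, convert the output back into a `TensorDict`, and put the new fields into the queue. Note that `client` in the new functions is module-level state not shown in this diff, presumably a TransferQueue client. Only the tail of the decorator appears above, so the following is a rough sketch under stated assumptions: the `_swap` helper, the return values, and the exact call order are guesses.

```python
# Rough sketch of batchmeta_dataproto_pipe(); only the last three lines
# (wrapper selection and the two returns) appear in this commit.
def batchmeta_dataproto_pipe():
    def decorator(func):
        def _swap(args, kwargs, data):
            # Replace the BatchMeta argument with the fetched DataProto
            # (hypothetical helper, not shown in the diff).
            args = tuple(data if isinstance(a, BatchMeta) else a for a in args)
            kwargs = {k: data if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
            return args, kwargs

        @wraps(func)
        def inner(*args, **kwargs):
            batchmeta = _find_batchmeta(*args, **kwargs)
            if batchmeta is None:  # no BatchMeta among the arguments: plain DataProto path
                return func(*args, **kwargs)
            data = _batchmeta_to_dataproto(batchmeta)         # pull rows from the queue
            args, kwargs = _swap(args, kwargs, data)
            output = func(*args, **kwargs)                    # run the worker method
            _update_batchmeta_with_output(output, batchmeta)  # push new fields back
            return batchmeta

        @wraps(func)
        async def async_inner(*args, **kwargs):
            batchmeta = _find_batchmeta(*args, **kwargs)
            if batchmeta is None:
                return await func(*args, **kwargs)
            # Caveat: _batchmeta_to_dataproto uses asyncio.run internally, which
            # cannot run inside an active event loop; the real async path likely
            # awaits client.async_get_data directly.
            data = _batchmeta_to_dataproto(batchmeta)
            args, kwargs = _swap(args, kwargs, data)
            output = await func(*args, **kwargs)
            await _async_update_batchmeta_with_output(output, batchmeta)
            return batchmeta

        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
        return wrapper

    return decorator
```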
