feat: Support conversion between dataproto and batchmeta #24

Changes from all commits
@@ -0,0 +1,56 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""

import verl.workers.fsdp_workers as workers
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.transferqueue_utils import create_transferqueue_client


class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


class CriticWorker(workers.CriticWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(workers.RewardModelWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )
# ================================= Async related workers =================================
class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
    pass
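For context, here is a minimal, hypothetical driver-side sketch of how these registered methods are meant to be used. It is not part of this diff: the worker-group names (`actor_rollout_wg`, `critic_wg`, `reward_model_wg`) are invented, and it assumes `controller_infos` / `storage_infos` were obtained from however the TransferQueue controller and storage servers were launched.

```python
# Hypothetical wiring sketch, not part of this PR.
# Methods registered with Dispatch.ONE_TO_ALL are exposed on the verl
# worker group and broadcast to every rank, so each worker process ends
# up with its own TransferQueue client (client_id = f"worker_{rank}").
actor_rollout_wg.create_transferqueue_client(
    controller_infos=controller_infos,
    storage_infos=storage_infos,
)
critic_wg.create_transferqueue_client(
    controller_infos=controller_infos,
    storage_infos=storage_infos,
)
reward_model_wg.create_transferqueue_client(
    controller_infos=controller_infos,
    storage_infos=storage_infos,
)
```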
@@ -0,0 +1,54 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""

import verl.workers.megatron_workers as workers
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.transferqueue_utils import create_transferqueue_client


class ActorRolloutRefWorker(workers.ActorRolloutRefWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


class AsyncActorRolloutRefWorker(workers.AsyncActorRolloutRefWorker):
    pass


class CriticWorker(workers.CriticWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )


class RewardModelWorker(workers.RewardModelWorker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def create_transferqueue_client(self, controller_infos, storage_infos):
        create_transferqueue_client(
            client_id=f"worker_{self.rank}",
            controller_infos=controller_infos,
            storage_infos=storage_infos,
        )
@@ -12,24 +12,126 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from verl.experimental.transfer_queue import ZMQServerInfo
-
-_TRANSFER_QUEUE_CONTROLLER_INFOS = None
-_TRANSFER_QUEUE_STORAGE_INFOS = None
-
-
-def set_transferqueue_server_info(controller_infos: dict[Any, ZMQServerInfo], storage_infos: dict[Any, ZMQServerInfo]):
-    global _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
-    if _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None:
-        return
-    _TRANSFER_QUEUE_CONTROLLER_INFOS = controller_infos
-    _TRANSFER_QUEUE_STORAGE_INFOS = storage_infos
-
-
-def get_transferqueue_server_info():
-    assert _TRANSFER_QUEUE_CONTROLLER_INFOS is not None and _TRANSFER_QUEUE_STORAGE_INFOS is not None, (
-        "TransferQueue server infos have not been set yet."
-    )
-    return _TRANSFER_QUEUE_CONTROLLER_INFOS, _TRANSFER_QUEUE_STORAGE_INFOS
+import asyncio
+import inspect
+from functools import wraps
+from typing import Any
+
+import numpy as np
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorDict
+
+from verl.experimental.transfer_queue import (
+    AsyncTransferQueueClient,
+    BatchMeta,
+    ZMQServerInfo,
+)
+from verl.protocol import DataProto
+
+_TRANSFER_QUEUE_CLIENT = None
+
+
+def create_transferqueue_client(
+    client_id: str,
+    controller_infos: dict[Any, ZMQServerInfo],
+    storage_infos: dict[Any, ZMQServerInfo],
+) -> None:
+    global _TRANSFER_QUEUE_CLIENT
+    _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
+
+
+def get_transferqueue_client() -> AsyncTransferQueueClient:
+    return _TRANSFER_QUEUE_CLIENT
+
+
+def _find_batchmeta(*args, **kwargs):
+    for arg in args:
+        if isinstance(arg, BatchMeta):
+            return arg
+    for v in kwargs.values():
+        if isinstance(v, BatchMeta):
+            return v
+    return None
+
+
+def _batchmeta_to_dataproto(batchmeta: BatchMeta):
+    tensordict = asyncio.run(_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta))
+
+    batch = {}
+    non_tensor_batch = {}
+    batch_size = None
+    for k, v in tensordict.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v
+            if batch_size is None:
+                batch_size = v.shape[:1]
+        elif isinstance(v, NonTensorStack):
+            non_tensor_batch[k] = np.array([elem.data for elem in v], dtype=object)
+        else:
+            non_tensor_batch[k] = v
+    return DataProto(
+        batch=TensorDict(batch, batch_size=batch_size),
+        non_tensor_batch=non_tensor_batch,
+        meta_info=batchmeta.extra_info.copy(),
+    )
+
+
+def _dataproto_to_tensordict(data: DataProto):
+    result_dict = {}
+
+    if data.batch is not None:
+        result_dict.update(data.batch)
+
+    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
+
+    if data.non_tensor_batch is not None:
+        for k, v in data.non_tensor_batch.items():
+            result_dict[k] = NonTensorData(data=v, batch_size=batch_size)
+
+    if data.meta_info == {} or data.meta_info is None:
+        result_dict["meta_info"] = NonTensorData(data=[None] * batch_size[0], batch_size=batch_size)
+    else:
+        result_dict["meta_info"] = NonTensorData(data=[data.meta_info] * batch_size[0], batch_size=batch_size)
+    return TensorDict(result_dict, batch_size=batch_size)

Copilot review comment on the `batch_size` line above: this will raise an IndexError if `data.batch` is None and `data.non_tensor_batch` is empty.
+
+
+def _update_batchmeta_with_output(output: DataProto, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    asyncio.run(_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta))
+
+
+async def _async_update_batchmeta_with_output(output, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+
+
+def batchmeta_dataproto_pipe():
+    def decorator(func):
+        @wraps(func)
+        def inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = func(*args, **kwargs)
+                _update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        @wraps(func)
+        async def async_inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return await func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = await func(*args, **kwargs)
+                await _async_update_batchmeta_with_output(output, batchmeta)
+                return batchmeta
+
+        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
+        return wrapper
+
+    return decorator
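As a rough illustration of what `_dataproto_to_tensordict` produces, here is a hedged round-trip sketch. It is not part of this diff; the field names and values are invented, and it assumes the helpers above live in `verl.utils.transferqueue_utils` (as the worker files import them from there). The private helper is called directly only for demonstration.

```python
# Hypothetical illustration of the DataProto -> TensorDict conversion above.
# Field names ("responses", "raw_prompt") are invented for this example.
import numpy as np
import torch
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.utils.transferqueue_utils import _dataproto_to_tensordict  # private helper, imported only to illustrate

data = DataProto(
    batch=TensorDict({"responses": torch.zeros(4, 8, dtype=torch.long)}, batch_size=[4]),
    non_tensor_batch={"raw_prompt": np.array(["a", "b", "c", "d"], dtype=object)},
    meta_info={"temperature": 1.0},
)

td = _dataproto_to_tensordict(data)
# td keeps "responses" as a tensor, wraps "raw_prompt" in a NonTensorData
# entry, and replicates meta_info once per row under the "meta_info" key,
# so the whole batch can be stored and routed row-wise by TransferQueue.
```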
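To make the intent of `batchmeta_dataproto_pipe` concrete, here is a minimal, hypothetical usage sketch, also not part of this diff. It assumes a TransferQueue client has already been created on the worker via `create_transferqueue_client`; the worker class, method name, and batch fields are made up for illustration.

```python
# Minimal usage sketch, not part of this PR. MyWorker, compute_log_prob
# and the "responses" field are hypothetical; real worker methods and
# batch fields come from verl itself.
from verl.protocol import DataProto
from verl.utils.transferqueue_utils import batchmeta_dataproto_pipe


class MyWorker:
    @batchmeta_dataproto_pipe()
    def compute_log_prob(self, data: DataProto) -> DataProto:
        # With a BatchMeta argument, the decorator fetches the referenced
        # rows from TransferQueue, builds a DataProto, runs this body,
        # writes the output fields back to the queue, and returns the
        # updated BatchMeta to the caller. With a plain DataProto argument
        # it is a transparent pass-through.
        return DataProto(batch=data.batch, meta_info={"computed": True})
```

The net effect is the conversion named in the PR title: worker methods keep their DataProto signatures, while the data plane can move through TransferQueue as lightweight BatchMeta handles.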