# See the License for the specific language governing permissions and
# limitations under the License.

+import asyncio
import inspect
from functools import wraps
from typing import Any

+import numpy as np
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorDict
from transfer_queue import BatchMeta

from verl.experimental.transfer_queue import ZMQServerInfo
+from verl.protocol import DataProto

_TRANSFER_QUEUE_CONTROLLER_INFOS = None
_TRANSFER_QUEUE_STORAGE_INFOS = None
@@ -50,15 +55,55 @@ def _find_batchmeta(*args, **kwargs):


def _batchmeta_to_dataproto(batchmeta: BatchMeta):
-    ...
+    tensordict = asyncio.run(client.async_get_data(batchmeta))
+
+    batch = {}
+    non_tensor_batch = {}
+    batch_size = None
+    for k, v in tensordict.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v
+            if batch_size is None:
+                batch_size = v.shape[:1]
+        elif isinstance(v, NonTensorStack):
+            non_tensor_batch[k] = np.array([elem.data for elem in v], dtype=object)
+        else:
+            non_tensor_batch[k] = v
+    return DataProto(
+        batch=TensorDict(batch, batch_size=batch_size),
+        non_tensor_batch=non_tensor_batch,
+        meta_info=batchmeta.extra_info.copy(),
+    )
+
+
+def _dataproto_to_tensordict(data: DataProto):
+    result_dict = {}

+    if data.batch is not None:
+        result_dict.update(data.batch)

-def _update_batchmeta_with_output(output, batchmeta: BatchMeta):
-    ...
+    batch_size = data.batch.batch_size if data.batch is not None else (len(list(data.non_tensor_batch.values())[0]),)
+    if data.non_tensor_batch is not None:
+        for k, v in data.non_tensor_batch.items():
+            result_dict[k] = NonTensorData(data=v, batch_size=batch_size)
+
+    if data.meta_info == {} or data.meta_info is None:
+        result_dict["meta_info"] = NonTensorData(data=[None] * batch_size[0], batch_size=batch_size)
+    else:
+        result_dict["meta_info"] = NonTensorData(data=[data.meta_info] * batch_size[0], batch_size=batch_size)
+    return TensorDict(result_dict, batch_size=batch_size)
+
+
+def _update_batchmeta_with_output(output: DataProto, batchmeta: BatchMeta):
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    asyncio.run(client.async_put(data=tensordict, metadata=batchmeta))


async def _async_update_batchmeta_with_output(output, batchmeta: BatchMeta):
-    ...
+    tensordict = _dataproto_to_tensordict(output)
+    batchmeta.add_fields(tensordict)
+    await client.async_put(data=tensordict, metadata=batchmeta)


def batchmeta_dataproto_pipe():
@@ -90,4 +135,3 @@ async def async_inner(*args, **kwargs):
        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner
        return wrapper
    return decorator
-
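
The new helpers round-trip between `BatchMeta`/`TensorDict` on the transfer-queue side and `DataProto` on the worker side: `_batchmeta_to_dataproto` fetches a `TensorDict` and splits it into tensor and non-tensor fields, while `_dataproto_to_tensordict` folds a `DataProto` back into a single `TensorDict`, replicating `meta_info` once per sample. Below is a minimal sketch of the `DataProto` → `TensorDict` direction, which needs no transfer-queue client; the field names `input_ids` and `raw_prompt` are illustrative only, and `_dataproto_to_tensordict` is assumed to be in scope from this module:

```python
import numpy as np
import torch
from tensordict import TensorDict

from verl.protocol import DataProto

# Hypothetical two-sample batch: one tensor field, one object field, shared meta_info.
data = DataProto(
    batch=TensorDict({"input_ids": torch.zeros(2, 8, dtype=torch.long)}, batch_size=(2,)),
    non_tensor_batch={"raw_prompt": np.array(["hi", "bye"], dtype=object)},
    meta_info={"temperature": 1.0},
)

td = _dataproto_to_tensordict(data)
# Expected: "input_ids" stays a tensor, "raw_prompt" becomes a non-tensor entry,
# and "meta_info" is materialized once per sample, all under batch_size=(2,).
print(sorted(td.keys()), td.batch_size)
```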
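`batchmeta_dataproto_pipe()` (its body is elided above) is what connects these converters to worker methods: given `_find_batchmeta` and the `inner`/`async_inner` selection at the end, it appears to let methods written against `DataProto` be invoked with a `BatchMeta` handle. A hedged usage sketch under that assumption; the worker class and method names are hypothetical, not part of this PR:

```python
from verl.protocol import DataProto


class RolloutWorker:  # hypothetical worker class, for illustration only
    @batchmeta_dataproto_pipe()  # assumed to be imported from this module
    def generate_sequences(self, data: DataProto) -> DataProto:
        # By assumption, the decorator fetches the batch referenced by the
        # incoming BatchMeta, passes it here as a DataProto, and writes the
        # returned DataProto back to the transfer queue via
        # _update_batchmeta_with_output.
        return data


# Caller side: pass the lightweight BatchMeta handle instead of the full batch.
# rollout.generate_sequences(batch_meta)  # batch_meta: BatchMeta
```

Since the wrapper is chosen with `inspect.iscoroutinefunction`, the same decorator should cover `async def` methods as well.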