Skip to content

Commit a884056

Browse files
authored
update client docstring (#5)
Signed-off-by: 0oshowero0 <[email protected]>
1 parent f6638aa commit a884056

File tree

1 file changed

+79
-9
lines changed

1 file changed

+79
-9
lines changed

verl/experimental/transfer_queue/client.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,40 @@ async def async_get_meta(
180180
data_fields (list[str]): List of fields to retrieve metadata for
181181
batch_size (int): Processing batch size
182182
global_step (int): Current training/processing step
183-
mode (str): Data fetch mode (TODO(hz): more details to be added)
184-
get_n_samples (bool): TODO(hz): more details to be added
183+
mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness.
184+
'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS.
185+
get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch'
186+
mode, only the samples of the same prompt that are all ready will be returned.
185187
task_name (str): Optional task name associated with the request
186188
target_controller (str): ID of the target controller to send the request to
187189
socket (zmq.asyncio.Socket): ZMQ async socket for message transmission
188190
191+
Example:
192+
>>> batch_size = 4
193+
>>> current_step = 0
194+
>>> # Example 1: "fetch" a batch of metadata that has been produced
195+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
196+
>>> batch_size=batch_size,
197+
>>> global_step=current_step,
198+
>>> mode="fetch",
199+
>>> get_n_samples=False,
200+
>>> task_name="generate_sequences",
201+
>>> ))
202+
>>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True
203+
>>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True]
204+
>>>
205+
>>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make
206+
>>> # sure the corresponding data has not been consumed)
207+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
208+
>>> batch_size=batch_size,
209+
>>> global_step=current_step,
210+
>>> mode="force_fetch",
211+
>>> get_n_samples=False,
212+
>>> task_name="generate_sequences",
213+
>>> ))
214+
>>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False
215+
>>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True]
216+
189217
Returns:
190218
BatchMeta: Metadata object containing data structure, sample info, etc.
191219
"""
@@ -239,6 +267,30 @@ async def async_put(
239267
metadata (BatchMeta, optional): Optional metadata containing index and storage unit information
240268
global_step (int, optional): Current step (required if no metadata is provided)
241269
270+
Example:
271+
>>> batch_size = 4
272+
>>> seq_len = 16
273+
>>> current_step = 0
274+
>>> # Example 1: normal usage
275+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
276+
>>> batch_size=batch_size,
277+
>>> global_step=current_step,
278+
>>> mode="fetch",
279+
>>> get_n_samples=False,
280+
>>> task_name="generate_sequences",
281+
>>> ))
282+
>>> batch = asyncio.run(client.async_get_data(batch_meta))
283+
>>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
284+
>>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
285+
>>>
286+
>>> # Example 2: put the initial data into the system without pre-existing metadata
287+
>>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step!
288+
>>> # So make sure the global_step is empty.
289+
>>> prompts = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]]))
290+
>>> prompt_batch = TensorDict({"prompts": prompts})
291+
>>> # This will create metadata in "insert" mode internally.
292+
>>> asyncio.run(client.async_put(data=prompt_batch, global_step=current_step))
293+
242294
"""
243295
if metadata is None:
244296
assert global_step is not None, "global_step must be provided if metadata is not given"
@@ -346,13 +398,20 @@ async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
346398
- "global_indexes" key: Maps each sample to its original global index.
347399
348400
Example:
349-
>>> returned_td = await async_get_data(metadata)
350-
>>> returned_td.keys()
351-
dict_keys(['prompt_token_ids', 'response_token_ids', 'global_indexes'])
352-
>>> returned_td["prompt_token_ids"].shape # Batch size 4, seq length 128
353-
torch.Size([4, 128])
354-
>>> returned_td["global_indexes"] # Preserves original global order
355-
tensor([7, 4, 6, 5])
401+
>>> batch_size = 4
402+
>>> seq_len = 16
403+
>>> current_step = 0
404+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
405+
>>> batch_size=batch_size,
406+
>>> global_step=current_step,
407+
>>> mode="fetch",
408+
>>> get_n_samples=False,
409+
>>> task_name="generate_sequences",
410+
>>> ))
411+
>>> batch = asyncio.run(client.async_get_data(batch_meta))
412+
>>> print(batch)
413+
>>> # this is a TensorDict with fields "prompts" and "attention_mask".
414+
>>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta
356415
357416
Note:
358417
Why track `global_indexes`?
@@ -408,6 +467,7 @@ async def async_clear(self, global_step: int):
408467
409468
Args:
410469
global_step (int): The training step associated with the clear operation
470+
411471
"""
412472
try:
413473
target_controller = next(iter(self._controllers.keys()))
@@ -514,6 +574,16 @@ async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=N
514574
logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}")
515575
raise
516576

577+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
578+
def check_current_step_consumption(self, task_name: str, global_step: int):
579+
# TODO: Implement this method to check if all samples for the current step have been consumed
580+
pass
581+
582+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
583+
def check_current_step_production(self, data_fields: list[str], global_step: int):
584+
# TODO: Implement this method to check if all samples for the current step are ready for consumption
585+
pass
586+
517587

518588
class TransferQueueClient(AsyncTransferQueueClient):
519589
def __init__(

0 commit comments

Comments
 (0)