@@ -180,12 +180,40 @@ async def async_get_meta(
180
180
data_fields (list[str]): List of fields to retrieve metadata for
181
181
batch_size (int): Processing batch size
182
182
global_step (int): Current training/processing step
183
- mode (str): Data fetch mode (TODO(hz): more details to be added)
184
- get_n_samples (bool): TODO(hz): more details to be added
183
+ mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness.
184
+ 'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS.
185
+ get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch'
186
+ mode, only the samples of the same prompt that are all ready will be returned.
185
187
task_name (str): Optional task name associated with the request
186
188
target_controller (str): ID of the target controller to send the request to
187
189
socket (zmq.asyncio.Socket): ZMQ async socket for message transmission
188
190
191
+ Example:
192
+ >>> batch_size = 4
193
+ >>> current_step = 0
194
+ >>> # Example 1: "fetch" a batch of metadata that has been produced
195
+ >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
196
+ >>> batch_size=batch_size,
197
+ >>> global_step=current_step,
198
+ >>> mode="fetch",
199
+ >>> get_n_samples=False,
200
+ >>> task_name="generate_sequences",
201
+ >>> ))
202
+ >>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True
203
+ >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True]
204
+ >>>
205
+ >>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make
206
+ >>> # sure the corresponding data has not been consumed)
207
+ >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
208
+ >>> batch_size=batch_size,
209
+ >>> global_step=current_step,
210
+ >>> mode="force_fetch",
211
+ >>> get_n_samples=False,
212
+ >>> task_name="generate_sequences",
213
+ >>> ))
214
+ >>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False
215
+ >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True]
216
+
189
217
Returns:
190
218
BatchMeta: Metadata object containing data structure, sample info, etc.
191
219
"""
@@ -239,6 +267,30 @@ async def async_put(
239
267
metadata (BatchMeta, optional): Optional metadata containing index and storage unit information
240
268
global_step (int, optional): Current step (required if no metadata is provided)
241
269
270
+ Example:
271
+ >>> batch_size = 4
272
+ >>> seq_len = 16
273
+ >>> current_step = 0
274
+ >>> # Example 1: normal usage
275
+ >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
276
+ >>> batch_size=batch_size,
277
+ >>> global_step=current_step,
278
+ >>> mode="fetch",
279
+ >>> get_n_samples=False,
280
+ >>> task_name="generate_sequences",
281
+ >>> ))
282
+ >>> batch = asyncio.run(client.async_get_data(batch_meta))
283
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
284
+ >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
285
+ >>>
286
+ >>> # Example 2: put the initial data into the system without pre-existing metadata
287
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step!
288
+ >>> # So make sure the global_step is empty.
289
+ >>> prompts = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]]))
290
+ >>> prompt_batch = TensorDict({"prompts": prompts})
291
+ >>> # This will create metadata in "insert" mode internally.
292
+ >>> asyncio.run(client.async_put(data=prompt_batch, global_step=current_step))
293
+
242
294
"""
243
295
if metadata is None :
244
296
assert global_step is not None , "global_steps must be provided if metadata is not given"
@@ -346,13 +398,20 @@ async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
346
398
- "global_indexes" key: Maps each sample to its original global index.
347
399
348
400
Example:
349
- >>> returned_td = await async_get_data(metadata)
350
- >>> returned_td.keys()
351
- dict_keys(['prompt_token_ids', 'response_token_ids', 'global_indexes'])
352
- >>> returned_td["prompt_token_ids"].shape # Batch size 4, seq length 128
353
- torch.Size([4, 128])
354
- >>> returned_td["global_indexes"] # Preserves original global order
355
- tensor([7, 4, 6, 5])
401
+ >>> batch_size = 4
402
+ >>> seq_len = 16
403
+ >>> current_step = 0
404
+ >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
405
+ >>> batch_size=batch_size,
406
+ >>> global_step=current_step,
407
+ >>> mode="fetch",
408
+ >>> get_n_samples=False,
409
+ >>> task_name="generate_sequences",
410
+ >>> ))
411
+ >>> batch = asyncio.run(client.async_get_data(batch_meta))
412
+ >>> print(batch)
413
+ >>> # this is a TensorDict with fields "prompts" and "attention_mask".
414
+ >>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta
356
415
357
416
Note:
358
417
Why track `global_indexes`?
@@ -408,6 +467,7 @@ async def async_clear(self, global_step: int):
408
467
409
468
Args:
410
469
global_step (int): The training step associated with the clear operation
470
+
411
471
"""
412
472
try :
413
473
target_controller = next (iter (self ._controllers .keys ()))
@@ -514,6 +574,16 @@ async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=N
514
574
logger .error (f"[{ self .client_id } ]: Error clearing storage unit { target_storage } : { str (e )} " )
515
575
raise
516
576
577
+ @dynamic_socket (target_role = TransferQueueRole .CONTROLLER , socket_name = "request_handle_socket" )
578
+ def check_current_step_consumption (self , task_name : str , global_step : int ):
579
+ # TODO: Implement this method to check if all samples for the current step has been consumed
580
+ pass
581
+
582
+ @dynamic_socket (target_role = TransferQueueRole .CONTROLLER , socket_name = "request_handle_socket" )
583
+ def check_current_step_production (self , data_fields : list [str ], global_step : int ):
584
+ # TODO: Implement this method to check if all samples for the current step is ready for consumption
585
+ pass
586
+
517
587
518
588
class TransferQueueClient (AsyncTransferQueueClient ):
519
589
def __init__ (
0 commit comments