[data] feat: Provide general decorator for DataProto <-> BatchMeta #21
Conversation
Signed-off-by: 0oshowero0 <[email protected]>
Signed-off-by: 0oshowero0 <[email protected]>
Signed-off-by: 0oshowero0 <[email protected]>
Pull Request Overview
This PR introduces a general decorator for converting between DataProto and BatchMeta objects to enable TransferQueue integration. The decorator wraps functions that operate on DataProto objects so they can work with BatchMeta and the TransferQueue system.
Key changes:
- Implements a `dataproto_batchmeta_conversion` decorator that handles conversion from BatchMeta to DataProto, execution of the wrapped function, and conversion of the result back to BatchMeta
- Provides both synchronous and asynchronous wrappers, with client-based data retrieval or mock data generation for testing
- Includes a comprehensive test suite validating decorator functionality with real DataProto instances and mock TransferQueue components
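For orientation, here is a minimal sketch of the round trip the decorator performs. The helper names mirror those quoted in the review excerpts below; the bodies and signatures are placeholders, not the PR's actual implementation:

```python
import functools


async def _batchmeta_to_dataproto_async(batch_meta, client):
    # Placeholder: fetch the tensors referenced by batch_meta through the
    # TransferQueue client and assemble a DataProto from them.
    ...


async def _update_batchmeta_with_result_async(result_data, batch_meta, client):
    # Placeholder: write result_data back through the client and return the
    # updated BatchMeta.
    ...


def dataproto_batchmeta_conversion(transfer_queue_client=None):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(batch_meta, *args, **kwargs):
            client = kwargs.pop("transfer_queue_client", transfer_queue_client)
            data = await _batchmeta_to_dataproto_async(batch_meta, client)  # BatchMeta -> DataProto
            result_data = await func(data, *args, **kwargs)                 # run the wrapped function
            return await _update_batchmeta_with_result_async(result_data, batch_meta, client)  # -> BatchMeta
        return wrapper
    return decorator
```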
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| recipe/transfer_queue/dataproto_conversion.py | Core decorator implementation with conversion utilities and client integration |
| recipe/transfer_queue/test_dataproto_decorator.py | Independent test script demonstrating decorator usage with DataProto and mock TransferQueue |
```python
data = await _batchmeta_to_dataproto_async(batch_meta, client)

# Call function with DataProto
result_data = await func(data, *other_args, **other_kwargs)
```
The function is being awaited but may not be async. The wrapper assumes `func` is async; it should check whether it is a coroutine function first, or handle both sync and async functions appropriately.
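One way to address this, as a sketch (the helper name `call_maybe_async` is hypothetical):

```python
import asyncio
import functools


def call_maybe_async(func):
    """Wrap func so it can be awaited whether it is sync or async (hypothetical helper)."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if asyncio.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        # Run a synchronous function in the default executor so it does not block the loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
    return wrapper
```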
```python
    # We're in a running loop, this shouldn't happen for sync wrapper
    raise RuntimeError("Sync wrapper called from within async context")
except RuntimeError:
    # No running loop, we can use asyncio.run
    data_dict = asyncio.run(client.async_get_data(batch_meta))
```
This logic is flawed: `asyncio.get_running_loop()` raises `RuntimeError` when no loop is running, but the code catches ALL `RuntimeError` exceptions. This could mask the intentionally raised error on line 127. Use the `asyncio.get_running_loop()` return value instead of exception handling.
Suggested change:

```python
except RuntimeError:
    # No running loop, we can use asyncio.run
    data_dict = asyncio.run(client.async_get_data(batch_meta))
else:
    # We're in a running loop, this shouldn't happen for sync wrapper
    raise RuntimeError("Sync wrapper called from within async context")
```
```python
try:
    return TensorDict(**tensor_dict, batch_size=len(data))
except Exception as e:
    logger.warning(f"TensorDict creation failed: {e}, trying fallback")
    # Fallback: create with batch_size parameter
    td = TensorDict({}, batch_size=len(data))
    for key, value in tensor_dict.items():
        td.set(key, value)
    return td
```
Catching broad `Exception` makes debugging difficult. Consider catching specific TensorDict-related exceptions, or at minimum log the specific exception type and the `tensor_dict` contents for better debugging.
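A sketch of the narrower handling for the excerpt above; the specific exception types chosen here (`TypeError`, `ValueError`, `RuntimeError`) are an assumption about how `TensorDict` construction fails:

```python
try:
    return TensorDict(**tensor_dict, batch_size=len(data))
except (TypeError, ValueError, RuntimeError) as e:
    # Log the exception type and the offending keys, not just the message.
    logger.warning(
        "TensorDict creation failed (%s: %s) for keys %s, trying fallback",
        type(e).__name__, e, list(tensor_dict.keys()),
    )
    td = TensorDict({}, batch_size=len(data))
    for key, value in tensor_dict.items():
        td.set(key, value)
    return td
```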
```python
def dataproto_batchmeta_conversion_v2(func: Optional[Callable] = None, *, transfer_queue_client: Optional[AsyncTransferQueueClient] = None):
    """
    Alternative decorator syntax that supports both @decorator and @decorator() usage.
    """
    def decorator(f: Callable) -> Callable:
        return dataproto_batchmeta_conversion(transfer_queue_client)(f)

    if func is not None:
        return decorator(func)
    return decorator
```
The `_v2` function appears to be unused and provides the same functionality as the main decorator. Consider removing this duplicate implementation to reduce code complexity.
Suggested change: delete `dataproto_batchmeta_conversion_v2` entirely.
Signed-off-by: 0oshowero0 <[email protected]>
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
```python
# Call function with DataProto
result_data = await func(data, *other_args, **other_kwargs)
```
The function `func` is being awaited but there's no guarantee it's a coroutine; this will fail if the wrapped function is synchronous. The check on line 101 should determine which wrapper to use, but the async wrapper shouldn't call non-async functions with `await` (see the `call_maybe_async` sketch earlier in this thread).
```python
# We're in a running loop, use run_coroutine_threadsafe
future = asyncio.run_coroutine_threadsafe(client.async_get_data(batch_meta), loop)
data_dict = future.result(timeout=10)  # 10 second timeout
```
Using `asyncio.run_coroutine_threadsafe` with the current running loop will likely cause deadlock: `future.result()` blocks the very thread the loop needs in order to execute the coroutine. When there's already a running event loop, use `await` instead of trying to run the coroutine on the same loop from a blocked thread context.
Suggested change:

```python
# We're in a running event loop in this thread; cannot safely run coroutine synchronously.
raise RuntimeError(
    "Cannot call _batchmeta_to_dataproto_sync when an event loop is running in this thread. "
    "Use the async version (_batchmeta_to_dataproto_async) instead."
)
```
```python
    loop = asyncio.get_running_loop()
except RuntimeError:
    # No running loop, we can use asyncio.run
    asyncio.run(client.async_put(data=output_tensor_dict, metadata=batch_meta))
else:
    # We're in a running loop, use run_coroutine_threadsafe
    future = asyncio.run_coroutine_threadsafe(client.async_put(data=output_tensor_dict, metadata=batch_meta), loop)
    future.result(timeout=10)  # 10 second timeout
```
Same issue as with `async_get_data`: using `asyncio.run_coroutine_threadsafe` with the current running loop will likely cause deadlock. This pattern is problematic when already inside an event loop.
Suggested change:

```python
    asyncio.get_running_loop()
except RuntimeError:
    # No running loop, we can use asyncio.run
    asyncio.run(client.async_put(data=output_tensor_dict, metadata=batch_meta))
else:
    # We're in a running event loop in this thread; cannot safely run async code synchronously.
    raise RuntimeError(
        "Cannot call _update_batchmeta_with_result_sync while an event loop is running in this thread. "
        "Use _update_batchmeta_with_result_async instead."
    )
```
```python
for field_name in batch_meta.field_names:
    if field_name == "input_ids":
        data_dict[field_name] = torch.randint(0, 1000, (batch_size, 10))
    elif field_name == "attention_mask":
        data_dict[field_name] = torch.ones(batch_size, 10)
    elif field_name == "responses":
        data_dict[field_name] = torch.randint(0, 1000, (batch_size, 5))
    else:
        # Generic mock data
        data_dict[field_name] = torch.ones(batch_size, 5)
```
The mock data generation logic is duplicated between the sync and async versions (lines 141-150 and 170-179). It should be extracted into a separate helper function to avoid code duplication, e.g. as sketched below.
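A sketch of the extracted helper (the name `_generate_mock_data` is hypothetical; the field shapes are copied from the excerpt above):

```python
import torch


def _generate_mock_data(batch_meta, batch_size: int) -> dict:
    """Build per-field mock tensors for testing; shared by the sync and async paths."""
    data_dict = {}
    for field_name in batch_meta.field_names:
        if field_name == "input_ids":
            data_dict[field_name] = torch.randint(0, 1000, (batch_size, 10))
        elif field_name == "attention_mask":
            data_dict[field_name] = torch.ones(batch_size, 10)
        elif field_name == "responses":
            data_dict[field_name] = torch.randint(0, 1000, (batch_size, 5))
        else:
            # Generic mock data
            data_dict[field_name] = torch.ones(batch_size, 5)
    return data_dict
```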
```python
# Test with client in a separate thread to avoid event loop issues
print("\n2. Testing compute_response_mask decorator with client...")
try:
    # Run in a separate thread to avoid event loop conflicts
    import concurrent.futures
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(compute_response_mask_decorated, batch_meta, transfer_queue_client=mock_client)
        result_batch_meta = future.result(timeout=10)
```
Using ThreadPoolExecutor to avoid event loop issues indicates a design problem with the decorator. The decorator should handle async and sync contexts properly without requiring thread workarounds in tests.
Suggested change:

```python
# Test with client, handle async/sync context properly
print("\n2. Testing compute_response_mask decorator with client...")
try:
    result = compute_response_mask_decorated(batch_meta, transfer_queue_client=mock_client)
    if asyncio.iscoroutine(result):
        result_batch_meta = await result
    else:
        result_batch_meta = result
```
```python
else:
    # We're in a running loop, use run_coroutine_threadsafe
    future = asyncio.run_coroutine_threadsafe(client.async_get_data(batch_meta), loop)
    data_dict = future.result(timeout=10)  # 10 second timeout
```
The 10-second timeout is a magic number that appears multiple times (lines 134, 204). It should be defined as a constant at module level for better maintainability.
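For example (later commits in this PR do appear to adopt such a constant, judging by the `DEFAULT_ASYNC_TIMEOUT` reference in the next review round):

```python
# Module-level constant instead of repeating the literal 10 at every call site.
DEFAULT_ASYNC_TIMEOUT = 10  # seconds

# usage: future.result(timeout=DEFAULT_ASYNC_TIMEOUT)
```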
Signed-off-by: 0oshowero0 <[email protected]>
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
```python
if "no running event loop" in str(e):
    # No running loop, we can use asyncio.run
    asyncio.run(
        client.async_put(data=output_tensor_dict, metadata=batch_meta), timeout=DEFAULT_ASYNC_TIMEOUT
```
The `asyncio.run()` function doesn't accept a `timeout` parameter; `timeout` belongs to `asyncio.wait_for()`. The coroutine should be wrapped in `asyncio.wait_for()`, similar to the async version.
Suggested change (replacing the inner call, so the result is `asyncio.run(asyncio.wait_for(...))`):

```python
asyncio.wait_for(
    client.async_put(data=output_tensor_dict, metadata=batch_meta),
    timeout=DEFAULT_ASYNC_TIMEOUT
)
```
```python
try:
    return TensorDict(**tensor_dict, batch_size=len(data))
except Exception as e:
    logger.warning(f"TensorDict creation failed: {e}, trying fallback")
```
The variable `logger` is not defined in this scope. A logger should be configured at module level, or the function should use `print()` for error output.
Suggested change:

```python
print(f"TensorDict creation failed: {e}, trying fallback")
```
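Alternatively, keep the `logger.warning` call and define a module-level logger (standard-library sketch):

```python
import logging

# Define once at module level so `logger` exists wherever the fallback runs.
logger = logging.getLogger(__name__)
```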
Signed-off-by: 0oshowero0 <[email protected]>
Signed-off-by: 0oshowero0 <[email protected]>
What does this PR do?
As the title
Checklist Before Starting
- PR title is formatted as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - For a breaking change, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
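A hedged usage sketch, based on the decorated `compute_response_mask` exercised in the test script; the import path and exact call signature are assumptions, not the confirmed API:

```python
# Hypothetical usage; names follow the files in this PR, but the import path
# and signature are assumptions.
from recipe.transfer_queue.dataproto_conversion import dataproto_batchmeta_conversion


@dataproto_batchmeta_conversion()
def compute_response_mask(data):
    # Operates on a DataProto exactly as before the decorator was applied.
    ...


# `batch_meta` and `client` come from the TransferQueue setup. The decorated
# function now accepts a BatchMeta and returns an updated BatchMeta, moving
# the underlying tensors through the TransferQueue client.
result_batch_meta = compute_response_mask(batch_meta, transfer_queue_client=client)
```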
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Request CI in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)