-
Notifications
You must be signed in to change notification settings - Fork 375
Add device ordinal XLA plumbing #976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add device ordinal XLA plumbing #976
Conversation
Signed-off-by: Chase Riley Roberts <[email protected]>
📝 WalkthroughWalkthroughIntroduces a global lock for Warp device context, updates JAX FFI to retrieve device ordinal from callframes, preloads modules across all local GPUs, and restructures FFI call paths to resolve device/stream via ordinal. Wraps FFI execution with synchronized ScopedStream usage. Adds pmap-related JAX tests for jax_callable kernels. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant XLA as XLA Runtime Thread
participant FFI as Python FFI Shim
participant XAPI as xla_ffi (API)
participant Warp as Warp Runtime
participant Mod as Kernel Modules
Note over XLA,FFI: FFI call (compute)
XLA->>FFI: call(callback, call_frame)
FFI->>XAPI: get_device_ordinal_from_callframe(call_frame)
XAPI-->>FFI: device_ordinal
FFI->>Warp: get_cuda_device(device_ordinal), Stream(device)
par Preload (once per __call__)
FFI->>Warp: load module(s) on all jax.local_devices()
Warp-->>FFI: modules ready
end
rect rgba(200,230,255,0.25)
note right of FFI: Synchronized execution
FFI->>FFI: acquire WARP_DEVICE_CONTEXT_LOCK
FFI->>Warp: ScopedStream(stream, sync_enter=true)
alt Capturing (JAX or Warp graph)
FFI->>Warp: capture / launch graph
Warp-->>FFI: handles, refs
else Direct
FFI->>Warp: launch kernel(s)
Warp-->>FFI: completion
end
FFI->>FFI: release lock
end
FFI-->>XLA: return status
sequenceDiagram
autonumber
participant User as User Code
participant JAX as jax_callable(__call__)
participant FFI as Python FFI Shim
participant Warp as Warp Runtime
User->>JAX: create jax_callable on kernel
JAX->>FFI: prepare callable
par Preload modules on all GPUs
FFI->>Warp: for d in jax.local_devices(): load module(d)
Warp-->>FFI: modules loaded per device
end
JAX-->>User: callable ready (modules preloaded)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (6)
warp/tests/interop/test_jax.py (1)
1074-1091: Register pmap-forward tests only when JAX >= 0.5.0Prevents scheduling tests that will inevitably skip/fail on older JAX.
- add_function_test( - TestJax, - "test_jax_callable_pmap_multi_output_forward", - test_jax_callable_pmap_multi_output_forward, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_callable_pmap_mul_forward", - test_jax_callable_pmap_mul_forward, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_callable_pmap_multi_stage_forward", - test_jax_callable_pmap_multi_stage_forward, - devices=jax_compatible_cuda_devices, - ) + if jax.__version_info__ >= (0, 5, 0): + add_function_test( + TestJax, + "test_jax_callable_pmap_multi_output_forward", + test_jax_callable_pmap_multi_output_forward, + devices=jax_compatible_cuda_devices, + ) + add_function_test( + TestJax, + "test_jax_callable_pmap_mul_forward", + test_jax_callable_pmap_mul_forward, + devices=jax_compatible_cuda_devices, + ) + add_function_test( + TestJax, + "test_jax_callable_pmap_multi_stage_forward", + test_jax_callable_pmap_multi_stage_forward, + devices=jax_compatible_cuda_devices, + )warp/jax_experimental/ffi.py (4)
248-252: Filter preloading to GPU devices onlyAvoid unnecessary CPU module loads and potential failures on CPU-only variants by restricting to GPU devices.
- # ensure the kernel module is loaded on all local GPUs to avoid per-device build races - for d in jax.local_devices(): - dev = wp.device_from_jax(d) - self.kernel.module.load(dev) + # ensure the kernel module is loaded on all local GPUs to avoid per-device build races + for d in jax.local_devices(): + if getattr(d, "platform", "") != "gpu": + continue + dev = wp.device_from_jax(d) + self.kernel.module.load(dev)
521-526: Preload callable module on GPUs onlySame rationale as above; limits work and avoids CPU load issues.
- module = wp.get_module(self.func.__module__) - for d in jax.local_devices(): - dev = wp.device_from_jax(d) - module.load(dev) + module = wp.get_module(self.func.__module__) + for d in jax.local_devices(): + if getattr(d, "platform", "") != "gpu": + continue + dev = wp.device_from_jax(d) + module.load(dev)
597-641: ScopedStream under global lock: solid fix; consider reentrancyLocking around ScopedStream prevents concurrent overwrites of Warp's global device/stream. If nested jax_callable invocations are possible on the same thread, consider threading.RLock for reentrancy.
28-31: Silence F405 by explicitly importing the ordinal getterRuff flags get_device_ordinal_from_callframe as undefined due to star-import. Add an explicit import.
-from .xla_ffi import * +from .xla_ffi import * +# Explicitly import to satisfy static analyzers that discourage star-imports. +from .xla_ffi import get_device_ordinal_from_callframewarp/jax_experimental/xla_ffi.py (1)
596-603: Optional: add basic validation for returned device ordinalDefensive check can surface clearer errors if the API returns an invalid ordinal.
def get_device_ordinal_from_callframe(call_frame): api = call_frame.api get_device_args = XLA_FFI_DeviceOrdinal_Get_Args( ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args), ctypes.POINTER(XLA_FFI_Extension_Base)(), call_frame.ctx, 0 ) api.contents.XLA_FFI_DeviceOrdinal_Get(get_device_args) - return get_device_args.device_ordinal + ord = int(get_device_args.device_ordinal) + if ord < 0: + raise RuntimeError("XLA_FFI_DeviceOrdinal_Get returned an invalid device ordinal") + return ord
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
warp/jax_experimental/ffi.py(6 hunks)warp/jax_experimental/xla_ffi.py(4 hunks)warp/tests/interop/test_jax.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
warp/tests/interop/test_jax.py (3)
warp/tests/interop/test_dlpack.py (1)
_jax_version(28-34)warp/jax_experimental/ffi.py (1)
jax_callable(707-773)warp/tests/unittest_utils.py (1)
add_function_test(284-303)
warp/jax_experimental/ffi.py (4)
warp/jax.py (1)
device_from_jax(44-58)warp/context.py (4)
load(2495-2607)get_cuda_device(4719-4727)stream(3196-3205)stream(3208-3209)warp/jax_experimental/xla_ffi.py (1)
get_device_ordinal_from_callframe(596-602)warp/utils.py (2)
ScopedStream(1238-1283)ScopedCapture(1471-1500)
🪛 Ruff (0.13.1)
warp/tests/interop/test_jax.py
796-796: Unused function argument: device
(ARG001)
828-828: Unused function argument: device
(ARG001)
875-875: Unused function argument: device
(ARG001)
1074-1074: add_function_test may be undefined, or defined from star imports
(F405)
1080-1080: add_function_test may be undefined, or defined from star imports
(F405)
1086-1086: add_function_test may be undefined, or defined from star imports
(F405)
warp/jax_experimental/ffi.py
326-326: get_device_ordinal_from_callframe may be undefined, or defined from star imports
(F405)
597-597: get_device_ordinal_from_callframe may be undefined, or defined from star imports
(F405)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: build-and-test / test-warp-macos
- GitHub Check: build-and-test / build-warp-ubuntu-x86_64
- GitHub Check: build-and-test / build-warp-ubuntu-aarch64
- GitHub Check: build-and-test / build-warp-windows
🔇 Additional comments (4)
warp/jax_experimental/ffi.py (2)
30-30: Global device-context lock: LGTMUsing a single lock to serialize ScopedStream usage across FFI callbacks addresses the Warp global context race.
325-343: Correct device/stream resolution for kernel FFI pathUsing XLA device ordinal + direct wp_cuda_launch_kernel avoids touching global device/stream state. Good.
warp/jax_experimental/xla_ffi.py (2)
382-399: DeviceOrdinal_Get args wiring: LGTMStructure and CFUNCTYPE match usage pattern; returns int32 ordinal via out param.
424-451: Confirm XLA_FFI_Api field order matches headerTo avoid ABI mismatches, ensure RunId_Get and DeviceOrdinal_Get positions align with your XLA header version.
Would you double-check xla/ffi/api/c_api.h for the exact ordering of these two fields for the JAX/XLA version you target?
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||
| def test_jax_callable_pmap_mul_forward(test, device): | ||
| import jax | ||
| import jax.numpy as jp | ||
|
|
||
| from warp.jax_experimental.ffi import jax_callable | ||
|
|
||
| if jax.local_device_count() < 2: | ||
| test.skipTest("requires >= 2 local devices") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Align version guard with jax_callable requirement; silence unused arg
These tests call warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 via check_jax_version(). Current guard is 0.4.31, leading to failures on 0.4.31–0.4.99. Also silence the unused device arg.
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_mul_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_mul_forward(test, _device):Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.13.1)
796-796: Unused function argument: device
(ARG001)
🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 795 to 804, the test uses
warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 but the skip
guard currently checks for >= 0.4.31, causing failures on JAX 0.4.x; update the
version guard to _jax_version() >= (0, 5, 0) and silence the unused device
parameter by either renaming it to _device or adding a simple del device (or _ =
device) so the test runner does not flag an unused-argument warning.
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||
| def test_jax_callable_pmap_multi_output_forward(test, device): | ||
| import jax | ||
| import jax.numpy as jp | ||
|
|
||
| from warp.jax_experimental.ffi import jax_callable | ||
|
|
||
| if jax.local_device_count() < 2: | ||
| test.skipTest("requires >= 2 local devices") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Fix version guard and unused arg here as well
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | |
| def test_jax_callable_pmap_multi_output_forward(test, device): | |
| import jax | |
| import jax.numpy as jp | |
| from warp.jax_experimental.ffi import jax_callable | |
| if jax.local_device_count() < 2: | |
| test.skipTest("requires >= 2 local devices") | |
| @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test") | |
| def test_jax_callable_pmap_multi_output_forward(test, _device): | |
| import jax | |
| import jax.numpy as jp | |
| from warp.jax_experimental.ffi import jax_callable | |
| if jax.local_device_count() < 2: | |
| test.skipTest("requires >= 2 local devices") |
🧰 Tools
🪛 Ruff (0.13.1)
828-828: Unused function argument: device
(ARG001)
🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 827 to 836, fix the version guard
and remove the unused function arguments: ensure the skipUnless uses a properly
comparable jax version (e.g. call _jax_version() and compare to a tuple or use
packaging.version.parse if _jax_version() returns a string) and change the test
signature to a normal unittest method (remove the unused 'device' and 'test'
args) so it becomes def test_jax_callable_pmap_multi_output_forward(self):, then
replace test.skipTest(...) with self.skipTest(...) when checking
jax.local_device_count() < 2.
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||
| def test_jax_callable_pmap_multi_stage_forward(test, device): | ||
| import jax | ||
| import jax.numpy as jp | ||
|
|
||
| from warp.jax_experimental.ffi import jax_callable | ||
|
|
||
| if jax.local_device_count() < 2: | ||
| test.skipTest("requires >= 2 local devices") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Fix version guard and unused arg here as well
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | |
| def test_jax_callable_pmap_multi_stage_forward(test, device): | |
| import jax | |
| import jax.numpy as jp | |
| from warp.jax_experimental.ffi import jax_callable | |
| if jax.local_device_count() < 2: | |
| test.skipTest("requires >= 2 local devices") | |
| @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test") | |
| def test_jax_callable_pmap_multi_stage_forward(test, _device): | |
| import jax | |
| import jax.numpy as jp | |
| from warp.jax_experimental.ffi import jax_callable | |
| if jax.local_device_count() < 2: | |
| test.skipTest("requires >= 2 local devices") |
🧰 Tools
🪛 Ruff (0.13.1)
875-875: Unused function argument: device
(ARG001)
🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 874-883, the test function has an
unused 'test' parameter and the JAX version guard needs to be corrected; change
the function signature to remove the unused parameter (e.g., def
test_jax_callable_pmap_multi_stage_forward(self, device): if this is a unittest
method, or def test_jax_callable_pmap_multi_stage_forward(device): for a plain
test function) and ensure the decorator uses the correct version check using
_jax_version(), e.g. unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax
version too old for pmap forward test"), so the guard is consistent and no
unused arg remains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @chaserileyroberts, this is great. I will get it merged shortly.
Add JAX pmap support (NVIDIAGH-976) See merge request omniverse/warp!1614
Description
This PR fixes failures related to using Warp within a
jax.pmapor ajax.shard_map. The main issue was that JAX would initiate the Warp callback from multiple threads targeting different devices, but the device Warp would target was always the default one. To fix this, we hookupXLA_FFI_DeviceOrdinal_Gettoxla_ffi.py, and plumb that value to the warp device context so that JAX and Warp would each target the same device on each callback.Since the device Warp uses is dependent on a global context, JAX calling the FFI from multiple threads would smash this value and cause errors. To fix this, we include a threading lock before entering the context. We may consider making the device context value a thread local instead of a program global one.
Before your PR is "Ready for review"
__init__.pyi,functions.rst)pre-commit run -aSummary by CodeRabbit
New Features
Bug Fixes
Tests