Conversation

@chaserileyroberts (Contributor) commented Sep 24, 2025

Description

This PR fixes failures related to using Warp within a jax.pmap or a jax.shard_map. The main issue was that JAX would initiate the Warp callback from multiple threads targeting different devices, but Warp always targeted the default device. To fix this, we hook up XLA_FFI_DeviceOrdinal_Get in xla_ffi.py and plumb that value into the Warp device context so that JAX and Warp target the same device on each callback.

Since the device Warp uses depends on a global context, JAX calling the FFI from multiple threads would clobber this value and cause errors. To fix this, we acquire a threading lock before entering the context. We may consider making the device context value thread-local instead of program-global.
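
A minimal sketch of the callback shape this change moves toward: resolve the target device from the XLA call frame's device ordinal, then serialize Warp's global device/stream context while launching. The helper names get_device_ordinal_from_callframe and WARP_DEVICE_CONTEXT_LOCK come from this PR; the callback body, the stream handling, and the launch_kernels_for helper are illustrative placeholders rather than the exact code.

import threading

import warp as wp
from warp.jax_experimental.xla_ffi import get_device_ordinal_from_callframe

WARP_DEVICE_CONTEXT_LOCK = threading.Lock()

def ffi_callback(call_frame):
    # Ask XLA which local device this particular callback targets.
    ordinal = get_device_ordinal_from_callframe(call_frame)
    device = wp.get_cuda_device(ordinal)

    # The real code wraps the XLA-provided CUDA stream from the call frame;
    # constructing a fresh stream on the resolved device is a simplification.
    stream = wp.Stream(device)

    # Warp's current device/stream live in global state, so hold the lock
    # around the scoped stream to keep concurrent callbacks from clobbering
    # each other.
    with WARP_DEVICE_CONTEXT_LOCK:
        with wp.ScopedStream(stream):
            launch_kernels_for(call_frame)  # hypothetical launch helper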

Before your PR is "Ready for review"

  • [x] All commits are signed-off to indicate that your contribution adheres to the Developer Certificate of Origin requirements
  • [x] Necessary tests have been added
  • [x] Documentation is up-to-date
  • [x] Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. __init__.pyi, functions.rst)
  • [x] Code passes formatting and linting checks with pre-commit run -a

Summary by CodeRabbit

  • New Features

    • Improved JAX interop: pmap now supports Warp kernels with multi-output and multi-stage pipelines.
    • Automatically preloads kernel modules across all local GPUs for smoother multi-GPU execution.
    • More reliable device/stream selection during FFI calls for correct multi-GPU routing.
  • Bug Fixes

    • Resolved race conditions and intermittent failures under concurrent XLA calls via global synchronization.
    • Eliminated per-device build races by loading modules for all GPUs up front.
  • Tests

    • Added pmap coverage for single, multi-output, and multi-stage Warp kernels.

Signed-off-by: Chase Riley Roberts <[email protected]>

coderabbitai bot commented Sep 24, 2025

📝 Walkthrough

Introduces a global lock for the Warp device context, updates the JAX FFI to retrieve the device ordinal from call frames, preloads modules across all local GPUs, and restructures FFI call paths to resolve the device and stream via that ordinal. Wraps FFI execution in synchronized ScopedStream usage. Adds pmap-related JAX tests for jax_callable kernels.

Changes

  • FFI call flow and synchronization (warp/jax_experimental/ffi.py):
    Added WARP_DEVICE_CONTEXT_LOCK; preloads kernel modules on all local GPUs; switches device resolution from get_jax_device to the device ordinal via xla_ffi; constructs Warp streams from ordinals; executes under the global lock with ScopedStream; preserves capture/graph handling and references.
  • XLA FFI API: device ordinal (warp/jax_experimental/xla_ffi.py):
    Added XLA_FFI_DeviceOrdinal_Get_Args and the XLA_FFI_DeviceOrdinal_Get type; extended XLA_FFI_Api with RunId and DeviceOrdinal getters; implemented a get_device_ordinal_from_callframe helper mirroring the Stream_Get pattern.
  • Tests: JAX pmap forwarding (warp/tests/interop/test_jax.py):
    Added three tests for jax_callable under jax.pmap (single-output, multi-output, and multi-stage kernels); gated by JAX version and device count; registered in the JAX CUDA suite.
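
For orientation, a rough sketch of the xla_ffi.py plumbing summarized above. The struct fields and the getter mirror the diff quoted later in this review; the exact ctypes types for ctx and extension_start are assumptions based on XLA's C FFI conventions, and XLA_FFI_Extension_Base stands in for the struct already defined in warp/jax_experimental/xla_ffi.py.

import ctypes

class XLA_FFI_Extension_Base(ctypes.Structure):
    # Placeholder for the existing definition in warp/jax_experimental/xla_ffi.py.
    pass

class XLA_FFI_DeviceOrdinal_Get_Args(ctypes.Structure):
    # XLA fills device_ordinal in place (int32 out-parameter).
    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("ctx", ctypes.c_void_p),            # XLA_FFI_ExecutionContext* (assumed)
        ("device_ordinal", ctypes.c_int32),
    ]

def get_device_ordinal_from_callframe(call_frame):
    args = XLA_FFI_DeviceOrdinal_Get_Args(
        ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args),
        ctypes.POINTER(XLA_FFI_Extension_Base)(),
        call_frame.ctx,
        0,
    )
    call_frame.api.contents.XLA_FFI_DeviceOrdinal_Get(args)
    return args.device_ordinal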

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant XLA as XLA Runtime Thread
  participant FFI as Python FFI Shim
  participant XAPI as xla_ffi (API)
  participant Warp as Warp Runtime
  participant Mod as Kernel Modules

  Note over XLA,FFI: FFI call (compute)
  XLA->>FFI: call(callback, call_frame)
  FFI->>XAPI: get_device_ordinal_from_callframe(call_frame)
  XAPI-->>FFI: device_ordinal
  FFI->>Warp: get_cuda_device(device_ordinal), Stream(device)
  par Preload (once per __call__)
    FFI->>Warp: load module(s) on all jax.local_devices()
    Warp-->>FFI: modules ready
  end
  rect rgba(200,230,255,0.25)
    note right of FFI: Synchronized execution
    FFI->>FFI: acquire WARP_DEVICE_CONTEXT_LOCK
    FFI->>Warp: ScopedStream(stream, sync_enter=true)
    alt Capturing (JAX or Warp graph)
      FFI->>Warp: capture / launch graph
      Warp-->>FFI: handles, refs
    else Direct
      FFI->>Warp: launch kernel(s)
      Warp-->>FFI: completion
    end
    FFI->>FFI: release lock
  end
  FFI-->>XLA: return status
sequenceDiagram
  autonumber
  participant User as User Code
  participant JAX as jax_callable(__call__)
  participant FFI as Python FFI Shim
  participant Warp as Warp Runtime

  User->>JAX: create jax_callable on kernel
  JAX->>FFI: prepare callable
  par Preload modules on all GPUs
    FFI->>Warp: for d in jax.local_devices(): load module(d)
    Warp-->>FFI: modules loaded per device
  end
  JAX-->>User: callable ready (modules preloaded)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title "Add device ordinal XLA plumbing" is concise and accurately summarizes the primary change in the PR: adding XLA device-ordinal plumbing to align JAX and Warp device selection (including XLA_FFI_DeviceOrdinal_Get and the related synchronization changes). It clearly conveys the main intent without unnecessary detail.

@coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
warp/tests/interop/test_jax.py (1)

1074-1091: Register pmap-forward tests only when JAX >= 0.5.0

Prevents scheduling tests that will inevitably skip/fail on older JAX.

-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_multi_output_forward",
-            test_jax_callable_pmap_multi_output_forward,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_mul_forward",
-            test_jax_callable_pmap_mul_forward,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_multi_stage_forward",
-            test_jax_callable_pmap_multi_stage_forward,
-            devices=jax_compatible_cuda_devices,
-        )
+        if jax.__version_info__ >= (0, 5, 0):
+            add_function_test(
+                TestJax,
+                "test_jax_callable_pmap_multi_output_forward",
+                test_jax_callable_pmap_multi_output_forward,
+                devices=jax_compatible_cuda_devices,
+            )
+            add_function_test(
+                TestJax,
+                "test_jax_callable_pmap_mul_forward",
+                test_jax_callable_pmap_mul_forward,
+                devices=jax_compatible_cuda_devices,
+            )
+            add_function_test(
+                TestJax,
+                "test_jax_callable_pmap_multi_stage_forward",
+                test_jax_callable_pmap_multi_stage_forward,
+                devices=jax_compatible_cuda_devices,
+            )
warp/jax_experimental/ffi.py (4)

248-252: Filter preloading to GPU devices only

Avoid unnecessary CPU module loads and potential failures on CPU-only variants by restricting to GPU devices.

-        # ensure the kernel module is loaded on all local GPUs to avoid per-device build races
-        for d in jax.local_devices():
-            dev = wp.device_from_jax(d)
-            self.kernel.module.load(dev)
+        # ensure the kernel module is loaded on all local GPUs to avoid per-device build races
+        for d in jax.local_devices():
+            if getattr(d, "platform", "") != "gpu":
+                continue
+            dev = wp.device_from_jax(d)
+            self.kernel.module.load(dev)

521-526: Preload callable module on GPUs only

Same rationale as above; limits work and avoids CPU load issues.

-        module = wp.get_module(self.func.__module__)
-        for d in jax.local_devices():
-            dev = wp.device_from_jax(d)
-            module.load(dev)
+        module = wp.get_module(self.func.__module__)
+        for d in jax.local_devices():
+            if getattr(d, "platform", "") != "gpu":
+                continue
+            dev = wp.device_from_jax(d)
+            module.load(dev)

597-641: ScopedStream under global lock: solid fix; consider reentrancy

Locking around ScopedStream prevents concurrent overwrites of Warp's global device/stream. If nested jax_callable invocations are possible on the same thread, consider threading.RLock for reentrancy.
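
A minimal sketch of the RLock variant, assuming the lock only needs to guard the ScopedStream region described above; run_under_stream is a hypothetical helper, not code from this PR.

import threading

import warp as wp

# An RLock lets the same thread re-acquire the lock, so a nested jax_callable
# invocation on one thread cannot deadlock on itself, while callbacks from
# other threads are still serialized.
WARP_DEVICE_CONTEXT_LOCK = threading.RLock()

def run_under_stream(stream, launch_fn):
    with WARP_DEVICE_CONTEXT_LOCK:      # re-entrant on the same thread
        with wp.ScopedStream(stream):
            launch_fn()                 # may itself call run_under_stream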


28-31: Silence F405 by explicitly importing the ordinal getter

Ruff flags get_device_ordinal_from_callframe as undefined due to star-import. Add an explicit import.

-from .xla_ffi import *
+from .xla_ffi import *
+# Explicitly import to satisfy static analyzers that discourage star-imports.
+from .xla_ffi import get_device_ordinal_from_callframe
warp/jax_experimental/xla_ffi.py (1)

596-603: Optional: add basic validation for returned device ordinal

Defensive check can surface clearer errors if the API returns an invalid ordinal.

 def get_device_ordinal_from_callframe(call_frame):
     api = call_frame.api
     get_device_args = XLA_FFI_DeviceOrdinal_Get_Args(
         ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args), ctypes.POINTER(XLA_FFI_Extension_Base)(), call_frame.ctx, 0
     )
     api.contents.XLA_FFI_DeviceOrdinal_Get(get_device_args)
-    return get_device_args.device_ordinal
+    ord = int(get_device_args.device_ordinal)
+    if ord < 0:
+        raise RuntimeError("XLA_FFI_DeviceOrdinal_Get returned an invalid device ordinal")
+    return ord
📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e1f4f20 and e8f4519.

📒 Files selected for processing (3)
  • warp/jax_experimental/ffi.py (6 hunks)
  • warp/jax_experimental/xla_ffi.py (4 hunks)
  • warp/tests/interop/test_jax.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
warp/tests/interop/test_jax.py (3)
warp/tests/interop/test_dlpack.py (1)
  • _jax_version (28-34)
warp/jax_experimental/ffi.py (1)
  • jax_callable (707-773)
warp/tests/unittest_utils.py (1)
  • add_function_test (284-303)
warp/jax_experimental/ffi.py (4)
warp/jax.py (1)
  • device_from_jax (44-58)
warp/context.py (4)
  • load (2495-2607)
  • get_cuda_device (4719-4727)
  • stream (3196-3205)
  • stream (3208-3209)
warp/jax_experimental/xla_ffi.py (1)
  • get_device_ordinal_from_callframe (596-602)
warp/utils.py (2)
  • ScopedStream (1238-1283)
  • ScopedCapture (1471-1500)
🪛 Ruff (0.13.1)
warp/tests/interop/test_jax.py

796-796: Unused function argument: device

(ARG001)


828-828: Unused function argument: device

(ARG001)


875-875: Unused function argument: device

(ARG001)


1074-1074: add_function_test may be undefined, or defined from star imports

(F405)


1080-1080: add_function_test may be undefined, or defined from star imports

(F405)


1086-1086: add_function_test may be undefined, or defined from star imports

(F405)

warp/jax_experimental/ffi.py

326-326: get_device_ordinal_from_callframe may be undefined, or defined from star imports

(F405)


597-597: get_device_ordinal_from_callframe may be undefined, or defined from star imports

(F405)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: build-and-test / test-warp-macos
  • GitHub Check: build-and-test / build-warp-ubuntu-x86_64
  • GitHub Check: build-and-test / build-warp-ubuntu-aarch64
  • GitHub Check: build-and-test / build-warp-windows
🔇 Additional comments (4)
warp/jax_experimental/ffi.py (2)

30-30: Global device-context lock: LGTM

Using a single lock to serialize ScopedStream usage across FFI callbacks addresses the Warp global context race.


325-343: Correct device/stream resolution for kernel FFI path

Using XLA device ordinal + direct wp_cuda_launch_kernel avoids touching global device/stream state. Good.

warp/jax_experimental/xla_ffi.py (2)

382-399: DeviceOrdinal_Get args wiring: LGTM

Structure and CFUNCTYPE match usage pattern; returns int32 ordinal via out param.


424-451: Confirm XLA_FFI_Api field order matches header

To avoid ABI mismatches, ensure RunId_Get and DeviceOrdinal_Get positions align with your XLA header version.

Would you double-check xla/ffi/api/c_api.h for the exact ordering of these two fields for the JAX/XLA version you target?
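
The underlying ctypes pitfall: field positions, not names, define the struct layout, so declaring RunId_Get / DeviceOrdinal_Get at the wrong offsets in XLA_FFI_Api would silently read the wrong function-pointer slots. A tiny generic illustration with int64 stand-ins for the function pointers (not the actual XLA header):

import ctypes

class Api(ctypes.Structure):
    _fields_ = [("run_id_get", ctypes.c_int64), ("device_ordinal_get", ctypes.c_int64)]

class ApiSwapped(ctypes.Structure):
    # Same field names, swapped order: the layout no longer matches the producer's.
    _fields_ = [("device_ordinal_get", ctypes.c_int64), ("run_id_get", ctypes.c_int64)]

api = Api(run_id_get=111, device_ordinal_get=222)
view = ctypes.cast(ctypes.pointer(api), ctypes.POINTER(ApiSwapped)).contents
assert view.device_ordinal_get == 111  # reads run_id_get's slot: an ABI mismatch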

Comment on lines +795 to +804
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_mul_forward(test, device):
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    if jax.local_device_count() < 2:
        test.skipTest("requires >= 2 local devices")


🛠️ Refactor suggestion

Align version guard with jax_callable requirement; silence unused arg

These tests call warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 via check_jax_version(). Current guard is 0.4.31, leading to failures on 0.4.31–0.4.99. Also silence the unused device arg.

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_mul_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_mul_forward(test, _device):

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.13.1)

796-796: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 795 to 804, the test uses
warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 but the skip
guard currently checks for >= 0.4.31, causing failures on JAX 0.4.x; update the
version guard to _jax_version() >= (0, 5, 0) and silence the unused device
parameter by either renaming it to _device or adding a simple del device (or _ =
device) so the test runner does not flag an unused-argument warning.

Comment on lines +827 to +836
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_multi_output_forward(test, device):
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    if jax.local_device_count() < 2:
        test.skipTest("requires >= 2 local devices")


🛠️ Refactor suggestion

Fix version guard and unused arg here as well

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):
     import jax
     import jax.numpy as jp
     from warp.jax_experimental.ffi import jax_callable
     if jax.local_device_count() < 2:
         test.skipTest("requires >= 2 local devices")
🧰 Tools
🪛 Ruff (0.13.1)

828-828: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 827 to 836, fix the version guard
and remove the unused function arguments: ensure the skipUnless uses a properly
comparable jax version (e.g. call _jax_version() and compare to a tuple or use
packaging.version.parse if _jax_version() returns a string) and change the test
signature to a normal unittest method (remove the unused 'device' and 'test'
args) so it becomes def test_jax_callable_pmap_multi_output_forward(self):, then
replace test.skipTest(...) with self.skipTest(...) when checking
jax.local_device_count() < 2.

Comment on lines +874 to +883
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_multi_stage_forward(test, device):
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    if jax.local_device_count() < 2:
        test.skipTest("requires >= 2 local devices")


🛠️ Refactor suggestion

Fix version guard and unused arg here as well

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):
     import jax
     import jax.numpy as jp
     from warp.jax_experimental.ffi import jax_callable
     if jax.local_device_count() < 2:
         test.skipTest("requires >= 2 local devices")
🧰 Tools
🪛 Ruff (0.13.1)

875-875: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 874-883, the test function has an
unused 'test' parameter and the JAX version guard needs to be corrected; change
the function signature to remove the unused parameter (e.g., def
test_jax_callable_pmap_multi_stage_forward(self, device): if this is a unittest
method, or def test_jax_callable_pmap_multi_stage_forward(device): for a plain
test function) and ensure the decorator uses the correct version check using
_jax_version(), e.g. unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax
version too old for pmap forward test"), so the guard is consistent and no
unused arg remains.

@shi-eric shi-eric requested a review from nvlukasz September 24, 2025 14:31
@shi-eric shi-eric added this to the 1.10.0 milestone Sep 24, 2025
btaba added a commit to btaba/mujoco that referenced this pull request Sep 24, 2025
@nvlukasz (Contributor) left a comment

Thank you @chaserileyroberts, this is great. I will get it merged shortly.

btaba added a commit to btaba/mujoco that referenced this pull request Oct 3, 2025
@shi-eric shi-eric merged commit df274a7 into NVIDIA:main Oct 3, 2025
14 checks passed
shi-eric pushed a commit to shi-eric/warp that referenced this pull request Oct 3, 2025
shi-eric pushed a commit to shi-eric/warp that referenced this pull request Oct 3, 2025
Add JAX pmap support (NVIDIAGH-976)

See merge request omniverse/warp!1614