Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 32 additions & 28 deletions warp/jax_experimental/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@

import warp as wp
from warp.codegen import get_full_arg_spec, make_full_qualified_name
from warp.jax import get_jax_device
from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp

from .xla_ffi import *

WARP_DEVICE_CONTEXT_LOCK = threading.Lock()


def check_jax_version():
# check if JAX version supports this
Expand Down Expand Up @@ -244,9 +245,10 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
input_output_aliases=self.input_output_aliases,
)

# ensure the kernel module is loaded before the callback, otherwise graph capture may fail
device = wp.device_from_jax(get_jax_device())
self.kernel.module.load(device)
# ensure the kernel module is loaded on all local GPUs to avoid per-device build races
for d in jax.local_devices():
dev = wp.device_from_jax(d)
self.kernel.module.load(dev)

# save launch data to be retrieved by callback
launch_id = self.launch_id
Expand Down Expand Up @@ -321,7 +323,7 @@ def ffi_callback(self, call_frame):
arg_refs.append(arg) # keep a reference

# get device and stream
device = wp.device_from_jax(get_jax_device())
device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
stream = get_stream_from_callframe(call_frame.contents)

# get kernel hooks
Expand Down Expand Up @@ -516,11 +518,11 @@ def __call__(self, *args, output_dims=None, vmap_method=None):
# has_side_effect=True, # force this function to execute even if outputs aren't used
)

# load the module
# NOTE: if the target function uses kernels from different modules, they will not be loaded here
device = wp.device_from_jax(get_jax_device())
# Preload relevant modules across local GPUs to avoid per-device build races (e.g., __main__)
module = wp.get_module(self.func.__module__)
module.load(device)
for d in jax.local_devices():
dev = wp.device_from_jax(d)
module.load(dev)

# save call data to be retrieved by callback
call_id = self.call_id
Expand Down Expand Up @@ -592,10 +594,9 @@ def ffi_callback(self, call_frame):

# early out
return

device = wp.device_from_jax(get_jax_device())
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
device = wp.get_cuda_device(device_ordinal)
stream = wp.Stream(device, cuda_stream=cuda_stream)

# reconstruct the argument list
arg_list = []

Expand All @@ -619,23 +620,26 @@ def ffi_callback(self, call_frame):
arg_list.append(arr)

# call the Python function with reconstructed arguments
with wp.ScopedStream(stream, sync_enter=False):
if stream.is_capturing:
# capturing with JAX
with wp.ScopedCapture(external=True) as capture:
# Lock is required here to prevent wp.ScopedStreams from overwriting each other
# when XLA calls this method from multiple threads.
with WARP_DEVICE_CONTEXT_LOCK:
with wp.ScopedStream(stream, sync_enter=True):
if stream.is_capturing:
# capturing with JAX
with wp.ScopedCapture(external=True) as capture:
self.func(*arg_list)
# keep a reference to the capture object to prevent required modules getting unloaded
call_desc.capture = capture
elif self.graph_mode == GraphMode.WARP:
# capturing with WARP
with wp.ScopedCapture() as capture:
self.func(*arg_list)
wp.capture_launch(capture.graph)
# keep a reference to the capture object and reuse it with same buffers
call_desc.captures[buffer_hash] = capture
else:
# not capturing
self.func(*arg_list)
# keep a reference to the capture object to prevent required modules getting unloaded
call_desc.capture = capture
elif self.graph_mode == GraphMode.WARP:
# capturing with WARP
with wp.ScopedCapture() as capture:
self.func(*arg_list)
wp.capture_launch(capture.graph)
# keep a reference to the capture object and reuse it with same buffers
call_desc.captures[buffer_hash] = capture
else:
# not capturing
self.func(*arg_list)

except Exception as e:
print(traceback.format_exc())
Expand Down
32 changes: 32 additions & 0 deletions warp/jax_experimental/xla_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,24 @@ class XLA_FFI_Stream_Get_Args(ctypes.Structure):
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))


# ctypes mirror of the argument struct passed to XLA_FFI_DeviceOrdinal_Get.
# The caller fills in struct_size, extension_start and ctx; device_ordinal is
# the output slot read back after the call.
#
# struct XLA_FFI_DeviceOrdinal_Get_Args {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   XLA_FFI_ExecutionContext* ctx;
#   int32_t device_ordinal;  // out
# };
class XLA_FFI_DeviceOrdinal_Get_Args(ctypes.Structure):
    # Field order and types must match the C struct above.
    _fields_ = (
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("device_ordinal", ctypes.c_int32),  # out
    )


# Function-pointer type for the XLA_FFI_DeviceOrdinal_Get entry of the API
# table: takes a pointer to the argument struct and returns an opaque pointer
# (declared as c_void_p; presumably XLA_FFI_Error* -- confirm against the XLA header).
XLA_FFI_DeviceOrdinal_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_DeviceOrdinal_Get_Args))


# struct XLA_FFI_Api {
# size_t struct_size;
# XLA_FFI_Extension_Base* extension_start;
Expand All @@ -402,6 +420,8 @@ class XLA_FFI_Stream_Get_Args(ctypes.Structure):
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_RunId_Get);
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceOrdinal_Get);
# };
class XLA_FFI_Api(ctypes.Structure):
_fields_ = (
Expand All @@ -425,6 +445,9 @@ class XLA_FFI_Api(ctypes.Structure):
("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
# TODO(chaserileyroberts): Make this return the correct value and not a c_void_p.
("XLA_FFI_RunId_Get", ctypes.c_void_p), # XLA_FFI_RunId_Get
("XLA_FFI_DeviceOrdinal_Get", XLA_FFI_DeviceOrdinal_Get), # XLA_FFI_DeviceOrdinal_Get
)


Expand Down Expand Up @@ -570,6 +593,15 @@ def get_stream_from_callframe(call_frame):
return get_stream_args.stream


def get_device_ordinal_from_callframe(call_frame):
    """Return the device ordinal XLA reports for the given FFI call frame.

    Args:
        call_frame: Call-frame structure exposing ``api`` (a pointer whose
            target provides ``XLA_FFI_DeviceOrdinal_Get``) and ``ctx`` (the
            execution context handle forwarded to that call).

    Returns:
        The ``device_ordinal`` field of the argument struct after the call.
    """
    # Null extension pointer: no extension chain is attached to the request.
    no_extensions = ctypes.POINTER(XLA_FFI_Extension_Base)()
    query = XLA_FFI_DeviceOrdinal_Get_Args(
        struct_size=ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args),
        extension_start=no_extensions,
        ctx=call_frame.ctx,
        device_ordinal=0,  # out field, expected to be overwritten by the call below
    )
    # NOTE(review): the value returned by the API call is not inspected here.
    call_frame.api.contents.XLA_FFI_DeviceOrdinal_Get(query)
    return query.device_ordinal


_dtype_from_ffi = {
XLA_FFI_DataType.S8: wp.int8,
XLA_FFI_DataType.S16: wp.int16,
Expand Down
150 changes: 150 additions & 0 deletions warp/tests/interop/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,137 @@ def warp_func(inputs, outputs, attrs, ctx):
assert_np_equal(d, 2 * np.arange(10, dtype=np.float32).reshape((5, 2)))


@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_mul_forward(test, device):
import jax
import jax.numpy as jp

from warp.jax_experimental.ffi import jax_callable

if jax.local_device_count() < 2:
test.skipTest("requires >= 2 local devices")

Comment on lines +795 to +804
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Align version guard with jax_callable requirement; silence unused arg

These tests call warp.jax_experimental.ffi.jax_callable, which requires JAX >= 0.5.0 via check_jax_version(). The current guard only requires 0.4.31, so the tests run and fail on any JAX release from 0.4.31 up to (but not including) 0.5.0. Also silence the unused device arg.

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_mul_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_mul_forward(test, _device):

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.13.1)

796-796: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 795 to 804, the test uses
warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 but the skip
guard currently checks for >= 0.4.31, causing failures on JAX 0.4.x; update the
version guard to _jax_version() >= (0, 5, 0) and silence the unused device
parameter by either renaming it to _device or adding a simple del device (or _ =
device) so the test runner does not flag an unused-argument warning.

@wp.kernel
def mul2(a: wp.array(dtype=float), out: wp.array(dtype=float)):
tid = wp.tid()
out[tid] = 2.0 * a[tid]

def mul2_py(a: wp.array(dtype=float), out: wp.array(dtype=float)):
wp.launch(mul2, dim=a.shape, inputs=[a], outputs=[out])

j = jax_callable(mul2_py, num_outputs=1)

per_device = 8
ndev = jax.local_device_count()
x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))

def per_device_fwd(v):
(y,) = j(v)
return y

y = jax.pmap(per_device_fwd)(x)
test.assertTrue(np.allclose(np.asarray(y), 2.0 * np.asarray(x), rtol=1e-5, atol=1e-6))


@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_multi_output_forward(test, device):
import jax
import jax.numpy as jp

from warp.jax_experimental.ffi import jax_callable

if jax.local_device_count() < 2:
test.skipTest("requires >= 2 local devices")

Comment on lines +827 to +836
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix version guard and unused arg here as well

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
-import jax
-import jax.numpy as jp
-from warp.jax_experimental.ffi import jax_callable
-if jax.local_device_count() < 2:
-test.skipTest("requires >= 2 local devices")
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):
+import jax
+import jax.numpy as jp
+from warp.jax_experimental.ffi import jax_callable
+if jax.local_device_count() < 2:
+test.skipTest("requires >= 2 local devices")
🧰 Tools
🪛 Ruff (0.13.1)

828-828: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 827 to 836, the test uses
warp.jax_experimental.ffi.jax_callable, which requires JAX >= 0.5.0, but the skip
guard currently checks for >= 0.4.31; update the version guard to
_jax_version() >= (0, 5, 0) and silence the unused device parameter by renaming
it to _device (or adding a simple del device). Keep the two-parameter
signature: the function is registered through add_function_test rather than
defined as a TestCase method, and test is still used for test.skipTest(...)
when jax.local_device_count() < 2 and for the final assertion.

@wp.kernel
def multi_out(
a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
):
tid = wp.tid()
c[tid] = a[tid] + b[tid]
d[tid] = s * a[tid]

def multi_out_py(
a: wp.array(dtype=float),
b: wp.array(dtype=float),
s: float,
c: wp.array(dtype=float),
d: wp.array(dtype=float),
):
wp.launch(multi_out, dim=a.shape, inputs=[a, b, s], outputs=[c, d])

j = jax_callable(multi_out_py, num_outputs=2)

per_device = 7
ndev = jax.local_device_count()
a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
b = jp.ones((ndev, per_device), dtype=jp.float32)
s = 3.0

def per_device_fwd(aa, bb):
c, d = j(aa, bb, s)
return c + d # simple combine to exercise both outputs

out = jax.pmap(per_device_fwd)(a, b)

a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
b_np = np.ones((ndev, per_device), dtype=np.float32)
ref = (a_np + b_np) + s * a_np
test.assertTrue(np.allclose(np.asarray(out), ref, rtol=1e-5, atol=1e-6))


@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
def test_jax_callable_pmap_multi_stage_forward(test, device):
import jax
import jax.numpy as jp

from warp.jax_experimental.ffi import jax_callable

if jax.local_device_count() < 2:
test.skipTest("requires >= 2 local devices")

Comment on lines +874 to +883
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix version guard and unused arg here as well

-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
-import jax
-import jax.numpy as jp
-from warp.jax_experimental.ffi import jax_callable
-if jax.local_device_count() < 2:
-test.skipTest("requires >= 2 local devices")
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):
+import jax
+import jax.numpy as jp
+from warp.jax_experimental.ffi import jax_callable
+if jax.local_device_count() < 2:
+test.skipTest("requires >= 2 local devices")
🧰 Tools
🪛 Ruff (0.13.1)

875-875: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 874-883, the test uses
warp.jax_experimental.ffi.jax_callable, which requires JAX >= 0.5.0, but the skip
guard currently checks for >= 0.4.31; change the decorator to
unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap
forward test") so the guard matches that requirement. The device parameter is
unused: rename it to _device (or add a simple del device) to silence the
unused-argument warning. Do not remove the test parameter; it is used for
test.skipTest(...) and for the final assertion, and the function is registered
through add_function_test.

@wp.kernel
def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), out: wp.array(dtype=float)):
tid = wp.tid()
out[tid] = a[tid] + b[tid]

@wp.kernel
def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
tid = wp.tid()
out[tid] = alpha * x[tid] + y[tid]

def multi_stage_py(
a: wp.array(dtype=float),
b: wp.array(dtype=float),
alpha: float,
tmp: wp.array(dtype=float),
out: wp.array(dtype=float),
):
wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])

j = jax_callable(multi_stage_py, num_outputs=2)

per_device = 9
ndev = jax.local_device_count()
a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
b = jp.ones((ndev, per_device), dtype=jp.float32)
alpha = 2.5

def per_device_fwd(aa, bb):
tmp, out = j(aa, bb, alpha)
return tmp + out

combined = jax.pmap(per_device_fwd)(a, b)

a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
b_np = np.ones((ndev, per_device), dtype=np.float32)
tmp_ref = a_np + b_np
out_ref = alpha * (a_np + b_np) + b_np
ref = tmp_ref + out_ref
test.assertTrue(np.allclose(np.asarray(combined), ref, rtol=1e-5, atol=1e-6))


class TestJax(unittest.TestCase):
    # Intentionally empty container: the test functions in this file are
    # attached to this class afterwards via add_function_test(TestJax, ...).
    pass

Expand Down Expand Up @@ -940,6 +1071,25 @@ class TestJax(unittest.TestCase):
# ffi callback tests
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)

add_function_test(
TestJax,
"test_jax_callable_pmap_multi_output_forward",
test_jax_callable_pmap_multi_output_forward,
devices=jax_compatible_cuda_devices,
)
add_function_test(
TestJax,
"test_jax_callable_pmap_mul_forward",
test_jax_callable_pmap_mul_forward,
devices=jax_compatible_cuda_devices,
)
add_function_test(
TestJax,
"test_jax_callable_pmap_multi_stage_forward",
test_jax_callable_pmap_multi_stage_forward,
devices=jax_compatible_cuda_devices,
)


except Exception as e:
print(f"Skipping Jax tests due to exception: {e}")
Expand Down
Loading