Skip to content

Commit e8f4519

Browse files
Add device ordinal xla plumbing
Signed-off-by: Chase Riley Roberts <[email protected]>
1 parent e1f4f20 commit e8f4519

File tree

3 files changed

+214
-28
lines changed

3 files changed

+214
-28
lines changed

warp/jax_experimental/ffi.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@
2323

2424
import warp as wp
2525
from warp.codegen import get_full_arg_spec, make_full_qualified_name
26-
from warp.jax import get_jax_device
2726
from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
2827

2928
from .xla_ffi import *
3029

30+
WARP_DEVICE_CONTEXT_LOCK = threading.Lock()
31+
3132

3233
def check_jax_version():
3334
# check if JAX version supports this
@@ -244,9 +245,10 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
244245
input_output_aliases=self.input_output_aliases,
245246
)
246247

247-
# ensure the kernel module is loaded before the callback, otherwise graph capture may fail
248-
device = wp.device_from_jax(get_jax_device())
249-
self.kernel.module.load(device)
248+
# ensure the kernel module is loaded on all local GPUs to avoid per-device build races
249+
for d in jax.local_devices():
250+
dev = wp.device_from_jax(d)
251+
self.kernel.module.load(dev)
250252

251253
# save launch data to be retrieved by callback
252254
launch_id = self.launch_id
@@ -321,7 +323,7 @@ def ffi_callback(self, call_frame):
321323
arg_refs.append(arg) # keep a reference
322324

323325
# get device and stream
324-
device = wp.device_from_jax(get_jax_device())
326+
device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
325327
stream = get_stream_from_callframe(call_frame.contents)
326328

327329
# get kernel hooks
@@ -516,11 +518,11 @@ def __call__(self, *args, output_dims=None, vmap_method=None):
516518
# has_side_effect=True, # force this function to execute even if outputs aren't used
517519
)
518520

519-
# load the module
520-
# NOTE: if the target function uses kernels from different modules, they will not be loaded here
521-
device = wp.device_from_jax(get_jax_device())
521+
# Preload relevant modules across local GPUs to avoid per-device build races (e.g., __main__)
522522
module = wp.get_module(self.func.__module__)
523-
module.load(device)
523+
for d in jax.local_devices():
524+
dev = wp.device_from_jax(d)
525+
module.load(dev)
524526

525527
# save call data to be retrieved by callback
526528
call_id = self.call_id
@@ -592,10 +594,9 @@ def ffi_callback(self, call_frame):
592594

593595
# early out
594596
return
595-
596-
device = wp.device_from_jax(get_jax_device())
597+
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
598+
device = wp.get_cuda_device(device_ordinal)
597599
stream = wp.Stream(device, cuda_stream=cuda_stream)
598-
599600
# reconstruct the argument list
600601
arg_list = []
601602

@@ -619,23 +620,26 @@ def ffi_callback(self, call_frame):
619620
arg_list.append(arr)
620621

621622
# call the Python function with reconstructed arguments
622-
with wp.ScopedStream(stream, sync_enter=False):
623-
if stream.is_capturing:
624-
# capturing with JAX
625-
with wp.ScopedCapture(external=True) as capture:
623+
# Lock is required here to prevent wp.ScopedStreams from overwriting each other
624+
# when XLA calls this method from multiple threads.
625+
with WARP_DEVICE_CONTEXT_LOCK:
626+
with wp.ScopedStream(stream, sync_enter=True):
627+
if stream.is_capturing:
628+
# capturing with JAX
629+
with wp.ScopedCapture(external=True) as capture:
630+
self.func(*arg_list)
631+
# keep a reference to the capture object to prevent required modules getting unloaded
632+
call_desc.capture = capture
633+
elif self.graph_mode == GraphMode.WARP:
634+
# capturing with WARP
635+
with wp.ScopedCapture() as capture:
636+
self.func(*arg_list)
637+
wp.capture_launch(capture.graph)
638+
# keep a reference to the capture object and reuse it with same buffers
639+
call_desc.captures[buffer_hash] = capture
640+
else:
641+
# not capturing
626642
self.func(*arg_list)
627-
# keep a reference to the capture object to prevent required modules getting unloaded
628-
call_desc.capture = capture
629-
elif self.graph_mode == GraphMode.WARP:
630-
# capturing with WARP
631-
with wp.ScopedCapture() as capture:
632-
self.func(*arg_list)
633-
wp.capture_launch(capture.graph)
634-
# keep a reference to the capture object and reuse it with same buffers
635-
call_desc.captures[buffer_hash] = capture
636-
else:
637-
# not capturing
638-
self.func(*arg_list)
639643

640644
except Exception as e:
641645
print(traceback.format_exc())

warp/jax_experimental/xla_ffi.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,24 @@ class XLA_FFI_Stream_Get_Args(ctypes.Structure):
379379
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
380380

381381

382+
# struct XLA_FFI_DeviceOrdinal_Get {
383+
# size_t struct_size;
384+
# XLA_FFI_Extension_Base* extension_start;
385+
# XLA_FFI_ExecutionContext* ctx;
386+
# int32_t device_ordinal; // out
387+
# };
388+
class XLA_FFI_DeviceOrdinal_Get_Args(ctypes.Structure):
389+
_fields_ = (
390+
("struct_size", ctypes.c_size_t),
391+
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
392+
("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
393+
("device_ordinal", ctypes.c_int32),
394+
) # // out
395+
396+
397+
XLA_FFI_DeviceOrdinal_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_DeviceOrdinal_Get_Args))
398+
399+
382400
# struct XLA_FFI_Api {
383401
# size_t struct_size;
384402
# XLA_FFI_Extension_Base* extension_start;
@@ -402,6 +420,8 @@ class XLA_FFI_Stream_Get_Args(ctypes.Structure):
402420
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
403421
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
404422
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
423+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_RunId_Get);
424+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceOrdinal_Get);
405425
# };
406426
class XLA_FFI_Api(ctypes.Structure):
407427
_fields_ = (
@@ -425,6 +445,9 @@ class XLA_FFI_Api(ctypes.Structure):
425445
("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
426446
("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
427447
("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
448+
# TODO(chaserileyroberts): Make this return the correct value and not a c_void_p.
449+
("XLA_FFI_RunId_Get", ctypes.c_void_p), # XLA_FFI_RunId_Get
450+
("XLA_FFI_DeviceOrdinal_Get", XLA_FFI_DeviceOrdinal_Get), # XLA_FFI_DeviceOrdinal_Get
428451
)
429452

430453

@@ -570,6 +593,15 @@ def get_stream_from_callframe(call_frame):
570593
return get_stream_args.stream
571594

572595

596+
def get_device_ordinal_from_callframe(call_frame):
597+
api = call_frame.api
598+
get_device_args = XLA_FFI_DeviceOrdinal_Get_Args(
599+
ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args), ctypes.POINTER(XLA_FFI_Extension_Base)(), call_frame.ctx, 0
600+
)
601+
api.contents.XLA_FFI_DeviceOrdinal_Get(get_device_args)
602+
return get_device_args.device_ordinal
603+
604+
573605
_dtype_from_ffi = {
574606
XLA_FFI_DataType.S8: wp.int8,
575607
XLA_FFI_DataType.S16: wp.int16,

warp/tests/interop/test_jax.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,137 @@ def warp_func(inputs, outputs, attrs, ctx):
792792
assert_np_equal(d, 2 * np.arange(10, dtype=np.float32).reshape((5, 2)))
793793

794794

795+
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
796+
def test_jax_callable_pmap_mul_forward(test, device):
797+
import jax
798+
import jax.numpy as jp
799+
800+
from warp.jax_experimental.ffi import jax_callable
801+
802+
if jax.local_device_count() < 2:
803+
test.skipTest("requires >= 2 local devices")
804+
805+
@wp.kernel
806+
def mul2(a: wp.array(dtype=float), out: wp.array(dtype=float)):
807+
tid = wp.tid()
808+
out[tid] = 2.0 * a[tid]
809+
810+
def mul2_py(a: wp.array(dtype=float), out: wp.array(dtype=float)):
811+
wp.launch(mul2, dim=a.shape, inputs=[a], outputs=[out])
812+
813+
j = jax_callable(mul2_py, num_outputs=1)
814+
815+
per_device = 8
816+
ndev = jax.local_device_count()
817+
x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
818+
819+
def per_device_fwd(v):
820+
(y,) = j(v)
821+
return y
822+
823+
y = jax.pmap(per_device_fwd)(x)
824+
test.assertTrue(np.allclose(np.asarray(y), 2.0 * np.asarray(x), rtol=1e-5, atol=1e-6))
825+
826+
827+
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
828+
def test_jax_callable_pmap_multi_output_forward(test, device):
829+
import jax
830+
import jax.numpy as jp
831+
832+
from warp.jax_experimental.ffi import jax_callable
833+
834+
if jax.local_device_count() < 2:
835+
test.skipTest("requires >= 2 local devices")
836+
837+
@wp.kernel
838+
def multi_out(
839+
a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
840+
):
841+
tid = wp.tid()
842+
c[tid] = a[tid] + b[tid]
843+
d[tid] = s * a[tid]
844+
845+
def multi_out_py(
846+
a: wp.array(dtype=float),
847+
b: wp.array(dtype=float),
848+
s: float,
849+
c: wp.array(dtype=float),
850+
d: wp.array(dtype=float),
851+
):
852+
wp.launch(multi_out, dim=a.shape, inputs=[a, b, s], outputs=[c, d])
853+
854+
j = jax_callable(multi_out_py, num_outputs=2)
855+
856+
per_device = 7
857+
ndev = jax.local_device_count()
858+
a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
859+
b = jp.ones((ndev, per_device), dtype=jp.float32)
860+
s = 3.0
861+
862+
def per_device_fwd(aa, bb):
863+
c, d = j(aa, bb, s)
864+
return c + d # simple combine to exercise both outputs
865+
866+
out = jax.pmap(per_device_fwd)(a, b)
867+
868+
a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
869+
b_np = np.ones((ndev, per_device), dtype=np.float32)
870+
ref = (a_np + b_np) + s * a_np
871+
test.assertTrue(np.allclose(np.asarray(out), ref, rtol=1e-5, atol=1e-6))
872+
873+
874+
@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
875+
def test_jax_callable_pmap_multi_stage_forward(test, device):
876+
import jax
877+
import jax.numpy as jp
878+
879+
from warp.jax_experimental.ffi import jax_callable
880+
881+
if jax.local_device_count() < 2:
882+
test.skipTest("requires >= 2 local devices")
883+
884+
@wp.kernel
885+
def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), out: wp.array(dtype=float)):
886+
tid = wp.tid()
887+
out[tid] = a[tid] + b[tid]
888+
889+
@wp.kernel
890+
def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
891+
tid = wp.tid()
892+
out[tid] = alpha * x[tid] + y[tid]
893+
894+
def multi_stage_py(
895+
a: wp.array(dtype=float),
896+
b: wp.array(dtype=float),
897+
alpha: float,
898+
tmp: wp.array(dtype=float),
899+
out: wp.array(dtype=float),
900+
):
901+
wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
902+
wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])
903+
904+
j = jax_callable(multi_stage_py, num_outputs=2)
905+
906+
per_device = 9
907+
ndev = jax.local_device_count()
908+
a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
909+
b = jp.ones((ndev, per_device), dtype=jp.float32)
910+
alpha = 2.5
911+
912+
def per_device_fwd(aa, bb):
913+
tmp, out = j(aa, bb, alpha)
914+
return tmp + out
915+
916+
combined = jax.pmap(per_device_fwd)(a, b)
917+
918+
a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
919+
b_np = np.ones((ndev, per_device), dtype=np.float32)
920+
tmp_ref = a_np + b_np
921+
out_ref = alpha * (a_np + b_np) + b_np
922+
ref = tmp_ref + out_ref
923+
test.assertTrue(np.allclose(np.asarray(combined), ref, rtol=1e-5, atol=1e-6))
924+
925+
795926
class TestJax(unittest.TestCase):
796927
pass
797928

@@ -940,6 +1071,25 @@ class TestJax(unittest.TestCase):
9401071
# ffi callback tests
9411072
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
9421073

1074+
add_function_test(
1075+
TestJax,
1076+
"test_jax_callable_pmap_multi_output_forward",
1077+
test_jax_callable_pmap_multi_output_forward,
1078+
devices=jax_compatible_cuda_devices,
1079+
)
1080+
add_function_test(
1081+
TestJax,
1082+
"test_jax_callable_pmap_mul_forward",
1083+
test_jax_callable_pmap_mul_forward,
1084+
devices=jax_compatible_cuda_devices,
1085+
)
1086+
add_function_test(
1087+
TestJax,
1088+
"test_jax_callable_pmap_multi_stage_forward",
1089+
test_jax_callable_pmap_multi_stage_forward,
1090+
devices=jax_compatible_cuda_devices,
1091+
)
1092+
9431093

9441094
except Exception as e:
9451095
print(f"Skipping Jax tests due to exception: {e}")

0 commit comments

Comments
 (0)