
Conversation

mehdiataei
Contributor

@mehdiataei mehdiataei commented Aug 14, 2025

Description

Added JAX interop adjoint support using FFI. Also added support for static inputs (with auto-detection), vmap, and multi-output kernels with AD, and fixed some bugs related to vmap in the forward pass.

This PR also enables distributed forward and backward multi-GPU runs (see test cases) and an inverse kinematic example.

This is a new area for me, so there may be some oddities or poor decisions, but all tests pass. We may need to discuss any missing features and whether this matches how users intend to use it.

This work is inspired by and thanks to J. Sevcik
https://gist.github.com/jaro-sevcik/893ffb891564dfc7617c8f12f3c2d1d2
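
A minimal usage sketch of the new differentiable path, following the API exercised in the tests below (treat the exact signature as illustrative):

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_ad_kernel


@wp.kernel
def triple_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    y[tid] = 3.0 * x[tid]


# Wrap the Warp kernel as a differentiable JAX callable (custom VJP under the hood).
jax_triple = jax_ad_kernel(triple_kernel, num_outputs=1)


def loss(x):
    (y,) = jax_triple(x)  # outputs are returned as a tuple
    return jnp.sum(jnp.square(y))


x = jnp.arange(4, dtype=jnp.float32)
print(jax.grad(loss)(x))  # expected: 18 * x, i.e. [0., 18., 36., 54.]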

Before your PR is "Ready for review"

  • All commits are signed-off to indicate that your contribution adheres to the Developer Certificate of Origin requirements
  • Necessary tests have been added
  • Documentation is up-to-date
  • Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. __init__.pyi, functions.rst)
  • Code passes formatting and linting checks with pre-commit run -a

Summary by CodeRabbit

  • New Features

    • Differentiable JAX support for Warp kernels, including AD for scalars, vectors, matrices, multi-output kernels, and per-sample gradients.
  • Documentation

    • Expanded AD/JAX interoperability guide with usage examples, VMAP options, launch-dim rules, per-sample gradient patterns, troubleshooting, and migration notes.
  • API

    • Kernel/callable wrappers accept differentiable and static-arg options to enable AD workflows.
  • Tests

    • Large set of new interop tests covering AD, vmap/pmap, multi-output and forward/backward scenarios.

@shi-eric shi-eric requested a review from nvlukasz August 14, 2025 19:13
@shi-eric shi-eric mentioned this pull request Aug 23, 2025

@johnpjf johnpjf left a comment


Thanks for doing this!
I left one comment, but another overall comment is whether we can get a test of this working with multiple devices? I think the older way to do this would be within a jax.pmap but there may be newer ways.

@mehdiataei
Contributor Author

mehdiataei commented Aug 26, 2025

Thanks for doing this! I left one comment, but another overall comment is whether we can get a test of this working with multiple devices? I think the older way to do this would be within a jax.pmap but there may be newer ways.

Thanks, John! I have added an example with shard_map for the forward pass to the docs, which as far as I know is the most up-to-date way of doing this (maybe @nouiz can confirm): https://nvidia.github.io/warp/modules/interoperability.html#distributed-computation

I can add an example to show that the backward pass works with shard_map as well (I don't see a reason why it wouldn't).

@nouiz
Contributor

nouiz commented Aug 26, 2025

Thanks for doing this! I left one comment, but another overall comment is whether we can get a test of this working with multiple devices? I think the older way to do this would be within a jax.pmap but there may be newer ways.

Thank John! I have added an example to the doc with shard_map here for the forward pass, which as far as I know is the most up-to-date way of doing this (maybe @nouiz can confirm): https://nvidia.github.io/warp/modules/interoperability.html#distributed-computation

shard_map is the right manual way of doing sharding.

@johnpjf

johnpjf commented Aug 26, 2025

FYI, I'm trying this on a two-GPU machine and getting an error:


@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
  tid = wp.tid()  # pytype: disable=module-attr
  output[tid] = 3.0 * input[tid]

class WarpDeferredTest(googletest.TestCase):

  def test_give_me_a_name(self):

    jax_triple = jax_experimental.jax_ad_kernel(triple_kernel)

    def f(x):
      return jax_triple(x)

    @jax.pmap
    def compute_grads(x):
      def loss_fn(x):
        return jnp.sum(jnp.square(f(x)[0]))

      loss, grads = jax.value_and_grad(
          loss_fn, has_aux=False)(x)
      return grads

    x = jnp.broadcast_to(jnp.arange(4, dtype=jnp.float32)[jnp.newaxis], (jax.local_device_count(), 4))
    print(compute_grads(x))



if __name__ == "__main__":
  googletest.main()

This fails with:

Module __main__ 89dc8ef load on device 'cuda:0' took 318.63 ms  (compiled)
Warp CUDA error 400: invalid resource handle (in function wp_cuda_launch_kernel, third_party/py/warp/native/warp.cu:4061)

Notably, if I hide one of the GPUs (with export CUDA_VISIBLE_DEVICES=0) so that I only have one, the test passes.

@mehdiataei
Contributor Author

mehdiataei commented Aug 26, 2025

Hey @johnpjf
Yes, if the examples are not executed with mpirun, we get that error. I'm not sure what the root cause is, but it's probably a stream, event, function, or module created on GPU 0 being accidentally used while GPU 1 is current. Running with mpirun fixes it. I made the following test cases to check whether the gradients work with shard_map, and they all pass successfully. However, I'm having trouble merging these tests into the current Warp test suite since they rely on mpirun.

import warp as wp
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather as allgather
from jax.experimental.shard_map import shard_map
from warp.jax_experimental.ffi import jax_ad_kernel

jax.distributed.initialize()


# Kernels
@wp.kernel
def multiply_by_two_kernel(a_in: wp.array(dtype=wp.float32), a_out: wp.array(dtype=wp.float32)):
    index = wp.tid()
    a_out[index] = a_in[index] * 2.0


@wp.kernel
def scale_sum_square_kernel(
    a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32), s: float, out: wp.array(dtype=wp.float32)
):
    tid = wp.tid()
    out[tid] = (a[tid] * s + b[tid]) ** 2.0


# multi-output kernel
@wp.kernel
def multi_output_kernel(
    a: wp.array(dtype=wp.float32),
    b: wp.array(dtype=wp.float32),
    s: float,
    c: wp.array(dtype=wp.float32),
    d: wp.array(dtype=wp.float32),
):
    tid = wp.tid()
    c[tid] = a[tid] ** 2.0
    d[tid] = a[tid] * b[tid] * s


# JAX callables with AD
jax_mul = jax_ad_kernel(multiply_by_two_kernel, num_outputs=1)
jax_scale = jax_ad_kernel(scale_sum_square_kernel, num_outputs=1, static_argnames=("s",))
jax_mo = jax_ad_kernel(multi_output_kernel, num_outputs=2, static_argnames=("s",))


def warp_multiply(x):
    (out,) = jax_mul(x)
    return out


def warp_scale_sum_square(a, b, s):
    (out,) = jax_scale(a, b, s)
    return out


def warp_multi_output(a, b, s):
    c, d = jax_mo(a, b, s)
    return c, d


def make_sharded_1d(mesh, n):
    sharding = jax.sharding.NamedSharding(mesh, P("x"))
    a = jnp.arange(n, dtype=jnp.float32)
    shape = (n,)
    arrays = [jax.device_put(a[idx], d) for d, idx in sharding.addressable_devices_indices_map(shape).items()]
    return jax.make_array_from_single_device_arrays(shape, sharding, arrays)


def example_mul2_backward(mesh):
    n = jax.device_count() * 5
    x = make_sharded_1d(mesh, n)

    def loss(x):
        y = shard_map(lambda v: warp_multiply(v), mesh=mesh, in_specs=(P("x"),), out_specs=P("x"), check_rep=False)(x)
        return jnp.sum(y)

    g = jax.grad(loss)(x)
    gf = allgather(g)
    if jax.process_index() == 0:
        ref = np.full(n, 2.0, dtype=np.float32)
        np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
        print("mul2: OK")


def example_scale_sum_square_backward(mesh):
    n = jax.device_count() * 6
    a = make_sharded_1d(mesh, n)
    b = make_sharded_1d(mesh, n)
    s = 1.5

    def loss(a, b):
        y = shard_map(
            lambda aa, bb: warp_scale_sum_square(aa, bb, s),
            mesh=mesh,
            in_specs=(P("x"), P("x")),
            out_specs=P("x"),
            check_rep=False,
        )(a, b)
        return jnp.sum(y)

    da, db = jax.grad(loss, argnums=(0, 1))(a, b)
    daf, dbf = allgather(da), allgather(db)
    if jax.process_index() == 0:
        a_np = np.arange(n, dtype=np.float32)
        b_np = np.arange(n, dtype=np.float32)
        ref_da = 2.0 * (a_np * s + b_np) * s
        ref_db = 2.0 * (a_np * s + b_np)
        np.testing.assert_allclose(np.asarray(daf), ref_da, rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(np.asarray(dbf), ref_db, rtol=1e-5, atol=1e-6)
        print("scale_sum_square: OK")


def example_mul2_with_psum(mesh):
    n = jax.device_count() * 4
    x = make_sharded_1d(mesh, n)

    def body(v):
        local = warp_multiply(v)
        red = jnp.sum(local)
        total = jax.lax.psum(red, "x")
        return local, total

    def loss(x):
        local, total = shard_map(body, mesh=mesh, in_specs=(P("x"),), out_specs=(P("x"), P()), check_rep=False)(x)
        return jnp.sum(local) + total

    g = jax.grad(loss)(x)
    gf = allgather(g)
    if jax.process_index() == 0:
        ref = np.full(n, 4.0, dtype=np.float32)
        np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
        print("mul2 + psum: OK")


def example_psum_mean(mesh):
    n = jax.device_count() * 3
    x = make_sharded_1d(mesh, n)

    def body(v):
        local = warp_multiply(v)
        red = jnp.sum(local)
        total = jax.lax.psum(red, "x")
        mean = total / n
        return local, mean

    def loss(x):
        local, mean = shard_map(body, mesh=mesh, in_specs=(P("x"),), out_specs=(P("x"), P()), check_rep=False)(x)
        return jnp.sum(local) + mean

    g = jax.grad(loss)(x)
    gf = allgather(g)
    if jax.process_index() == 0:
        # d/dx sum(2x) = 2; d/dx mean(2x) = 2/n per element
        ref = np.full(n, 2.0 + 2.0 / n, dtype=np.float32)
        np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
        print("mul2 + psum mean: OK")


def example_multi_output_with_psum(mesh):
    n = jax.device_count() * 5
    a = make_sharded_1d(mesh, n)
    b = make_sharded_1d(mesh, n)
    s = 2.0

    def body(aa, bb):
        c, d = warp_multi_output(aa, bb, s)
        local = c + d
        total = jax.lax.psum(jnp.sum(local), "x")
        return local, total

    def loss(a, b):
        local, total = shard_map(body, mesh=mesh, in_specs=(P("x"), P("x")), out_specs=(P("x"), P()), check_rep=False)(
            a, b
        )
        return jnp.sum(local) + total

    da, db = jax.grad(loss, argnums=(0, 1))(a, b)
    daf, dbf = allgather(da), allgather(db)
    if jax.process_index() == 0:
        a_np = np.arange(n, dtype=np.float32)
        b_np = np.arange(n, dtype=np.float32)
        base_da = 2.0 * a_np + b_np * s
        base_db = a_np * s
        ref_da = 2.0 * base_da
        ref_db = 2.0 * base_db
        np.testing.assert_allclose(np.asarray(daf), ref_da, rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(np.asarray(dbf), ref_db, rtol=1e-5, atol=1e-6)
        print("multi_output + psum: OK")


def main():
    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    example_mul2_backward(mesh)
    example_scale_sum_square_backward(mesh)
    example_mul2_with_psum(mesh)
    example_psum_mean(mesh)
    example_multi_output_with_psum(mesh)


if __name__ == "__main__":
    main()

You can run them with
mpirun -np 2 python test.py

@johnpjf

johnpjf commented Aug 26, 2025 via email

@mehdiataei
Contributor Author

It is fairly common to run multi-GPU workloads with multiple processes in JAX even on a single machine, so I don’t think this is a blocking issue (and this limitation existed before this PR). I have some ideas about the root cause and a potential fix, so let me take a shot at it first.

@johnpjf

johnpjf commented Aug 26, 2025

Yes, multi-GPU training within a pmap is essentially a requirement for all of our workloads so it would be great to get a fix for this. Thanks

@mehdiataei
Contributor Author

@johnpjf Good news: I was able to resolve the CUDA 400 issue by launching the FFI callback on the call frame's CUDA stream and preloading modules across GPUs to prevent per-device race conditions. With these changes in place, your example now runs correctly without requiring mpirun. Any multi-GPU run (and distributed multi-GPU with MPI) should now work the same way it does in plain JAX.
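
Roughly, the change looks like this (simplified sketch; the actual code lives in warp/jax_experimental/ffi.py and handles more cases, so take the helper signature here as illustrative):

import warp as wp
from warp.jax import get_jax_device
from warp.jax_experimental.xla_ffi import get_stream_from_callframe


def ffi_callback(call_frame, kernel, arg_list, launch_dims):
    # Bind Warp to the CUDA stream XLA passes into the FFI callback so each
    # pmap replica launches on its own stream/device.
    stream_handle = get_stream_from_callframe(call_frame.contents)
    device = wp.device_from_jax(get_jax_device())
    stream = wp.Stream(device, cuda_stream=stream_handle)
    with wp.ScopedStream(stream):
        wp.launch(kernel, dim=launch_dims, inputs=arg_list, stream=stream)


# Done once, ahead of time, to avoid per-device module build races under pmap.
def preload(kernel):
    for d in wp.get_cuda_devices():
        kernel.module.load(d)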

This fix also made it possible to integrate both the shard_map and pmap test cases directly into test_jax.py.

Can you give it a try? Thanks.

@johnpjf

johnpjf commented Aug 26, 2025

Thanks, but I don't think it's working for me yet:

class WarpDeferredTest(googletest.TestCase):

  def test_give_me_a_name(self):

    jax_triple = jax_experimental.jax_ad_kernel(triple_kernel)

    def f(x):
      return jax_triple(x)

    @jax.pmap
    def compute_grads(x):
      def loss_fn(x):
        return jnp.sum(jnp.square(f(x)[0]))

      loss, grads = jax.value_and_grad(
          loss_fn, has_aux=False)(x)
      return loss, grads

    x = jnp.array(((1.0, 2.0), (1.0, 2.0)))
    print('input: ', x)

    loss, grads = compute_grads(x)
    print('loss: ', loss)
    print('grads: ', grads)

Outputs:

loss:  [45.  0.]
grads:  [[18. 36.]
 [ 0.  0.]]

So it seems like something went wrong on the second GPU.

@mehdiataei
Contributor Author

mehdiataei commented Aug 26, 2025

Hmm I can't reproduce it:

input:  [[1. 2.]
 [1. 2.]]
Warp 1.8.1 initialized:
   CUDA Toolkit 12.9, Driver 12.9
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA RTX 6000 Ada Generation" (47 GiB, sm_89, mempool enabled)
     "cuda:1"   : "NVIDIA RTX 6000 Ada Generation" (47 GiB, sm_89, mempool enabled)
   CUDA peer access:
     Supported fully (all-directional)
   Kernel cache:
     /home/mehdi/.cache/warp/1.8.1
Module warp.jax_experimental.ffi d07651c load on device 'cuda:0' took 0.41 ms  (cached)
Module test_john 09f893d load on device 'cuda:0' took 223.74 ms  (compiled)
Module warp.jax_experimental.ffi d07651c load on device 'cuda:1' took 0.35 ms  (cached)
Module test_john 09f893d load on device 'cuda:1' took 0.24 ms  (cached)
loss:  [45. 45.]
grads:  [[18. 36.]
 [18. 36.]]
.
----------------------------------------------------------------------
Ran 1 test in 1.761s

OK

My code:

import jax
import jax.numpy as jnp
import warp.jax_experimental as jax_experimental
import unittest
import warp as wp


@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
  tid = wp.tid()
  output[tid] = 3.0 * input[tid]


class WarpDeferredTest(unittest.TestCase):

  def test_give_me_a_name(self):

    jax_triple = jax_experimental.jax_ad_kernel(triple_kernel)

    def f(x):
      return jax_triple(x)

    @jax.pmap
    def compute_grads(x):
      def loss_fn(x):
        return jnp.sum(jnp.square(f(x)[0]))

      loss, grads = jax.value_and_grad(
          loss_fn, has_aux=False)(x)
      return loss, grads

    x = jnp.array(((1.0, 2.0), (1.0, 2.0)))
    print('input: ', x)

    loss, grads = compute_grads(x)
    print('loss: ', loss)
    print('grads: ', grads)

Can you delete the kernel cache folder and try again?

@johnpjf

johnpjf commented Aug 27, 2025 via email

@johnpjf

johnpjf commented Aug 28, 2025

Ok, turns out I have a bad GPU, pmap works great on a different dual-GPU machine, thanks!

@mehdiataei
Contributor Author

Great! I think for JAX to work you need to have identical GPUs (maybe that's the cause)

@nouiz
Contributor

nouiz commented Aug 28, 2025

Great! I think for JAX to work you need to have identical GPUs (maybe that's the cause)

Exactly. You can't mix different kinds of GPUs; that isn't supported.

@johnpjf

johnpjf commented Aug 28, 2025

Yeah, it wasn't that. I have identical GPUs and have been happily pmap-ing for years, and now it just stopped working and gives all zeros for the other shard, even in straight JAX.

@mehdiataei
Contributor Author

Added an inverse kinematics example using JAX and Warp, and added changes that enable direct kernel wrapping, reducing the amount of code required when using the interop.

@shi-eric shi-eric added this to the 1.10.0 milestone Sep 1, 2025
@btaba
Contributor

btaba commented Sep 6, 2025

I tried this PR with this code google-deepmind/mujoco_warp#475 (comment)
and the module loading works fine. But the graph capture doesn't work with pmap. I get

E0905 22:27:46.696872   14773 pjrt_stream_executor_client.cc:3008] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Warp CUDA error 904: capturing stream has unjoined work (in function wp_cuda_graph_pause_capture, /builds/omniverse/warp/warp/native/warp.cu:2948)
E0905 22:27:46.697047   14776 pjrt_stream_executor_client.cc:3008] Execution of replica 1 failed: UNKNOWN: FFI callback error: RuntimeError: Warp CUDA error 904: capturing stream has unjoined work (in function wp_cuda_graph_pause_capture, /builds/omniverse/warp/warp/native/warp.cu:2948)

To repro, I modified this file directly with the implementation in this PR. Any ideas what could be going wrong?

@mehdiataei
Contributor Author

mehdiataei commented Sep 8, 2025

@btaba could you share more details on the WIP? I played around with this a bit over the weekend, and it looks like the issue comes from how modules are being loaded in the “third_party” FFI (in MuJoCo). I made some progress but was planning to dig deeper into it this week. Is this the issue you’re working on fixing?

@btaba
Contributor

btaba commented Sep 10, 2025

@btaba could you share more details on the WIP ? I played around with this a bit over the weekend and it looks like the issue comes from how modules are being loaded in the “third_party” FFI (in Mujoco). I made some progress but was planning to dig deeper into it this week. Is this the issue you’re working on fixing?

I'm not actively working on it, just wanted to take a stab to try to get to a semi-working state. If you have some cycles, please let us know what you find! There is currently a race condition, and data on one device is corrupt. Other than that, the code at least runs with pmap (occasionally).

I'm a bit confused by this PR: __call__ uses the default device (device = wp.get_device("cuda")), so why is it not running on the device from the XLA context? (In my branch, I do a bit of a hack to get the actual device from the XLA buffer.) It might make sense to split this PR into one for adjoint support and one for pmap support.

Args:
kernel: The Warp kernel to wrap.
num_outputs: Number of output arrays produced by the kernel.
static_argnames: Optional iterable of argument names that should be treated

It seems there is some confusion in this function between static-for-jit arguments and arguments that you don't want gradients for. I can have a JAX array as an input that is not static-for-jit but that I don't want gradients for. For example, when ray tracing, I don't want gradients of the rays, but I don't want them to be static for jit either (i.e., they are regular JAX arrays).

Contributor Author

@mehdiataei mehdiataei Sep 15, 2025


Hmm, I don't think you can pass JAX arrays as static argnums in JAX...? That was the case before, and I believe it should still be the case now since arrays are not hashable. Is this still true, @nouiz?

Maybe I don’t fully understand your request. Could you provide an example of something we currently do not support?


Actually I think my comment was a little wrong, but I do think this code slightly conflates static args for jitting with static args for custom gradients. I don't really think this code should be concerned with jitting at all; static args for this might not be (and usually won't be!) static args for jit.
For specific changes:
  • You should call these nondiff_argnums to match JAX's terminology: https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html
  • I also wouldn't try to be clever and use heuristics to figure out what should or should not be nondiff_argnums; that could cause silent bugs where someone forgets to wrap a float in a jnp.array.
  • I think we should also allow for JAX input arrays that shouldn't have gradients computed (i.e., as if we had wrapped them in jax.stop_gradient); see the sketch below. I don't know what to call these, or whether Warp properly supports not computing gradients for some inputs.
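
To make the distinction concrete in pure JAX (a hypothetical ray-tracing-shaped example, not this PR's API): rays is a regular traced array that simply gets no gradient, while n_bounces is genuinely static.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def render(scene, rays, n_bounces):
    # stand-in for a real renderer; n_bounces is a plain Python int
    return jnp.sum(scene * rays) * n_bounces


def render_fwd(scene, rays, n_bounces):
    return render(scene, rays, n_bounces), (rays,)


def render_bwd(n_bounces, residuals, g):
    (rays,) = residuals
    # cotangents for (scene, rays): rays gets zeros, as if wrapped in stop_gradient
    return (g * rays * n_bounces, jnp.zeros_like(rays))


render.defvjp(render_fwd, render_bwd)

scene = jnp.ones((8,))
rays = jnp.linspace(0.0, 1.0, 8)
grads = jax.grad(render)(scene, rays, 3)      # differentiates w.r.t. scene only
fast = jax.jit(render, static_argnums=(2,))   # n_bounces is static for jit; rays is not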

as static (non-differentiable) in JAX. If None, scalar (non-array) inputs
are treated as static by default.
vmap_method: How the callback transforms under jax.vmap.
launch_dim_arg_index: Index (in the kernel's argument list) of the input

(Sorry, this was in the wrong place before.)
I don't think this will work for everyone. For example, if I want to do something for every pixel in an RGB image, I want the kernel to launch over the height and width but not the channel dimension. Also, this is different from the existing jax_kernel interface. I wonder if it's better to force providing these explicitly?

Contributor Author


I believe we do currently handle this when channels are represented as dtype dims (e.g., wp.vec3) and appear as the trailing axis in the JAX array. I pushed a test to demonstrate it.

Caveat: this works when channels are encoded as dtype dims (e.g., vec3) and are trailing in the JAX tensor. If someone insists on modeling channels as an extra array axis in the Warp type (e.g., wp.array3d(dtype=float)), then the kernel would launch over H×W×C. If we need to support that representation while still launching over H×W, we can extend jax_ad_kernel to accept an explicit launch_dims or launch_dims_axes parameter and update both forward/backward wrappers accordingly.
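
For illustration, the supported pattern looks roughly like this (hypothetical kernel; assumes the launch and output dims are inferred from the input as described above):

import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_ad_kernel


@wp.kernel
def brighten(img: wp.array2d(dtype=wp.vec3), out: wp.array2d(dtype=wp.vec3)):
    i, j = wp.tid()
    out[i, j] = img[i, j] * 2.0


jax_brighten = jax_ad_kernel(brighten, num_outputs=1)

img = jnp.ones((64, 128, 3), dtype=jnp.float32)  # trailing axis maps to the vec3 dtype dims
(out,) = jax_brighten(img)                       # kernel launches over 64x128, not 64x128x3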

Please let me know what you think.

@nouiz @nvlukasz it would be great if you could chime in.


I guess I just don't think it's a good idea to hard-code that the launch dims must come from one of the input dims; you can definitely think of examples where that wouldn't be true.

@mehdiataei
Contributor Author

@btaba could you share more details on the WIP ? I played around with this a bit over the weekend and it looks like the issue comes from how modules are being loaded in the “third_party” FFI (in Mujoco). I made some progress but was planning to dig deeper into it this week. Is this the issue you’re working on fixing?

I'm not actively working on it, just wanted to take a stab to try to get to a semi-working state. If you have some cycles, please let us know what you find! There is currently a race condition, and data on one device is corrupt. Other than that, the code at least runs with pmap (occasionally).

I'm a bit confused by this PR as __call__ uses the default device device = wp.get_device("cuda"), why is it not running on the device from the XLA context (in my branch, I do a bit of a hack to get the actual device from the XLA buffer)? It might make sense to split this PR into one for adjoint support and one for pmap support.

Thanks. I did try to fix the issue in MuJoCo. I had a bit of success, but it seems like there are multiple versions of Warp in that repo that need to be synced to the same FFI implementation.

The reason __call__ uses wp.get_device("cuda") in the FFI path is that inside the FFI callback we bind the Warp stream to the CUDA stream that XLA passes us and rely on the current CUDA context that XLA sets for the shard. That guarantees the stream and device match across pmap replicas and avoids subtle mismatches that can occur when re-deriving the device from a JAX device handle or from a buffer pointer. Note that the stream itself comes from XLA, and we preload everything ahead of time to avoid build races.

If you try pmap with this Warp branch, it works quite well on multi-GPU devices. All the changes above were made in response to different issues that came up during the implementation.

@btaba
Contributor

btaba commented Sep 23, 2025

@mehdiataei I'm also a bit confused: graph_mode is no longer respected in ffi_callback:

if use_xla_stream:
    with wp.ScopedDevice(device):
        with wp.ScopedStream(ffi_stream, sync_enter=False), wp.context.ScopedExternalStream(ffi_stream):
            with _temporarily_disable_backward():
                self.func(*arg_list)
else:
    # Use buffer-device default stream; explicitly clear any external stream
    with wp.context.ScopedExternalStream(None):
        with wp.ScopedDevice(device):
            if self.graph_mode == GraphMode.WARP:

and the JAX graph mode is removed?

@mehdiataei
Contributor Author

mehdiataei commented Sep 24, 2025

@btaba

uv venv
source .venv/bin/activate
git clone https://github.com/google-deepmind/mujoco.git
cd mujoco/mjx
uv pip install -e .
uv pip install "jax[cuda12]"
uv pip install -e <ADDRESS_TO_WARP>/warp/  # note: build Warp first with build_lib.py
# copy the content of ffi.py into the ffi.py file in mujoco's third_party/warp
cd ../../
python <example>.py

The above worked for me; can you give it a try?

@btaba
Contributor

btaba commented Oct 3, 2025

Thanks @mehdiataei, I just tried with the latest branch but still get zeros on the second device. I'll have to try on another machine; do you have any ideas otherwise?

@nvlukasz
Contributor

nvlukasz commented Oct 3, 2025

Thanks @mehdiataei , I just tried with the latest branch but still get zeros on the second device. I'll have to try on another machine; do you have any ideas otherwise?

@btaba @mehdiataei, I have a pmap solution based on GH-976 from @chaserileyroberts. Thanks Chase! It's currently under review by the Warp team, but I expect it to be merged shortly.

I had to add some additional thread safety, but the Mujoco example works fine now. That should take care of pmap, so once that's merged, @mehdiataei you can modify this PR to only add the autodiff functionality.

Thank you all for your contributions and patience!

@mehdiataei
Contributor Author

Thanks @mehdiataei , I just tried with the latest branch but still get zeros on the second device. I'll have to try on another machine; do you have any ideas otherwise?

Hmm, no idea. I followed those steps again and it works fine for me. Is it working correctly with pure JAX and pmap?

@nvlukasz That's great! Yes. I can merge those changes and add the adjoint stuff.

@nvlukasz
Contributor

nvlukasz commented Oct 3, 2025

@nvlukasz That's great! Yes. I can merge those changes and add the adjoint stuff.

Alright, the new pmap support is now merged (599ee98)


coderabbitai bot commented Oct 6, 2025

Warning

Rate limit exceeded

@mehdiataei has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 20 minutes and 53 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.


📥 Commits

Reviewing files that changed from the base of the PR and between 4fee379 and 525285c.

📒 Files selected for processing (3)
  • docs/modules/interoperability.rst (1 hunks)
  • warp/jax_experimental/ffi.py (9 hunks)
  • warp/tests/interop/test_jax.py (2 hunks)
📝 Walkthrough

Adds a differentiable JAX path for Warp FFI kernels: implements a custom-VJP differentiable wrapper cached separately, updates warp-aware shape/launch-dim inference and runtime handling, extends jax_kernel/jax_callable APIs with differentiable and static_argnames, expands AD-focused JAX interoperability docs, and adds extensive interop tests for AD, vmap/pmap, and static-arg scenarios.

Changes

  • Docs: JAX AD integration
    File(s): docs/modules/interoperability.rst
    Adds comprehensive AD docs for Warp↔JAX: usage patterns for jax_kernel/jax_callable with differentiable=True and static_argnames, vmap methods, per-sample gradients, launch-dim inference, VMAP guidance, examples, error cases, and migration notes (FFI vs annotated Python).
  • Core JAX FFI: differentiable path & shape handling
    File(s): warp/jax_experimental/ffi.py
    Introduces a differentiable code path (custom VJP) with a separate _FFI_DIFF_KERNEL_REGISTRY and new DiffKernelCacheKey alias; extends jax_kernel and jax_callable signatures to accept differentiable: bool = False and static_argnames=None; implements warp-dim-aware shape extraction (get_warp_shape/warp_ndim/batch_ndim), constructs/caches forward/backward wrappers (with signature inspection), integrates with JAX autograd, disallows in_out_argnames for the differentiable path, and adjusts output-type construction and caching behavior.
  • Tests: expanded AD/vmap/pmap interop coverage
    File(s): warp/tests/interop/test_jax.py
    Adds many new JAX/FFI interop tests covering differentiable kernels (scalar/vector/matrix/multi-output), JIT-of-grad cases, pmap/per-device gradients, static-arg handling, graph/FFI pathways, and conditional registration/skips based on JAX version/device.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant U as User code
    participant JK as jax_kernel / jax_callable
    participant REGd as _FFI_DIFF_KERNEL_REGISTRY
    participant REG as _FFI_KERNEL_REGISTRY
    participant WK as Warp kernel
    participant AK as Adjoint kernel

    Note over U,JK: create callable (maybe differentiable)
    U->>JK: jax_kernel(kernel, differentiable=True, static_argnames)
    JK->>REGd: lookup DiffKernelCacheKey
    alt cache miss
        JK->>JK: inspect signature & static args
        JK->>JK: build forward wrapper (launch dims, warp shapes)
        JK->>JK: build backward (custom_vjp) that calls adjoint
        JK->>REGd: cache diff wrapper
    else cache hit
        REGd-->>JK: return cached wrapper
    end
    JK-->>U: return JAX-callable

    Note over U,WK: forward call
    U->>JK: call(args)
    JK->>WK: infer warp dims, launch kernel
    WK-->>JK: outputs
    JK-->>U: outputs

    Note over U,AK: reverse pass
    U->>JK: JAX requests grad
    JK->>AK: call adjoint with cotangents (warp-aware shapes)
    AK-->>JK: input gradients
    JK-->>U: return gradients

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 4.65%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title “added JAX interop adjoint support using FFI” concisely and accurately reflects the primary purpose of the pull request, which is to introduce JAX automatic differentiation support for Warp kernels via an FFI-based path. It is specific to the main feature change and avoids unnecessary detail or noise.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
warp/jax_experimental/ffi.py (2)

1465-1472: Fix warp-shape inference under vmap (use trailing dims, not leading).

Current get_warp_shape() returns the first warp_ndim dims, which is wrong when leading batch dims exist. Use the last warp_ndim dims before dtype dims. This affects launch dims and output type construction.

-def get_warp_shape(arg, dims):
-    if arg.dtype_ndim > 0:
-        # vector/matrix array
-        return dims[: arg.warp_ndim]
-    else:
-        # scalar array
-        return dims
+def get_warp_shape(arg, dims):
+    # dims may include leading batch dims and trailing dtype dims
+    jax_rank = len(dims) - (arg.dtype_ndim if arg.dtype_ndim > 0 else 0)
+    if jax_rank < arg.warp_ndim:
+        raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+    start = jax_rank - arg.warp_ndim
+    end = jax_rank
+    return tuple(dims[start:end])

324-347: Enforce stream/pointer device-ordinal agreement in FfiKernel callback.

FfiKernel.ffi_callback derives device from the XLA stream but doesn’t validate array pointer ordinals. Under pmap this risks cross-device launches. Mirror FfiCallable logic: prefer pointer ordinal; only bind the XLA stream when ordinals match.

-            # get stream and derive device from stream to be replica-local under pmap
-            stream_handle = get_stream_from_callframe(call_frame.contents)
-            try:
-                ordinal = wp.context.runtime.core.wp_cuda_stream_get_device_ordinal(stream_handle)
-                device = wp.get_cuda_device(ordinal)
-            except Exception:
-                device = wp.device_from_jax(get_jax_device())
-            stream = wp.Stream(device, cuda_stream=stream_handle)
+            # get stream and derive device/stream safely under pmap
+            stream_handle = get_stream_from_callframe(call_frame.contents)
+            try:
+                stream_ord = wp.context.runtime.core.wp_cuda_stream_get_device_ordinal(stream_handle)
+            except Exception:
+                stream_ord = -1
+            # get first array input pointer ordinal (if any)
+            ptr_ord = -1
+            for i, arg in enumerate(self.input_args):
+                if arg.is_array:
+                    buf = inputs[i].contents
+                    ptr_ord = wp.context.runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(buf.data))
+                    break
+            ord_to_use = ptr_ord if ptr_ord >= 0 else stream_ord
+            if ord_to_use < 0:
+                device = wp.device_from_jax(get_jax_device())
+                stream = device.stream  # fallback to device default stream
+            else:
+                device = wp.get_cuda_device(ord_to_use)
+                # Bind XLA stream only if ordinals agree; otherwise use device default stream
+                stream = (
+                    wp.Stream(device, cuda_stream=stream_handle)
+                    if (ptr_ord >= 0 and ptr_ord == stream_ord)
+                    else device.stream
+                )
warp/native/warp.cu (2)

2398-2401: Compile error: missing parentheses in if-statement

C++ requires parentheses around the condition.

Apply:

-    if check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS))
+    if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)))
         return (void*) device_ptr;
     else
         return NULL;

2539-2543: Use check_cu for CUDA Driver API calls (cu)*

These lines use check_cuda with cu* functions; switch to check_cu for accurate error handling.

Apply (illustrative diffs):

-    check_cuda(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
+    check_cu(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
-    if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))

Also applies to: 3062-3066, 3113-3116, 3221-3224, 3264-3267

♻️ Duplicate comments (1)
warp/tests/interop/test_jax.py (1)

352-375: Addresses prior feedback: added jit(grad(...)) tests

test_jax_ad_kernel_jit_of_grad_simple and ..._multi_output cover jit-of-grad usage. Looks good.

🧹 Nitpick comments (8)
warp/context.py (1)

5578-5595: Tighten exceptions and guard CUDA-only pointer-ordinal lookup

Good diagnostics, but avoid blind excepts and skip pointer-ordinal queries on non‑CUDA or null ptr. Also use getattr for ext stream safely.

Apply this diff:

-            # check device; provide detailed diagnostics on mismatch
-            if value.device != device:
-                try:
-                    ext_stream = _get_external_stream()
-                    ext_dev = ext_stream.device if ext_stream is not None else None
-                except Exception:
-                    ext_dev = None
-                try:
-                    ptr_ord = runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(value.ptr)) if value.ptr is not None else -1
-                except Exception:
-                    ptr_ord = -1
+            # check device; provide detailed diagnostics on mismatch
+            if value.device != device:
+                ext_stream = _get_external_stream()
+                ext_dev = getattr(ext_stream, "device", None)
+                ptr_ord = -1
+                if getattr(device, "is_cuda", False) and value.ptr is not None:
+                    try:
+                        ptr_ord = runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(value.ptr))
+                    except (AttributeError, OSError, ValueError):
+                        ptr_ord = -1
                 raise RuntimeError(
                     (
                         f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', "
                         f"but input array for argument '{arg_name}' is on device={value.device}. "
                         f"[debug ext_stream_device={ext_dev}, ptr_device_ordinal={ptr_ord}]"
                     )
                 )

Note: This also resolves Ruff BLE001/S110 warnings. [As per static analysis hints]

warp/examples/optim/example_inverse_kinematics_jax.py (2)

69-69: Avoid double-wrapping when creating JAX arrays.

Construct directly with JAX to skip the intermediate NumPy array.

-        self.target = jp.array(np.array((2.0, 1.0, 0.0), dtype=np.float32))
+        self.target = jp.array((2.0, 1.0, 0.0), dtype=jp.float32)

192-201: Narrow the exception and avoid silent pass when setting env vars.

Only import errors should be caught here; consider logging unexpected failures.

-    os_env = {}
-    try:
-        import os as _os
-
-        os_env = _os.environ
-        os_env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
-        os_env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
-    except Exception:
-        pass
+    try:
+        import os as _os
+        _os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+        _os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+    except ImportError:
+        # JAX will use defaults; nothing to configure
+        pass
warp/jax_experimental/ffi.py (4)

30-30: Import get_stream_from_callframe explicitly.

Silences F405 and makes dependency clear.

-from .xla_ffi import *
+from .xla_ffi import *
+from .xla_ffi import get_stream_from_callframe

732-736: Remove unused call_desc.capture assignment.

This attribute is never read; use captures dict only to avoid confusion.

-                            if ffi_stream.is_capturing:
-                                with wp.ScopedCapture(external=True) as capture:
-                                    self.func(*arg_list)
-                                call_desc.capture = capture
+                            if ffi_stream.is_capturing:
+                                with wp.ScopedCapture(external=True) as capture:
+                                    self.func(*arg_list)

476-487: Allow “non‑diff” JAX arrays separately from “static” args; avoid heuristic defaults.

Current logic conflates nondifferentiable args with jit-static args and uses a heuristic to treat non-arrays as nondiff. This prevents common patterns (e.g., pass JAX arrays that shouldn’t receive grads but aren’t static for jit).

  • Add a nondiff_argnames parameter distinct from static_argnames.
  • Drop the heuristic; require explicit nondiff_argnames for clarity.
  • Keep static_argnames purely for jitting.
  • Continue to pass nondiff_argnums to custom_vjp.

Would you like a follow-up patch adding nondiff_argnames to both jax_callable and jax_ad_kernel and updating call sites?

Also applies to: 882-907


236-251: Module preloading across GPUs is good; consider caching to avoid repeated loads.

module.load(dev) is idempotent but may still cost. Cache per (module, device) to skip redundant calls during tracing.

If you want, I can wire a simple per-process cache keyed by (module.name, device.ordinal) and guard with _FFI_REGISTRY_LOCK.

Also applies to: 525-556
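
For example, a per-process cache along these lines (untested sketch; assumes module.name, device.ordinal, and the existing _FFI_REGISTRY_LOCK):

_PRELOADED_MODULES: set[tuple[str, int]] = set()


def _preload_module(module, device):
    # Skip modules already loaded on this device; hold the lock so concurrent
    # tracers don't race to build the same module.
    key = (module.name, device.ordinal)
    with _FFI_REGISTRY_LOCK:
        if key not in _PRELOADED_MODULES:
            module.load(device)
            _PRELOADED_MODULES.add(key)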

docs/modules/interoperability.rst (1)

956-1090: AD section reads well and matches API semantics

Examples and notes are clear (statics, VMAP options, multi-output). Consider adding a short cross-link near “JAX Foreign Function Interface (FFI)” pointing here, plus an anchor (e.g., .. _jax-ad:) for easier navigation.

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cdad7fb and 01d087c.

📒 Files selected for processing (8)
  • docs/modules/interoperability.rst (1 hunks)
  • warp/context.py (6 hunks)
  • warp/examples/optim/example_inverse_kinematics_jax.py (1 hunks)
  • warp/jax_experimental/__init__.py (1 hunks)
  • warp/jax_experimental/ffi.py (18 hunks)
  • warp/native/warp.cu (1 hunks)
  • warp/native/warp.h (1 hunks)
  • warp/tests/interop/test_jax.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
warp/jax_experimental/__init__.py (1)
warp/jax_experimental/ffi.py (1)
  • jax_ad_kernel (1162-1445)
warp/examples/optim/example_inverse_kinematics_jax.py (6)
warp/jax_experimental/ffi.py (1)
  • jax_ad_kernel (1162-1445)
warp/sim/model.py (9)
  • ModelBuilder (1103-4821)
  • add_articulation (1398-1399)
  • joint_count (1354-1355)
  • add_body (1598-1657)
  • add_joint_revolute (1800-1888)
  • add_shape_sphere (2873-2940)
  • add_shape_box (2942-3011)
  • articulation_count (1386-1387)
  • color (4515-4555)
warp/sim/render.py (1)
  • SimRenderer (69-424)
warp/context.py (2)
  • zeros (5162-5187)
  • ones (5212-5233)
warp/jax.py (3)
  • to_jax (160-172)
  • from_jax (175-186)
  • device_to_jax (19-41)
warp/utils.py (1)
  • ScopedDevice (1194-1235)
warp/native/warp.h (2)
warp/native/warp.cpp (16)
  • int (1003-1003)
  • int (1004-1004)
  • int (1007-1007)
  • int (1010-1010)
  • int (1013-1013)
  • int (1014-1014)
  • int (1016-1016)
  • int (1017-1017)
  • int (1018-1018)
  • int (1019-1019)
  • int (1020-1020)
  • int (1021-1021)
  • int (1022-1022)
  • int (1036-1036)
  • int (1037-1037)
  • int (1041-1041)
warp/native/warp.cu (4)
  • wp_cuda_stream_get_device_ordinal (2545-2549)
  • wp_cuda_stream_get_device_ordinal (2545-2545)
  • wp_cuda_pointer_get_device_ordinal (2551-2569)
  • wp_cuda_pointer_get_device_ordinal (2551-2551)
warp/jax_experimental/ffi.py (4)
warp/jax.py (2)
  • device_from_jax (44-58)
  • get_jax_device (61-71)
warp/context.py (13)
  • load (2196-2448)
  • context (3031-3046)
  • get_cuda_device (4494-4502)
  • stream (3054-3063)
  • stream (3066-3067)
  • Stream (2690-2834)
  • ScopedExternalStream (2842-2851)
  • is_capturing (2827-2829)
  • is_capturing (3019-3028)
  • get_device (4325-4339)
  • get_device (4526-4531)
  • launch (5862-5903)
  • launch (5906-6164)
warp/jax_experimental/xla_ffi.py (1)
  • get_stream_from_callframe (554-561)
warp/native/warp.cu (4)
  • wp_cuda_stream_get_device_ordinal (2545-2549)
  • wp_cuda_stream_get_device_ordinal (2545-2545)
  • wp_cuda_pointer_get_device_ordinal (2551-2569)
  • wp_cuda_pointer_get_device_ordinal (2551-2551)
warp/native/warp.cu (3)
warp/native/cuda_util.h (2)
  • get_stream_context (155-162)
  • get_stream_context (164-167)
warp/native/warp.cpp (1)
  • wp_cuda_context_get_device_ordinal (1036-1036)
warp/native/cuda_util.cpp (2)
  • cuPointerGetAttribute_f (597-600)
  • cuPointerGetAttribute_f (597-597)
warp/tests/interop/test_jax.py (4)
warp/tests/interop/test_dlpack.py (1)
  • _jax_version (28-34)
warp/jax_experimental/ffi.py (2)
  • jax_ad_kernel (1162-1445)
  • jax_callable (825-1087)
warp/jax.py (1)
  • device_to_jax (19-41)
warp/tests/unittest_utils.py (1)
  • add_function_test (284-303)
warp/context.py (2)
warp/native/warp.cu (4)
  • wp_cuda_stream_get_device_ordinal (2545-2549)
  • wp_cuda_stream_get_device_ordinal (2545-2545)
  • wp_cuda_pointer_get_device_ordinal (2551-2569)
  • wp_cuda_pointer_get_device_ordinal (2551-2551)
warp/types.py (1)
  • is_array (2021-2023)
🪛 Ruff (0.13.3)
warp/examples/optim/example_inverse_kinematics_jax.py

148-148: Unused function argument: body_qd

(ARG001)


199-200: try-except-pass detected, consider logging the exception

(S110)


199-199: Do not catch blind exception: Exception

(BLE001)

warp/jax_experimental/ffi.py

179-181: Avoid specifying long messages outside the exception class

(TRY003)


239-239: Do not catch blind exception: Exception

(BLE001)


248-248: Avoid specifying long messages outside the exception class

(TRY003)


325-325: get_stream_from_callframe may be undefined, or defined from star imports

(F405)


329-329: Do not catch blind exception: Exception

(BLE001)


478-480: Avoid specifying long messages outside the exception class

(TRY003)


528-528: Do not catch blind exception: Exception

(BLE001)


537-538: try-except-continue detected, consider logging the exception

(S112)


537-537: Do not catch blind exception: Exception

(BLE001)


542-543: try-except-pass detected, consider logging the exception

(S110)


542-542: Do not catch blind exception: Exception

(BLE001)


544-545: try-except-pass detected, consider logging the exception

(S110)


544-544: Do not catch blind exception: Exception

(BLE001)


632-632: Abstract raise to an inner function

(TRY301)


632-632: Avoid specifying long messages outside the exception class

(TRY003)


661-663: Abstract raise to an inner function

(TRY301)


661-663: Avoid specifying long messages outside the exception class

(TRY003)


677-679: Abstract raise to an inner function

(TRY301)


677-679: Avoid specifying long messages outside the exception class

(TRY003)


683-685: Abstract raise to an inner function

(TRY301)


683-685: Avoid specifying long messages outside the exception class

(TRY003)


700-702: Abstract raise to an inner function

(TRY301)


700-702: Avoid specifying long messages outside the exception class

(TRY003)


715-717: Abstract raise to an inner function

(TRY301)


715-717: Avoid specifying long messages outside the exception class

(TRY003)


721-723: Abstract raise to an inner function

(TRY301)


721-723: Avoid specifying long messages outside the exception class

(TRY003)


927-927: Do not catch blind exception: Exception

(BLE001)


940-940: Avoid specifying long messages outside the exception class

(TRY003)


995-996: try-except-pass detected, consider logging the exception

(S110)


995-995: Do not catch blind exception: Exception

(BLE001)


1222-1222: Do not catch blind exception: Exception

(BLE001)


1285-1286: try-except-pass detected, consider logging the exception

(S110)


1285-1285: Do not catch blind exception: Exception

(BLE001)


1405-1405: Do not catch blind exception: Exception

(BLE001)


1414-1415: try-except-pass detected, consider logging the exception

(S110)


1414-1414: Do not catch blind exception: Exception

(BLE001)

warp/tests/interop/test_jax.py

312-312: Unused function argument: test

(ARG001)


348-348: assert_np_equal may be undefined, or defined from star imports

(F405)


349-349: assert_np_equal may be undefined, or defined from star imports

(F405)


353-353: Unused function argument: test

(ARG001)


388-388: assert_np_equal may be undefined, or defined from star imports

(F405)


389-389: assert_np_equal may be undefined, or defined from star imports

(F405)


393-393: Unused function argument: test

(ARG001)


429-429: assert_np_equal may be undefined, or defined from star imports

(F405)


430-430: assert_np_equal may be undefined, or defined from star imports

(F405)


434-434: Unused function argument: test

(ARG001)


455-455: Unused function argument: s

(ARG001)


477-477: assert_np_equal may be undefined, or defined from star imports

(F405)


478-478: assert_np_equal may be undefined, or defined from star imports

(F405)


482-482: Unused function argument: test

(ARG001)


511-511: assert_np_equal may be undefined, or defined from star imports

(F405)


515-515: Unused function argument: test

(ARG001)


540-540: assert_np_equal may be undefined, or defined from star imports

(F405)


544-544: Unused function argument: test

(ARG001)


572-572: assert_np_equal may be undefined, or defined from star imports

(F405)


576-576: Unused function argument: test

(ARG001)


608-608: assert_np_equal may be undefined, or defined from star imports

(F405)


609-609: assert_np_equal may be undefined, or defined from star imports

(F405)


613-613: Unused function argument: test

(ARG001)


644-644: assert_np_equal may be undefined, or defined from star imports

(F405)


645-645: assert_np_equal may be undefined, or defined from star imports

(F405)


649-649: Unused function argument: test

(ARG001)


674-674: assert_np_equal may be undefined, or defined from star imports

(F405)


678-678: Unused function argument: test

(ARG001)


711-711: assert_np_equal may be undefined, or defined from star imports

(F405)


712-712: assert_np_equal may be undefined, or defined from star imports

(F405)


716-716: Unused function argument: test

(ARG001)


748-748: assert_np_equal may be undefined, or defined from star imports

(F405)


749-749: assert_np_equal may be undefined, or defined from star imports

(F405)


753-753: Unused function argument: test

(ARG001)


785-785: assert_np_equal may be undefined, or defined from star imports

(F405)


786-786: assert_np_equal may be undefined, or defined from star imports

(F405)


790-790: Unused function argument: test

(ARG001)


822-822: assert_np_equal may be undefined, or defined from star imports

(F405)


823-823: assert_np_equal may be undefined, or defined from star imports

(F405)


827-827: Unused function argument: test

(ARG001)


860-860: assert_np_equal may be undefined, or defined from star imports

(F405)


861-861: assert_np_equal may be undefined, or defined from star imports

(F405)


865-865: Unused function argument: test

(ARG001)


883-883: Unused function argument: c

(ARG001)


934-934: assert_np_equal may be undefined, or defined from star imports

(F405)


935-935: assert_np_equal may be undefined, or defined from star imports

(F405)


938-938: Unused function argument: test

(ARG001)


962-962: assert_np_equal may be undefined, or defined from star imports

(F405)


966-966: Unused function argument: test

(ARG001)


990-990: assert_np_equal may be undefined, or defined from star imports

(F405)


1034-1034: Unused function argument: test

(ARG001)


1065-1065: assert_np_equal may be undefined, or defined from star imports

(F405)


1066-1066: assert_np_equal may be undefined, or defined from star imports

(F405)


1070-1070: Unused function argument: test

(ARG001)


1112-1112: assert_np_equal may be undefined, or defined from star imports

(F405)


1113-1113: assert_np_equal may be undefined, or defined from star imports

(F405)


1116-1116: Unused function argument: device

(ARG001)


1153-1153: Unused function argument: device

(ARG001)


1184-1184: Unused function argument: device

(ARG001)


1218-1218: Unused function argument: device

(ARG001)


1263-1263: Unused function argument: device

(ARG001)


1314-1314: Unused function argument: device

(ARG001)


1363-1363: assert_np_equal may be undefined, or defined from star imports

(F405)


1364-1364: assert_np_equal may be undefined, or defined from star imports

(F405)


1368-1368: Unused function argument: device

(ARG001)


1401-1401: Unused function argument: device

(ARG001)


1436-1436: assert_np_equal may be undefined, or defined from star imports

(F405)


1437-1437: assert_np_equal may be undefined, or defined from star imports

(F405)


1493-1493: add_function_test may be undefined, or defined from star imports

(F405)


1497-1497: add_function_test may be undefined, or defined from star imports

(F405)


1503-1503: add_function_test may be undefined, or defined from star imports

(F405)


1506-1506: add_function_test may be undefined, or defined from star imports

(F405)


1507-1507: add_function_test may be undefined, or defined from star imports

(F405)


1510-1510: add_function_test may be undefined, or defined from star imports

(F405)


1516-1516: add_function_test may be undefined, or defined from star imports

(F405)


1522-1522: add_function_test may be undefined, or defined from star imports

(F405)


1528-1528: add_function_test may be undefined, or defined from star imports

(F405)


1531-1531: add_function_test may be undefined, or defined from star imports

(F405)


1537-1537: add_function_test may be undefined, or defined from star imports

(F405)


1543-1543: add_function_test may be undefined, or defined from star imports

(F405)


1549-1549: add_function_test may be undefined, or defined from star imports

(F405)


1555-1555: add_function_test may be undefined, or defined from star imports

(F405)


1561-1561: add_function_test may be undefined, or defined from star imports

(F405)


1567-1567: add_function_test may be undefined, or defined from star imports

(F405)


1574-1574: add_function_test may be undefined, or defined from star imports

(F405)


1580-1580: add_function_test may be undefined, or defined from star imports

(F405)


1587-1587: add_function_test may be undefined, or defined from star imports

(F405)


1593-1593: add_function_test may be undefined, or defined from star imports

(F405)


1599-1599: add_function_test may be undefined, or defined from star imports

(F405)


1606-1606: add_function_test may be undefined, or defined from star imports

(F405)


1612-1612: add_function_test may be undefined, or defined from star imports

(F405)


1619-1619: add_function_test may be undefined, or defined from star imports

(F405)


1626-1626: add_function_test may be undefined, or defined from star imports

(F405)


1632-1632: add_function_test may be undefined, or defined from star imports

(F405)


1639-1639: add_function_test may be undefined, or defined from star imports

(F405)


1645-1645: add_function_test may be undefined, or defined from star imports

(F405)


1652-1652: add_function_test may be undefined, or defined from star imports

(F405)

warp/context.py

5583-5583: Do not catch blind exception: Exception

(BLE001)


5587-5587: Do not catch blind exception: Exception

(BLE001)


5968-5969: try-except-pass detected, consider logging the exception

(S110)


5968-5968: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (5)
warp/context.py (1)

3822-3826: CTypes bindings for device-ordinal helpers look correct

Signatures match native helpers; this unblocks better diagnostics and device mapping.

warp/examples/optim/example_inverse_kinematics_jax.py (1)

87-92: Good use of jax_ad_kernel with explicit output dims and launch index.

Launch dims taken from mask arg and explicit output dims resolve shape inference; aligns with adjoint support.

warp/jax_experimental/ffi.py (1)

628-756: Stream/pointer ordinal checks and batch-dim collapse logic look solid.

This should address pmap multi-GPU races and shape alignment under vmap.

Please run multi-device tests (pmap/shard_map) with mixed leading batch dims and vector dtypes to confirm no regressions.

warp/jax_experimental/__init__.py (1)

17-17: Public export looks good

Re-exporting jax_ad_kernel at the package level is convenient and aligns with docs/tests.

warp/native/warp.cu (1)

2545-2569: New device-ordinal helpers: OK

Implementations for stream/pointer → device ordinal are correct and defensive (return -1 on failure/non-device memory).

warp/context.py Outdated
Comment on lines 2837 to 2852
_EXTERNAL_STREAM_LOCAL = threading.local()

def _get_external_stream():
    return getattr(_EXTERNAL_STREAM_LOCAL, "stream", None)

class ScopedExternalStream:
    def __init__(self, stream: Stream | None):
        self.stream = stream

    def __enter__(self):
        _EXTERNAL_STREAM_LOCAL.stream = self.stream
        return self.stream

    def __exit__(self, exc_type, exc_val, exc_tb):
        _EXTERNAL_STREAM_LOCAL.stream = None


⚠️ Potential issue | 🟠 Major

ScopedExternalStream is not nest-safe; restore the prior TLS value on exit

__exit__ unconditionally sets the TLS value to None, breaking nested usage and re-entrancy. Save the previous value in __enter__ and restore it in __exit__.

Apply this diff:

 class ScopedExternalStream:
     def __init__(self, stream: Stream | None):
         self.stream = stream
+        self._prev = None
 
     def __enter__(self):
-        _EXTERNAL_STREAM_LOCAL.stream = self.stream
-        return self.stream
+        self._prev = getattr(_EXTERNAL_STREAM_LOCAL, "stream", None)
+        _EXTERNAL_STREAM_LOCAL.stream = self.stream
+        return self.stream
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        _EXTERNAL_STREAM_LOCAL.stream = None
+        _EXTERNAL_STREAM_LOCAL.stream = self._prev
+        self._prev = None
🤖 Prompt for AI Agents
In warp/context.py around lines 2837 to 2852, ScopedExternalStream.__exit__
unconditionally clears the thread-local stream, which breaks nested or
re-entrant usage; modify ScopedExternalStream so __enter__ saves the previous
value (e.g., prev = _get_external_stream() or getattr from
_EXTERNAL_STREAM_LOCAL) onto the instance and __exit__ restores that saved
previous value instead of setting TLS to None, ensuring the original stream (or
absence) is reinstated when exiting the scope.

Comment on lines 148 to 149
def loss_from_fk_outputs(body_q, body_qd):
    return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)

⚠️ Potential issue | 🟡 Minor

Remove unused parameter in local loss helper.

body_qd is unused. Simplify signature to avoid ARG001.

-        def loss_from_fk_outputs(body_q, body_qd):
+        def loss_from_fk_outputs(body_q, _):
             return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def loss_from_fk_outputs(body_q, body_qd):
-    return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)
+def loss_from_fk_outputs(body_q, _):
+    return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)
🧰 Tools
🪛 Ruff (0.13.3)

148-148: Unused function argument: body_qd

(ARG001)

🤖 Prompt for AI Agents
In warp/examples/optim/example_inverse_kinematics_jax.py around lines 148-149,
the local helper loss_from_fk_outputs declares an unused parameter body_qd
causing an ARG001 warning; remove body_qd from the function signature and update
any internal references/call sites to call loss_from_fk_outputs(body_q) instead,
ensuring the function definition and all invocations match the new
single-parameter signature.

Comment on lines 301 to 303
WP_API int wp_cuda_stream_get_device_ordinal(void* stream);
WP_API int wp_cuda_pointer_get_device_ordinal(void* ptr);


⚠️ Potential issue | 🔴 Critical

❓ Verification inconclusive

Add CPU stubs to prevent link errors when CUDA is disabled

These symbols need counterparts in warp/native/warp.cpp (CPU builds). Please add no-op stubs returning -1.

Suggested stubs (in warp/native/warp.cpp):

WP_API int wp_cuda_stream_get_device_ordinal(void* /*stream*/) { return -1; }
WP_API int wp_cuda_pointer_get_device_ordinal(void* /*ptr*/) { return -1; }

🏁 Script executed:

#!/bin/bash
# Verify stubs exist in warp/native/warp.cpp
rg -nP '^\s*WP_API\s+int\s+wp_cuda_(stream|pointer)_get_device_ordinal\b' warp/native/warp.cpp || true

Length of output: 0


Add CPU stubs for device-ordinal helpers in warp/native/warp.cpp

Add no-op implementations to avoid link errors when CUDA is disabled:

WP_API int wp_cuda_stream_get_device_ordinal(void* /*stream*/) { return -1; }
WP_API int wp_cuda_pointer_get_device_ordinal(void* /*ptr*/) { return -1; }
🤖 Prompt for AI Agents
In warp/native/warp.h around lines 301 to 303, the header declares
wp_cuda_stream_get_device_ordinal and wp_cuda_pointer_get_device_ordinal but
there are no CPU-side definitions when CUDA is disabled, causing link errors;
add no-op CPU stubs in warp/native/warp.cpp that match the WP_API signatures and
return -1 (accepting the parameters but not using them) so the symbols are
defined in non-CUDA builds.

Comment on lines 753 to 1426
def test_jax_ad_kernel_vmap_expand_dims_simple(test, device):
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_ad_kernel

    @wp.kernel
    def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
        tid = wp.tid()
        c[tid] = (a[tid] * s + b[tid]) ** 2.0

    jax_func = jax_ad_kernel(
        scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="broadcast_all"
    )


⚠️ Potential issue | 🟡 Minor

Mismatch: test names say “expand_dims” but code uses vmap_method="broadcast_all"

To actually exercise expand_dims, set vmap_method="expand_dims".

-    jax_func = jax_ad_kernel(
-        scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="broadcast_all"
-    )
+    jax_func = jax_ad_kernel(
+        scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="expand_dims"
+    )
-    jax_func = jax_ad_kernel(add_one_2d, num_outputs=1, vmap_method="broadcast_all")
+    jax_func = jax_ad_kernel(add_one_2d, num_outputs=1, vmap_method="expand_dims")

Also applies to: 966-979

🧰 Tools
🪛 Ruff (0.13.3)

753-753: Unused function argument: test

(ARG001)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 753 to 767, the test name
indicates it should exercise the "expand_dims" vmap behavior but the call uses
vmap_method="broadcast_all"; change vmap_method to "expand_dims" to match the
test intent and ensure the kernel is exercised correctly; apply the same change
to the analogous test at lines 966-979.

Comment on lines 1620 to 1686
            TestJax,
            "test_jax_ad_kernel_pmap_multi_output",
            test_jax_ad_kernel_pmap_multi_output,
            devices=jax_compatible_cuda_devices,
        )

        add_function_test(
            TestJax,
            "test_jax_ad_kernel_pmap_multi_output",
            test_jax_ad_kernel_pmap_multi_output,
            devices=jax_compatible_cuda_devices,
        )

⚠️ Potential issue | 🟡 Minor

Duplicate test registrations

test_jax_ad_kernel_pmap_multi_output and test_jax_callable_pmap_multi_output_forward are added twice. Remove the duplicates.

-        add_function_test(
-            TestJax,
-            "test_jax_ad_kernel_pmap_multi_output",
-            test_jax_ad_kernel_pmap_multi_output,
-            devices=jax_compatible_cuda_devices,
-        )
...
-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_multi_output_forward",
-            test_jax_callable_pmap_multi_output_forward,
-            devices=jax_compatible_cuda_devices,
-        )

Also applies to: 1640-1657

🧰 Tools
🪛 Ruff (0.13.3)

1626-1626: add_function_test may be undefined, or defined from star imports

(F405)

🤖 Prompt for AI Agents
In warp/tests/interop/test_jax.py around lines 1620-1631 (and similarly around
1640-1657), the same tests are being registered twice (duplicate
add_function_test calls for test_jax_ad_kernel_pmap_multi_output and
test_jax_callable_pmap_multi_output_forward); remove the redundant duplicate
calls so each test is registered only once by deleting the repeated
add_function_test blocks, keeping a single registration per test and verifying
no other duplicate registrations remain in the surrounding lines.

@mehdiataei mehdiataei force-pushed the jax_ad_ffi branch 3 times, most recently from 8dc4fa7 to 10f6829 on October 6, 2025 17:02

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
warp/jax_experimental/ffi.py (1)

1193-1200: Fix warp-shape inference under vmap: slice from tail before dtype dims (not from head).

Current get_warp_shape() takes the first dims, which are batch dims under vmap. This breaks launch_dims/output_dims inference and can cause OOB or shape mismatches.

Apply:

-def get_warp_shape(arg, dims):
-    if arg.dtype_ndim > 0:
-        # vector/matrix array
-        return dims[: arg.warp_ndim]
-    else:
-        # scalar array
-        return dims
+def get_warp_shape(arg, dims):
+    # dims may include leading batch dims and trailing dtype dims (for vec/mat)
+    if arg.dtype_ndim > 0:
+        # vector/matrix array: pick the core warp dims immediately before dtype dims
+        core_rank = len(dims) - arg.dtype_ndim
+        start = max(0, core_rank - arg.warp_ndim)
+        end = core_rank
+        return dims[start:end]
+    else:
+        # scalar array: warp dims are the trailing dims
+        if arg.warp_ndim == 0:
+            return ()
+        return dims[-arg.warp_ndim:]

This corrects all call sites that rely on get_warp_shape(), e.g., Lines 235 and 545 when computing in-out output types and when inferring default launch/output dims.
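
For intuition, the intended slicing can be exercised with plain tuples; warp_ndim and dtype_ndim below mirror the attributes used in the diff, and the shapes are made-up examples:

def core_dims(dims, warp_ndim, dtype_ndim):
    # strip trailing dtype dims first, then take the warp dims just before them;
    # anything remaining in front is treated as vmap batch dims
    if dtype_ndim > 0:
        core_end = len(dims) - dtype_ndim
        return dims[max(0, core_end - warp_ndim):core_end]
    return dims[-warp_ndim:] if warp_ndim else ()

# wp.vec3 array (warp_ndim=1, dtype_ndim=1) batched by vmap:
assert core_dims((8, 1024, 3), warp_ndim=1, dtype_ndim=1) == (1024,)
# scalar 2D array (warp_ndim=2, dtype_ndim=0) with one batch dim:
assert core_dims((8, 16, 32), warp_ndim=2, dtype_ndim=0) == (16, 32)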

♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (1)

750-752: Terminology/API: decouple JIT statics from nondifferentiable args; prefer ‘nondiff_argnames’.

The static_argnames parameter is used both to (1) mark JIT-static arguments and (2) to populate nondiff_argnums for custom_vjp. These are conceptually distinct. Consider:

  • Rename to nondiff_argnames for the AD aspect, and optionally keep a separate jit_static_argnames (or infer JIT statics via the user’s @jit usage).
  • Update docs accordingly.

This mirrors JAX terminology (nondiff_argnums).

Also applies to: 964-966
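
The distinction can be seen in plain JAX, independent of Warp; a minimal sketch with a hypothetical scale function:

import jax
import jax.numpy as jnp
from functools import partial

# (a) AD aspect: custom_vjp's nondiff_argnums marks `s` as non-differentiable
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, s):
    return x * s

def scale_fwd(x, s):
    return x * s, (x,)      # primal output plus residuals

def scale_bwd(s, res, g):   # non-differentiable args come first in the bwd rule
    (x,) = res
    return (g * s,)         # cotangent only for the differentiable arg x

scale.defvjp(scale_fwd, scale_bwd)

# (b) JIT aspect: static_argnames makes `s` a compile-time constant
def scale_model(x, s):
    return scale(x, s)

scale_jit = jax.jit(scale_model, static_argnames=("s",))
grad_x = jax.grad(lambda x: jnp.sum(scale_jit(x, s=2.0)))(jnp.ones(3))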

warp/tests/interop/test_jax.py (1)

1412-1425: Test intent mismatch: functions named “expand_dims” use vmap_method="sequential".

Use vmap_method="expand_dims" to match the test names and actually exercise that path.

-    jax_func = jax_kernel(
-        scale_sum_square_kernel, num_outputs=1, differentiable=True, static_argnames=("s",), vmap_method="sequential"
-    )
+    jax_func = jax_kernel(
+        scale_sum_square_kernel, num_outputs=1, differentiable=True, static_argnames=("s",), vmap_method="expand_dims"
+    )
-    jax_func = jax_kernel(add_one_2d, num_outputs=1, differentiable=True, vmap_method="sequential")
+    jax_func = jax_kernel(add_one_2d, num_outputs=1, differentiable=True, vmap_method="expand_dims")

Also applies to: 1498-1499

🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)

840-846: Avoid blind except when zeroing adjoint inputs.

Catching Exception and pass can hide real issues. Keep the isinstance guard and drop the try/except.

-        try:
-            for gi in grad_in:
-                if isinstance(gi, wp.array):
-                    gi.zero_()
-        except Exception:
-            pass
+        for gi in grad_in:
+            if isinstance(gi, wp.array):
+                gi.zero_()
📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 01d087c and 683682d.

📒 Files selected for processing (4)
  • docs/modules/interoperability.rst (1 hunks)
  • warp/jax_experimental/__init__.py (1 hunks)
  • warp/jax_experimental/ffi.py (10 hunks)
  • warp/tests/interop/test_jax.py (17 hunks)
✅ Files skipped from review due to trivial changes (1)
  • docs/modules/interoperability.rst
🚧 Files skipped from review as they are similar to previous changes (1)
  • warp/jax_experimental/__init__.py
🧰 Additional context used
🧬 Code graph analysis (2)
warp/jax_experimental/ffi.py (2)
warp/context.py (4)
  • kernel (1069-1141)
  • func (853-883)
  • launch (6111-6152)
  • launch (6155-6387)
warp/types.py (5)
  • array (2256-3547)
  • array (4442-4448)
  • zero_ (3161-3168)
  • zero_ (3730-3731)
  • dtype (4615-4621)
warp/tests/interop/test_jax.py (4)
warp/tests/interop/test_dlpack.py (1)
  • _jax_version (28-34)
warp/jax_experimental/ffi.py (1)
  • jax_kernel (742-991)
warp/jax.py (1)
  • device_to_jax (19-41)
warp/tests/unittest_utils.py (2)
  • assert_np_equal (241-247)
  • add_function_test (284-303)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py

217-219: Avoid specifying long messages outside the exception class

(TRY003)


527-529: Avoid specifying long messages outside the exception class

(TRY003)


844-845: try-except-pass detected, consider logging the exception

(S110)


844-844: Do not catch blind exception: Exception

(BLE001)


940-940: Do not catch blind exception: Exception

(BLE001)


948-949: try-except-pass detected, consider logging the exception

(S110)


948-948: Do not catch blind exception: Exception

(BLE001)

warp/tests/interop/test_jax.py

445-445: Unused function argument: test

(ARG001)


470-470: Unused function argument: test

(ARG001)


499-499: Unused function argument: test

(ARG001)


528-528: Unused function argument: test

(ARG001)


548-548: Unused function argument: test

(ARG001)


571-571: Unused function argument: test

(ARG001)


663-663: Unused function argument: test

(ARG001)


696-696: Unused function argument: test

(ARG001)


729-729: Unused function argument: test

(ARG001)


850-850: Unused function argument: device

(ARG001)


882-882: Unused function argument: device

(ARG001)


929-929: Unused function argument: device

(ARG001)


1037-1037: Unused function argument: test

(ARG001)


1075-1075: assert_np_equal may be undefined, or defined from star imports

(F405)


1076-1076: assert_np_equal may be undefined, or defined from star imports

(F405)


1080-1080: Unused function argument: test

(ARG001)


1115-1115: assert_np_equal may be undefined, or defined from star imports

(F405)


1116-1116: assert_np_equal may be undefined, or defined from star imports

(F405)


1120-1120: Unused function argument: test

(ARG001)


1156-1156: assert_np_equal may be undefined, or defined from star imports

(F405)


1157-1157: assert_np_equal may be undefined, or defined from star imports

(F405)


1161-1161: Unused function argument: test

(ARG001)


1182-1182: Unused function argument: s

(ARG001)


1204-1204: assert_np_equal may be undefined, or defined from star imports

(F405)


1205-1205: assert_np_equal may be undefined, or defined from star imports

(F405)


1209-1209: Unused function argument: test

(ARG001)


1238-1238: assert_np_equal may be undefined, or defined from star imports

(F405)


1242-1242: Unused function argument: test

(ARG001)


1267-1267: assert_np_equal may be undefined, or defined from star imports

(F405)


1271-1271: Unused function argument: test

(ARG001)


1299-1299: assert_np_equal may be undefined, or defined from star imports

(F405)


1303-1303: Unused function argument: test

(ARG001)


1336-1336: assert_np_equal may be undefined, or defined from star imports

(F405)


1337-1337: assert_np_equal may be undefined, or defined from star imports

(F405)


1341-1341: Unused function argument: test

(ARG001)


1368-1368: assert_np_equal may be undefined, or defined from star imports

(F405)


1372-1372: Unused function argument: test

(ARG001)


1407-1407: assert_np_equal may be undefined, or defined from star imports

(F405)


1408-1408: assert_np_equal may be undefined, or defined from star imports

(F405)


1412-1412: Unused function argument: test

(ARG001)


1444-1444: assert_np_equal may be undefined, or defined from star imports

(F405)


1445-1445: assert_np_equal may be undefined, or defined from star imports

(F405)


1449-1449: Unused function argument: test

(ARG001)


1482-1482: assert_np_equal may be undefined, or defined from star imports

(F405)


1483-1483: assert_np_equal may be undefined, or defined from star imports

(F405)


1487-1487: Unused function argument: test

(ARG001)


1511-1511: assert_np_equal may be undefined, or defined from star imports

(F405)


1515-1515: Unused function argument: test

(ARG001)


1546-1546: assert_np_equal may be undefined, or defined from star imports

(F405)


1547-1547: assert_np_equal may be undefined, or defined from star imports

(F405)


1551-1551: Unused function argument: device

(ARG001)


1588-1588: Unused function argument: device

(ARG001)


1619-1619: Unused function argument: device

(ARG001)


1668-1668: assert_np_equal may be undefined, or defined from star imports

(F405)


1669-1669: assert_np_equal may be undefined, or defined from star imports

(F405)


1673-1673: Unused function argument: device

(ARG001)


1708-1708: assert_np_equal may be undefined, or defined from star imports

(F405)


1709-1709: assert_np_equal may be undefined, or defined from star imports

(F405)


1803-1803: add_function_test may be undefined, or defined from star imports

(F405)


1804-1804: add_function_test may be undefined, or defined from star imports

(F405)


1807-1807: add_function_test may be undefined, or defined from star imports

(F405)


1810-1810: add_function_test may be undefined, or defined from star imports

(F405)


1813-1813: add_function_test may be undefined, or defined from star imports

(F405)


1819-1819: add_function_test may be undefined, or defined from star imports

(F405)


1825-1825: add_function_test may be undefined, or defined from star imports

(F405)


1831-1831: add_function_test may be undefined, or defined from star imports

(F405)


1839-1839: add_function_test may be undefined, or defined from star imports

(F405)


1845-1845: add_function_test may be undefined, or defined from star imports

(F405)


1851-1851: add_function_test may be undefined, or defined from star imports

(F405)


1854-1854: add_function_test may be undefined, or defined from star imports

(F405)


1860-1860: add_function_test may be undefined, or defined from star imports

(F405)


1866-1866: add_function_test may be undefined, or defined from star imports

(F405)


1872-1872: add_function_test may be undefined, or defined from star imports

(F405)


1880-1880: add_function_test may be undefined, or defined from star imports

(F405)


1882-1882: add_function_test may be undefined, or defined from star imports

(F405)


1885-1885: add_function_test may be undefined, or defined from star imports

(F405)


1888-1888: add_function_test may be undefined, or defined from star imports

(F405)


1891-1891: add_function_test may be undefined, or defined from star imports

(F405)


1897-1897: add_function_test may be undefined, or defined from star imports

(F405)


1903-1903: add_function_test may be undefined, or defined from star imports

(F405)


1909-1909: add_function_test may be undefined, or defined from star imports

(F405)


1917-1917: add_function_test may be undefined, or defined from star imports

(F405)


1923-1923: add_function_test may be undefined, or defined from star imports

(F405)


1929-1929: add_function_test may be undefined, or defined from star imports

(F405)


1932-1932: add_function_test may be undefined, or defined from star imports

(F405)


1938-1938: add_function_test may be undefined, or defined from star imports

(F405)


1944-1944: add_function_test may be undefined, or defined from star imports

(F405)


1952-1952: add_function_test may be undefined, or defined from star imports

(F405)


1954-1954: add_function_test may be undefined, or defined from star imports

(F405)


1958-1958: add_function_test may be undefined, or defined from star imports

(F405)


1964-1964: add_function_test may be undefined, or defined from star imports

(F405)


1971-1971: add_function_test may be undefined, or defined from star imports

(F405)


1978-1978: add_function_test may be undefined, or defined from star imports

(F405)


1985-1985: add_function_test may be undefined, or defined from star imports

(F405)


1992-1992: add_function_test may be undefined, or defined from star imports

(F405)


1999-1999: add_function_test may be undefined, or defined from star imports

(F405)


2005-2005: add_function_test may be undefined, or defined from star imports

(F405)


2012-2012: add_function_test may be undefined, or defined from star imports

(F405)


2018-2018: add_function_test may be undefined, or defined from star imports

(F405)


2025-2025: add_function_test may be undefined, or defined from star imports

(F405)


2032-2032: add_function_test may be undefined, or defined from star imports

(F405)


2038-2038: add_function_test may be undefined, or defined from star imports

(F405)


2041-2041: test_jax_ad_kernel_launch_dim_and_output_dims may be undefined, or defined from star imports

(F405)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (2)

839-840: Fix launch dim inference to handle non-array first arguments.

Using args[0].shape assumes the first argument is always an array, which will fail if the first parameter is a scalar. This was flagged in a previous review but remains unresolved.

Apply this diff to match the non-differentiable path logic:

-    def fwd_kernel_wrapper(*args):
-        wp.launch(kernel, dim=args[0].shape, inputs=args[:num_inputs], outputs=args[num_inputs:])
+    def fwd_kernel_wrapper(*args):
+        # Derive launch dims from the first array-typed input
+        launch_dim = None
+        for idx, p in enumerate(parameters[:num_inputs]):
+            if isinstance(p.annotation, wp.array):
+                launch_dim = args[idx].shape
+                break
+        if launch_dim is None:
+            raise RuntimeError("Unable to infer launch dims: no array inputs")
+        wp.launch(kernel, dim=launch_dim, inputs=args[:num_inputs], outputs=args[num_inputs:])

867-867: Fix launch dim inference in backward wrapper.

Same issue as the forward wrapper — using inputs[0].shape assumes the first input is an array.

Apply this diff:

+        # Derive launch dims from the first array-typed input
+        launch_dim = None
+        for idx, p in enumerate(parameters[:num_inputs]):
+            if isinstance(p.annotation, wp.array):
+                launch_dim = inputs[idx].shape
+                break
+        if launch_dim is None:
+            raise RuntimeError("Unable to infer launch dims: no array inputs")
+
         wp.launch(
             kernel,
-            dim=inputs[0].shape,
+            dim=launch_dim,
             inputs=inputs,
             outputs=outputs,
             adj_inputs=grad_in,
             adj_outputs=grad_out,
             adjoint=True,
         )
🧹 Nitpick comments (6)
warp/jax_experimental/ffi.py (6)

41-41: Consider a more specific type hint for the cache key.

The key type is tuple, but the actual structure is (kernel.func, kernel.sig, num_outputs, vmap_method, tuple[str, ...]) based on usage at lines 994 and 1003. A type alias would improve clarity:

+# Type alias for differentiable kernel cache key
+DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]]
+
-_FFI_DIFF_KERNEL_REGISTRY: dict[tuple, Callable] = {}
+_FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {}

815-818: LGTM! Clear limitation documented.

The explicit check prevents unsupported usage of in_out_argnames in differentiable mode. Consider tracking this as a TODO or GitHub issue for future enhancement.

Do you want me to open a GitHub issue to track support for in_out_argnames in differentiable mode?


858-863: Improve error handling when zeroing gradient arrays.

The blind try-except-pass silently swallows all exceptions, which could hide real issues. At minimum, verify the object is a wp.array before calling zero_().

Apply this diff:

-        try:
-            for gi in grad_in:
-                if isinstance(gi, wp.array):
-                    gi.zero_()
-        except Exception:
-            pass
+        for gi in grad_in:
+            if isinstance(gi, wp.array):
+                try:
+                    gi.zero_()
+                except Exception as e:
+                    # Log or handle the exception appropriately
+                    wp.utils.warn(f"Failed to zero gradient array: {e}", stacklevel=2)

917-924: Simplify output handling logic.

The nested conditions for single vs. multiple outputs are fragile. JAX's custom_vjp expects fwd to return (outputs, residuals) where outputs can be a single value or tuple. Simplify by always returning a tuple:

     def fwd_function(*args):
         outputs = jax_fwd_kernel(*args)
         non_static_inputs = list(args)
         for i in reversed(static_args):
             del non_static_inputs[i]
-        if num_outputs == 1:
-            if isinstance(outputs, (list, tuple)):
-                outputs_tuple = (outputs[0],)
-            else:
-                outputs_tuple = (outputs,)
-        else:
-            outputs_tuple = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
+        # Normalize to tuple for consistent handling
+        if num_outputs == 1:
+            outputs_tuple = (outputs,) if not isinstance(outputs, (list, tuple)) else (outputs[0],)
+        else:
+            outputs_tuple = outputs if isinstance(outputs, tuple) else tuple(outputs)
         return outputs, (tuple(non_static_inputs), outputs_tuple)

937-947: Simplify gradient output unpacking.

Similar to the forward function, this logic for unpacking grad_out_args is complex. Consider normalizing to a tuple early:

-        if num_outputs > 1:
-            if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
-                grad_out_tuple = tuple(grad_out_args[0])
-            else:
-                grad_out_tuple = tuple(grad_out_args)
-        else:
-            go = grad_out_args[0]
-            if isinstance(go, (list, tuple)):
-                grad_out_tuple = (go[0],)
-            else:
-                grad_out_tuple = (go,)
+        # Normalize grad_out_args to tuple
+        if num_outputs == 1:
+            grad_out_tuple = (grad_out_args[0],)
+        else:
+            grad_out_tuple = tuple(grad_out_args) if not isinstance(grad_out_args[0], (list, tuple)) else tuple(grad_out_args[0])

958-967: Avoid bare exception handling in attribute access.

The nested try-except blocks with bare Exception catching (lines 958, 966) can hide bugs. Use specific attribute checks or document why broad catching is needed.

             ann = param_ann.get(name)
             if ann is None:
                 continue
-            try:
-                is_array_ann = isinstance(ann, wp.array)
-            except Exception:
-                is_array_ann = False
+            # Check if annotation is a warp array type
+            is_array_ann = isinstance(ann, type) and issubclass(ann, wp.array)
             if not is_array_ann:
                 continue
             dtype_ndim = 0
-            try:
-                if hasattr(ann.dtype, "_wp_scalar_type_"):
-                    dtype_ndim = len(ann.dtype._shape_)
-            except Exception:
-                pass
+            # Extract dtype ndim if it's a vector/matrix type
+            if hasattr(ann, 'dtype') and hasattr(ann.dtype, "_wp_scalar_type_"):
+                dtype_ndim = len(ann.dtype._shape_)
📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 683682d and 426d343.

📒 Files selected for processing (1)
  • warp/jax_experimental/ffi.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (1)
warp/types.py (5)
  • strides_from_shape (2174-2185)
  • dtype (4615-4621)
  • array_t (1697-1721)
  • array (2256-3547)
  • array (4442-4448)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py

217-219: Avoid specifying long messages outside the exception class

(TRY003)


537-539: Avoid specifying long messages outside the exception class

(TRY003)


862-863: try-except-pass detected, consider logging the exception

(S110)


862-862: Do not catch blind exception: Exception

(BLE001)


958-958: Do not catch blind exception: Exception

(BLE001)


966-967: try-except-pass detected, consider logging the exception

(S110)


966-966: Do not catch blind exception: Exception

(BLE001)


1218-1218: Avoid specifying long messages outside the exception class

(TRY003)


1225-1225: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (10)
warp/jax_experimental/ffi.py (10)

215-225: LGTM! Correctly relaxed for vmap support.

The updated validation allows leading batch dimensions under vmap while still validating the required trailing dtype dimensions. This is the correct approach for supporting batched operations.


535-545: LGTM! Correctly relaxed for vmap support.

Same validation logic as in FfiKernel.__call__ — allows leading batch dimensions while validating trailing dtype dimensions.


1213-1226: LGTM! Correctly extracts core dimensions.

The logic properly handles both vector/matrix arrays (extracting middle dims between batch and dtype) and scalar arrays (extracting trailing dims excluding batch). The implementation aligns with the relaxed validation in __call__ methods.


1239-1239: LGTM! Aligns with batch dimension support.

Changing to ndim >= arg.jax_ndim correctly allows for exact shape match or additional leading batch dimensions, consistent with the relaxed validation.


795-813: LGTM! Non-differentiable path correctly preserved.

The logic preserves the existing behavior when differentiable=False, using the cached FfiKernel approach with proper locking.


820-836: LGTM! Signature introspection correctly implemented.

The logic properly extracts the kernel signature, identifies static arguments from static_argnames, and ensures static arguments are only considered among inputs.


875-901: LGTM! Backward signature construction is correct.

The logic properly constructs the backward wrapper signature with input/output/gradient parameters, correctly excludes static arguments from gradient inputs, and follows JAX naming conventions.


982-1009: LGTM! VJP registration and caching correctly implemented.

The logic properly:

  • Registers the custom VJP with correct nondiff_argnums
  • Applies jax.jit with static_argnames when needed
  • Uses deterministic cache keys with sorted static argument names

Note: The cache key structure matches the type hint suggestion in the earlier comment on line 41.


343-350: LGTM! Buffer shape extraction correctly handles batch and dtype dimensions.

The logic properly computes the batch dimension offset and extracts only the core warp dimensions, excluding both leading batch dimensions and trailing dtype dimensions. This is consistent with the get_warp_shape utility function.

Also applies to: 363-370


755-755: Clarify vmap_method default inconsistency
The default vmap_method differs between jax_kernel ("sequential") and jax_callable ("broadcast_all"). Confirm this is intentional and either document the rationale or align the defaults across the API.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)

955-1003: Consider refactoring to reduce backward function complexity.

The bwd_function handles multiple concerns: unpacking residuals, inserting static args, normalizing gradient tuples, and computing output dimensions. While functionally correct, the cognitive complexity is high.

Consider extracting helper functions:

def _unpack_residuals(bwd_args, static_args):
    nondiff_vals = list(bwd_args[:len(static_args)])
    residuals = bwd_args[len(static_args)]
    grad_out_args = bwd_args[len(static_args) + 1:]
    return nondiff_vals, residuals, grad_out_args

def _normalize_grad_out_tuple(grad_out_args, num_outputs):
    if num_outputs == 1:
        go = grad_out_args[0]
        return tuple(go) if isinstance(go, (list, tuple)) else (go,)
    if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
        return tuple(grad_out_args[0])
    return tuple(grad_out_args)

def _compute_output_dims_map(parameters, differentiable_input_names, non_static_inputs, num_inputs):
    out_dims_map = {}
    param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
    for name, val in zip(differentiable_input_names, non_static_inputs):
        ann = param_ann.get(name)
        if ann is None or not isinstance(ann, wp.array):
            continue
        # ... rest of logic
    return out_dims_map

This would make bwd_function more readable by clearly separating its phases.

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 426d343 and ede6efe.

📒 Files selected for processing (2)
  • docs/modules/interoperability.rst (1 hunks)
  • warp/jax_experimental/ffi.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (1)
warp/types.py (5)
  • strides_from_shape (2174-2185)
  • dtype (4615-4621)
  • array_t (1697-1721)
  • array (2256-3547)
  • array (4442-4448)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py

220-222: Avoid specifying long messages outside the exception class

(TRY003)


540-542: Avoid specifying long messages outside the exception class

(TRY003)


867-867: Avoid specifying long messages outside the exception class

(TRY003)


893-893: Do not catch blind exception: Exception

(BLE001)


1241-1241: Avoid specifying long messages outside the exception class

(TRY003)


1248-1248: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
warp/jax_experimental/ffi.py (1)

1236-1249: Shape extraction logic correctly handles vmap batch dimensions.

The get_warp_shape function properly extracts core Warp dimensions by excluding both leading batch dimensions (from vmap) and trailing dtype payload dimensions (for vector/matrix types). The implementation correctly handles the different cases for vector/matrix vs. scalar arrays.

The logic is intricate but appears correct. Consider adding an inline example in a comment to aid future maintainers:

# Example: For wp.vec3 array under vmap with JAX shape (B, N, 3):
#   - warp_ndim=1, dtype_ndim=1, dims=(B, N, 3)
#   - core_end = 3 - 1 = 2, core_begin = 2 - 1 = 1
#   - Returns dims[1:2] = (N,) — the core warp dimension
docs/modules/interoperability.rst (1)

959-1119: Comprehensive AD documentation with clear examples.

The Automatic Differentiation section provides excellent coverage of the new differentiable=True functionality, including:

  • Basic usage patterns with static arguments
  • Multiple output handling
  • Vector/matrix array support
  • VMAP integration with per-sample gradients (see the sketch after this list)
  • Launch dimension inference
  • Troubleshooting guidance

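For instance, the per-sample gradient pattern mentioned above composes jax.vmap with jax.grad around the wrapped kernel; a rough sketch assuming differentiable=True and a vmap-capable vmap_method (hypothetical kernel, with the single-output return packing normalized defensively):

import jax
import jax.numpy as jp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

wp.init()

@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid] * x[tid]

jax_square = jax_kernel(square_kernel, num_outputs=1, differentiable=True, vmap_method="sequential")

def loss(x):
    y = jax_square(x)
    y = y[0] if isinstance(y, (tuple, list)) else y  # tolerate either return packing
    return jp.sum(y)

batch = jp.arange(12, dtype=jp.float32).reshape(4, 3)  # 4 samples of length 3
per_sample_grads = jax.vmap(jax.grad(loss))(batch)     # shape (4, 3), equals 2 * batch
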
Minor suggestions for enhancement:

  1. Add performance note: Consider mentioning that vmap_method="broadcast_all" is more parallel than "sequential" (the new default), which could guide users toward better performance when applicable.

  2. Clarify scalar limitation: The note "Scalars must be static in JAX" could be expanded to explain why (JAX tracing vs. FFI callback requirements).

  3. Link to examples: Consider adding a reference to example files in the repo that demonstrate the AD functionality in practice.

Comment on lines 1077 to 1082
There are three options for ``vmap_method``:

- ``"sequential"``: lowers to a :func:`jax.lax.scan`; most general, least parallel.
- ``"broadcast_all"`` (default): assumes the kernel can handle broadcasted batch axes; best for data-parallel batches.
- ``"expand_dims"``: expands a new leading axis per-batched argument; use only if your kernel semantics require it.


⚠️ Potential issue | 🟡 Minor

Document the change in default vmap_method for differentiable=True.

The documentation lists three vmap_method options but doesn't explicitly state that "sequential" is now the default (changed from "broadcast_all"). Since this is a breaking change from the non-differentiable path's default, it should be clearly documented.

Consider adding a note:

.. note::
   When ``differentiable=True``, the default ``vmap_method`` is ``"sequential"`` for stability.
   For the non-differentiable path, the default is ``"broadcast_all"``. If you need maximum
   parallelism with AD, explicitly pass ``vmap_method="broadcast_all"``.
🤖 Prompt for AI Agents
In docs/modules/interoperability.rst around lines 1077 to 1082, the docstring
lists vmap_method options but doesn't state that the default changed to
"sequential" when differentiable=True; update the text to add a short note
explaining that for differentiable=True the default vmap_method is "sequential"
(changed from "broadcast_all" in the non-differentiable path), indicate that the
non-differentiable default remains "broadcast_all", and advise users to
explicitly pass vmap_method="broadcast_all" if they need maximum parallelism
with AD.

     kernel,
     num_outputs=1,
-    vmap_method="broadcast_all",
+    vmap_method="sequential",

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify the breaking change in default vmap_method.

The default vmap_method changed from "broadcast_all" to "sequential". This could break existing code that relies on the previous default behavior. "sequential" lowers to jax.lax.scan, which is less parallel than "broadcast_all".

Consider either:

  1. Reverting to "broadcast_all" as the default for backward compatibility, or
  2. Documenting this breaking change prominently in release notes and migration guides.

Could you clarify the rationale for this change and confirm that existing users have been notified?


🏁 Script executed:

#!/bin/bash
set -e
rg -n "class FFI" -C3 warp/jax_experimental/ffi.py
rg -n "def __init__" -C5 warp/jax_experimental/ffi.py
rg -n -C2 "vmap_method" warp/jax_experimental/ffi.py
rg -n -C2 "broadcast_all" .

Length of output: 0


🏁 Script executed:

#!/bin/bash
set -e
echo "Locating ffi.py in the repository:"
fd -t f ffi.py
echo
echo "Searching for vmap_method definitions and defaults:"
rg -n "vmap_method" -C3 .
echo
echo "Searching for broadcast_all references:"
rg -n '"broadcast_all"' -C3 .

Length of output: 20460


Align vmap_method default in code and docs
The default in jax_kernel (warp/jax_experimental/ffi.py:758) was changed from "broadcast_all" to "sequential", but docs (docs/modules/interoperability.rst) still state "broadcast_all" as the default. This breaks backward compatibility. Either revert the default to "broadcast_all" or update code, docs, and migration notes to reflect "sequential".

🤖 Prompt for AI Agents
In warp/jax_experimental/ffi.py around line 758, the vmap_method default was
changed to "sequential" which conflicts with docs and breaks backward
compatibility; revert the default back to "broadcast_all" (restore the original
default value at that line) and ensure any unit tests that assume the old
behavior still pass; if the change to "sequential" was intentional instead,
update docs/modules/interoperability.rst to state "sequential" as the new
default and add a short migration note explaining the behavioral change and
rationale.

@shi-eric
Contributor

shi-eric commented Oct 7, 2025

FYI @mehdiataei, I'm planning to make the release candidate for Warp v1.10 sometime next Friday, October 17th, so the changes in this pull request need to be merged into the main branch by then to be included in the release. Please work with @nvlukasz to resolve any outstanding issues.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
warp/tests/interop/test_jax.py (2)

1882-1914: Remove duplicate test registrations.

These test registrations duplicate those already added at lines 1803-1836. Running the same tests multiple times adds no value and may cause test collection issues.

Apply this diff to remove the duplicate block:

-        add_function_test(
-            TestJax, "test_jax_kernel_sincos", test_jax_kernel_sincos, devices=jax_compatible_cuda_devices
-        )
-        add_function_test(
-            TestJax, "test_jax_kernel_diagonal", test_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
-        )
-        add_function_test(
-            TestJax, "test_jax_kernel_in_out", test_jax_kernel_in_out, devices=jax_compatible_cuda_devices
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_kernel_scale_vec_constant",
-            test_jax_kernel_scale_vec_constant,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_kernel_scale_vec_static",
-            test_jax_kernel_scale_vec_static,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_kernel_launch_dims_default",
-            test_jax_kernel_launch_dims_default,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_kernel_launch_dims_custom",
-            test_jax_kernel_launch_dims_custom,
-            devices=jax_compatible_cuda_devices,
-        )

1917-1949: Remove duplicate test registrations.

These test registrations duplicate those already added at lines 1839-1877. Delete this redundant block.

Apply this diff to remove the duplicate block:

-        # ffi.jax_callable() tests
-        add_function_test(
-            TestJax,
-            "test_jax_callable_scale_constant",
-            test_jax_callable_scale_constant,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_scale_static",
-            test_jax_callable_scale_static,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax, "test_jax_callable_in_out", test_jax_callable_in_out, devices=jax_compatible_cuda_devices
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_graph_cache",
-            test_jax_callable_graph_cache,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_multi_output_forward",
-            test_jax_callable_pmap_multi_output_forward,
-            devices=jax_compatible_cuda_devices,
-        )
-        add_function_test(
-            TestJax,
-            "test_jax_callable_pmap_multi_stage_forward",
-            test_jax_callable_pmap_multi_stage_forward,
-            devices=jax_compatible_cuda_devices,
-        )
-
-        # ffi callback tests
-        add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ede6efe and 61c0a99.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/tests/interop/test_jax.py (4)
warp/tests/interop/test_dlpack.py (1)
  • _jax_version (28-34)
warp/jax_experimental/ffi.py (1)
  • jax_kernel (755-1032)
warp/jax.py (1)
  • device_to_jax (19-41)
warp/tests/unittest_utils.py (2)
  • assert_np_equal (241-247)
  • add_function_test (284-303)
🪛 Ruff (0.13.3)
warp/tests/interop/test_jax.py

445-445: Unused function argument: test

(ARG001)


470-470: Unused function argument: test

(ARG001)


499-499: Unused function argument: test

(ARG001)


528-528: Unused function argument: test

(ARG001)


548-548: Unused function argument: test

(ARG001)


571-571: Unused function argument: test

(ARG001)


663-663: Unused function argument: test

(ARG001)


696-696: Unused function argument: test

(ARG001)


729-729: Unused function argument: test

(ARG001)


850-850: Unused function argument: device

(ARG001)


882-882: Unused function argument: device

(ARG001)


929-929: Unused function argument: device

(ARG001)


1037-1037: Unused function argument: test

(ARG001)


1075-1075: assert_np_equal may be undefined, or defined from star imports

(F405)


1076-1076: assert_np_equal may be undefined, or defined from star imports

(F405)


1080-1080: Unused function argument: test

(ARG001)


1115-1115: assert_np_equal may be undefined, or defined from star imports

(F405)


1116-1116: assert_np_equal may be undefined, or defined from star imports

(F405)


1120-1120: Unused function argument: test

(ARG001)


1156-1156: assert_np_equal may be undefined, or defined from star imports

(F405)


1157-1157: assert_np_equal may be undefined, or defined from star imports

(F405)


1161-1161: Unused function argument: test

(ARG001)


1182-1182: Unused function argument: s

(ARG001)


1204-1204: assert_np_equal may be undefined, or defined from star imports

(F405)


1205-1205: assert_np_equal may be undefined, or defined from star imports

(F405)


1209-1209: Unused function argument: test

(ARG001)


1238-1238: assert_np_equal may be undefined, or defined from star imports

(F405)


1242-1242: Unused function argument: test

(ARG001)


1267-1267: assert_np_equal may be undefined, or defined from star imports

(F405)


1271-1271: Unused function argument: test

(ARG001)


1299-1299: assert_np_equal may be undefined, or defined from star imports

(F405)


1303-1303: Unused function argument: test

(ARG001)


1336-1336: assert_np_equal may be undefined, or defined from star imports

(F405)


1337-1337: assert_np_equal may be undefined, or defined from star imports

(F405)


1341-1341: Unused function argument: test

(ARG001)


1368-1368: assert_np_equal may be undefined, or defined from star imports

(F405)


1372-1372: Unused function argument: test

(ARG001)


1407-1407: assert_np_equal may be undefined, or defined from star imports

(F405)


1408-1408: assert_np_equal may be undefined, or defined from star imports

(F405)


1412-1412: Unused function argument: test

(ARG001)


1444-1444: assert_np_equal may be undefined, or defined from star imports

(F405)


1445-1445: assert_np_equal may be undefined, or defined from star imports

(F405)


1449-1449: Unused function argument: test

(ARG001)


1482-1482: assert_np_equal may be undefined, or defined from star imports

(F405)


1483-1483: assert_np_equal may be undefined, or defined from star imports

(F405)


1487-1487: Unused function argument: test

(ARG001)


1511-1511: assert_np_equal may be undefined, or defined from star imports

(F405)


1515-1515: Unused function argument: test

(ARG001)


1546-1546: assert_np_equal may be undefined, or defined from star imports

(F405)


1547-1547: assert_np_equal may be undefined, or defined from star imports

(F405)


1551-1551: Unused function argument: device

(ARG001)


1588-1588: Unused function argument: device

(ARG001)


1619-1619: Unused function argument: device

(ARG001)


1668-1668: assert_np_equal may be undefined, or defined from star imports

(F405)


1669-1669: assert_np_equal may be undefined, or defined from star imports

(F405)


1673-1673: Unused function argument: device

(ARG001)


1708-1708: assert_np_equal may be undefined, or defined from star imports

(F405)


1709-1709: assert_np_equal may be undefined, or defined from star imports

(F405)


1803-1803: add_function_test may be undefined, or defined from star imports

(F405)


1804-1804: add_function_test may be undefined, or defined from star imports

(F405)


1807-1807: add_function_test may be undefined, or defined from star imports

(F405)


1810-1810: add_function_test may be undefined, or defined from star imports

(F405)


1813-1813: add_function_test may be undefined, or defined from star imports

(F405)


1819-1819: add_function_test may be undefined, or defined from star imports

(F405)


1825-1825: add_function_test may be undefined, or defined from star imports

(F405)


1831-1831: add_function_test may be undefined, or defined from star imports

(F405)


1839-1839: add_function_test may be undefined, or defined from star imports

(F405)


1845-1845: add_function_test may be undefined, or defined from star imports

(F405)


1851-1851: add_function_test may be undefined, or defined from star imports

(F405)


1854-1854: add_function_test may be undefined, or defined from star imports

(F405)


1860-1860: add_function_test may be undefined, or defined from star imports

(F405)


1866-1866: add_function_test may be undefined, or defined from star imports

(F405)


1872-1872: add_function_test may be undefined, or defined from star imports

(F405)


1880-1880: add_function_test may be undefined, or defined from star imports

(F405)


1882-1882: add_function_test may be undefined, or defined from star imports

(F405)


1885-1885: add_function_test may be undefined, or defined from star imports

(F405)


1888-1888: add_function_test may be undefined, or defined from star imports

(F405)


1891-1891: add_function_test may be undefined, or defined from star imports

(F405)


1897-1897: add_function_test may be undefined, or defined from star imports

(F405)


1903-1903: add_function_test may be undefined, or defined from star imports

(F405)


1909-1909: add_function_test may be undefined, or defined from star imports

(F405)


1917-1917: add_function_test may be undefined, or defined from star imports

(F405)


1923-1923: add_function_test may be undefined, or defined from star imports

(F405)


1929-1929: add_function_test may be undefined, or defined from star imports

(F405)


1932-1932: add_function_test may be undefined, or defined from star imports

(F405)


1938-1938: add_function_test may be undefined, or defined from star imports

(F405)


1944-1944: add_function_test may be undefined, or defined from star imports

(F405)


1952-1952: add_function_test may be undefined, or defined from star imports

(F405)


1954-1954: add_function_test may be undefined, or defined from star imports

(F405)


1958-1958: add_function_test may be undefined, or defined from star imports

(F405)


1964-1964: add_function_test may be undefined, or defined from star imports

(F405)


1971-1971: add_function_test may be undefined, or defined from star imports

(F405)


1978-1978: add_function_test may be undefined, or defined from star imports

(F405)


1985-1985: add_function_test may be undefined, or defined from star imports

(F405)


1992-1992: add_function_test may be undefined, or defined from star imports

(F405)


1999-1999: add_function_test may be undefined, or defined from star imports

(F405)


2005-2005: add_function_test may be undefined, or defined from star imports

(F405)


2012-2012: add_function_test may be undefined, or defined from star imports

(F405)


2018-2018: add_function_test may be undefined, or defined from star imports

(F405)


2025-2025: add_function_test may be undefined, or defined from star imports

(F405)


2032-2032: add_function_test may be undefined, or defined from star imports

(F405)

🔇 Additional comments (1)
warp/tests/interop/test_jax.py (1)

1036-1710: LGTM: Comprehensive AD test coverage.

The new automatic differentiation tests are well-structured with:

  • Clear test organization covering single/multi-output, vmap, pmap, shard_map scenarios
  • Proper data type coverage (float, vec2, mat22, 2D arrays)
  • Correct reference gradient computations for validation
  • Appropriate JAX version guards and device availability checks
  • Consistent patterns that make the test suite maintainable

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (2)

758-758: Verify vmap_method default change is intentional.

The default changed from "broadcast_all" to "sequential". A previous review comment (lines 758-758) noted this conflicts with documentation and breaks backward compatibility. "sequential" lowers to jax.lax.scan, which is less parallel than "broadcast_all".

Per past review: either revert to "broadcast_all" for backward compatibility, or update documentation (docs/modules/interoperability.rst) and add migration notes explaining the rationale for the change.


841-867: Launch dims inference remains fragile.

A previous review comment (lines 841-867) noted that starting with call_args[0].shape will fail when the first kernel argument is a scalar. While there's fallback logic, it adds unnecessary complexity.

Per past review: remove the initial args[0].shape attempt and start directly with the array-finding loop:

 def _resolve_launch_dims(call_args):
-    s = getattr(call_args[0], "shape", None)
-    if s is not None:
-        return s
     param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
     for i in range(num_inputs):

This simplifies the logic and handles scalar-first kernels correctly from the start.

🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)

218-228: Consider extracting shape validation to a shared helper.

The updated validation correctly allows leading batch dimensions under vmap while ensuring trailing dtype dimensions match. However, this logic is duplicated in both FfiKernel.__call__ and FfiCallable.__call__.

Consider extracting to a helper function:

def validate_array_shape(input_arg: FfiArg, input_value, allow_batch_dims: bool = True):
    """Validate array shape, optionally allowing leading batch dimensions."""
    if input_value.dtype != input_arg.jax_scalar_type:
        raise TypeError(
            f"Invalid data type for array argument '{input_arg.name}', "
            f"expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
        )
    
    if input_value.ndim < input_arg.jax_ndim:
        raise TypeError(
            f"Invalid dimensionality for array argument '{input_arg.name}', "
            f"expected at least {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
        )
    
    # Check inner (dtype) dims at the end
    for d in range(input_arg.dtype_ndim):
        if input_value.shape[-input_arg.dtype_ndim + d] != input_arg.dtype_shape[d]:
            raise TypeError(
                f"Invalid inner dimensions for array argument '{input_arg.name}', "
                f"expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim:]}"
            )

Then use validate_array_shape(input_arg, input_value) in both locations.

Also applies to: 538-548

📜 Review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 61c0a99 and 5fc4fe7.

📒 Files selected for processing (1)
  • warp/jax_experimental/ffi.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (2)
warp/types.py (5)
  • strides_from_shape (2174-2185)
  • dtype (4615-4621)
  • array_t (1697-1721)
  • array (2256-3547)
  • array (4442-4448)
warp/context.py (4)
  • kernel (1069-1141)
  • func (853-883)
  • launch (6111-6152)
  • launch (6155-6387)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py

  • 220-222, 540-542, 867, 1242, 1249: Avoid specifying long messages outside the exception class (TRY003)
  • 895: Use raise without specifying exception name; remove the exception name (TRY201)

🔇 Additional comments (7)
warp/jax_experimental/ffi.py (7)

18-18: LGTM: Import and registry additions are well-structured.

The inspect import enables runtime signature introspection for the differentiable path. The DiffKernelCacheKey type alias clearly documents the cache key structure, and the separate _FFI_DIFF_KERNEL_REGISTRY prevents conflicts between differentiable and non-differentiable kernel paths.

Also applies to: 33-34, 44-44


238-238: LGTM: Correct usage of get_warp_shape.

The calls to get_warp_shape correctly extract core Warp dimensions from JAX shapes that may include leading batch dimensions (from vmap) and trailing dtype dimensions (from vector/matrix types).

Also applies to: 558-558


346-353: LGTM: FFI callback correctly handles batch and dtype dimensions.

The callback properly excludes leading batch dimensions (from vmap) and trailing dtype dimensions when constructing array_t for the kernel launch. The calculation batch_ndim = max(0, rank - dtype_ndim - warp_ndim) and the slice buffer.dims[batch_ndim : batch_ndim + warp_ndim] correctly extract the core Warp dimensions.

Also applies to: 366-373
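As an assumed worked example of that bookkeeping, consider a vmapped wp.array2d(dtype=wp.vec3) buffer arriving with JAX shape (B, H, W, 3):

rank = 4                                            # len(buffer.dims) for (B, H, W, 3)
dtype_ndim = 1                                      # vec3 adds one trailing dim of size 3
warp_ndim = 2                                       # the kernel declares a 2D Warp array
batch_ndim = max(0, rank - dtype_ndim - warp_ndim)  # -> 1, the leading vmap dim B
# Core Warp dims used for array_t:
# buffer.dims[batch_ndim : batch_ndim + warp_ndim]  # -> (H, W)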


763-764: LGTM: New differentiable parameters and constraints are well-documented.

The differentiable flag and static_argnames parameter enable the AD path. The restriction that in_out_argnames is not supported with differentiable=True is clearly documented in both the docstring and enforced at runtime.

Also applies to: 782-785, 818-821
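A minimal usage sketch of the two options together. The kernel is hypothetical, and the exact calling convention (positional vs. keyword static arguments, bare vs. tuple single output) should be checked against the docstring rather than taken from this sketch.

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def scale(x: wp.array(dtype=float), alpha: float, y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = alpha * x[tid]

# alpha is static: excluded from differentiation via nondiff_argnums.
jax_scale = jax_kernel(scale, differentiable=True, static_argnames=("alpha",))

def loss(x):
    out = jax_scale(x, 2.0)
    y = out[0] if isinstance(out, (tuple, list)) else out
    return jnp.sum(y)

g = jax.grad(loss)(jnp.ones(8, dtype=jnp.float32))  # expect all 2.0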


907-1033: LGTM: Custom VJP implementation follows JAX conventions.

The differentiable path correctly implements JAX's custom_vjp API:

  • Forward function returns (outputs, residuals) where residuals include non-static inputs and outputs
  • Backward function reconstructs the full input/output/gradient argument list for the adjoint kernel
  • Static arguments are properly excluded from differentiation via nondiff_argnums
  • Differentiable wrappers are cached to avoid redundant construction
  • Output dimensions are correctly inferred from input array annotations

The complexity is warranted for supporting full AD semantics with static arguments, vmap, and multi-output kernels.
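For readers unfamiliar with the mechanism, the overall shape is the standard custom_vjp recipe below; this is a generic stand-in using plain JAX ops, not the module's actual forward/adjoint launches.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    y = jnp.sin(x)
    return y, (x, y)                 # (outputs, residuals saved for the backward pass)

def f_bwd(residuals, y_bar):
    x, y = residuals
    return (jnp.cos(x) * y_bar,)     # adjoint reconstructs the input gradient

f.defvjp(f_fwd, f_bwd)

print(jax.grad(lambda x: f(x).sum())(jnp.ones(3)))  # -> cos(1.0) for each element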


1070-1072: LGTM: Documentation and helper function updates are consistent.

The jax_callable docstring update aligns with jax_kernel, and the get_warp_shape refactor correctly handles:

  • Vector/matrix arrays: extracts core Warp dims positioned before trailing dtype dims
  • Scalar arrays: extracts the last warp_ndim dims (excluding leading batch dims)

The logic properly supports vmap's leading batch dimensions across both array types.

Also applies to: 1237-1250
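A small stand-alone re-statement of the two cases (an assumed helper for illustration, not the module's get_warp_shape):

def core_warp_dims(jax_shape, warp_ndim, dtype_shape=()):
    # Strip leading batch dims and trailing dtype dims to recover the core Warp dims.
    dtype_ndim = len(dtype_shape)
    if dtype_ndim:
        return tuple(jax_shape[-(warp_ndim + dtype_ndim):-dtype_ndim])
    return tuple(jax_shape[-warp_ndim:])

# wp.array(dtype=wp.mat22) under vmap: (B, N, 2, 2) -> (N,)
assert core_warp_dims((4, 100, 2, 2), warp_ndim=1, dtype_shape=(2, 2)) == (100,)
# wp.array2d(dtype=float) under vmap: (B, H, W) -> (H, W)
assert core_warp_dims((4, 32, 64), warp_ndim=2) == (32, 64)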


1253-1276: LGTM: Output type inference correctly supports batch dimensions.

The updated conditions:

  • Line 1263: ndim >= arg.jax_ndim allows leading batch dimensions while validating inner dtype dimensions
  • Line 1274: ndim < arg.warp_ndim correctly requires at least the core Warp dimensions

Both changes consistently support vmap's leading batch dimensions in output arrays.
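Re-stated as a tiny predicate for illustration (assumed, not the module's actual check):

def accepts_output_shape(shape, warp_ndim, dtype_shape=()):
    # Allow leading batch dims; require trailing dtype dims plus warp_ndim core dims.
    jax_ndim = warp_ndim + len(dtype_shape)
    if len(shape) < jax_ndim:
        return False
    return tuple(shape[len(shape) - len(dtype_shape):]) == tuple(dtype_shape)

assert accepts_output_shape((100, 2), warp_ndim=1, dtype_shape=(2,))      # plain vec2 output
assert accepts_output_shape((8, 100, 2), warp_ndim=1, dtype_shape=(2,))   # vmapped vec2 output
assert not accepts_output_shape((100,), warp_ndim=1, dtype_shape=(2,))    # missing dtype dim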

@mehdiataei mehdiataei force-pushed the jax_ad_ffi branch 3 times, most recently from 4fee379 to e6fa8d7 on October 8, 2025 20:56
pull bot pushed a commit to mcx/warp that referenced this pull request Oct 17, 2025
Add automatic differentiation support with jax_kernel (NVIDIAGH-515, NVIDIAGH-912)

See merge request omniverse/warp!1682
@nvlukasz
Contributor

This is now merged in b66f253. Thank you @mehdiataei for this great contribution!

@nvlukasz nvlukasz closed this Oct 18, 2025