added JAX interop adjoint support using FFI #912
Conversation
Thanks for doing this!
I left one comment, but another overall comment is whether we can get a test of this working with multiple devices? I think the older way to do this would be within a jax.pmap, but there may be newer ways.
Thanks John! I have added an example to the docs using shard_map for the forward pass, which as far as I know is the most up-to-date way of doing this (maybe @nouiz can confirm): https://nvidia.github.io/warp/modules/interoperability.html#distributed-computation. I can add an example showing that the backward pass works with shard_map as well (I don't see a reason why it wouldn't).
shard_map is the right manual way of doing sharding.
FYI, I'm trying this on a two-GPU machine and getting an error:
This fails with:
Notably, if I hide one of the GPUs (with
Hey @johnpjf
Yes, if the examples are not executed with mpirun, we get that error. I'm not sure what the root cause is, but it's probably a stream, event, function, or module created on GPU 0 being accidentally used while GPU 1 is current. Running with mpirun fixes it. I made the following test cases to check whether the grad works with shard_map, and they all pass successfully. However, I'm having trouble merging these tests with the current Warp test cases since they rely on mpirun.
import warp as wp
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather as allgather
from jax.experimental.shard_map import shard_map
from warp.jax_experimental.ffi import jax_ad_kernel
jax.distributed.initialize()
# Kernels
@wp.kernel
def multiply_by_two_kernel(a_in: wp.array(dtype=wp.float32), a_out: wp.array(dtype=wp.float32)):
index = wp.tid()
a_out[index] = a_in[index] * 2.0
@wp.kernel
def scale_sum_square_kernel(
a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32), s: float, out: wp.array(dtype=wp.float32)
):
tid = wp.tid()
out[tid] = (a[tid] * s + b[tid]) ** 2.0
# multi-output kernel
@wp.kernel
def multi_output_kernel(
a: wp.array(dtype=wp.float32),
b: wp.array(dtype=wp.float32),
s: float,
c: wp.array(dtype=wp.float32),
d: wp.array(dtype=wp.float32),
):
tid = wp.tid()
c[tid] = a[tid] ** 2.0
d[tid] = a[tid] * b[tid] * s
# JAX callables with AD
jax_mul = jax_ad_kernel(multiply_by_two_kernel, num_outputs=1)
jax_scale = jax_ad_kernel(scale_sum_square_kernel, num_outputs=1, static_argnames=("s",))
jax_mo = jax_ad_kernel(multi_output_kernel, num_outputs=2, static_argnames=("s",))
def warp_multiply(x):
(out,) = jax_mul(x)
return out
def warp_scale_sum_square(a, b, s):
(out,) = jax_scale(a, b, s)
return out
def warp_multi_output(a, b, s):
c, d = jax_mo(a, b, s)
return c, d
def make_sharded_1d(mesh, n):
sharding = jax.sharding.NamedSharding(mesh, P("x"))
a = jnp.arange(n, dtype=jnp.float32)
shape = (n,)
arrays = [jax.device_put(a[idx], d) for d, idx in sharding.addressable_devices_indices_map(shape).items()]
return jax.make_array_from_single_device_arrays(shape, sharding, arrays)
def example_mul2_backward(mesh):
n = jax.device_count() * 5
x = make_sharded_1d(mesh, n)
def loss(x):
y = shard_map(lambda v: warp_multiply(v), mesh=mesh, in_specs=(P("x"),), out_specs=P("x"), check_rep=False)(x)
return jnp.sum(y)
g = jax.grad(loss)(x)
gf = allgather(g)
if jax.process_index() == 0:
ref = np.full(n, 2.0, dtype=np.float32)
np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
print("mul2: OK")
def example_scale_sum_square_backward(mesh):
n = jax.device_count() * 6
a = make_sharded_1d(mesh, n)
b = make_sharded_1d(mesh, n)
s = 1.5
def loss(a, b):
y = shard_map(
lambda aa, bb: warp_scale_sum_square(aa, bb, s),
mesh=mesh,
in_specs=(P("x"), P("x")),
out_specs=P("x"),
check_rep=False,
)(a, b)
return jnp.sum(y)
da, db = jax.grad(loss, argnums=(0, 1))(a, b)
daf, dbf = allgather(da), allgather(db)
if jax.process_index() == 0:
a_np = np.arange(n, dtype=np.float32)
b_np = np.arange(n, dtype=np.float32)
ref_da = 2.0 * (a_np * s + b_np) * s
ref_db = 2.0 * (a_np * s + b_np)
np.testing.assert_allclose(np.asarray(daf), ref_da, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(np.asarray(dbf), ref_db, rtol=1e-5, atol=1e-6)
print("scale_sum_square: OK")
def example_mul2_with_psum(mesh):
n = jax.device_count() * 4
x = make_sharded_1d(mesh, n)
def body(v):
local = warp_multiply(v)
red = jnp.sum(local)
total = jax.lax.psum(red, "x")
return local, total
def loss(x):
local, total = shard_map(body, mesh=mesh, in_specs=(P("x"),), out_specs=(P("x"), P()), check_rep=False)(x)
return jnp.sum(local) + total
g = jax.grad(loss)(x)
gf = allgather(g)
if jax.process_index() == 0:
ref = np.full(n, 4.0, dtype=np.float32)
np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
print("mul2 + psum: OK")
def example_psum_mean(mesh):
n = jax.device_count() * 3
x = make_sharded_1d(mesh, n)
def body(v):
local = warp_multiply(v)
red = jnp.sum(local)
total = jax.lax.psum(red, "x")
mean = total / n
return local, mean
def loss(x):
local, mean = shard_map(body, mesh=mesh, in_specs=(P("x"),), out_specs=(P("x"), P()), check_rep=False)(x)
return jnp.sum(local) + mean
g = jax.grad(loss)(x)
gf = allgather(g)
if jax.process_index() == 0:
# d/dx sum(2x) = 2; d/dx mean(2x) = 2/n per element
ref = np.full(n, 2.0 + 2.0 / n, dtype=np.float32)
np.testing.assert_allclose(np.asarray(gf), ref, rtol=1e-5, atol=1e-6)
print("mul2 + psum mean: OK")
def example_multi_output_with_psum(mesh):
n = jax.device_count() * 5
a = make_sharded_1d(mesh, n)
b = make_sharded_1d(mesh, n)
s = 2.0
def body(aa, bb):
c, d = warp_multi_output(aa, bb, s)
local = c + d
total = jax.lax.psum(jnp.sum(local), "x")
return local, total
def loss(a, b):
local, total = shard_map(body, mesh=mesh, in_specs=(P("x"), P("x")), out_specs=(P("x"), P()), check_rep=False)(
a, b
)
return jnp.sum(local) + total
da, db = jax.grad(loss, argnums=(0, 1))(a, b)
daf, dbf = allgather(da), allgather(db)
if jax.process_index() == 0:
a_np = np.arange(n, dtype=np.float32)
b_np = np.arange(n, dtype=np.float32)
base_da = 2.0 * a_np + b_np * s
base_db = a_np * s
ref_da = 2.0 * base_da
ref_db = 2.0 * base_db
np.testing.assert_allclose(np.asarray(daf), ref_da, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(np.asarray(dbf), ref_db, rtol=1e-5, atol=1e-6)
print("multi_output + psum: OK")
def main():
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), "x")
example_mul2_backward(mesh)
example_scale_sum_square_backward(mesh)
example_mul2_with_psum(mesh)
example_psum_mean(mesh)
example_multi_output_with_psum(mesh)
if __name__ == "__main__":
    main()

You can run them with mpirun -np 2 python test.py
Ok, is the intent to fix this multi-GPU issue before checking in?
It is fairly common to run multi-GPU workloads with multiple processes in JAX even on a single machine, so I don't think this is a blocking issue (and this limitation existed before this PR). I have some ideas about the root cause and a potential fix, so let me take a shot at it first.
Yes, multi-GPU training within a pmap is essentially a requirement for all of our workloads, so it would be great to get a fix for this. Thanks!
@johnpjf Good news. I was able to resolve the CUDA 400 issue by launching the FFI on the callframe's CUDA stream and preloading modules across GPUs to prevent per-device race conditions. With these changes in place, your example now runs correctly without requiring mpirun. This fix also made it possible to integrate these tests with the current Warp test cases. Can you give it a try? Thanks.
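For reference, a minimal sketch of the preloading idea (illustrative only, assuming Warp's public wp.get_cuda_devices and wp.load_module APIs; this is not the code added in the PR):

import warp as wp

wp.init()

# Preload this module's kernels on every visible GPU up front, so the first FFI call
# on a non-default device never triggers a per-device module build (a common source of
# races when several pmap replicas hit the kernel cache at the same time).
for device in wp.get_cuda_devices():
    wp.load_module(device=device)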
Thanks, but I don't think it's working for me yet:
Outputs:
So it seems like something went wrong on the second GPU.
Hmm I can't reproduce it:
My code:
import jax
import jax.numpy as jnp
import warp.jax_experimental as jax_experimental
import unittest
import warp as wp
@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
tid = wp.tid()
output[tid] = 3.0 * input[tid]
class WarpDeferredTest(unittest.TestCase):
def test_give_me_a_name(self):
jax_triple = jax_experimental.jax_ad_kernel(triple_kernel)
def f(x):
return jax_triple(x)
@jax.pmap
def compute_grads(x):
def loss_fn(x):
return jnp.sum(jnp.square(f(x)[0]))
loss, grads = jax.value_and_grad(
loss_fn, has_aux=False)(x)
return loss, grads
x = jnp.array(((1.0, 2.0), (1.0, 2.0)))
print('input: ', x)
loss, grads = compute_grads(x)
print('loss: ', loss)
        print('grads: ', grads)

Can you delete the kernel cache folder and try again?
Tried that, but it didn't fix it. It looks like we have a copy of Warp in our monorepo that's from 8/7; is it possible that your change depends on something more recent? I'll try to get ours updated and try again tomorrow. Thanks again for working on this!
Ok, turns out I have a bad GPU; pmap works great on a different dual-GPU machine, thanks!
Great! I think for JAX to work you need identical GPUs (maybe that's the cause).
Exactly. You can't mix different kinds of GPUs; this isn't supported.
Yeah, it wasn't that: I have identical GPUs and have been happily pmap-ing for years, and now it has just stopped working, giving all zeros for the other shard, even in straight JAX.
I tried this PR with this code: google-deepmind/mujoco_warp#475 (comment). To repro, I modified this file directly with the implementation in this PR. Any ideas what could be going wrong?
@btaba could you share more details on the WIP? I played around with this a bit over the weekend and it looks like the issue comes from how modules are being loaded in the "third_party" FFI (in MuJoCo). I made some progress but was planning to dig deeper into it this week. Is this the issue you're working on fixing?
I'm not actively working on it, just wanted to take a stab at getting to a semi-working state. If you have some cycles, please let us know what you find! There is currently a race condition, and data on one device is corrupt. Other than that, the code at least runs with pmap (occasionally). I'm a bit confused by this PR as
warp/jax_experimental/ffi.py
Outdated
Args:
    kernel: The Warp kernel to wrap.
    num_outputs: Number of output arrays produced by the kernel.
    static_argnames: Optional iterable of argument names that should be treated
It seems there is some confusion in this function between static-for-jit arguments and arguments that you don't want the gradients for. I can have a jax array as an input that is not static-for-jit but that I don't want gradients for. For example, when ray tracing, I don't want gradients of rays, but I don't want them to be static for jit (i.e. they are regular JAX Arrays).
Hmm, I don't think in JAX you can pass JAX arrays as static argnums...? It was the case before, and I believe it's still the case now, since arrays are not hashable. Is this still true @nouiz?
Maybe I don’t fully understand your request. Could you provide an example of something we currently do not support?
Actually I think my comment was a little wrong, but I do think this code slightly conflates static args for jitting with static args for custom gradients. I don't really think this code should be concerned with jit-ing at all; static args for this might not be (and usually won't be!) static args for jit.
For specific changes:
Also, you should call these nondiff_argnums to match JAX's terminology: https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html
I also wouldn't try to be clever and use heuristics to figure out what should or should not be nondiff_argnums; that could cause silent bugs where someone forgets to wrap a float in a jnp.array.
I think we should also allow for JAX input arrays that shouldn't have gradients computed (i.e. as if we had wrapped them in jax.stop_gradient). I don't know what to call these, or whether Warp really supports skipping gradient computation for selected inputs.
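A plain-JAX sketch of the distinction discussed above (hand-written custom_vjp, no Warp involved, names made up for illustration): an argument can stay a regular traced array, receive no gradient, and still not be static for jit.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scale_by_ray_sum(x, rays):  # rays: a regular JAX array, not static for jit
    return x * jnp.sum(rays)

def fwd(x, rays):
    return scale_by_ray_sum(x, rays), (rays,)

def bwd(residuals, g):
    (rays,) = residuals
    # differentiate w.r.t. x only; return a zero cotangent for rays
    # instead of forcing rays to be a jit-static argument
    return (g * jnp.sum(rays), jnp.zeros_like(rays))

scale_by_ray_sum.defvjp(fwd, bwd)

x = jnp.ones(4)
rays = jnp.arange(4.0)
grad_x = jax.jit(jax.grad(lambda x, r: jnp.sum(scale_by_ray_sum(x, r))))(x, rays)
print(grad_x)  # [6. 6. 6. 6.]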
warp/jax_experimental/ffi.py
Outdated
        as static (non-differentiable) in JAX. If None, scalar (non-array) inputs
        are treated as static by default.
    vmap_method: How the callback transforms under jax.vmap.
    launch_dim_arg_index: Index (in the kernel's argument list) of the input
(Sorry, this was in the wrong place before.)
I don't think this will work for everyone. For example, if I want to do something for every pixel in an RGB image, then I want the kernel to launch over the height and width but not the channel dimension. Also, this is different from the existing jax_kernel interface. I wonder if it's better to force providing these explicitly?
I believe we do currently handle this when channels are represented as dtype dims (e.g., wp.vec3) and appear as the trailing axis in the JAX array. I pushed a test to demonstrate it.
Caveat: this works when channels are encoded as dtype dims (e.g., vec3) and are trailing in the JAX tensor. If someone insists on modeling channels as an extra array axis in the Warp type (e.g., wp.array3d(dtype=float)), then the kernel would launch over H×W×C. If we need to support that representation while still launching over H×W, we can extend jax_ad_kernel to accept an explicit launch_dims or launch_dims_axes parameter and update both forward/backward wrappers accordingly.
Please let me know what you think.
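A short sketch of the vec3-as-dtype case described above (assuming the FFI jax_kernel entry point from this PR; kernel and array names are illustrative): JAX sees an (H, W, 3) float32 array, while the Warp launch covers only H x W.

import warp as wp
import jax.numpy as jnp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def brighten(img_in: wp.array2d(dtype=wp.vec3), img_out: wp.array2d(dtype=wp.vec3)):
    i, j = wp.tid()                     # 2D launch over height and width only
    img_out[i, j] = img_in[i, j] * 1.5  # channels live inside the vec3 dtype

jax_brighten = jax_kernel(brighten, num_outputs=1)

img = jnp.ones((4, 8, 3), dtype=jnp.float32)  # trailing axis maps to the vec3 components
(out,) = jax_brighten(img)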
I guess I just don't think it's a good idea to hard-code that the launch dims must match one of the input dims; you can definitely think of examples where that wouldn't be true.
Force-pushed from e0f325f to d2693bc (compare).
Thanks. I did try to fix the issue in MuJoCo. I had a bit of success, but it seems like there are multiple versions of Warp in that repo that need to be synced to the same FFI implementation. The reason for using wp.get_device("cuda") in the FFI path is that inside the FFI callback we bind the Warp stream to the CUDA stream that XLA passes us and rely on the current CUDA context that XLA sets for the shard. That guarantees the stream/device match across pmap replicas and avoids subtle mismatches that can occur when re-deriving the device from a JAX device handle or from a buffer pointer. Note that the stream itself comes from XLA, and ahead of time we preload everything to avoid build races. If you try pmap with this Warp branch, it works quite well on multi-GPU machines. All the changes above were made in response to different issues that came up during the implementation.
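Roughly, the callback-side logic described above looks like the following sketch (simplified and illustrative; the actual code lives in warp/jax_experimental/ffi.py, and wp_cuda_stream_get_device_ordinal is the native helper this PR adds):

import warp as wp

def launch_on_xla_stream(kernel, inputs, dim, xla_stream_handle):
    # map the CUDA stream XLA hands us back to a Warp device via its device ordinal
    ordinal = wp.context.runtime.core.wp_cuda_stream_get_device_ordinal(xla_stream_handle)
    device = wp.get_cuda_device(ordinal)

    # wrap the external stream rather than creating a new one, so the launch is
    # ordered with XLA's own work on this replica's GPU
    stream = wp.Stream(device, cuda_stream=xla_stream_handle)
    with wp.ScopedStream(stream):
        wp.launch(kernel, dim=dim, inputs=inputs, device=device)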
@mehdiataei I'm also a bit confused: graph_mode is no longer respected in warp/jax_experimental/ffi.py (lines 729 to 738 at 6673762), and the JAX graph mode is removed?
The above worked for me; can you try?
Thanks @mehdiataei, I just tried with the latest branch but still get zeros on the second device. I'll have to try on another machine; do you have any ideas otherwise?
@btaba @mehdiataei, I have a pmap solution based on GH-976 from @chaserileyroberts. Thanks Chase! It's currently under review by the Warp team, but I expect it to be merged shortly. I had to add some additional thread safety, but the Mujoco example works fine now. That should take care of pmap, so once that's merged, @mehdiataei you can modify this PR to only add the autodiff functionality. Thank you all for your contributions and patience!
Hmm, no idea. I did follow those steps again and it works fine for me. Is it working correctly with pure JAX and pmap?
@nvlukasz That's great! Yes, I can merge those changes and add the adjoint stuff.
📝 Walkthrough
Adds a differentiable JAX path for Warp FFI kernels: implements a custom-VJP differentiable wrapper cached separately, updates warp-aware shape/launch-dim inference and runtime handling, and extends the accompanying tests and documentation.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant U as User code
participant JK as jax_kernel / jax_callable
participant REGd as _FFI_DIFF_KERNEL_REGISTRY
participant REG as _FFI_KERNEL_REGISTRY
participant WK as Warp kernel
participant AK as Adjoint kernel
Note over U,JK: create callable (maybe differentiable)
U->>JK: jax_kernel(kernel, differentiable=True, static_argnames)
JK->>REGd: lookup DiffKernelCacheKey
alt cache miss
JK->>JK: inspect signature & static args
JK->>JK: build forward wrapper (launch dims, warp shapes)
JK->>JK: build backward (custom_vjp) that calls adjoint
JK->>REGd: cache diff wrapper
else cache hit
REGd-->>JK: return cached wrapper
end
JK-->>U: return JAX-callable
Note over U,WK: forward call
U->>JK: call(args)
JK->>WK: infer warp dims, launch kernel
WK-->>JK: outputs
JK-->>U: outputs
Note over U,AK: reverse pass
U->>JK: JAX requests grad
JK->>AK: call adjoint with cotangents (warp-aware shapes)
AK-->>JK: input gradients
JK-->>U: return gradients
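The forward/backward flow in the diagram, sketched in plain JAX (placeholder Python functions stand in for the Warp forward and adjoint launches; illustrative only, not the PR's implementation):

import jax
import jax.numpy as jnp

def warp_forward(x):             # placeholder for the forward Warp kernel launch
    return 2.0 * x

def warp_adjoint(x, cotangent):  # placeholder for the adjoint (backward) kernel launch
    return 2.0 * cotangent

@jax.custom_vjp
def f(x):
    return warp_forward(x)

def f_fwd(x):
    return warp_forward(x), (x,)  # save residuals for the reverse pass

def f_bwd(residuals, g):
    (x,) = residuals
    return (warp_adjoint(x, g),)  # one cotangent per differentiable input

f.defvjp(f_fwd, f_bwd)

print(jax.grad(lambda x: jnp.sum(f(x)))(jnp.arange(3.0)))  # [2. 2. 2.]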
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
warp/jax_experimental/ffi.py (2)
1465-1472
: Fix warp-shape inference under vmap (use trailing dims, not leading).Current
get_warp_shape()
returns the firstwarp_ndim
dims, which is wrong when leading batch dims exist. Use the lastwarp_ndim
dims before dtype dims. This affects launch dims and output type construction.-def get_warp_shape(arg, dims): - if arg.dtype_ndim > 0: - # vector/matrix array - return dims[: arg.warp_ndim] - else: - # scalar array - return dims +def get_warp_shape(arg, dims): + # dims may include leading batch dims and trailing dtype dims + jax_rank = len(dims) - (arg.dtype_ndim if arg.dtype_ndim > 0 else 0) + if jax_rank < arg.warp_ndim: + raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}") + start = jax_rank - arg.warp_ndim + end = jax_rank + return tuple(dims[start:end])
324-347
: Enforce stream/pointer device-ordinal agreement in FfiKernel callback.
FfiKernel.ffi_callback
derives device from the XLA stream but doesn’t validate array pointer ordinals. Under pmap this risks cross-device launches. MirrorFfiCallable
logic: prefer pointer ordinal; only bind the XLA stream when ordinals match.- # get stream and derive device from stream to be replica-local under pmap - stream_handle = get_stream_from_callframe(call_frame.contents) - try: - ordinal = wp.context.runtime.core.wp_cuda_stream_get_device_ordinal(stream_handle) - device = wp.get_cuda_device(ordinal) - except Exception: - device = wp.device_from_jax(get_jax_device()) - stream = wp.Stream(device, cuda_stream=stream_handle) + # get stream and derive device/stream safely under pmap + stream_handle = get_stream_from_callframe(call_frame.contents) + try: + stream_ord = wp.context.runtime.core.wp_cuda_stream_get_device_ordinal(stream_handle) + except Exception: + stream_ord = -1 + # get first array input pointer ordinal (if any) + ptr_ord = -1 + for i, arg in enumerate(self.input_args): + if arg.is_array: + buf = inputs[i].contents + ptr_ord = wp.context.runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(buf.data)) + break + ord_to_use = ptr_ord if ptr_ord >= 0 else stream_ord + if ord_to_use < 0: + device = wp.device_from_jax(get_jax_device()) + stream = device.stream # fallback to device default stream + else: + device = wp.get_cuda_device(ord_to_use) + # Bind XLA stream only if ordinals agree; otherwise use device default stream + stream = ( + wp.Stream(device, cuda_stream=stream_handle) + if (ptr_ord >= 0 and ptr_ord == stream_ord) + else device.stream + )warp/native/warp.cu (2)
2398-2401
: Compile error: missing parentheses in if-statementC++ requires parentheses around the condition.
Apply:
- if check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)) + if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS))) return (void*) device_ptr; else return NULL;
2539-2543
: Use check_cu for CUDA Driver API calls (cu)*These lines use check_cuda with cu* functions; switch to check_cu for accurate error handling.
Apply (illustrative diffs):
- check_cuda(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority)); + check_cu(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));- if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL))) + if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))Also applies to: 3062-3066, 3113-3116, 3221-3224, 3264-3267
♻️ Duplicate comments (1)
warp/tests/interop/test_jax.py (1)
352-375
: Addresses prior feedback: added jit(grad(...)) tests. test_jax_ad_kernel_jit_of_grad_simple and ..._multi_output cover jit-of-grad usage. Looks good.
🧹 Nitpick comments (8)
warp/context.py (1)
5578-5595
: Tighten exceptions and guard CUDA-only pointer-ordinal lookupGood diagnostics, but avoid blind excepts and skip pointer-ordinal queries on non‑CUDA or null ptr. Also use getattr for ext stream safely.
Apply this diff:
- # check device; provide detailed diagnostics on mismatch - if value.device != device: - try: - ext_stream = _get_external_stream() - ext_dev = ext_stream.device if ext_stream is not None else None - except Exception: - ext_dev = None - try: - ptr_ord = runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(value.ptr)) if value.ptr is not None else -1 - except Exception: - ptr_ord = -1 + # check device; provide detailed diagnostics on mismatch + if value.device != device: + ext_stream = _get_external_stream() + ext_dev = getattr(ext_stream, "device", None) + ptr_ord = -1 + if getattr(device, "is_cuda", False) and value.ptr is not None: + try: + ptr_ord = runtime.core.wp_cuda_pointer_get_device_ordinal(ctypes.c_void_p(value.ptr)) + except (AttributeError, OSError, ValueError): + ptr_ord = -1 raise RuntimeError( ( f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', " f"but input array for argument '{arg_name}' is on device={value.device}. " f"[debug ext_stream_device={ext_dev}, ptr_device_ordinal={ptr_ord}]" ) )Note: This also resolves Ruff BLE001/S110 warnings. [As per static analysis hints]
warp/examples/optim/example_inverse_kinematics_jax.py (2)
69-69
: Avoid double-wrapping when creating JAX arrays.Construct directly with JAX to skip the intermediate NumPy array.
- self.target = jp.array(np.array((2.0, 1.0, 0.0), dtype=np.float32)) + self.target = jp.array((2.0, 1.0, 0.0), dtype=jp.float32)
192-201
: Narrow the exception and avoid silent pass when setting env vars.Only import errors should be caught here; consider logging unexpected failures.
- os_env = {} - try: - import os as _os - - os_env = _os.environ - os_env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - os_env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - except Exception: - pass + try: + import os as _os + _os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + _os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + except ImportError: + # JAX will use defaults; nothing to configure + passwarp/jax_experimental/ffi.py (4)
30-30
: Importget_stream_from_callframe
explicitly.Silences F405 and makes dependency clear.
-from .xla_ffi import * +from .xla_ffi import * +from .xla_ffi import get_stream_from_callframe
732-736
: Remove unusedcall_desc.capture
assignment.This attribute is never read; use
captures
dict only to avoid confusion.- if ffi_stream.is_capturing: - with wp.ScopedCapture(external=True) as capture: - self.func(*arg_list) - call_desc.capture = capture + if ffi_stream.is_capturing: + with wp.ScopedCapture(external=True) as capture: + self.func(*arg_list)
476-487
: Allow “non‑diff” JAX arrays separately from “static” args; avoid heuristic defaults.Current logic conflates nondifferentiable args with jit-static args and uses a heuristic to treat non-arrays as nondiff. This prevents common patterns (e.g., pass JAX arrays that shouldn’t receive grads but aren’t static for jit).
- Add a
nondiff_argnames
parameter distinct fromstatic_argnames
.- Drop the heuristic; require explicit
nondiff_argnames
for clarity.- Keep
static_argnames
purely for jitting.- Continue to pass
nondiff_argnums
tocustom_vjp
.Would you like a follow-up patch adding
nondiff_argnames
to bothjax_callable
andjax_ad_kernel
and updating call sites?Also applies to: 882-907
236-251
: Module preloading across GPUs is good; consider caching to avoid repeated loads.
module.load(dev)
is idempotent but may still cost. Cache per (module, device) to skip redundant calls during tracing.If you want, I can wire a simple per-process cache keyed by
(module.name, device.ordinal)
and guard with_FFI_REGISTRY_LOCK
.Also applies to: 525-556
docs/modules/interoperability.rst (1)
956-1090
: AD section reads well and matches API semanticsExamples and notes are clear (statics, VMAP options, multi-output). Consider adding a short cross-link near “JAX Foreign Function Interface (FFI)” pointing here, plus an anchor (e.g., .. _jax-ad:) for easier navigation.
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
docs/modules/interoperability.rst (1 hunk)
warp/context.py (6 hunks)
warp/examples/optim/example_inverse_kinematics_jax.py (1 hunk)
warp/jax_experimental/__init__.py (1 hunk)
warp/jax_experimental/ffi.py (18 hunks)
warp/native/warp.cu (1 hunk)
warp/native/warp.h (1 hunk)
warp/tests/interop/test_jax.py (2 hunks)
🔇 Additional comments (5)
warp/context.py (1)
3822-3826
: CTypes bindings for device-ordinal helpers look correct. Signatures match the native helpers; this unblocks better diagnostics and device mapping.
warp/examples/optim/example_inverse_kinematics_jax.py (1)
87-92
: Good use of jax_ad_kernel with explicit output dims and launch index. Launch dims taken from the mask arg and explicit output dims resolve shape inference; aligns with adjoint support.
warp/jax_experimental/ffi.py (1)
628-756
: Stream/pointer ordinal checks and batch-dim collapse logic look solid. This should address pmap multi-GPU races and shape alignment under vmap.
Please run multi-device tests (pmap/shard_map) with mixed leading batch dims and vector dtypes to confirm no regressions.
warp/jax_experimental/__init__.py (1)
17-17
: Public export looks good. Re-exporting jax_ad_kernel at the package level is convenient and aligns with docs/tests.
warp/native/warp.cu (1)
2545-2569
: New device-ordinal helpers: OK. Implementations for stream/pointer → device ordinal are correct and defensive (return -1 on failure/non-device memory).
warp/context.py
Outdated
_EXTERNAL_STREAM_LOCAL = threading.local()


def _get_external_stream():
    return getattr(_EXTERNAL_STREAM_LOCAL, "stream", None)


class ScopedExternalStream:
    def __init__(self, stream: Stream | None):
        self.stream = stream

    def __enter__(self):
        _EXTERNAL_STREAM_LOCAL.stream = self.stream
        return self.stream

    def __exit__(self, exc_type, exc_val, exc_tb):
        _EXTERNAL_STREAM_LOCAL.stream = None
ScopedExternalStream is not nest-safe; restore prior TLS value on exit
__exit__ unconditionally sets TLS to None, breaking nested usage or re-entrancy. Save the previous value on enter and restore it on exit.
Apply this diff:
class ScopedExternalStream:
def __init__(self, stream: Stream | None):
self.stream = stream
+ self._prev = None
def __enter__(self):
- _EXTERNAL_STREAM_LOCAL.stream = self.stream
- return self.stream
+ self._prev = getattr(_EXTERNAL_STREAM_LOCAL, "stream", None)
+ _EXTERNAL_STREAM_LOCAL.stream = self.stream
+ return self.stream
def __exit__(self, exc_type, exc_val, exc_tb):
- _EXTERNAL_STREAM_LOCAL.stream = None
+ _EXTERNAL_STREAM_LOCAL.stream = self._prev
+ self._prev = None
def loss_from_fk_outputs(body_q, body_qd):
    return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)
Remove unused parameter in local loss helper: body_qd is unused. Simplify the signature to avoid ARG001.
- def loss_from_fk_outputs(body_q, body_qd):
+ def loss_from_fk_outputs(body_q, _):
return jp.sum((body_q[self.model.body_count - 1, 0:3] - self.target) ** 2.0)
warp/native/warp.h
Outdated
WP_API int wp_cuda_stream_get_device_ordinal(void* stream);
WP_API int wp_cuda_pointer_get_device_ordinal(void* ptr);
❓ Verification inconclusive
Add CPU stubs to prevent link errors when CUDA is disabled
These symbols need counterparts in warp/native/warp.cpp (CPU builds). Please add no-op stubs returning -1.
Suggested stubs (in warp/native/warp.cpp):
WP_API int wp_cuda_stream_get_device_ordinal(void* /*stream*/) { return -1; }
WP_API int wp_cuda_pointer_get_device_ordinal(void* /*ptr*/) { return -1; }
🏁 Script executed:
#!/bin/bash
# Verify stubs exist in warp/native/warp.cpp
rg -nP '^\s*WP_API\s+int\s+wp_cuda_(stream|pointer)_get_device_ordinal\b' warp/native/warp.cpp || true
Length of output: 0
Add CPU stubs for device-ordinal helpers in warp/native/warp.cpp
Add no-op implementations to avoid link errors when CUDA is disabled:
WP_API int wp_cuda_stream_get_device_ordinal(void* /*stream*/) { return -1; }
WP_API int wp_cuda_pointer_get_device_ordinal(void* /*ptr*/) { return -1; }
warp/tests/interop/test_jax.py
Outdated
def test_jax_ad_kernel_vmap_expand_dims_simple(test, device):
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_ad_kernel

    @wp.kernel
    def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
        tid = wp.tid()
        c[tid] = (a[tid] * s + b[tid]) ** 2.0

    jax_func = jax_ad_kernel(
        scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="broadcast_all"
    )
Mismatch: test names say “expand_dims” but code uses vmap_method="broadcast_all"
To actually exercise expand_dims, set vmap_method="expand_dims".
- jax_func = jax_ad_kernel(
- scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="broadcast_all"
- )
+ jax_func = jax_ad_kernel(
+ scale_sum_square_kernel, num_outputs=1, static_argnames=("s",), vmap_method="expand_dims"
+ )
- jax_func = jax_ad_kernel(add_one_2d, num_outputs=1, vmap_method="broadcast_all")
+ jax_func = jax_ad_kernel(add_one_2d, num_outputs=1, vmap_method="expand_dims")
Also applies to: 966-979
    TestJax,
    "test_jax_ad_kernel_pmap_multi_output",
    test_jax_ad_kernel_pmap_multi_output,
    devices=jax_compatible_cuda_devices,
)

add_function_test(
    TestJax,
    "test_jax_ad_kernel_pmap_multi_output",
    test_jax_ad_kernel_pmap_multi_output,
    devices=jax_compatible_cuda_devices,
)
Duplicate test registrations
test_jax_ad_kernel_pmap_multi_output and test_jax_callable_pmap_multi_output_forward are added twice. Remove the duplicates.
- add_function_test(
- TestJax,
- "test_jax_ad_kernel_pmap_multi_output",
- test_jax_ad_kernel_pmap_multi_output,
- devices=jax_compatible_cuda_devices,
- )
...
- add_function_test(
- TestJax,
- "test_jax_callable_pmap_multi_output_forward",
- test_jax_callable_pmap_multi_output_forward,
- devices=jax_compatible_cuda_devices,
- )
Also applies to: 1640-1657
Force-pushed from 8dc4fa7 to 10f6829 (compare).
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
warp/jax_experimental/ffi.py (1)
1193-1200: Fix warp-shape inference under vmap: slice from the tail before dtype dims (not from the head). Current get_warp_shape() takes the first dims, which are batch dims under vmap. This breaks launch_dims/output_dims inference and can cause OOB or shape mismatches.
Apply:
-def get_warp_shape(arg, dims):
-    if arg.dtype_ndim > 0:
-        # vector/matrix array
-        return dims[: arg.warp_ndim]
-    else:
-        # scalar array
-        return dims
+def get_warp_shape(arg, dims):
+    # dims may include leading batch dims and trailing dtype dims (for vec/mat)
+    if arg.dtype_ndim > 0:
+        # vector/matrix array: pick the core warp dims immediately before dtype dims
+        core_rank = len(dims) - arg.dtype_ndim
+        start = max(0, core_rank - arg.warp_ndim)
+        end = core_rank
+        return dims[start:end]
+    else:
+        # scalar array: warp dims are the trailing dims
+        if arg.warp_ndim == 0:
+            return ()
+        return dims[-arg.warp_ndim:]
This corrects all call sites that rely on get_warp_shape(), e.g., lines 235 and 545 when computing in-out output types and when inferring default launch/output dims.
♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (1)
750-752: Terminology/API: decouple JIT statics from nondifferentiable args; prefer 'nondiff_argnames'. The parameter static_argnames is used to (1) mark JIT-static args and (2) as nondiff_argnums in custom_vjp. These are conceptually distinct. Consider:
- Rename to nondiff_argnames for the AD aspect, and optionally keep a separate jit_static_argnames (or infer JIT statics via the user’s @jit usage).
- Update docs accordingly.
This mirrors JAX terminology (nondiff_argnums).
Also applies to: 964-966
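For readers unfamiliar with the distinction, here is a minimal, hedged sketch in plain JAX (not Warp's API) of the two mechanisms this comment wants decoupled: custom_vjp's nondiff_argnums marks arguments that carry no gradient, while jit's static_argnums/static_argnames only controls re-tracing. The names below (scale, scale_fwd, scale_bwd) are illustrative, not taken from this PR.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, s):
    return x * s

def scale_fwd(x, s):
    return x * s, (x,)           # residuals saved for the backward pass

def scale_bwd(s, residuals, g):  # nondiff args are passed first to bwd
    (x,) = residuals
    return (g * s,)              # cotangent only for the differentiable arg

scale.defvjp(scale_fwd, scale_bwd)

# JIT statics are a separate, orthogonal concept:
scale_jit = jax.jit(scale, static_argnums=(1,))
print(jax.grad(scale_jit)(jnp.float32(3.0), 2.0))  # prints 2.0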
warp/tests/interop/test_jax.py (1)
1412-1425: Test intent mismatch: functions named "expand_dims" use vmap_method="sequential". Use vmap_method="expand_dims" to match the test names and actually exercise that path.
-    jax_func = jax_kernel(
-        scale_sum_square_kernel, num_outputs=1, differentiable=True, static_argnames=("s",), vmap_method="sequential"
-    )
+    jax_func = jax_kernel(
+        scale_sum_square_kernel, num_outputs=1, differentiable=True, static_argnames=("s",), vmap_method="expand_dims"
+    )
-    jax_func = jax_kernel(add_one_2d, num_outputs=1, differentiable=True, vmap_method="sequential")
+    jax_func = jax_kernel(add_one_2d, num_outputs=1, differentiable=True, vmap_method="expand_dims")
Also applies to: 1498-1499
🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)
840-846: Avoid blind except when zeroing adjoint inputs. Catching Exception and pass can hide real issues. Keep the isinstance guard and drop the try/except.
-        try:
-            for gi in grad_in:
-                if isinstance(gi, wp.array):
-                    gi.zero_()
-        except Exception:
-            pass
+        for gi in grad_in:
+            if isinstance(gi, wp.array):
+                gi.zero_()
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- docs/modules/interoperability.rst (1 hunks)
- warp/jax_experimental/__init__.py (1 hunks)
- warp/jax_experimental/ffi.py (10 hunks)
- warp/tests/interop/test_jax.py (17 hunks)
✅ Files skipped from review due to trivial changes (1)
- docs/modules/interoperability.rst
🚧 Files skipped from review as they are similar to previous changes (1)
- warp/jax_experimental/__init__.py
🧰 Additional context used
🧬 Code graph analysis (2)
warp/jax_experimental/ffi.py (2)
- warp/context.py (4): kernel (1069-1141), func (853-883), launch (6111-6152), launch (6155-6387)
- warp/types.py (5): array (2256-3547), array (4442-4448), zero_ (3161-3168), zero_ (3730-3731), dtype (4615-4621)
warp/tests/interop/test_jax.py (4)
- warp/tests/interop/test_dlpack.py (1): _jax_version (28-34)
- warp/jax_experimental/ffi.py (1): jax_kernel (742-991)
- warp/jax.py (1): device_to_jax (19-41)
- warp/tests/unittest_utils.py (2): assert_np_equal (241-247), add_function_test (284-303)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py
217-219: Avoid specifying long messages outside the exception class (TRY003)
527-529: Avoid specifying long messages outside the exception class (TRY003)
844-845: try-except-pass detected, consider logging the exception (S110)
844-844: Do not catch blind exception: Exception (BLE001)
940-940: Do not catch blind exception: Exception (BLE001)
948-949: try-except-pass detected, consider logging the exception (S110)
948-948: Do not catch blind exception: Exception (BLE001)
warp/tests/interop/test_jax.py
Repeated ARG001 warnings (unused function arguments test, device, and s) across the test functions in lines 445-1673, and repeated F405 warnings (assert_np_equal, add_function_test, and test_jax_ad_kernel_launch_dim_and_output_dims may be undefined, or defined from star imports) throughout lines 1075-2041.
Actionable comments posted: 0
♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (2)
839-840: Fix launch dim inference to handle non-array first arguments. Using args[0].shape assumes the first argument is always an array, which will fail if the first parameter is a scalar. This was flagged in a previous review but remains unresolved. Apply this diff to match the non-differentiable path logic:
-    def fwd_kernel_wrapper(*args):
-        wp.launch(kernel, dim=args[0].shape, inputs=args[:num_inputs], outputs=args[num_inputs:])
+    def fwd_kernel_wrapper(*args):
+        # Derive launch dims from the first array-typed input
+        launch_dim = None
+        for idx, p in enumerate(parameters[:num_inputs]):
+            if isinstance(p.annotation, wp.array):
+                launch_dim = args[idx].shape
+                break
+        if launch_dim is None:
+            raise RuntimeError("Unable to infer launch dims: no array inputs")
+        wp.launch(kernel, dim=launch_dim, inputs=args[:num_inputs], outputs=args[num_inputs:])
867-867: Fix launch dim inference in backward wrapper. Same issue as the forward wrapper: using inputs[0].shape assumes the first input is an array. Apply this diff:
+    # Derive launch dims from the first array-typed input
+    launch_dim = None
+    for idx, p in enumerate(parameters[:num_inputs]):
+        if isinstance(p.annotation, wp.array):
+            launch_dim = inputs[idx].shape
+            break
+    if launch_dim is None:
+        raise RuntimeError("Unable to infer launch dims: no array inputs")
+    wp.launch(
         kernel,
-        dim=inputs[0].shape,
+        dim=launch_dim,
         inputs=inputs,
         outputs=outputs,
         adj_inputs=grad_in,
         adj_outputs=grad_out,
         adjoint=True,
     )
🧹 Nitpick comments (6)
warp/jax_experimental/ffi.py (6)
41-41: Consider a more specific type hint for the cache key. The key type is tuple, but the actual structure is (kernel.func, kernel.sig, num_outputs, vmap_method, tuple[str, ...]) based on usage at lines 994 and 1003. A type alias would improve clarity:
+# Type alias for differentiable kernel cache key
+DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]]
+
-_FFI_DIFF_KERNEL_REGISTRY: dict[tuple, Callable] = {}
+_FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {}
815-818: LGTM! Clear limitation documented. The explicit check prevents unsupported usage of in_out_argnames in differentiable mode. Consider tracking this as a TODO or GitHub issue for future enhancement.
Do you want me to open a GitHub issue to track support for in_out_argnames in differentiable mode?
858-863: Improve error handling when zeroing gradient arrays. The blind try-except-pass silently swallows all exceptions, which could hide real issues. At minimum, verify the object is a wp.array before calling zero_(). Apply this diff:
-        try:
-            for gi in grad_in:
-                if isinstance(gi, wp.array):
-                    gi.zero_()
-        except Exception:
-            pass
+        for gi in grad_in:
+            if isinstance(gi, wp.array):
+                try:
+                    gi.zero_()
+                except Exception as e:
+                    # Log or handle the exception appropriately
+                    wp.utils.warn(f"Failed to zero gradient array: {e}", stacklevel=2)
917-924: Simplify output handling logic. The nested conditions for single vs. multiple outputs are fragile. JAX's custom_vjp expects fwd to return (outputs, residuals) where outputs can be a single value or tuple. Simplify by always returning a tuple:
 def fwd_function(*args):
     outputs = jax_fwd_kernel(*args)
     non_static_inputs = list(args)
     for i in reversed(static_args):
         del non_static_inputs[i]
-    if num_outputs == 1:
-        if isinstance(outputs, (list, tuple)):
-            outputs_tuple = (outputs[0],)
-        else:
-            outputs_tuple = (outputs,)
-    else:
-        outputs_tuple = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
+    # Normalize to tuple for consistent handling
+    if num_outputs == 1:
+        outputs_tuple = (outputs,) if not isinstance(outputs, (list, tuple)) else (outputs[0],)
+    else:
+        outputs_tuple = outputs if isinstance(outputs, tuple) else tuple(outputs)
     return outputs, (tuple(non_static_inputs), outputs_tuple)
937-947: Simplify gradient output unpacking. Similar to the forward function, this logic for unpacking grad_out_args is complex. Consider normalizing to a tuple early:
-    if num_outputs > 1:
-        if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
-            grad_out_tuple = tuple(grad_out_args[0])
-        else:
-            grad_out_tuple = tuple(grad_out_args)
-    else:
-        go = grad_out_args[0]
-        if isinstance(go, (list, tuple)):
-            grad_out_tuple = (go[0],)
-        else:
-            grad_out_tuple = (go,)
+    # Normalize grad_out_args to tuple
+    if num_outputs == 1:
+        grad_out_tuple = (grad_out_args[0],)
+    else:
+        grad_out_tuple = tuple(grad_out_args) if not isinstance(grad_out_args[0], (list, tuple)) else tuple(grad_out_args[0])
958-967: Avoid bare exception handling in attribute access. The nested try-except blocks with bare Exception catching (lines 958, 966) can hide bugs. Use specific attribute checks or document why broad catching is needed.
 ann = param_ann.get(name)
 if ann is None:
     continue
-try:
-    is_array_ann = isinstance(ann, wp.array)
-except Exception:
-    is_array_ann = False
+# Check if annotation is a warp array type
+is_array_ann = isinstance(ann, type) and issubclass(ann, wp.array)
 if not is_array_ann:
     continue
 dtype_ndim = 0
-try:
-    if hasattr(ann.dtype, "_wp_scalar_type_"):
-        dtype_ndim = len(ann.dtype._shape_)
-except Exception:
-    pass
+# Extract dtype ndim if it's a vector/matrix type
+if hasattr(ann, 'dtype') and hasattr(ann.dtype, "_wp_scalar_type_"):
+    dtype_ndim = len(ann.dtype._shape_)
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- warp/jax_experimental/ffi.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (1)
- warp/types.py (5): strides_from_shape (2174-2185), dtype (4615-4621), array_t (1697-1721), array (2256-3547), array (4442-4448)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py
217-219: Avoid specifying long messages outside the exception class (TRY003)
537-539: Avoid specifying long messages outside the exception class (TRY003)
862-863: try-except-pass detected, consider logging the exception (S110)
862-862: Do not catch blind exception: Exception (BLE001)
958-958: Do not catch blind exception: Exception (BLE001)
966-967: try-except-pass detected, consider logging the exception (S110)
966-966: Do not catch blind exception: Exception (BLE001)
1218-1218: Avoid specifying long messages outside the exception class (TRY003)
1225-1225: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (10)
warp/jax_experimental/ffi.py (10)
215-225: LGTM! Correctly relaxed for vmap support. The updated validation allows leading batch dimensions under vmap while still validating the required trailing dtype dimensions. This is the correct approach for supporting batched operations.
535-545: LGTM! Correctly relaxed for vmap support. Same validation logic as in FfiKernel.__call__: allows leading batch dimensions while validating trailing dtype dimensions.
1213-1226: LGTM! Correctly extracts core dimensions. The logic properly handles both vector/matrix arrays (extracting middle dims between batch and dtype) and scalar arrays (extracting trailing dims excluding batch). The implementation aligns with the relaxed validation in the __call__ methods.
1239-1239: LGTM! Aligns with batch dimension support. Changing to ndim >= arg.jax_ndim correctly allows for exact shape match or additional leading batch dimensions, consistent with the relaxed validation.
795-813: LGTM! Non-differentiable path correctly preserved. The logic preserves the existing behavior when differentiable=False, using the cached FfiKernel approach with proper locking.
820-836: LGTM! Signature introspection correctly implemented. The logic properly extracts the kernel signature, identifies static arguments from static_argnames, and ensures static arguments are only considered among inputs.
875-901: LGTM! Backward signature construction is correct. The logic properly constructs the backward wrapper signature with input/output/gradient parameters, correctly excludes static arguments from gradient inputs, and follows JAX naming conventions.
982-1009: LGTM! VJP registration and caching correctly implemented. The logic properly:
- Registers the custom VJP with correct nondiff_argnums
- Applies jax.jit with static_argnames when needed
- Uses deterministic cache keys with sorted static argument names
Note: The cache key structure matches the type hint suggestion in the earlier comment on line 41.
343-350: LGTM! Buffer shape extraction correctly handles batch and dtype dimensions. The logic properly computes the batch dimension offset and extracts only the core warp dimensions, excluding both leading batch dimensions and trailing dtype dimensions. This is consistent with the get_warp_shape utility function.
Also applies to: 363-370
755-755: Clarify vmap_method default inconsistency. The default vmap_method differs between jax_kernel ("sequential") and jax_callable ("broadcast_all"). Confirm this is intentional and either document the rationale or align the defaults across the API.
Actionable comments posted: 4
🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)
955-1003: Consider refactoring to reduce backward function complexity. The bwd_function handles multiple concerns: unpacking residuals, inserting static args, normalizing gradient tuples, and computing output dimensions. While functionally correct, the cognitive complexity is high. Consider extracting helper functions:
def _unpack_residuals(bwd_args, static_args):
    nondiff_vals = list(bwd_args[:len(static_args)])
    residuals = bwd_args[len(static_args)]
    grad_out_args = bwd_args[len(static_args) + 1:]
    return nondiff_vals, residuals, grad_out_args

def _normalize_grad_out_tuple(grad_out_args, num_outputs):
    if num_outputs == 1:
        go = grad_out_args[0]
        return tuple(go) if isinstance(go, (list, tuple)) else (go,)
    if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
        return tuple(grad_out_args[0])
    return tuple(grad_out_args)

def _compute_output_dims_map(parameters, differentiable_input_names, non_static_inputs, num_inputs):
    out_dims_map = {}
    param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
    for name, val in zip(differentiable_input_names, non_static_inputs):
        ann = param_ann.get(name)
        if ann is None or not isinstance(ann, wp.array):
            continue
        # ... rest of logic
    return out_dims_map
This would make bwd_function more readable by clearly separating its phases.
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
docs/modules/interoperability.rst
(1 hunks)warp/jax_experimental/ffi.py
(16 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (1)
warp/types.py (5)
strides_from_shape
(2174-2185)dtype
(4615-4621)array_t
(1697-1721)array
(2256-3547)array
(4442-4448)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py
220-222: Avoid specifying long messages outside the exception class
(TRY003)
540-542: Avoid specifying long messages outside the exception class
(TRY003)
867-867: Avoid specifying long messages outside the exception class
(TRY003)
893-893: Do not catch blind exception: Exception
(BLE001)
1241-1241: Avoid specifying long messages outside the exception class
(TRY003)
1248-1248: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
warp/jax_experimental/ffi.py (1)
1236-1249
: Shape extraction logic correctly handles vmap batch dimensions.The
get_warp_shape
function properly extracts core Warp dimensions by excluding both leading batch dimensions (from vmap) and trailing dtype payload dimensions (for vector/matrix types). The implementation correctly handles the different cases for vector/matrix vs. scalar arrays.The logic is intricate but appears correct. Consider adding an inline example in a comment to aid future maintainers:
# Example: For wp.vec3 array under vmap with JAX shape (B, N, 3): # - warp_ndim=1, dtype_ndim=1, dims=(B, N, 3) # - core_end = 3 - 1 = 2, core_begin = 2 - 1 = 1 # - Returns dims[1:2] = (N,) — the core warp dimensiondocs/modules/interoperability.rst (1)
959-1119
: Comprehensive AD documentation with clear examples.The Automatic Differentiation section provides excellent coverage of the new
differentiable=True
functionality, including:
- Basic usage patterns with static arguments
- Multiple output handling
- Vector/matrix array support
- VMAP integration with per-sample gradients
- Launch dimension inference
- Troubleshooting guidance
Minor suggestions for enhancement:
Add performance note: Consider mentioning that
vmap_method="broadcast_all"
is more parallel than"sequential"
(the new default), which could guide users toward better performance when applicable.Clarify scalar limitation: The note "Scalars must be static in JAX" could be expanded to explain why (JAX tracing vs. FFI callback requirements).
Link to examples: Consider adding a reference to example files in the repo that demonstrate the AD functionality in practice.
docs/modules/interoperability.rst
Outdated
There are three options for ``vmap_method``:

- ``"sequential"``: lowers to a :func:`jax.lax.scan`; most general, least parallel.
- ``"broadcast_all"`` (default): assumes the kernel can handle broadcasted batch axes; best for data-parallel batches.
- ``"expand_dims"``: expands a new leading axis per-batched argument; use only if your kernel semantics require it.
Document the change in default vmap_method for differentiable=True.
The documentation lists three vmap_method options but doesn't explicitly state that "sequential" is now the default (changed from "broadcast_all"). Since this is a breaking change from the non-differentiable path's default, it should be clearly documented.
Consider adding a note:
.. note::
When ``differentiable=True``, the default ``vmap_method`` is ``"sequential"`` for stability.
For the non-differentiable path, the default is ``"broadcast_all"``. If you need maximum
parallelism with AD, explicitly pass ``vmap_method="broadcast_all"``.
🤖 Prompt for AI Agents
In docs/modules/interoperability.rst around lines 1077 to 1082, the docstring
lists vmap_method options but doesn't state that the default changed to
"sequential" when differentiable=True; update the text to add a short note
explaining that for differentiable=True the default vmap_method is "sequential"
(changed from "broadcast_all" in the non-differentiable path), indicate that the
non-differentiable default remains "broadcast_all", and advise users to
explicitly pass vmap_method="broadcast_all" if they need maximum parallelism
with AD.
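To make the trade-off concrete, here is a hedged sketch of explicitly opting back into the more parallel lowering under jax.vmap. The elementwise kernel below is hypothetical, and the sketch assumes the differentiable jax_kernel API from this PR.

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def square_kernel(a: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    out[tid] = a[tid] * a[tid]

# Explicitly request "broadcast_all" instead of the AD default "sequential"
jax_square = jax_kernel(square_kernel, num_outputs=1, differentiable=True, vmap_method="broadcast_all")

def loss(a):
    result = jax_square(a)
    out = result[0] if isinstance(result, (list, tuple)) else result
    return jnp.sum(out)

batch = jnp.arange(32, dtype=jnp.float32).reshape(4, 8)
per_sample_grads = jax.vmap(jax.grad(loss))(batch)  # expected 2 * batch, one gradient row per sample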
warp/jax_experimental/ffi.py
Outdated
 kernel,
 num_outputs=1,
-vmap_method="broadcast_all",
+vmap_method="sequential",
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify the breaking change in default vmap_method.
The default vmap_method changed from "broadcast_all" to "sequential". This could break existing code that relies on the previous default behavior. "sequential" lowers to jax.lax.scan, which is less parallel than "broadcast_all".
Consider either:
- Reverting to "broadcast_all" as the default for backward compatibility, or
- Documenting this breaking change prominently in release notes and migration guides.
Could you clarify the rationale for this change and confirm that existing users have been notified?
🏁 Script executed:
#!/bin/bash
set -e
rg -n "class FFI" -C3 warp/jax_experimental/ffi.py
rg -n "def __init__" -C5 warp/jax_experimental/ffi.py
rg -n -C2 "vmap_method" warp/jax_experimental/ffi.py
rg -n -C2 "broadcast_all" .
Length of output: 0
🏁 Script executed:
#!/bin/bash
set -e
echo "Locating ffi.py in the repository:"
fd -t f ffi.py
echo
echo "Searching for vmap_method definitions and defaults:"
rg -n "vmap_method" -C3 .
echo
echo "Searching for broadcast_all references:"
rg -n '"broadcast_all"' -C3 .
Length of output: 20460
Align vmap_method default in code and docs
The default in jax_kernel (warp/jax_experimental/ffi.py:758) was changed from "broadcast_all" to "sequential", but the docs (docs/modules/interoperability.rst) still state "broadcast_all" as the default. This breaks backward compatibility. Either revert the default to "broadcast_all" or update code, docs, and migration notes to reflect "sequential".
🤖 Prompt for AI Agents
In warp/jax_experimental/ffi.py around line 758, the vmap_method default was
changed to "sequential" which conflicts with docs and breaks backward
compatibility; revert the default back to "broadcast_all" (restore the original
default value at that line) and ensure any unit tests that assume the old
behavior still pass; if the change to "sequential" was intentional instead,
update docs/modules/interoperability.rst to state "sequential" as the new
default and add a short migration note explaining the behavioral change and
rationale.
FYI @mehdiataei, I'm planning to make the release candidate for Warp v1.10 sometime next Friday, October 17th, so the changes in this pull request need to be merged into the
Actionable comments posted: 2
♻️ Duplicate comments (2)
warp/tests/interop/test_jax.py (2)
1882-1914
: Remove duplicate test registrations.These test registrations duplicate those already added at lines 1803-1836. Running the same tests multiple times adds no value and may cause test collection issues.
Apply this diff to remove the duplicate block:
- add_function_test( - TestJax, "test_jax_kernel_sincos", test_jax_kernel_sincos, devices=jax_compatible_cuda_devices - ) - add_function_test( - TestJax, "test_jax_kernel_diagonal", test_jax_kernel_diagonal, devices=jax_compatible_cuda_devices - ) - add_function_test( - TestJax, "test_jax_kernel_in_out", test_jax_kernel_in_out, devices=jax_compatible_cuda_devices - ) - add_function_test( - TestJax, - "test_jax_kernel_scale_vec_constant", - test_jax_kernel_scale_vec_constant, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_kernel_scale_vec_static", - test_jax_kernel_scale_vec_static, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_kernel_launch_dims_default", - test_jax_kernel_launch_dims_default, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_kernel_launch_dims_custom", - test_jax_kernel_launch_dims_custom, - devices=jax_compatible_cuda_devices, - )
1917-1949
: Remove duplicate test registrations.These test registrations duplicate those already added at lines 1839-1877. Delete this redundant block.
Apply this diff to remove the duplicate block:
- # ffi.jax_callable() tests - add_function_test( - TestJax, - "test_jax_callable_scale_constant", - test_jax_callable_scale_constant, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_callable_scale_static", - test_jax_callable_scale_static, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, "test_jax_callable_in_out", test_jax_callable_in_out, devices=jax_compatible_cuda_devices - ) - add_function_test( - TestJax, - "test_jax_callable_graph_cache", - test_jax_callable_graph_cache, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_callable_pmap_multi_output_forward", - test_jax_callable_pmap_multi_output_forward, - devices=jax_compatible_cuda_devices, - ) - add_function_test( - TestJax, - "test_jax_callable_pmap_multi_stage_forward", - test_jax_callable_pmap_multi_stage_forward, - devices=jax_compatible_cuda_devices, - ) - - # ffi callback tests - add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- warp/tests/interop/test_jax.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/tests/interop/test_jax.py (4)
- warp/tests/interop/test_dlpack.py (1): _jax_version (28-34)
- warp/jax_experimental/ffi.py (1): jax_kernel (755-1032)
- warp/jax.py (1): device_to_jax (19-41)
- warp/tests/unittest_utils.py (2): assert_np_equal (241-247), add_function_test (284-303)
🪛 Ruff (0.13.3)
warp/tests/interop/test_jax.py
Repeated ARG001 warnings (unused function arguments test, device, and s) across the test functions in lines 445-1673, and repeated F405 warnings (assert_np_equal and add_function_test may be undefined, or defined from star imports) throughout lines 1075-2032.
🔇 Additional comments (1)
warp/tests/interop/test_jax.py (1)
1036-1710: LGTM: Comprehensive AD test coverage. The new automatic differentiation tests are well-structured with:
- Clear test organization covering single/multi-output, vmap, pmap, shard_map scenarios
- Proper data type coverage (float, vec2, mat22, 2D arrays)
- Correct reference gradient computations for validation
- Appropriate JAX version guards and device availability checks
- Consistent patterns that make the test suite maintainable
Actionable comments posted: 1
♻️ Duplicate comments (2)
warp/jax_experimental/ffi.py (2)
758-758: Verify the vmap_method default change is intentional. The default changed from "broadcast_all" to "sequential". A previous review comment (lines 758-758) noted this conflicts with documentation and breaks backward compatibility. "sequential" lowers to jax.lax.scan, which is less parallel than "broadcast_all". Per past review: either revert to "broadcast_all" for backward compatibility, or update documentation (docs/modules/interoperability.rst) and add migration notes explaining the rationale for the change.
841-867: Launch dims inference remains fragile. A previous review comment (lines 841-867) noted that starting with call_args[0].shape will fail when the first kernel argument is a scalar. While there's fallback logic, it adds unnecessary complexity. Per past review: remove the initial args[0].shape attempt and start directly with the array-finding loop:
 def _resolve_launch_dims(call_args):
-    s = getattr(call_args[0], "shape", None)
-    if s is not None:
-        return s
     param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
     for i in range(num_inputs):
This simplifies the logic and handles scalar-first kernels correctly from the start.
🧹 Nitpick comments (1)
warp/jax_experimental/ffi.py (1)
218-228: Consider extracting shape validation to a shared helper. The updated validation correctly allows leading batch dimensions under vmap while ensuring trailing dtype dimensions match. However, this logic is duplicated in both FfiKernel.__call__ and FfiCallable.__call__. Consider extracting to a helper function:
def validate_array_shape(input_arg: FfiArg, input_value, allow_batch_dims: bool = True):
    """Validate array shape, optionally allowing leading batch dimensions."""
    if input_value.dtype != input_arg.jax_scalar_type:
        raise TypeError(
            f"Invalid data type for array argument '{input_arg.name}', "
            f"expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
        )
    if input_value.ndim < input_arg.jax_ndim:
        raise TypeError(
            f"Invalid dimensionality for array argument '{input_arg.name}', "
            f"expected at least {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
        )
    # Check inner (dtype) dims at the end
    for d in range(input_arg.dtype_ndim):
        if input_value.shape[-input_arg.dtype_ndim + d] != input_arg.dtype_shape[d]:
            raise TypeError(
                f"Invalid inner dimensions for array argument '{input_arg.name}', "
                f"expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim:]}"
            )
Then use validate_array_shape(input_arg, input_value) in both locations.
Also applies to: 538-548
📜 Review details
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- warp/jax_experimental/ffi.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
warp/jax_experimental/ffi.py (2)
- warp/types.py (5): strides_from_shape (2174-2185), dtype (4615-4621), array_t (1697-1721), array (2256-3547), array (4442-4448)
- warp/context.py (4): kernel (1069-1141), func (853-883), launch (6111-6152), launch (6155-6387)
🪛 Ruff (0.13.3)
warp/jax_experimental/ffi.py
220-222: Avoid specifying long messages outside the exception class (TRY003)
540-542: Avoid specifying long messages outside the exception class (TRY003)
867-867: Avoid specifying long messages outside the exception class (TRY003)
895-895: Use raise without specifying exception name; remove exception name (TRY201)
1242-1242: Avoid specifying long messages outside the exception class (TRY003)
1249-1249: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (7)
warp/jax_experimental/ffi.py (7)
18-18: LGTM: Import and registry additions are well-structured. The inspect import enables runtime signature introspection for the differentiable path. The DiffKernelCacheKey type alias clearly documents the cache key structure, and the separate _FFI_DIFF_KERNEL_REGISTRY prevents conflicts between differentiable and non-differentiable kernel paths.
Also applies to: 33-34, 44-44
238-238: LGTM: Correct usage of get_warp_shape. The calls to get_warp_shape correctly extract core Warp dimensions from JAX shapes that may include leading batch dimensions (from vmap) and trailing dtype dimensions (from vector/matrix types).
Also applies to: 558-558
346-353: LGTM: FFI callback correctly handles batch and dtype dimensions. The callback properly excludes leading batch dimensions (from vmap) and trailing dtype dimensions when constructing array_t for the kernel launch. The calculation batch_ndim = max(0, rank - dtype_ndim - warp_ndim) and the slice buffer.dims[batch_ndim : batch_ndim + warp_ndim] correctly extract the core Warp dimensions.
Also applies to: 363-370
763-764: LGTM: New differentiable parameters and constraints are well-documented. The differentiable flag and static_argnames parameter enable the AD path. The restriction that in_out_argnames is not supported with differentiable=True is clearly documented in both the docstring and enforced at runtime.
Also applies to: 782-785, 818-821
907-1033: LGTM: Custom VJP implementation follows JAX conventions. The differentiable path correctly implements JAX's custom_vjp API:
- Forward function returns (outputs, residuals) where residuals include non-static inputs and outputs
- Backward function reconstructs the full input/output/gradient argument list for the adjoint kernel
- Static arguments are properly excluded from differentiation via nondiff_argnums
- Differentiable wrappers are cached to avoid redundant construction
- Output dimensions are correctly inferred from input array annotations
The complexity is warranted for supporting full AD semantics with static arguments, vmap, and multi-output kernels.
1070-1072: LGTM: Documentation and helper function updates are consistent. The jax_callable docstring update aligns with jax_kernel, and the get_warp_shape refactor correctly handles:
- Vector/matrix arrays: extracts core Warp dims positioned before trailing dtype dims
- Scalar arrays: extracts the last warp_ndim dims (excluding leading batch dims)
The logic properly supports vmap's leading batch dimensions across both array types.
Also applies to: 1237-1250
1253-1276: LGTM: Output type inference correctly supports batch dimensions. The updated conditions:
- Line 1263: ndim >= arg.jax_ndim allows leading batch dimensions while validating inner dtype dimensions
- Line 1274: ndim < arg.warp_ndim correctly requires at least the core Warp dimensions
Both changes consistently support vmap's leading batch dimensions in output arrays.
4fee379 to e6fa8d7
Signed-off-by: Mehdi Ataei <[email protected]>
Add automatic differentiation support with jax_kernel (NVIDIAGH-515, NVIDIAGH-912) See merge request omniverse/warp!1682
This is now merged in b66f253. Thank you @mehdiataei for this great contribution!
Description
Added JAX interop adjoint support using FFI. Also added support for static inputs (with auto-detection), vmap, multi-output with AD. Fixed some bugs related to vmap for the forward pass.
This PR also enables distributed forward and backward multi-GPU runs (see test cases) and an inverse kinematic example.
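For reference, a rough usage sketch of the multi-output AD path. The kernel and names below are hypothetical (not kernels from this PR), and the sketch assumes the jax_kernel(differentiable=True) entry point and a CUDA-capable JAX/Warp setup.

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def sum_prod_kernel(
    a: wp.array(dtype=wp.float32),
    b: wp.array(dtype=wp.float32),
    c: wp.array(dtype=wp.float32),
    d: wp.array(dtype=wp.float32),
):
    tid = wp.tid()
    c[tid] = a[tid] + b[tid]   # first output: elementwise sum
    d[tid] = a[tid] * b[tid]   # second output: elementwise product

jax_sum_prod = jax_kernel(sum_prod_kernel, num_outputs=2, differentiable=True)

def loss(a, b):
    c, d = jax_sum_prod(a, b)
    return jnp.sum(c) + jnp.sum(d)

a = jnp.ones(8, dtype=jnp.float32) * 2.0
b = jnp.ones(8, dtype=jnp.float32) * 3.0
ga, gb = jax.grad(loss, argnums=(0, 1))(a, b)  # analytically ga = 1 + b, gb = 1 + a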
This is a new area for me, so there may be some oddities or poor decisions, but all tests pass. We may need to discuss any missing features and whether this is the way the user intends to use it.
This work is inspired by and thanks to J. Sevcik
https://gist.github.com/jaro-sevcik/893ffb891564dfc7617c8f12f3c2d1d2
Before your PR is "Ready for review"
(__init__.pyi, functions.rst)
pre-commit run -a
Summary by CodeRabbit
New Features
Documentation
API
Tests