-
Notifications
You must be signed in to change notification settings - Fork 377
Add device ordinal XLA plumbing #976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -792,6 +792,137 @@ def warp_func(inputs, outputs, attrs, ctx): | |||||||||||||||||||||||||||||||||||||
| assert_np_equal(d, 2 * np.arange(10, dtype=np.float32).reshape((5, 2))) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||||||||||||||||||||||||||||||||||||||
| def test_jax_callable_pmap_mul_forward(test, device): | ||||||||||||||||||||||||||||||||||||||
| import jax | ||||||||||||||||||||||||||||||||||||||
| import jax.numpy as jp | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from warp.jax_experimental.ffi import jax_callable | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if jax.local_device_count() < 2: | ||||||||||||||||||||||||||||||||||||||
| test.skipTest("requires >= 2 local devices") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @wp.kernel | ||||||||||||||||||||||||||||||||||||||
| def mul2(a: wp.array(dtype=float), out: wp.array(dtype=float)): | ||||||||||||||||||||||||||||||||||||||
| tid = wp.tid() | ||||||||||||||||||||||||||||||||||||||
| out[tid] = 2.0 * a[tid] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def mul2_py(a: wp.array(dtype=float), out: wp.array(dtype=float)): | ||||||||||||||||||||||||||||||||||||||
| wp.launch(mul2, dim=a.shape, inputs=[a], outputs=[out]) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| j = jax_callable(mul2_py, num_outputs=1) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| per_device = 8 | ||||||||||||||||||||||||||||||||||||||
| ndev = jax.local_device_count() | ||||||||||||||||||||||||||||||||||||||
| x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def per_device_fwd(v): | ||||||||||||||||||||||||||||||||||||||
| (y,) = j(v) | ||||||||||||||||||||||||||||||||||||||
| return y | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| y = jax.pmap(per_device_fwd)(x) | ||||||||||||||||||||||||||||||||||||||
| test.assertTrue(np.allclose(np.asarray(y), 2.0 * np.asarray(x), rtol=1e-5, atol=1e-6)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||||||||||||||||||||||||||||||||||||||
| def test_jax_callable_pmap_multi_output_forward(test, device): | ||||||||||||||||||||||||||||||||||||||
| import jax | ||||||||||||||||||||||||||||||||||||||
| import jax.numpy as jp | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from warp.jax_experimental.ffi import jax_callable | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if jax.local_device_count() < 2: | ||||||||||||||||||||||||||||||||||||||
| test.skipTest("requires >= 2 local devices") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+827
to
+836
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Fix version guard and unused arg here as well -@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_output_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_output_forward(test, _device):📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.13.1)828-828: Unused function argument: (ARG001) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| @wp.kernel | ||||||||||||||||||||||||||||||||||||||
| def multi_out( | ||||||||||||||||||||||||||||||||||||||
| a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float) | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| tid = wp.tid() | ||||||||||||||||||||||||||||||||||||||
| c[tid] = a[tid] + b[tid] | ||||||||||||||||||||||||||||||||||||||
| d[tid] = s * a[tid] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def multi_out_py( | ||||||||||||||||||||||||||||||||||||||
| a: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| b: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| s: float, | ||||||||||||||||||||||||||||||||||||||
| c: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| d: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| wp.launch(multi_out, dim=a.shape, inputs=[a, b, s], outputs=[c, d]) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| j = jax_callable(multi_out_py, num_outputs=2) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| per_device = 7 | ||||||||||||||||||||||||||||||||||||||
| ndev = jax.local_device_count() | ||||||||||||||||||||||||||||||||||||||
| a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device)) | ||||||||||||||||||||||||||||||||||||||
| b = jp.ones((ndev, per_device), dtype=jp.float32) | ||||||||||||||||||||||||||||||||||||||
| s = 3.0 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def per_device_fwd(aa, bb): | ||||||||||||||||||||||||||||||||||||||
| c, d = j(aa, bb, s) | ||||||||||||||||||||||||||||||||||||||
| return c + d # simple combine to exercise both outputs | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| out = jax.pmap(per_device_fwd)(a, b) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device)) | ||||||||||||||||||||||||||||||||||||||
| b_np = np.ones((ndev, per_device), dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||
| ref = (a_np + b_np) + s * a_np | ||||||||||||||||||||||||||||||||||||||
| test.assertTrue(np.allclose(np.asarray(out), ref, rtol=1e-5, atol=1e-6)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test") | ||||||||||||||||||||||||||||||||||||||
| def test_jax_callable_pmap_multi_stage_forward(test, device): | ||||||||||||||||||||||||||||||||||||||
| import jax | ||||||||||||||||||||||||||||||||||||||
| import jax.numpy as jp | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from warp.jax_experimental.ffi import jax_callable | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if jax.local_device_count() < 2: | ||||||||||||||||||||||||||||||||||||||
| test.skipTest("requires >= 2 local devices") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+874
to
+883
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Fix version guard and unused arg here as well -@unittest.skipUnless(_jax_version() >= (0, 4, 31), "Jax version too old for pmap forward test")
-def test_jax_callable_pmap_multi_stage_forward(test, device):
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old for pmap forward test")
+def test_jax_callable_pmap_multi_stage_forward(test, _device):📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.13.1)875-875: Unused function argument: (ARG001) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| @wp.kernel | ||||||||||||||||||||||||||||||||||||||
| def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), out: wp.array(dtype=float)): | ||||||||||||||||||||||||||||||||||||||
| tid = wp.tid() | ||||||||||||||||||||||||||||||||||||||
| out[tid] = a[tid] + b[tid] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @wp.kernel | ||||||||||||||||||||||||||||||||||||||
| def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)): | ||||||||||||||||||||||||||||||||||||||
| tid = wp.tid() | ||||||||||||||||||||||||||||||||||||||
| out[tid] = alpha * x[tid] + y[tid] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def multi_stage_py( | ||||||||||||||||||||||||||||||||||||||
| a: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| b: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| alpha: float, | ||||||||||||||||||||||||||||||||||||||
| tmp: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| out: wp.array(dtype=float), | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp]) | ||||||||||||||||||||||||||||||||||||||
| wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out]) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| j = jax_callable(multi_stage_py, num_outputs=2) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| per_device = 9 | ||||||||||||||||||||||||||||||||||||||
| ndev = jax.local_device_count() | ||||||||||||||||||||||||||||||||||||||
| a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device)) | ||||||||||||||||||||||||||||||||||||||
| b = jp.ones((ndev, per_device), dtype=jp.float32) | ||||||||||||||||||||||||||||||||||||||
| alpha = 2.5 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def per_device_fwd(aa, bb): | ||||||||||||||||||||||||||||||||||||||
| tmp, out = j(aa, bb, alpha) | ||||||||||||||||||||||||||||||||||||||
| return tmp + out | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| combined = jax.pmap(per_device_fwd)(a, b) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device)) | ||||||||||||||||||||||||||||||||||||||
| b_np = np.ones((ndev, per_device), dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||
| tmp_ref = a_np + b_np | ||||||||||||||||||||||||||||||||||||||
| out_ref = alpha * (a_np + b_np) + b_np | ||||||||||||||||||||||||||||||||||||||
| ref = tmp_ref + out_ref | ||||||||||||||||||||||||||||||||||||||
| test.assertTrue(np.allclose(np.asarray(combined), ref, rtol=1e-5, atol=1e-6)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class TestJax(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -940,6 +1071,25 @@ class TestJax(unittest.TestCase): | |||||||||||||||||||||||||||||||||||||
| # ffi callback tests | ||||||||||||||||||||||||||||||||||||||
| add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| add_function_test( | ||||||||||||||||||||||||||||||||||||||
| TestJax, | ||||||||||||||||||||||||||||||||||||||
| "test_jax_callable_pmap_multi_output_forward", | ||||||||||||||||||||||||||||||||||||||
| test_jax_callable_pmap_multi_output_forward, | ||||||||||||||||||||||||||||||||||||||
| devices=jax_compatible_cuda_devices, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| add_function_test( | ||||||||||||||||||||||||||||||||||||||
| TestJax, | ||||||||||||||||||||||||||||||||||||||
| "test_jax_callable_pmap_mul_forward", | ||||||||||||||||||||||||||||||||||||||
| test_jax_callable_pmap_mul_forward, | ||||||||||||||||||||||||||||||||||||||
| devices=jax_compatible_cuda_devices, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| add_function_test( | ||||||||||||||||||||||||||||||||||||||
| TestJax, | ||||||||||||||||||||||||||||||||||||||
| "test_jax_callable_pmap_multi_stage_forward", | ||||||||||||||||||||||||||||||||||||||
| test_jax_callable_pmap_multi_stage_forward, | ||||||||||||||||||||||||||||||||||||||
| devices=jax_compatible_cuda_devices, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||
| print(f"Skipping Jax tests due to exception: {e}") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Align version guard with jax_callable requirement; silence unused arg
These tests call warp.jax_experimental.ffi.jax_callable which requires JAX >= 0.5.0 via check_jax_version(). Current guard is 0.4.31, leading to failures on 0.4.31–0.4.99. Also silence the unused device arg.
🧰 Tools
🪛 Ruff (0.13.1)
796-796: Unused function argument:
device(ARG001)
🤖 Prompt for AI Agents