@@ -792,6 +792,137 @@ def warp_func(inputs, outputs, attrs, ctx):
792792 assert_np_equal (d , 2 * np .arange (10 , dtype = np .float32 ).reshape ((5 , 2 )))
793793
794794
795+ @unittest .skipUnless (_jax_version () >= (0 , 4 , 31 ), "Jax version too old for pmap forward test" )
796+ def test_jax_callable_pmap_mul_forward (test , device ):
797+ import jax
798+ import jax .numpy as jp
799+
800+ from warp .jax_experimental .ffi import jax_callable
801+
802+ if jax .local_device_count () < 2 :
803+ test .skipTest ("requires >= 2 local devices" )
804+
805+ @wp .kernel
806+ def mul2 (a : wp .array (dtype = float ), out : wp .array (dtype = float )):
807+ tid = wp .tid ()
808+ out [tid ] = 2.0 * a [tid ]
809+
810+ def mul2_py (a : wp .array (dtype = float ), out : wp .array (dtype = float )):
811+ wp .launch (mul2 , dim = a .shape , inputs = [a ], outputs = [out ])
812+
813+ j = jax_callable (mul2_py , num_outputs = 1 )
814+
815+ per_device = 8
816+ ndev = jax .local_device_count ()
817+ x = jp .arange (ndev * per_device , dtype = jp .float32 ).reshape ((ndev , per_device ))
818+
819+ def per_device_fwd (v ):
820+ (y ,) = j (v )
821+ return y
822+
823+ y = jax .pmap (per_device_fwd )(x )
824+ test .assertTrue (np .allclose (np .asarray (y ), 2.0 * np .asarray (x ), rtol = 1e-5 , atol = 1e-6 ))
825+
826+
827+ @unittest .skipUnless (_jax_version () >= (0 , 4 , 31 ), "Jax version too old for pmap forward test" )
828+ def test_jax_callable_pmap_multi_output_forward (test , device ):
829+ import jax
830+ import jax .numpy as jp
831+
832+ from warp .jax_experimental .ffi import jax_callable
833+
834+ if jax .local_device_count () < 2 :
835+ test .skipTest ("requires >= 2 local devices" )
836+
837+ @wp .kernel
838+ def multi_out (
839+ a : wp .array (dtype = float ), b : wp .array (dtype = float ), s : float , c : wp .array (dtype = float ), d : wp .array (dtype = float )
840+ ):
841+ tid = wp .tid ()
842+ c [tid ] = a [tid ] + b [tid ]
843+ d [tid ] = s * a [tid ]
844+
845+ def multi_out_py (
846+ a : wp .array (dtype = float ),
847+ b : wp .array (dtype = float ),
848+ s : float ,
849+ c : wp .array (dtype = float ),
850+ d : wp .array (dtype = float ),
851+ ):
852+ wp .launch (multi_out , dim = a .shape , inputs = [a , b , s ], outputs = [c , d ])
853+
854+ j = jax_callable (multi_out_py , num_outputs = 2 )
855+
856+ per_device = 7
857+ ndev = jax .local_device_count ()
858+ a = jp .arange (ndev * per_device , dtype = jp .float32 ).reshape ((ndev , per_device ))
859+ b = jp .ones ((ndev , per_device ), dtype = jp .float32 )
860+ s = 3.0
861+
862+ def per_device_fwd (aa , bb ):
863+ c , d = j (aa , bb , s )
864+ return c + d # simple combine to exercise both outputs
865+
866+ out = jax .pmap (per_device_fwd )(a , b )
867+
868+ a_np = np .arange (ndev * per_device , dtype = np .float32 ).reshape ((ndev , per_device ))
869+ b_np = np .ones ((ndev , per_device ), dtype = np .float32 )
870+ ref = (a_np + b_np ) + s * a_np
871+ test .assertTrue (np .allclose (np .asarray (out ), ref , rtol = 1e-5 , atol = 1e-6 ))
872+
873+
874+ @unittest .skipUnless (_jax_version () >= (0 , 4 , 31 ), "Jax version too old for pmap forward test" )
875+ def test_jax_callable_pmap_multi_stage_forward (test , device ):
876+ import jax
877+ import jax .numpy as jp
878+
879+ from warp .jax_experimental .ffi import jax_callable
880+
881+ if jax .local_device_count () < 2 :
882+ test .skipTest ("requires >= 2 local devices" )
883+
884+ @wp .kernel
885+ def add_kernel (a : wp .array (dtype = float ), b : wp .array (dtype = float ), out : wp .array (dtype = float )):
886+ tid = wp .tid ()
887+ out [tid ] = a [tid ] + b [tid ]
888+
889+ @wp .kernel
890+ def axpy_kernel (x : wp .array (dtype = float ), y : wp .array (dtype = float ), alpha : float , out : wp .array (dtype = float )):
891+ tid = wp .tid ()
892+ out [tid ] = alpha * x [tid ] + y [tid ]
893+
894+ def multi_stage_py (
895+ a : wp .array (dtype = float ),
896+ b : wp .array (dtype = float ),
897+ alpha : float ,
898+ tmp : wp .array (dtype = float ),
899+ out : wp .array (dtype = float ),
900+ ):
901+ wp .launch (add_kernel , dim = a .shape , inputs = [a , b ], outputs = [tmp ])
902+ wp .launch (axpy_kernel , dim = a .shape , inputs = [tmp , b , alpha ], outputs = [out ])
903+
904+ j = jax_callable (multi_stage_py , num_outputs = 2 )
905+
906+ per_device = 9
907+ ndev = jax .local_device_count ()
908+ a = jp .arange (ndev * per_device , dtype = jp .float32 ).reshape ((ndev , per_device ))
909+ b = jp .ones ((ndev , per_device ), dtype = jp .float32 )
910+ alpha = 2.5
911+
912+ def per_device_fwd (aa , bb ):
913+ tmp , out = j (aa , bb , alpha )
914+ return tmp + out
915+
916+ combined = jax .pmap (per_device_fwd )(a , b )
917+
918+ a_np = np .arange (ndev * per_device , dtype = np .float32 ).reshape ((ndev , per_device ))
919+ b_np = np .ones ((ndev , per_device ), dtype = np .float32 )
920+ tmp_ref = a_np + b_np
921+ out_ref = alpha * (a_np + b_np ) + b_np
922+ ref = tmp_ref + out_ref
923+ test .assertTrue (np .allclose (np .asarray (combined ), ref , rtol = 1e-5 , atol = 1e-6 ))
924+
925+
795926class TestJax (unittest .TestCase ):
796927 pass
797928
@@ -940,6 +1071,25 @@ class TestJax(unittest.TestCase):
9401071 # ffi callback tests
9411072 add_function_test (TestJax , "test_ffi_callback" , test_ffi_callback , devices = jax_compatible_cuda_devices )
9421073
1074+ add_function_test (
1075+ TestJax ,
1076+ "test_jax_callable_pmap_multi_output_forward" ,
1077+ test_jax_callable_pmap_multi_output_forward ,
1078+ devices = jax_compatible_cuda_devices ,
1079+ )
1080+ add_function_test (
1081+ TestJax ,
1082+ "test_jax_callable_pmap_mul_forward" ,
1083+ test_jax_callable_pmap_mul_forward ,
1084+ devices = jax_compatible_cuda_devices ,
1085+ )
1086+ add_function_test (
1087+ TestJax ,
1088+ "test_jax_callable_pmap_multi_stage_forward" ,
1089+ test_jax_callable_pmap_multi_stage_forward ,
1090+ devices = jax_compatible_cuda_devices ,
1091+ )
1092+
9431093
9441094except Exception as e :
9451095 print (f"Skipping Jax tests due to exception: { e } " )
0 commit comments