Here's an example of how to use ``shard_map`` with a Warp kernel:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    # Initialize JAX distributed environment
    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def multiply_by_two_kernel(
        a_in: wp.array(dtype=wp.float32),
        a_out: wp.array(dtype=wp.float32),
    ):
        # Each thread doubles one element of the input array
        index = wp.tid()
        a_out[index] = a_in[index] * 2.0
    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

    def warp_multiply(x):
        result = jax_warp_multiply(x)
        return result

    # a_in here is the full sharded array with shape (M,)
    # The output will also be a sharded array with shape (M,)
    def warp_distributed_operator(a_in):
        def _sharded_operator(a_in):
            # Inside the sharded operator, a_in is a local shard on each device
            # If we have N devices and input size M, each shard has shape (M/N,)

            # warp_multiply applies the Warp kernel to the local shard
            result = warp_multiply(a_in)[0]

            # result has the same shape as the input shard (M/N,)
            return result

        # shard_map distributes the computation across devices
        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x"),),  # Input is sharded along the 'x' axis
            out_specs=P("x"),  # Output is also sharded along the 'x' axis
            check_rep=False,
        )(a_in)

    print_on_process_0("Test distributed multiplication using JAX + Warp")

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))

    input_size = num_gpus * 5  # 5 elements per device
    single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)

    # Define the shape of the input array based on the total input size
    shape = (input_size,)

    # Create a list of arrays by distributing single_device_arrays across the available devices
    # Each device receives a portion of the input data
    arrays = [
        jax.device_put(single_device_arrays[index], d)  # Place each slice on its corresponding device
        for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
    ]

    # Combine the individual device arrays into a single sharded array
    sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)

    # sharded_array has shape (input_size,) but is distributed across devices
    print_on_process_0(f"Input array: {allgather(sharded_array)}")

    # warp_result has the same shape and sharding as sharded_array
    warp_result = warp_distributed_operator(sharded_array)

    # allgather collects results from all devices, producing a full array of shape (input_size,)
    print_on_process_0("Warp Output:", allgather(warp_result))

In this example, ``shard_map`` is used to distribute the computation across available devices. The input array ``a_in`` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel ``multiply_by_two_kernel`` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
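
The sketch below shows the same pattern with the Warp kernel swapped out for plain ``jnp`` arithmetic, to isolate what ``shard_map`` itself contributes. It is illustrative only and assumes a single host whose devices form a one-dimensional mesh named ``x``; the names ``_local_double`` and ``double_sharded`` are invented for this sketch.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    # One-dimensional device mesh with a single axis named "x"
    mesh = jax.sharding.Mesh(np.array(jax.devices()), "x")

    def _local_double(x):
        # x is the per-device shard, shape (M / num_devices,)
        return x * 2.0

    # Same mesh, in_specs, and out_specs as the Warp version above
    double_sharded = shard_map(
        _local_double,
        mesh=mesh,
        in_specs=(P("x"),),
        out_specs=P("x"),
    )

    a = jnp.arange(len(jax.devices()) * 5, dtype=jnp.float32)  # full array, shape (M,)
    print(double_sharded(a))  # element-wise doubling, matching the Warp kernel's result

In the full example above, ``jax_kernel`` only changes what runs on each local shard; the mesh, the input/output specs, and the sharding itself are handled entirely by ``shard_map``.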

In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for the Warp kernel explicitly:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
        # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N)
        i, j = wp.tid()
        M = a.shape[0]  # M/num_gpus
        K = a.shape[1]  # K
        N = b.shape[1]  # N
        if i < M and j < N:
            s = wp.float32(0.0)
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s

    # Specify launch dimensions based on the number of GPUs
    def create_jax_warp_matmul(M, N):
        # M: total rows, N: total columns
        block_size_m = M // num_gpus  # Rows per GPU
        block_size_n = N  # All columns
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    def warp_distributed_matmul(a, b):
        # a: (M, K) sharded across GPUs, b: (K, N) replicated
        M, K = a.shape
        _, N = b.shape
        jax_warp_matmul = create_jax_warp_matmul(M, N)

        def _sharded_operator(a_shard, b):
            # a_shard: (M/num_gpus, K), b: (K, N)
            return jax_warp_matmul(a_shard, b)[0]  # Result: (M/num_gpus, N)

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x", None), P(None, None)),  # a sharded along the first dim, b replicated
            out_specs=P("x", None),  # Output sharded along the first dim
            check_rep=False,
        )(a, b)

    print_on_process_0("Test distributed matrix multiplication using JAX + Warp")

    # Define matrix dimensions
    M = 8 * num_gpus  # Scale M with the number of devices
    K, N = 4, 6

    # Create input matrices
    a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)  # Shape: (M, K)
    b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N)  # Shape: (K, N)

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None))
    sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None))

    # Shard matrix A and replicate matrix B
    sharded_a = jax.device_put(a, sharding_spec_a)  # Shard shape per device: (M/num_gpus, K)
    replicated_b = jax.device_put(b, sharding_spec_b)  # Replicated shape: (K, N) on all devices

    print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}")  # Shape: (M, K)
    print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}")  # Shape: (K, N)

    warp_result = warp_distributed_matmul(sharded_a, replicated_b)  # Sharded result: (M/num_gpus, N) per device
    print_on_process_0("Warp Output:")
    # Use allgather to collect results from all devices
    print_on_process_0(allgather(warp_result))  # Shape: (M, N)

    jax_result = jnp.matmul(a, b)  # Shape: (M, N)
    print_on_process_0("JAX Output:")
    print_on_process_0(jax_result)

    expected_shape = (M, N)
    print_on_process_0(f"Expected shape: {expected_shape}")
    print_on_process_0(f"Warp output shape: {warp_result.shape}")  # Global shape (M, N); each device holds an (M/num_gpus, N) shard
    print_on_process_0(f"JAX output shape: {jax_result.shape}")  # Should be (M, N)

    allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5)
    print_on_process_0(f"Allclose: {allclose}")

In this example, we create a function ``create_jax_warp_matmul`` that calculates the launch dimensions based on the number of available GPUs. We use ``jax.device_count()`` to get the global number of GPUs and divide the ``M`` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The ``N`` dimension (columns) remains unchanged as we're not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. ``block_size_m`` is calculated by dividing the total number of rows by the number of GPUs, while ``block_size_n`` is set to the full width of the output matrix.
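
As a rough numerical illustration of that sizing, the snippet below assumes four GPUs together with the ``M = 8 * num_gpus`` and ``K, N = 4, 6`` dimensions used above; the numbers are hypothetical and only meant to show the arithmetic.

.. code-block:: python

    # Illustrative arithmetic only, assuming 4 GPUs and the sizes used above
    num_gpus = 4
    M = 8 * num_gpus  # 32 rows in total
    K, N = 4, 6

    block_size_m = M // num_gpus  # 8 rows per GPU
    block_size_n = N              # all 6 columns, since N is not sharded

    # Each GPU launches an (8, 6) thread grid and fills an (8, 6) tile of the
    # (32, 6) output; gathering the per-device tiles yields the full (M, N) result.
    print((block_size_m, block_size_n))  # (8, 6)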