Commit c3dca56

Made the examples in doc runnable
1 parent 26e7ae5 commit c3dca56

File tree (1 file changed)

docs/modules/interoperability.rst

Lines changed: 127 additions & 12 deletions
@@ -483,11 +483,20 @@ Here's an example of how to use `shard_map` with a Warp kernel:
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
+    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
+    import numpy as np

    # Initialize JAX distributed environment
    jax.distributed.initialize()
+    num_gpus = jax.device_count()
+
+    def print_on_process_0(*args, **kwargs):
+        if jax.process_index() == 0:
+            print(*args, **kwargs)
+
+    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def multiply_by_two_kernel(
@@ -499,18 +508,63 @@ Here's an example of how to use `shard_map` with a Warp kernel:

    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

+    def warp_multiply(x):
+        result = jax_warp_multiply(x)
+        return result
+
+    # a_in here is the full sharded array with shape (M,)
+    # The output will also be a sharded array with shape (M,)
    def warp_distributed_operator(a_in):
        def _sharded_operator(a_in):
-            return jax_warp_multiply(a_in)[0]
-
+            # Inside the sharded operator, a_in is a local shard on each device
+            # If we have N devices and input size M, each shard has shape (M/N,)
+
+            # warp_multiply applies the Warp kernel to the local shard
+            result = warp_multiply(a_in)[0]
+
+            # result has the same shape as the input shard (M/N,)
+            return result
+
+        # shard_map distributes the computation across devices
        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
-            in_specs=(P("x"),),
-            out_specs=P("x"),
+            in_specs=(P("x"),),  # Input is sharded along the 'x' axis
+            out_specs=P("x"),  # Output is also sharded along the 'x' axis
            check_rep=False,
        )(a_in)

+    print_on_process_0("Test distributed multiplication using JAX + Warp")
+
+    devices = jax.devices()
+    mesh = jax.sharding.Mesh(np.array(devices), "x")
+    sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))
+
+    input_size = num_gpus * 5  # 5 elements per device
+    single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)
+
+    # Define the shape of the input array based on the total input size
+    shape = (input_size,)
+
+    # Create a list of arrays by distributing the single_device_arrays across the available devices
+    # Each device will receive a portion of the input data
+    arrays = [
+        jax.device_put(single_device_arrays[index], d)  # Place each element on the corresponding device
+        for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
+    ]
+
+    # Combine the individual device arrays into a single sharded array
+    sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)
+
+    # sharded_array has shape (input_size,) but is distributed across devices
+    print_on_process_0(f"Input array: {allgather(sharded_array)}")
+
+    # warp_result has the same shape and sharding as sharded_array
+    warp_result = warp_distributed_operator(sharded_array)
+
+    # allgather collects results from all devices, resulting in a full array of shape (input_size,)
+    print_on_process_0("Warp Output:", allgather(warp_result))
+
In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
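
As a quick sanity check of the doubled values, one could append a comparison against a plain JAX reference at the end of the example above. This is a minimal sketch, not part of the committed example; it assumes the snippet has just run and reuses its `single_device_arrays`, `warp_result`, `allgather`, and `print_on_process_0` names:

    # Hypothetical verification step (assumes the example above has just run).
    # The kernel multiplies by two, so the gathered Warp result should equal
    # the full input array scaled by 2.0.
    expected = single_device_arrays * 2.0
    matches = bool(jnp.allclose(allgather(warp_result), expected))
    print_on_process_0(f"Matches JAX reference: {matches}")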
@@ -529,15 +583,35 @@ In some cases, particularly for matrix operations, it's necessary to specify the

.. code-block:: python

+    import warp as wp
+    import jax
+    import jax.numpy as jnp
+    from jax.sharding import PartitionSpec as P
+    from jax.experimental.multihost_utils import process_allgather as allgather
+    from jax.experimental.shard_map import shard_map
+    from warp.jax_experimental import jax_kernel
+    import numpy as np
+
+    jax.distributed.initialize()
+    num_gpus = jax.device_count()
+
+    def print_on_process_0(*args, **kwargs):
+        if jax.process_index() == 0:
+            print(*args, **kwargs)
+
+    print_on_process_0(f"Running on {num_gpus} GPU(s)")
+
    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
+        # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N)
        i, j = wp.tid()
-        M, K = a.shape
-        N = b.shape[1]
+        M = a.shape[0]  # M/num_gpus
+        K = a.shape[1]  # K
+        N = b.shape[1]  # N
        if i < M and j < N:
            s = wp.float32(0.0)
            for k in range(K):
@@ -546,27 +620,68 @@ In some cases, particularly for matrix operations, it's necessary to specify the

    # Specify launch dimensions based on the number of GPUs
    def create_jax_warp_matmul(M, N):
-        num_gpus = jax.device_count()
-        block_size_m = M // num_gpus
-        block_size_n = N
+        # M: total rows, N: total columns
+        block_size_m = M // num_gpus  # Rows per GPU
+        block_size_n = N  # All columns
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    def warp_distributed_matmul(a, b):
+        # a: (M, K) sharded across GPUs, b: (K, N) replicated
        M, K = a.shape
        _, N = b.shape
        jax_warp_matmul = create_jax_warp_matmul(M, N)

        def _sharded_operator(a_shard, b):
-            return jax_warp_matmul(a_shard, b)
+            # a_shard: (M/num_gpus, K), b: (K, N)
+            return jax_warp_matmul(a_shard, b)[0]  # Result: (M/num_gpus, N)

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
-            in_specs=(P("x", None), P(None, None)),
-            out_specs=P("x", None),
+            in_specs=(P("x", None), P(None, None)),  # a sharded in first dim, b replicated
+            out_specs=P("x", None),  # Output sharded in first dim
            check_rep=False,
        )(a, b)

+    print_on_process_0("Test distributed matrix multiplication using JAX + Warp")
+
+    # Define matrix dimensions
+    M = 8 * num_gpus  # Scale M with the number of devices
+    K, N = 4, 6
+
+    # Create input matrices
+    a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)  # Shape: (M, K)
+    b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N)  # Shape: (K, N)
+
+    devices = jax.devices()
+    mesh = jax.sharding.Mesh(np.array(devices), "x")
+    sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None))
+    sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None))
+
+    # Shard matrix A and replicate matrix B
+    sharded_a = jax.device_put(a, sharding_spec_a)  # Sharded shape: (M/num_gpus, K) per device
+    replicated_b = jax.device_put(b, sharding_spec_b)  # Replicated shape: (K, N) on all devices
+
+    print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}")  # Shape: (M, K)
+    print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}")  # Shape: (K, N)
+
+    warp_result = warp_distributed_matmul(sharded_a, replicated_b)  # Sharded result: (M/num_gpus, N) per device
+    print_on_process_0("Warp Output:")
+    # Use allgather to collect results from all devices
+    print_on_process_0(allgather(warp_result))  # Shape: (M, N)
+
+    jax_result = jnp.matmul(a, b)  # Shape: (M, N)
+    print_on_process_0("JAX Output:")
+    print_on_process_0(jax_result)
+
+    expected_shape = (M, N)
+    print_on_process_0(f"Expected shape: {expected_shape}")
+    print_on_process_0(f"Warp output shape: {warp_result.shape}")  # Should be (M/num_gpus, N) on each device
+    print_on_process_0(f"JAX output shape: {jax_result.shape}")  # Should be (M, N)
+
+    allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5)
+    print_on_process_0(f"Allclose: {allclose}")
+
In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.
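
As a concrete illustration of the launch-dimension arithmetic described above, here is a hypothetical walk-through assuming a run on 4 GPUs (the device count is an assumption; the committed example derives it from `jax.device_count()` at runtime):

    # Hypothetical walk-through, assuming num_gpus == 4
    num_gpus = 4
    M, K, N = 8 * num_gpus, 4, 6     # A is (32, 4), B is (4, 6), C is (32, 6)
    block_size_m = M // num_gpus     # 32 // 4 = 8 rows of A (and C) per GPU
    block_size_n = N                 # 6: the full width of the output
    launch_dims = (block_size_m, block_size_n)
    print(launch_dims)               # (8, 6) -> each GPU computes an (8, 6) tile of C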
