
Commit 2b3a7c8

Merge branch 'GH-310' into 'main'
GitHub PR (#310). Closes GH-310. See merge request omniverse/warp!730
2 parents e6a4eda + c3dca56

3 files changed: +313 -16 lines


docs/modules/interoperability.rst

Lines changed: 227 additions & 1 deletion
@@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations:
 - Kernel launch dimensions are inferred from the shape of the first argument.
 - Input arguments are followed by output arguments in the Warp kernel definition.
 - There must be at least one input argument and at least one output argument.
-- Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
 - All arrays must be contiguous.
 - Only the CUDA backend is supported.

@@ -462,6 +461,233 @@ Here is an example of an operation with three inputs and two outputs::

    print(x)
    print(y)

Using shard_map for distributed computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Warp can be used in conjunction with JAX's `shard_map <https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html>`_ to perform distributed multi-GPU computations.

To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ for more details):

.. code-block:: python

    import jax
    jax.distributed.initialize()

This initialization must be called at the beginning of your program, before any other JAX operations.

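After initialization, the standard JAX process APIs report how the distributed run is laid out. The following is a minimal sketch for sanity-checking the setup (the printed messages are illustrative only):

.. code-block:: python

    import jax

    jax.distributed.initialize()

    # device_count() is global across all processes; local_device_count()
    # covers only the devices owned by this process.
    print(f"process {jax.process_index()} of {jax.process_count()}")
    print(f"{jax.local_device_count()} local / {jax.device_count()} total devices")
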
Here's an example of how to use `shard_map` with a Warp kernel:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    # Initialize JAX distributed environment
    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def multiply_by_two_kernel(
        a_in: wp.array(dtype=wp.float32),
        a_out: wp.array(dtype=wp.float32),
    ):
        index = wp.tid()
        a_out[index] = a_in[index] * 2.0

    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

    def warp_multiply(x):
        result = jax_warp_multiply(x)
        return result

    # a_in here is the full sharded array with shape (M,)
    # The output will also be a sharded array with shape (M,)
    def warp_distributed_operator(a_in):
        def _sharded_operator(a_in):
            # Inside the sharded operator, a_in is a local shard on each device
            # If we have N devices and input size M, each shard has shape (M/N,)

            # warp_multiply applies the Warp kernel to the local shard
            result = warp_multiply(a_in)[0]

            # result has the same shape as the input shard (M/N,)
            return result

        # shard_map distributes the computation across devices
        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x"),),  # Input is sharded along the 'x' axis
            out_specs=P("x"),  # Output is also sharded along the 'x' axis
            check_rep=False,
        )(a_in)

    print_on_process_0("Test distributed multiplication using JAX + Warp")

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))

    input_size = num_gpus * 5  # 5 elements per device
    single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)

    # Define the shape of the input array based on the total input size
    shape = (input_size,)

    # Create a list of arrays by distributing the single_device_arrays across the available devices
    # Each device will receive a portion of the input data
    arrays = [
        jax.device_put(single_device_arrays[index], d)  # Place each element on the corresponding device
        for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
    ]

    # Combine the individual device arrays into a single sharded array
    sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)

    # sharded_array has shape (input_size,) but is distributed across devices
    print_on_process_0(f"Input array: {allgather(sharded_array)}")

    # warp_result has the same shape and sharding as sharded_array
    warp_result = warp_distributed_operator(sharded_array)

    # allgather collects results from all devices, resulting in a full array of shape (input_size,)
    print_on_process_0("Warp Output:", allgather(warp_result))

In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.

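To confirm how the result is distributed, the sharding metadata of the returned array can be inspected directly. A minimal sketch, reusing `warp_result` and `print_on_process_0` from the example above:

.. code-block:: python

    # Each process only addresses its own shards; with N devices and input
    # size M, every local shard should have shape (M/N,).
    print_on_process_0(f"Result sharding: {warp_result.sharding}")
    print_on_process_0(f"Local shard shapes: {[s.data.shape for s in warp_result.addressable_shards]}")
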
To run this program on multiple GPUs, you must have OpenMPI installed; see the `OpenMPI installation guide <https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html>`_ for installation instructions. Once OpenMPI is installed, launch the script with `mpirun`:

.. code-block:: bash

    mpirun -np <NUM_OF_GPUS> python <filename>.py

Specifying launch dimensions for matrix operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for Warp kernels. This is because the default behavior of inferring dimensions from the first argument may not always be suitable for matrix operations. Here's an example of a distributed matrix multiplication using Warp and JAX:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
        # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N)
        i, j = wp.tid()
        M = a.shape[0]  # M/num_gpus
        K = a.shape[1]  # K
        N = b.shape[1]  # N
        if i < M and j < N:
            s = wp.float32(0.0)
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s

    # Specify launch dimensions based on the number of GPUs
    def create_jax_warp_matmul(M, N):
        # M: total rows, N: total columns
        block_size_m = M // num_gpus  # Rows per GPU
        block_size_n = N  # All columns
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    def warp_distributed_matmul(a, b):
        # a: (M, K) sharded across GPUs, b: (K, N) replicated
        M, K = a.shape
        _, N = b.shape
        jax_warp_matmul = create_jax_warp_matmul(M, N)

        def _sharded_operator(a_shard, b):
            # a_shard: (M/num_gpus, K), b: (K, N)
            return jax_warp_matmul(a_shard, b)[0]  # Result: (M/num_gpus, N)

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x", None), P(None, None)),  # a sharded in first dim, b replicated
            out_specs=P("x", None),  # Output sharded in first dim
            check_rep=False,
        )(a, b)

    print_on_process_0("Test distributed matrix multiplication using JAX + Warp")

    # Define matrix dimensions
    M = 8 * num_gpus  # Scale M with the number of devices
    K, N = 4, 6

    # Create input matrices
    a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)  # Shape: (M, K)
    b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N)  # Shape: (K, N)

    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), "x")
    sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None))
    sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None))

    # Shard matrix A and replicate matrix B
    sharded_a = jax.device_put(a, sharding_spec_a)  # Sharded shape: (M/num_gpus, K) per device
    replicated_b = jax.device_put(b, sharding_spec_b)  # Replicated shape: (K, N) on all devices

    print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}")  # Shape: (M, K)
    print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}")  # Shape: (K, N)

    warp_result = warp_distributed_matmul(sharded_a, replicated_b)  # Sharded result: (M/num_gpus, N) per device
    print_on_process_0("Warp Output:")
    # Use allgather to collect results from all devices
    print_on_process_0(allgather(warp_result))  # Shape: (M, N)

    jax_result = jnp.matmul(a, b)  # Shape: (M, N)
    print_on_process_0("JAX Output:")
    print_on_process_0(jax_result)

    expected_shape = (M, N)
    print_on_process_0(f"Expected shape: {expected_shape}")
    print_on_process_0(f"Warp output shape: {warp_result.shape}")  # Global shape: (M, N); each device holds an (M/num_gpus, N) shard
    print_on_process_0(f"JAX output shape: {jax_result.shape}")  # Should be (M, N)

    allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5)
    print_on_process_0(f"Allclose: {allclose}")

In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of input matrix A. The `N` dimension (columns) remains unchanged, as we're not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU: `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.

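As a concrete illustration of this arithmetic (the numbers below are hypothetical, assuming four GPUs):

.. code-block:: python

    # Hypothetical sizes, assuming num_gpus = 4
    num_gpus = 4
    M, N = 8 * num_gpus, 6        # Global output shape: (32, 6)
    block_size_m = M // num_gpus  # 8 rows per GPU
    block_size_n = N              # 6 columns, not sharded
    # launch_dims=(8, 6): one Warp thread per element of the local (M/num_gpus, N) output shard
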
Note that this is a naive implementation of matrix multiplication for the sake of this illustration, and there are many optimizations that can be made to improve performance.

.. _DLPack:

DLPack

warp/jax_experimental.py

Lines changed: 28 additions & 15 deletions
@@ -21,17 +21,22 @@
 _registered_kernel_to_id = {}


-def jax_kernel(wp_kernel):
+def jax_kernel(wp_kernel, launch_dims=None):
     """Create a Jax primitive from a Warp kernel.

     NOTE: This is an experimental feature under development.

+    Args:
+        wp_kernel: The Warp kernel to be wrapped.
+        launch_dims: Optional. Specify the kernel launch dimensions. If None,
+            dimensions are inferred from the shape of the first argument.
+            When set, this option also determines the output dimensions.
+
     Current limitations:
     - All kernel arguments must be arrays.
-    - Kernel launch dimensions are inferred from the shape of the first argument.
+    - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument.
     - Input arguments are followed by output arguments in the Warp kernel definition.
     - There must be at least one input argument and at least one output argument.
-    - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
     - All arrays must be contiguous.
     - Only the CUDA backend is supported.
     """

@@ -47,7 +52,7 @@ def jax_kernel(wp_kernel):
     id = _registered_kernel_to_id[wp_kernel]

     def bind(*args):
-        return _jax_warp_p.bind(*args, kernel=id)
+        return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)

     return bind

@@ -106,7 +111,7 @@ def _get_jax_device():
     device = jax.config.jax_default_device
     # if default device is not set, use first device
     if device is None:
-        device = jax.devices()[0]
+        device = jax.local_devices()[0]
     return device

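Presumably the motivation for this hunk is that `jax.devices()` lists devices across all participating processes, so in a multi-process run index 0 may refer to a device owned by another process, whereas `jax.local_devices()` only returns devices addressable by the current process. A small illustrative check:

.. code-block:: python

    import jax

    jax.distributed.initialize()
    # jax.devices() spans every process; jax.local_devices() is the subset
    # this process can actually place computations on.
    print([d.id for d in jax.devices()])
    print([(d.id, d.process_index) for d in jax.local_devices()])
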

@@ -223,12 +228,17 @@ def base_type_is_compatible(warp_type, jax_ir_type):
         raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")

 # Abstract evaluation.
-def jax_warp_abstract(*args, kernel=None):
+def jax_warp_abstract(*args, kernel=None, launch_dims=None):
     wp_kernel = _registered_kernels[kernel]
     # All the extra arguments to the warp kernel are outputs.
     warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
-    # TODO. Let's just use the first input dimension to infer the output's dimensions.
-    dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
+
+    if launch_dims is None:
+        # Use the first input dimension to infer the output's dimensions if launch_dims is not provided
+        dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
+    else:
+        dims = launch_dims
+
     jax_outputs = []
     for o in warp_outputs:
         shape = list(dims) + list(get_vecmat_shape(o))

@@ -260,7 +270,7 @@ def jax_warp_abstract(*args, kernel=None):
 def default_layout(shape):
     return range(len(shape) - 1, -1, -1)

-def warp_call_lowering(ctx, *args, kernel=None):
+def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None):
     if not kernel:
         raise Exception("Unknown kernel id " + str(kernel))
     wp_kernel = _registered_kernels[kernel]

@@ -272,12 +282,15 @@ def warp_call_lowering(ctx, *args, kernel=None):
     if not module.load(device):
         raise Exception("Could not load kernel on device")

-    # Infer dimensions from the first input.
-    warp_arg0 = wp_kernel.adj.args[0]
-    actual_shape0 = ir.RankedTensorType(args[0].type).shape
-    dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
-    warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
-
+    if launch_dims is None:
+        # Infer dimensions from the first input.
+        warp_arg0 = wp_kernel.adj.args[0]
+        actual_shape0 = ir.RankedTensorType(args[0].type).shape
+        dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
+        warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
+    else:
+        dims = launch_dims
+        warp_dims = launch_dims
     # Figure out the types and shapes of the input arrays.
     arg_strings = []
     operand_layouts = []

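Taken together, the new `launch_dims` argument flows from `jax_kernel` through the primitive's `bind` call into abstract evaluation and lowering. A minimal usage sketch of the argument, mirroring the matmul example from the documentation above (the kernel and sizes here are illustrative only):

.. code-block:: python

    import warp as wp
    from warp.jax_experimental import jax_kernel

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),  # (M, K)
        b: wp.array2d(dtype=wp.float32),  # (K, N)
        c: wp.array2d(dtype=wp.float32),  # (M, N)
    ):
        i, j = wp.tid()
        s = wp.float32(0.0)
        for k in range(a.shape[1]):
            s += a[i, k] * b[k, j]
        c[i, j] = s

    M, K, N = 128, 64, 32
    # The output shape (M, N) differs from the first input's shape (M, K), so
    # the launch dimensions (and hence the output dimensions) are given
    # explicitly; without launch_dims they would be inferred as (M, K).
    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(M, N))
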
