docs/modules/interoperability.rst: 227 additions & 1 deletion

@@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations:
- Kernel launch dimensions are inferred from the shape of the first argument.
- Input arguments are followed by output arguments in the Warp kernel definition (see the sketch after this list).
- There must be at least one input argument and at least one output argument.
-- Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
- All arrays must be contiguous.
- Only the CUDA backend is supported.
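
For illustration, a minimal kernel that satisfies these constraints might look like the following sketch (the names `scale_kernel` and `jax_scale` are hypothetical, not taken from the Warp documentation):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import warp as wp
    from warp.jax_experimental import jax_kernel

    # One input argument followed by one output argument, both contiguous arrays
    @wp.kernel
    def scale_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        y[tid] = 2.0 * x[tid]

    jax_scale = jax_kernel(scale_kernel)

    @jax.jit
    def scale(x):
        # Launch dimensions are inferred from the shape of the first argument
        return jax_scale(x)

    result = scale(jnp.arange(16, dtype=jnp.float32))
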
@@ -462,6 +461,233 @@ Here is an example of an operation with three inputs and two outputs::
    print(x)
    print(y)

Using shard_map for distributed computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Warp can be used in conjunction with JAX's `shard_map <https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html>`_ to perform distributed multi-GPU computations.
To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ for more details):
.. code-block:: python

    import jax

    jax.distributed.initialize()

This initialization must be called at the beginning of your program, before any other JAX operations.
Here's an example of how to use `shard_map` with a Warp kernel:
.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    # Initialize JAX distributed environment
    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")
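
    # --------------------------------------------------------------------
    # The diff view collapses the remainder of this example. The code below
    # is a hedged reconstruction based on the explanation that follows
    # (kernel `multiply_by_two_kernel`, input `a_in` sharded along 'x');
    # it is not the verbatim original.
    # --------------------------------------------------------------------

    @wp.kernel
    def multiply_by_two_kernel(
        a_in: wp.array(dtype=wp.float32),
        a_out: wp.array(dtype=wp.float32),
    ):
        index = wp.tid()
        a_out[index] = a_in[index] * 2.0

    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

    def warp_multiply(x):
        return jax_warp_multiply(x)

    # Mesh with a single axis 'x' spanning all devices
    mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",))
    sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))

    # Global input array, distributed along 'x' so each device holds one slice
    a_in = jax.device_put(jnp.arange(num_gpus * 1024, dtype=jnp.float32), sharding_spec)

    # Each device runs the Warp kernel on its local shard
    multiply_by_two = shard_map(
        warp_multiply,
        mesh=mesh,
        in_specs=(P("x"),),
        out_specs=P("x"),
        check_rep=False,
    )

    a_out = multiply_by_two(a_in)

    # Gather the distributed result for printing on process 0
    print_on_process_0("a_in: ", allgather(a_in))
    print_on_process_0("a_out:", allgather(a_out))
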
In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output.
This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
To run this program on multiple GPUs, you must have OpenMPI installed. You can consult the `OpenMPI installation guide <https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html>`_ for instructions on how to install it. Once OpenMPI is installed, you can use `mpirun` with the following command:
.. code-block:: bash

    mpirun -np <NUM_OF_GPUS> python <filename>.py

Specifying launch dimensions for matrix operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases, particularly for matrix operations, it is necessary to specify the launch dimensions for Warp kernels explicitly, because inferring them from the shape of the first argument is not always suitable. Here's an example of a distributed matrix multiplication using Warp and JAX:
.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.multihost_utils import process_allgather as allgather
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel
    import numpy as np

    jax.distributed.initialize()
    num_gpus = jax.device_count()

    def print_on_process_0(*args, **kwargs):
        if jax.process_index() == 0:
            print(*args, **kwargs)

    print_on_process_0(f"Running on {num_gpus} GPU(s)")
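
    # --------------------------------------------------------------------
    # The diff view collapses the remainder of this example. The code below
    # is a hedged reconstruction based on the explanation that follows
    # (`create_jax_warp_matmul`, `block_size_m`, `block_size_n`); in
    # particular, the `launch_dims` argument name is an assumption, and this
    # is not the verbatim original.
    # --------------------------------------------------------------------

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
        # Each thread computes one element of the local output tile
        i, j = wp.tid()
        K = a.shape[1]
        s = wp.float32(0.0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s

    def create_jax_warp_matmul(M, N):
        # Each GPU owns M // num_gpus rows of A and the full width N of C
        block_size_m = M // num_gpus
        block_size_n = N
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    M, K, N = 8 * num_gpus, 16, 32
    jax_warp_matmul = create_jax_warp_matmul(M, N)

    def warp_matmul(a, b):
        return jax_warp_matmul(a, b)

    # A is sharded along its rows ('x'); B is replicated on every device
    mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",))
    a = jax.device_put(
        jnp.ones((M, K), dtype=jnp.float32),
        jax.sharding.NamedSharding(mesh, P("x", None)),
    )
    b = jax.device_put(
        jnp.ones((K, N), dtype=jnp.float32),
        jax.sharding.NamedSharding(mesh, P(None, None)),
    )

    distributed_matmul = shard_map(
        warp_matmul,
        mesh=mesh,
        in_specs=(P("x", None), P(None, None)),
        out_specs=P("x", None),
        check_rep=False,
    )

    c = distributed_matmul(a, b)
    print_on_process_0("C:", allgather(c))
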
In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction.
Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.
Note that this is a naive matrix multiplication implementation, used purely for illustration; many optimizations are possible to improve performance.