Skip to content

Bump jax from 0.4.35 to 0.4.36 in /mpi4jax/_src (#270) #80

Bump jax from 0.4.35 to 0.4.36 in /mpi4jax/_src (#270)

Bump jax from 0.4.35 to 0.4.36 in /mpi4jax/_src (#270) #80

Workflow file for this run

# CI workflow that builds the CUDA extension (see the `build` job).
name: Build CUDA extensions

# Triggers: every pull request, plus pushes that land on master.
on:
  pull_request:
  push:
    branches:
      - master
jobs:
  # One job, fanned out over the OS / CUDA combinations in the matrix below.
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      # Keep the remaining matrix entries running when one of them fails.
      fail-fast: false
      matrix:
        include:
          # `cuda` is handed to the "Install CUDA" step: "pypi" selects the
          # pip-installed NVIDIA wheels, any other value takes the
          # conf/install-cuda-ubuntu.sh path (presumably read there as the
          # toolkit version -- confirm in that script).
          - os: ubuntu-22.04
            cuda: "11.7"
          - os: ubuntu-22.04
            cuda: "12.0"
          - os: ubuntu-22.04
            cuda: "12.1"
          - os: ubuntu-22.04
            cuda: "pypi"
    steps:
      - uses: actions/checkout@v4
      # make sure tags are fetched so we can get a version
      - run: |
          git fetch --prune --unshallow --tags
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install CUDA
        env:
          # Expose the matrix value to the script as ${cuda}.
          cuda: ${{ matrix.cuda }}
        run: |
          if [[ "${cuda}" == 'pypi' ]]; then
            # This branch installs the CUDA 12 component wheels published by
            # NVIDIA on PyPI; it does not install jax itself.
            echo "Installing jax[cuda] from PyPI"
            pip install 'nvidia-cublas-cu12>=12.1.3.1'
            pip install 'nvidia-cuda-cupti-cu12>=12.1.105'
            pip install 'nvidia-cuda-nvcc-cu12>=12.1.105'
            pip install 'nvidia-cuda-runtime-cu12>=12.1.105'
          else
            # Sourced (not executed) so variables it sets, such as CUDA_PATH,
            # stay visible to the lines below.
            source ./conf/install-cuda-ubuntu.sh
            # NOTE(review): with `shell: bash` the runner uses `bash -eo pipefail`,
            # so a failing install script likely aborts the step before this
            # check is reached -- confirm before relying on the silent skip.
            if [[ $? -eq 0 ]]; then
              # Set paths for subsequent steps, using ${CUDA_PATH}
              echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH"
              echo "CUDA_PATH=${CUDA_PATH}" >> "$GITHUB_ENV"
              echo "${CUDA_PATH}/bin" >> "$GITHUB_PATH"
              # NOTE(review): CUDA toolkits on Linux commonly keep libraries in
              # lib64 -- confirm `lib` matches what the install script lays out.
              echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"
            fi
          fi
        shell: bash
      - name: Setup MPI (mpich)
        uses: mpi4py/setup-mpi@v1
        with:
          mpi: mpich
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel mpi4py cython
      - name: Build GPU extensions
        run: |
          python setup.py build_ext --inplace
          # Fail the step unless the build produced the CUDA bridge library.
          test -f mpi4jax/_src/xla_bridge/mpi_xla_bridge_cuda*.so