# Bump jax from 0.4.35 to 0.4.36 in /mpi4jax/_src (#270) #80
# Workflow file for this run
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# CI workflow: builds the mpi4jax CUDA XLA bridge extension against several
# CUDA installations and checks that the compiled shared object exists.
# Only the build is exercised here; no step in this workflow runs GPU code.
name: Build CUDA extensions

# Run on every pull request and on pushes to master.
on:
  pull_request:
  push:
    branches:
      - master

jobs:
  build:
    runs-on: ${{ matrix.os }}

    strategy:
      # Let the remaining CUDA variants finish even if one of them fails.
      fail-fast: false
      matrix:
        include:
          # All variants run on ubuntu-22.04. A version number selects a
          # toolkit installed by conf/install-cuda-ubuntu.sh; "pypi" selects
          # the NVIDIA CUDA 12 wheels from PyPI instead.
          # Versions are quoted so YAML does not parse them as floats.
          - os: ubuntu-22.04
            cuda: "11.7"
          - os: ubuntu-22.04
            cuda: "12.0"
          - os: ubuntu-22.04
            cuda: "12.1"
          - os: ubuntu-22.04
            cuda: "pypi"

    steps:
      - uses: actions/checkout@v4

      # make sure tags are fetched so we can get a version
      - run: |
          git fetch --prune --unshallow --tags

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install CUDA
        env:
          # Exposed to the script below as ${cuda}.
          # NOTE(review): conf/install-cuda-ubuntu.sh presumably reads this
          # variable too - confirm before renaming it.
          cuda: ${{ matrix.cuda }}
        run: |
          if [[ "${cuda}" == 'pypi' ]]; then
            echo "Installing jax[cuda] from PyPI"
            pip install 'nvidia-cublas-cu12>=12.1.3.1'
            pip install 'nvidia-cuda-cupti-cu12>=12.1.105'
            pip install 'nvidia-cuda-nvcc-cu12>=12.1.105'
            pip install 'nvidia-cuda-runtime-cu12>=12.1.105'
          else
            # Sourced (not executed) so that CUDA_PATH set by the script is
            # visible in this shell.
            source ./conf/install-cuda-ubuntu.sh
            if [[ $? -eq 0 ]]; then
              # Set paths for subsequent steps, using ${CUDA_PATH}
              echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH"
              echo "CUDA_PATH=${CUDA_PATH}" >> "$GITHUB_ENV"
              echo "${CUDA_PATH}/bin" >> "$GITHUB_PATH"
              # The Linux CUDA toolkit keeps its shared libraries in lib64;
              # lib is kept as a fallback. The previous value is appended only
              # when non-empty, so no empty path element (= current directory)
              # ends up in the search path.
              echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib64:${CUDA_PATH}/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> "$GITHUB_ENV"
            fi
          fi
        shell: bash

      - name: Setup MPI (mpich)
        uses: mpi4py/setup-mpi@v1
        with:
          mpi: mpich

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel mpi4py cython

      - name: Build GPU extensions
        run: |
          python setup.py build_ext --inplace
          # Fail the job if the CUDA bridge shared object was not produced.
          test -f mpi4jax/_src/xla_bridge/mpi_xla_bridge_cuda*.so