Skip to content

Llama Performance Test #644

Llama Performance Test

Llama Performance Test #644

Workflow file for this run

# Workflow name as shown for this workflow's runs.
name: CI
# Triggers: pushes to, and pull requests targeting, master or any release/*
# branch, plus manual runs through workflow_dispatch.
on:
  push:
    branches:
      - master
      - 'release/*'
  pull_request:
    branches:
      - master
      - 'release/*'
  workflow_dispatch:
# Group runs by workflow name plus the PR head ref (falling back to the full
# ref when head_ref is empty), and cancel a still-running run in the same
# group when a newer one starts.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
# Pipeline: call-build-wheels -> call-build-docker -> run-python-unit-tests.
# Each stage waits for the previous one through `needs`.
jobs:
  # Calls the reusable build-wheels workflow once per ROCm version in the
  # matrix.
  call-build-wheels:
    strategy:
      matrix:
        rocm-version: ["7.0.0"]
    uses: ./.github/workflows/build-wheels.yml
    with:
      # TODO: Add back Python 3.13 when we're ready to move to a more recent
      # version of XLA. 3.13 fails with a complaint about the pipes module.
      python-versions: "3.11,3.12"
      rocm-version: ${{ matrix.rocm-version }}

  # Calls the reusable build-docker workflow after the wheels job finishes.
  call-build-docker:
    needs: call-build-wheels
    strategy:
      matrix:
        rocm-version: ["7.0.0"]
    uses: ./.github/workflows/build-docker.yml
    with:
      rocm-version: ${{ matrix.rocm-version }}

  # Runs the Python unit tests inside the image tagged with this commit's SHA
  # on the runner labelled mi-250.
  run-python-unit-tests:
    needs: call-build-docker
    runs-on: mi-250
    strategy:
      matrix:
        rocm-version: ["7.0.0"]
    steps:
      # Best-effort ownership reset of the workspace (dotfiles included via
      # dotglob) so later steps can modify files left behind by earlier runs.
      # $UID sits inside double quotes, so it is expanded by the step's shell,
      # not inside the container. `|| true` keeps a chown failure non-fatal.
      - name: Change owners for cleanup
        run: |
          docker run --rm -v "./:/rocm-jax" ubuntu \
            /bin/bash -c "shopt -s dotglob; chown -R $UID /rocm-jax/* || true"
      # clean: false keeps existing workspace contents instead of wiping them.
      - name: Checkout plugin repo
        uses: actions/checkout@v4
        with:
          clean: false
      # Checks out rocm/jax at a pinned ref into ./jax; the test step below
      # runs a test file from that directory.
      - name: Checkout JAX repo
        uses: actions/checkout@v4
        with:
          # TODO: Load the ref from a file that sets the min and max JAX version
          # TODO: Change the repo and ref once we figure out how exactly we're
          # going to manage tests
          clean: false
          repository: rocm/jax
          ref: rocm-jaxlib-v0.7.1
          path: jax
      # The token and actor are passed through env and quoted rather than
      # interpolated into the script text, so their values are never parsed
      # as shell syntax.
      - name: Authenticate to GitHub Container Registry
        env:
          GHCR_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GHCR_USER: ${{ github.actor }}
        run: |
          echo "${GHCR_TOKEN}" \
            | docker login ghcr.io -u "${GHCR_USER}" --password-stdin
      # Diagnostic only: `|| true` keeps a rocm-smi failure from failing the
      # job. ${ROCM_VERSION//.} strips the dots (7.0.0 -> 700) to form the
      # image name.
      - name: Check ROCm GPUs with rocm-smi
        env:
          ROCM_VERSION: ${{ matrix.rocm-version }}
        run: |
          docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
            "ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
            rocm-smi -a || true
      # Runs pytest on jax/tests/core_test.py through build/ci_build in the
      # same image as above.
      # NOTE(review): GPU_COUNT, GFX and JAX_SKIP_SLOW_TESTS are presumably
      # read by build/ci_build or the tests themselves -- confirm.
      - name: Run tests
        env:
          GPU_COUNT: "8"
          GFX: "gfx90a"
          JAX_SKIP_SLOW_TESTS: "1"
          ROCM_VERSION: ${{ matrix.rocm-version }}
        # TODO: Add the tests/linalg_test.py test back once we fix the XLAClient thing.
        run: |
          python3 build/ci_build test \
            "ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
            --test-cmd "pytest -c ci/pytest_skips.ini jax/tests/core_test.py"