Commit 1e59cf7

charleshofer, JehandadKhan, jakevdp, Jake VanderPlas, and Google-ML-Automation authored
Prep for v0.5.1 (#295)
* Add workflow for nightly pull from upstream
* Only run on weekdays
* Fix yaml checker
* Set runners for ROCM
* Allow devs to kick off sync job manually (#119)
* Unpin container in CI build and remove libssl-dev install
* Rename the CI flow to 'ROCm CI' and only run it on PRs to rocm-main (#126)
* Rename the CI flow to 'ROCm CI' and only run it on PRs to the rocm-main branch
* Change name to 'ROCm CPU CI'
* Fix nightly sync permissions (#124)
* Add GHA workflow for opening PRs upstream (#116)
* Add file for opening PRs upstream
* Add HEAD_REF as environment variable
* Fill out code for making a new branch and opening a PR to upstream
* Add names for steps
* Fix yaml
* Fix yaml again
* Leave a comment on the old PR linking to the new one
* Add proper permissions for creating branches and opening PRs
* Fix YAML
* Create a new branch when merging upstream main to rocm-main (#128)
* Fix upstream sync checkout (#130)
* Checkout main before trying to switch to it
* Fix the checkout command
* Add git fetch (#132)
* Fix debug_nans false positive in jnp.quantile
* Remove some obsolete deprecation registrations
  PiperOrigin-RevId: 693793727
* Update XLA dependency to use revision http://github.com/openxla/xla/commit/0f6331b1881ae34c8b1cd59580900d556bc8305c.
  PiperOrigin-RevId: 693819727
* Adding start index and kv_seq_len to decode kernel
* Add workflow for nightly pull from upstream
* Only run on weekdays
* Fix yaml checker
* Set runners for ROCM
* Allow devs to kick off sync job manually (#119)
* Unpin container in CI build and remove libssl-dev install
* Rename the CI flow to 'ROCm CI' and only run it on PRs to rocm-main (#126)
* Rename the CI flow to 'ROCm CI' and only run it on PRs to the rocm-main branch
* Change name to 'ROCm CPU CI'
* Fix nightly sync permissions (#124)
* Add GHA workflow for opening PRs upstream (#116)
* Add file for opening PRs upstream
* Add HEAD_REF as environment variable
* Fill out code for making a new branch and opening a PR to upstream
* Add names for steps
* Fix yaml
* Fix yaml again
* Leave a comment on the old PR linking to the new one
* Add proper permissions for creating branches and opening PRs
* Fix YAML
* Create a new branch when merging upstream main to rocm-main (#128)
* Fix upstream sync checkout (#130)
* Checkout main before trying to switch to it
* Fix the checkout command
* Fix FFI example test in CI
* Add commit to see if it triggers CI
* Make daily sync permissions at the workflow level and fix merge CI (#143)
* Longer timeout for doc render
* Fix upstream PR workflow to use origin branches (#151)
* Add token for GitHub CLI (#152)
* Change the workflow for opening upstream PRs to post links that open PRs (#157)
* Add GH auth token to env
* Make the job post a comment with a link to open the PR instead of actually opening the PR
* Fix rebase command to exclude rocm-main (#158)
* Fix user identity for rebase (#159)
* Fix the link to the downstream PR (#160)
* Use the reference format for links instead of inline (#162)
* Update ci-build.yaml to use specific image
* Update ci-build.yaml
* Don't look for CUDA files when building the ROCm wheel (#173)
* GH 9948: Automerge daily sync PRs (#181)
* Run CPU CI again
* Add upload wheels file for pypi (#184)
* Change to trigger CI
* Skip failing tests
* Skip one more test
* Add GPU CI (#137)
* Commit to trigger CI
* Add option to ci_build to run different tests
* Only run core tests for CI
* Quote test command in workflow file
* Add dev guide (#188)
* Use hipfft XLA fix
* Skip PallasCallRemoteDMAInterpretTest.test_interpret_remote_dma_ppermute for failing on ROCm
* Reduce pytest threads
* Remove conflicting param for ci_build
* Run GPU CI on PRs destined for QA branches (#228)
* Change to make CI run
* Use a GitHub app for syncing rocm-main and upstream main (#224)
* Add CODEOWNERS file (#236)
* Use bazel for PR tests (#216)
* Use bazel for running pre-merge CI tests
* Don't use HEREDOC
* Fix block text
* Use bash array
* Add bazel install
* Put Bazel in the build image
* Use Bazelisk
* Remove bazel install in Docker
* Go back to upstream XLA
* Remove bazel test command from workflow
* Move test command to build container
* Fix string format typos
* Change CODEOWNERS (#237)
* Install numa library
* Fix numa package
* Fix numactl-devel name
* Fix auditwheel version issue (#288)
  Auditwheel 6.3.0 changed/removed the lddtree function so cap constraint to 6.2.x

---------

Co-authored-by: JD <[email protected]>
Co-authored-by: Jake VanderPlas <[email protected]>
Co-authored-by: Jake VanderPlas <[email protected]>
Co-authored-by: jax authors <[email protected]>
Co-authored-by: Robert Dyro <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Ruturaj Vaidya <[email protected]>
Co-authored-by: Mathew Odden <[email protected]>
1 parent 07440f4 commit 1e59cf7

14 files changed: +295 -41 lines changed

.github/CODEOWNERS

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# Require approvals from someone on the JAX team before PRs are merged
+* @ROCm/jax-devs
+

.github/workflows/ci-build.yaml

Lines changed: 12 additions & 27 deletions
@@ -1,4 +1,4 @@
-name: CI
+name: ROCm CPU CI

 # We test all supported Python versions as follows:
 # - 3.10 : Documentation build
@@ -11,10 +11,10 @@ on:
   # but only for the main branch
   push:
     branches:
-      - main
+      - rocm-main
   pull_request:
     branches:
-      - main
+      - rocm-main

 permissions:
   contents: read # to fetch code
@@ -42,12 +42,8 @@ jobs:
       - run: pre-commit run --show-diff-on-failure --color=always --all-files

   build:
-    # Don't execute in fork due to runner type
-    if: github.repository == 'jax-ml/jax'
     name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
-    runs-on: linux-x86-n2-32
-    container:
-      image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
+    runs-on: ROCM-Ubuntu
     timeout-minutes: 60
     strategy:
       matrix:
@@ -65,10 +61,6 @@ jobs:
             num_generated_cases: 1
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Image Setup
-        run: |
-          apt update
-          apt install -y libssl-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
         with:
@@ -95,12 +87,12 @@ jobs:
           echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
           echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
           echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
-          pytest -n auto --tb=short --maxfail=20 tests examples
+          pytest -n 4 --tb=short --maxfail=20 tests examples


   documentation:
     name: Documentation - test code snippets
-    runs-on: ubuntu-latest
+    runs-on: ROCM-Ubuntu
     timeout-minutes: 10
     strategy:
       matrix:
@@ -128,19 +120,13 @@ jobs:

   documentation_render:
     name: Documentation - render documentation
-    runs-on: linux-x86-n2-16
-    container:
-      image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
-    timeout-minutes: 10
+    runs-on: ubuntu-latest
+    timeout-minutes: 20
     strategy:
       matrix:
         python-version: ['3.10']
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Image Setup
-        run: |
-          apt update
-          apt install -y libssl-dev libsqlite3-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
         with:
@@ -193,9 +179,7 @@ jobs:

   ffi:
     name: FFI example
-    runs-on: linux-x86-g2-16-l4-1gpu
-    container:
-      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
+    runs-on: ROCM-Ubuntu
     timeout-minutes: 30
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -206,7 +190,7 @@ jobs:
       - name: Install JAX
         run: |
           pip install uv
-          uv pip install --system .[cuda12]
+          uv pip install --system .
       - name: Build and install example project
         run: uv pip install --system ./examples/ffi[test]
         env:
@@ -215,10 +199,11 @@ jobs:
           # a different toolchain. GCC is the default compiler on the
           # 'ubuntu-latest' runner, but we still set this explicitly just to be
           # clear.
-          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
+          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
       - name: Run CPU tests
         run: python -m pytest examples/ffi/tests
         env:
           JAX_PLATFORM_NAME: cpu
       - name: Run GPU tests
         run: python -m pytest examples/ffi/tests
+

.github/workflows/rocm-ci.yml

Lines changed: 20 additions & 11 deletions
@@ -1,23 +1,27 @@
-name: ROCm GPU Post-Merge Check
+name: ROCm GPU CI

 on:
-  # Trigger the workflow after a push into the main branch
+  # Trigger the workflow on push or pull request,
+  # but only for the rocm-main branch
   push:
     branches:
-      - main
-
-permissions:
-  contents: read
+      - rocm-main
+      - 'rocm-jaxlib-v*'
+  pull_request:
+    branches:
+      - rocm-main
+      - 'rocm-jaxlib-v*'

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true

 jobs:
-  build-jax-in-docker:
-    runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250
+  build-jax-in-docker: # strategy and matrix come here
+    runs-on: mi-250
     env:
       BASE_IMAGE: "ubuntu:22.04"
-      TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
+      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
       PYTHON_VERSION: "3.10"
       ROCM_VERSION: "6.2.4"
       WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
@@ -32,6 +36,9 @@ jobs:
           ls
       - name: Print system info
         run: |
+          whoami
+          printenv
+          df -h
           rocm-smi
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
@@ -50,9 +57,11 @@ jobs:
         uses: actions/upload-artifact@v4
         with:
           name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
-          path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
-          retention-days: 2
+          path: ./dist/*.whl
       - name: Run tests
+        env:
+          GPU_COUNT: "8"
+          GFX: "gfx90a"
         run: |
           cd $WORKSPACE_DIR
           python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"
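
The `concurrency` block in this workflow groups runs by workflow name plus either the pull request's head branch or the pushed ref, and `cancel-in-progress: true` cancels an older run when a newer one lands in the same group. A minimal Python sketch of how the group key resolves, assuming the documented behaviour that `github.head_ref` is empty outside pull request events (the function name `concurrency_group` is hypothetical):

```python
def concurrency_group(workflow: str, head_ref: str, ref: str) -> str:
    """Mimic `${{ github.workflow }}-${{ github.head_ref || github.ref }}`.

    `head_ref` is only populated for pull_request events, so on a push
    the expression falls back to the full ref of the pushed branch.
    """
    return f"{workflow}-{head_ref or ref}"


# A push to rocm-main and a PR from a feature branch land in different groups,
# so they never cancel each other.
push_group = concurrency_group("ROCm GPU CI", "", "refs/heads/rocm-main")
pr_group = concurrency_group("ROCm GPU CI", "my-feature", "refs/pull/1/merge")
```

Two pushes to the same pull request branch produce the same key, which is what lets the newer run cancel the older one.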
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# Pulls the latest changes from upstream into main and opens a PR to merge
+# them into rocm-main branch.
+
+name: ROCm Nightly Upstream Sync
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 6 * * 1-5'
+permissions:
+  contents: write
+  pull-requests: write
+env:
+  SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
+jobs:
+  sync-main:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Generate an app token
+        id: generate-token
+        uses: actions/create-github-app-token@v1
+        with:
+          app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
+          private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
+      - name: Sync our main with upstream main
+        run: |
+          gh auth status
+          gh repo sync rocm/jax -b main
+        env:
+          GH_TOKEN: ${{ steps.generate-token.outputs.token }}
+  create-sync-branch:
+    needs: sync-main
+    runs-on: ubuntu-latest
+    env:
+      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: Create branch
+        run: |
+          git fetch
+          git checkout origin/main
+          git checkout -b $SYNC_BRANCH_NAME
+          # Try and merge rocm-main into this new branch so that we don't run upstream's CI code
+          git config --global user.email "[email protected]"
+          git config --global user.name "GitHub Actions"
+          git merge origin/rocm-main || true
+          # If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
+          git merge --abort || true
+          git push origin HEAD
+  open-sync-pr:
+    needs: create-sync-branch
+    runs-on: ubuntu-latest
+    steps:
+      - name: Generate an app token
+        id: generate-token
+        uses: actions/create-github-app-token@v1
+        with:
+          app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
+          private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
+      - name: Open a PR to rocm-main
+        run: |
+          gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
+          gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
+        env:
+          GH_TOKEN: ${{ steps.generate-token.outputs.token }}
+
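
The schedule `'0 6 * * 1-5'` in the workflow above reads as minute 0, hour 6, any day of month, any month, days of week 1 through 5 (Monday to Friday). GitHub Actions evaluates cron schedules in UTC. A small Python sketch of that reading, where `matches_nightly_sync` is a hypothetical helper and not part of the commit:

```python
from datetime import datetime, timezone


def matches_nightly_sync(ts: datetime) -> bool:
    """Return True if `ts` falls on the '0 6 * * 1-5' schedule.

    Expects a timezone-aware datetime; it is converted to UTC first
    because the workflow schedule is interpreted in UTC.
    """
    ts = ts.astimezone(timezone.utc)
    # datetime.weekday(): Monday == 0 ... Sunday == 6, so 0-4 are weekdays.
    is_weekday = ts.weekday() <= 4
    return is_weekday and ts.hour == 6 and ts.minute == 0


monday_6am = datetime(2024, 11, 4, 6, 0, tzinfo=timezone.utc)
saturday_6am = datetime(2024, 11, 9, 6, 0, tzinfo=timezone.utc)
```

Note that scheduled workflows can start later than the nominal time when runners are busy, so this only models the nominal trigger.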
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+name: ROCm Open Upstream PR
+on:
+  pull_request:
+    types: [ labeled ]
+    branches: [ rocm-main ]
+jobs:
+  open-upstream:
+    if: ${{ github.event.label.name == 'open-upstream' }}
+    permissions:
+      contents: write
+      pull-requests: write
+    runs-on: ubuntu-latest
+    env:
+      NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: Rebase code to main
+        run: |
+          git config --global user.email "[email protected]"
+          git config --global user.name "Github Actions"
+          git fetch
+          git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
+          git rebase --onto origin/main origin/rocm-main
+          # Force push here so that we don't run into conflicts with the origin branch
+          git push origin HEAD --force
+      - name: Leave link to create PR
+        env:
+          GH_TOKEN: ${{ github.token }}
+        run: |
+          # Bash is not friendly with newline characters, so make our own
+          NL=$'\n'
+          # Encode the PR title and body for passing as URL get parameters
+          TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
+          BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
+          # Create a link that will open up a new PR form to upstream and autofill the fields
+          CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
+          # Add a comment with the link to the PR
+          COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
+          gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"
+
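
The "Leave link to create PR" step percent-encodes the title and body with jq's `@uri` filter and splices them into a GitHub compare URL. The same idea can be sketched in Python with `urllib.parse.quote`; passing `safe=""` is an assumption chosen to approximate `@uri`, which leaves only unreserved characters unescaped. The function name `upstream_pr_link` and the sample arguments are hypothetical:

```python
from urllib.parse import quote


def upstream_pr_link(branch: str, title: str, body: str, pr_number: int) -> str:
    """Build a prefilled 'open a PR' link like the one the workflow comments."""
    # Mirror the workflow: prefix the title and append a back-reference.
    full_title = f"[ROCm] {title}"
    full_body = f"{body}\n\nCreated from: rocm/jax#{pr_number}"
    # safe="" also escapes "/", so only letters, digits and "_.-~" survive.
    title_enc = quote(full_title, safe="")
    body_enc = quote(full_body, safe="")
    return (
        "https://github.com/jax-ml/jax/compare/main...ROCm:jax:"
        f"{branch}?expand=1&title={title_enc}&body={body_enc}"
    )


link = upstream_pr_link("my-fix-upstream", "Fix & test", "Some details", 123)
```

Encoding matters here because characters such as `&`, `#`, and newlines in a PR title or body would otherwise be read as part of the URL's own syntax.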

.github/workflows/upstream-nightly.yml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ on:

 jobs:
   upstream-dev:
-    runs-on: ubuntu-latest
+    runs-on: ROCM-Ubuntu
     permissions:
       contents: read
       issues: write # for failed-build-issue

build/rocm/README.md

Lines changed: 1 addition & 0 deletions
@@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
 ### Simplified Build Script

 For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
+

build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM
 # manylinux base image. However, adding this does fix an issue where Bazel isn't able
 # to find them.
 RUN --mount=type=cache,target=/var/cache/dnf \
-    dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
+    dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel

 RUN --mount=type=cache,target=/var/cache/dnf \
     --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \

build/rocm/ci_build

Lines changed: 26 additions & 0 deletions
@@ -127,6 +127,31 @@ def dist_wheels(
         ]
     )

+    # Add command for unit tests
+    cmd.extend(
+        [
+            "&&",
+            "bazel",
+            "test",
+            "-k",
+            "--jobs=4",
+            "--test_verbose_timeout_warnings=true",
+            "--test_output=all",
+            "--test_summary=detailed",
+            "--local_test_jobs=1",
+            "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
+            "--test_env=JAX_SKIP_SLOW_TESTS=0",
+            "--verbose_failures=true",
+            "--config=rocm",
+            "--action_env=ROCM_PATH=/opt/rocm",
+            "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
+            "--test_tag_filters=-multiaccelerator",
+            "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
+            "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
+            "//tests:gpu_tests",
+        ]
+    )
+
     LOG.info("Running: %s", cmd)
     _ = subprocess.run(cmd, check=True)

@@ -356,3 +381,4 @@ def main():

 if __name__ == "__main__":
     main()
+
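
The hunk above hard-codes the accelerator count (`4`) and GPU target (`gfx90a`) inside the Bazel argument list, while the GPU workflow exports `GPU_COUNT` and `GFX` for the test step. One way the two could be tied together is sketched below; the helper `bazel_test_args` and its parameters are hypothetical, and only flags that already appear in the diff are used:

```python
from typing import List


def bazel_test_args(accelerator_count: int = 4, gfx: str = "gfx90a") -> List[str]:
    """Return the bazel test invocation from the diff, with two values parameterized."""
    return [
        "bazel",
        "test",
        "-k",
        "--jobs=4",
        "--test_verbose_timeout_warnings=true",
        "--test_output=all",
        "--test_summary=detailed",
        "--local_test_jobs=1",
        # Same flags as the diff, but filled from arguments instead of literals.
        "--test_env=JAX_ACCELERATOR_COUNT=%i" % accelerator_count,
        "--test_env=JAX_SKIP_SLOW_TESTS=0",
        "--verbose_failures=true",
        "--config=rocm",
        "--action_env=ROCM_PATH=/opt/rocm",
        "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % gfx,
        "--test_tag_filters=-multiaccelerator",
        "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
        "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
        "//tests:gpu_tests",
    ]


default_args = bazel_test_args()
```

With the defaults, the list matches what the commit appends after the `"&&"` separator; a caller could instead pass values read from the environment.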

build/rocm/tools/build_wheels.py

Lines changed: 4 additions & 1 deletion
@@ -226,7 +226,10 @@ def fix_wheel(path, jax_path):
     py_bin = "/opt/python/cp310-cp310/bin"
     env["PATH"] = "%s:%s" % (py_bin, env["PATH"])

-    cmd = ["pip", "install", "auditwheel>=6"]
+    # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed
+    # the function to ldd and also changed its behavior
+    # constrain range to 6.0 to 6.2.x
+    cmd = ["pip", "install", "auditwheel>=6,<6.3"]
     subprocess.run(cmd, check=True, env=env)

     fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
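
The new requirement `auditwheel>=6,<6.3` admits the 6.0.x through 6.2.x releases and rejects 6.3.0 and later. The sketch below illustrates that range with plain tuple comparison. It is a simplification for plain `X.Y.Z` version strings only; pip itself follows the full version specifier rules (pre-releases, post-releases, and so on), and the names `parse_version` and `auditwheel_allowed` are hypothetical:

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn a plain dotted version such as '6.2.0' into (6, 2, 0)."""
    return tuple(int(part) for part in text.split("."))


def auditwheel_allowed(version: str) -> bool:
    """Approximate the '>=6,<6.3' constraint for simple X.Y.Z versions."""
    parsed = parse_version(version)
    # Tuples compare element by element, and a shorter prefix sorts first,
    # so (6,) <= (6, 0, 0) and (6, 2, 9) < (6, 3) <= (6, 3, 0).
    return (6,) <= parsed < (6, 3)


checks = {v: auditwheel_allowed(v) for v in ["5.4.0", "6.0.0", "6.2.0", "6.3.0"]}
```

Under these assumptions only the two 6.0 to 6.2.x entries come back as allowed, which matches the intent stated in the code comment.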
