[Documentation]: Instructions for building Jax binaries using The Rock (as of 3/24/2025) #247

@gabeweisz

These are the steps necessary (as of 3/24/2025) to build JAX using The Rock:

Set Environment Variables

# Environment variables (set these before running the commands below):
#  THE_ROCK_DIR points to the location of The Rock
#  JAX_DIR points to the location of the Jax source
#  XLA_PATH points to the location of the XLA source

export ROCM_PATH=$THE_ROCK_DIR
export PATH=$ROCM_PATH/bin:$ROCM_PATH/llvm/bin:$PATH
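If any of these variables is unset, the later commands silently operate on the wrong paths (for example, with ROCM_PATH empty, $ROCM_PATH/amdgcn expands to /amdgcn). A quick up-front check can catch this. The sketch below is a hypothetical helper, not part of The Rock or Jax; it assumes a POSIX shell and only checks that THE_ROCK_DIR, JAX_DIR and XLA_PATH each name an existing directory:

```shell
# Hypothetical helper, not part of The Rock or Jax: report any build
# directory variable that is unset or does not point at a directory.
check_build_dirs() {
    status=0
    for name in THE_ROCK_DIR JAX_DIR XLA_PATH; do
        # POSIX-compatible indirect expansion: value=${THE_ROCK_DIR:-}, etc.
        eval "value=\${$name:-}"
        if [ -z "$value" ]; then
            echo "error: $name is not set" >&2
            status=1
        elif [ ! -d "$value" ]; then
            echo "error: $name=$value is not a directory" >&2
            status=1
        fi
    done
    return "$status"
}

# Usage: run the check before exporting ROCM_PATH and PATH, e.g.
#   check_build_dirs && export ROCM_PATH=$THE_ROCK_DIR
```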

Create and activate a Python venv (if needed)

python3 -m venv .venv-jax
source .venv-jax/bin/activate
# jax needs patchelf to build its wheels
pip install patchelf
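The roctracer steps below assume this venv is still active. The venv's activate script exports VIRTUAL_ENV, so a minimal guard can refuse to continue in a shell where it is not set. This is a hypothetical helper, not something roctracer or Jax provides:

```shell
# Hypothetical guard: the venv's activate script exports VIRTUAL_ENV,
# so an empty value means no venv is active in this shell.
require_venv() {
    if [ -z "${VIRTUAL_ENV:-}" ]; then
        echo "error: no Python venv is active; run: source .venv-jax/bin/activate" >&2
        return 1
    fi
    echo "using venv at $VIRTUAL_ENV"
}

# Usage, before the pip installs below:
#   require_venv && pip install CppHeaderParser
```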

Build and install roctracer into The Rock

The Rock ships the device libraries under $ROCM_PATH/lib/llvm/amdgcn, but roctracer needs them at $ROCM_PATH/amdgcn to build, so add a symlink:

ln -s $ROCM_PATH/lib/llvm/amdgcn/ $ROCM_PATH/amdgcn
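Note that with GNU ln, running this command a second time does not fail: because $ROCM_PATH/amdgcn already resolves to a directory, ln creates an extra link named amdgcn inside it instead. If you expect to rerun the setup, a guarded version such as the hypothetical helper below only creates the link when nothing named amdgcn exists yet:

```shell
# Hypothetical helper: create $1/amdgcn -> $1/lib/llvm/amdgcn only if
# nothing named amdgcn exists yet, so rerunning the setup is harmless.
link_device_libs() {
    rocm_path=$1
    link="$rocm_path/amdgcn"
    if [ -e "$link" ] || [ -L "$link" ]; then
        echo "$link already exists; leaving it unchanged"
        return 0
    fi
    ln -s "$rocm_path/lib/llvm/amdgcn" "$link"
}

# Usage:
#   link_device_libs "$ROCM_PATH"
```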

Build roctracer as follows:

# An active Python venv is assumed
pip install CppHeaderParser
git clone https://github.com/rocm/roctracer
cd roctracer
export HIP_PLATFORM=amd
# Skip some targets we don't need (mytest, package) and run make in parallel to speed up the build
sed -i '/make mytest/d' build.sh
sed -i '/make package/d' build.sh
sed -i 's/^make/make -j `nproc`/' build.sh

# build and install
bash build.sh
cd build
make install
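Because the three sed edits rewrite build.sh in place, it can be worth previewing them on a scratch file first. The sketch below applies the same commands (GNU sed, as on most Linux systems) to a made-up stand-in for build.sh; its contents are illustrative only and are not roctracer's real script:

```shell
# Scratch copy with made-up contents standing in for roctracer's build.sh.
demo=$(mktemp -d)
cat > "$demo/build.sh" <<'EOF'
cmake ..
make
make mytest
make package
EOF

# Same edits as above: drop the mytest and package targets first, then
# make the remaining bare "make" parallel. The backticks stay literal in
# the file and are expanded by bash when build.sh runs.
sed -i '/make mytest/d' "$demo/build.sh"
sed -i '/make package/d' "$demo/build.sh"
sed -i 's/^make/make -j `nproc`/' "$demo/build.sh"

cat "$demo/build.sh"
# Remove the scratch directory afterwards with: rm -rf "$demo"
```

The order matters: if the parallel-make substitution ran first, the mytest line would become make -j `nproc` mytest and would no longer match /make mytest/d.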

Build Jax

Jax currently uses the Bazel build system, which (as of version 6.5, the version used to build Jax) fails on a versioned libamdhip64.so.6. To work around this, edit third_party/tsl/third_party/gpus/rocm/BUILD.tpl inside $XLA_PATH and, around line 140, change

srcs = glob(["%{rocm_root}/lib/libamdhip*.so.6*"])

to

srcs = glob(["%{rocm_root}/lib/libamdhip*.so.6"])
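Instead of editing BUILD.tpl by hand, the same one-line change can be scripted. The sketch below is a hypothetical helper that assumes GNU sed and that the line reads exactly as quoted above. It is idempotent: once patched, the .so.6* pattern no longer matches, so a second run changes nothing.

```shell
# Hypothetical helper: rewrite the libamdhip glob from ".so.6*" to ".so.6"
# in the given BUILD.tpl. In the sed pattern, \* and \. match a literal
# "*" and "."; in the replacement they need no escaping.
patch_hip_glob() {
    sed -i 's|libamdhip\*\.so\.6\*"|libamdhip*.so.6"|' "$1"
    # Show the resulting line(s) so the change can be checked by eye.
    grep -n 'libamdhip' "$1"
}

# Usage:
#   patch_hip_glob "$XLA_PATH/third_party/tsl/third_party/gpus/rocm/BUILD.tpl"
```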

After this, wheels for the Jax binaries can be built with:

# we want a release build
export JAXLIB_RELEASE=1

cd $JAX_DIR
python build/build.py build --wheels jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_path $ROCM_PATH --local_xla_path=$XLA_PATH --rocm_amdgpu_targets gfx942

If your source tree is on NFS, the build will be much faster if Bazel stores its intermediate files on local storage. To do this, add

--bazel_startup_options='--output_user_root=/tmp/bazel-jax-build'

to the build command above.
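If you are not sure whether the checkout is on NFS, GNU coreutils' stat -f -c %T prints the filesystem type of a path. The hypothetical helper below uses that to emit the --bazel_startup_options flag only when the given directory is on NFS; it assumes Linux with GNU stat and reuses the /tmp/bazel-jax-build location from above:

```shell
# Hypothetical helper: print the Bazel startup option only when the given
# directory is on an NFS mount; print nothing otherwise (or if stat fails).
nfs_startup_option() {
    fstype=$(stat -f -c %T "$1" 2>/dev/null) || return 0
    case "$fstype" in
        nfs*) echo "--bazel_startup_options=--output_user_root=/tmp/bazel-jax-build" ;;
    esac
}

# Usage (the unquoted $(...) expands to nothing when not on NFS):
#   cd "$JAX_DIR"
#   python build/build.py build --wheels jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
#       --rocm_path "$ROCM_PATH" --local_xla_path="$XLA_PATH" \
#       --rocm_amdgpu_targets gfx942 $(nfs_startup_option "$JAX_DIR")
```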

To build Jax 0.4.35-qa, check out the appropriate branches of Jax and XLA, then run:

export JAXLIB_RELEASE=1
python ./build/build.py --noenable_cuda --enable_rocm --build_gpu_kernel_plugin=rocm --build_gpu_plugin=true --build_gpu_pjrt_plugin=true --rocm_amdgpu_targets=gfx942 --bazel_options=--override_repository=xla="$XLA_PATH" --rocm_path="$ROCM_PATH"
