These are the steps necessary (as of 3/24/2025) to build JAX using The Rock:
Set Environment Variables
# Environment variables:
# THE_ROCK_DIR points to the location of The Rock
# JAX_DIR points to the location of the Jax source
# XLA_PATH points to the location of the XLA source
export ROCM_PATH=$THE_ROCK_DIR
export PATH=$ROCM_PATH/bin:$ROCM_PATH/llvm/bin:$PATH
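A mistyped path in one of these variables can otherwise surface only partway through a long build, so it may be worth checking them up front. The sketch below is a hypothetical bash helper (`check_build_dirs` is not part of The Rock, Jax, or XLA) that verifies `THE_ROCK_DIR`, `JAX_DIR`, and `XLA_PATH` (the name the build commands below use for the XLA source) are set and point at existing directories.

```shell
#!/usr/bin/env bash
# Hypothetical helper: report any of the guide's directory variables that is
# unset or does not point at an existing directory. Returns non-zero if so.
check_build_dirs() {
  local status=0 name
  for name in THE_ROCK_DIR JAX_DIR XLA_PATH; do
    # ${!name:-} expands the variable whose name is stored in $name
    if [ ! -d "${!name:-}" ]; then
      echo "error: \$$name is unset or not a directory" >&2
      status=1
    fi
  done
  return "$status"
}
```

Run it after the exports, e.g. `check_build_dirs || echo "fix the paths before continuing"`.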
Create and activate a Python venv if needed:
python -m venv .venv-jax
source .venv-jax/bin/activate
# jax needs patchelf to build its wheels
pip install patchelf
Build and install roctracer into The Rock
The Rock comes with device libs but is missing a symlink needed for roctracer to build:
ln -s $ROCM_PATH/lib/llvm/amdgcn/ $ROCM_PATH/amdgcn
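If this setup is re-run, a plain `ln -s` whose destination already exists as a link to a directory creates a second link *inside* that directory (here `lib/llvm/amdgcn/amdgcn`) instead of failing. A minimal guard, assuming the same layout as the command above (the function name `link_device_libs` is hypothetical), only creates the link when nothing is there yet:

```shell
#!/usr/bin/env bash
# Hypothetical helper: create $1/amdgcn -> $1/lib/llvm/amdgcn/ only if no
# file, directory, or (possibly dangling) symlink already exists there.
link_device_libs() {
  local rocm="$1"
  if [ ! -e "$rocm/amdgcn" ] && [ ! -L "$rocm/amdgcn" ]; then
    ln -s "$rocm/lib/llvm/amdgcn/" "$rocm/amdgcn"
  fi
}
```

Usage: `link_device_libs "$ROCM_PATH"`. Running it twice leaves a single link in place.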
Build roctracer as follows:
# An active Python venv is assumed
pip install CppHeaderParser
git clone https://github.com/rocm/roctracer
cd roctracer
export HIP_PLATFORM=amd
# skip the mytest and package targets, which we don't need, and run make in parallel
sed -i '/make mytest/d' build.sh
sed -i '/make package/d' build.sh
sed -i 's/^make/make -j $(nproc)/' build.sh
# build and install
bash build.sh
cd build
make install
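To see what the three `sed` edits do to `build.sh` before touching the real file, they can be applied to a throwaway copy. The sample script below is hypothetical and only mimics the shape of lines that start with `make`; the `$(nproc)` form is equivalent to backtick command substitution and stays literal in the file, so it is expanded only when `build.sh` itself runs.

```shell
#!/usr/bin/env bash
# Apply the build.sh edits to a hypothetical sample script and print it.
set -euo pipefail

demo_dir=$(mktemp -d)
cat > "$demo_dir/build.sh" <<'EOF'
cmake ..
make
make mytest
make package
EOF

sed -i '/make mytest/d' "$demo_dir/build.sh"             # drop the mytest target
sed -i '/make package/d' "$demo_dir/build.sh"            # drop packaging
sed -i 's/^make/make -j $(nproc)/' "$demo_dir/build.sh"  # parallel make

cat "$demo_dir/build.sh"
```

This prints `cmake ..` followed by `make -j $(nproc)`. Because the pattern is anchored with `^`, the `cmake ..` line is left alone.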
Build Jax
Jax currently uses the Bazel build system, which (as of version 6.5, the version used to build Jax) fails on the more fully versioned copies of libamdhip64.so.6 (libamdhip64.so.6.*) that the existing glob matches. To work around this, edit third_party/tsl/third_party/gpus/rocm/BUILD.tpl inside $XLA_PATH and change this line (around line 140):
srcs = glob(["%{rocm_root}/lib/libamdhip*.so.6*"])
to:
srcs = glob(["%{rocm_root}/lib/libamdhip*.so.6"])
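The same edit can be scripted instead of made by hand. This is a sketch that assumes the line reads exactly as quoted above; `patch_amdhip_glob` is a hypothetical name, and the final `grep` is there so you can confirm the result, since the file may change between XLA revisions.

```shell
#!/usr/bin/env bash
# Hypothetical helper: narrow the libamdhip glob in XLA's ROCm BUILD.tpl.
# $1 is the XLA source directory.
patch_amdhip_glob() {
  local tpl="$1/third_party/tsl/third_party/gpus/rocm/BUILD.tpl"
  # In a sed basic regex, \* and \. match a literal '*' and '.'
  sed -i 's|libamdhip\*\.so\.6\*"|libamdhip*.so.6"|' "$tpl"
  # Show the resulting line(s) for a quick visual check
  grep -n 'libamdhip' "$tpl"
}
```

Usage: `patch_amdhip_glob "$XLA_PATH"`. Running it a second time is harmless because the narrowed glob no longer matches the pattern.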
After this, wheels for the Jax binaries can be built with:
# we want a release build
export JAXLIB_RELEASE=1
cd $JAX_DIR
python build/build.py build --wheels jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_path $ROCM_PATH --local_xla_path=$XLA_PATH --rocm_amdgpu_targets gfx942
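Before starting the build command above, a few of its prerequisites from the earlier steps can be checked quickly. The sketch below is a hypothetical bash helper (`preflight_jax_rocm` is not part of Jax's build tooling) that checks for `patchelf` on `PATH`, for `libamdhip64.so.6` under `$ROCM_PATH/lib`, and that the BUILD.tpl edit is in place.

```shell
#!/usr/bin/env bash
# Hypothetical preflight helper; checks mirror the steps in this guide.
preflight_jax_rocm() {
  local status=0
  local tpl="${XLA_PATH:-}/third_party/tsl/third_party/gpus/rocm/BUILD.tpl"

  # jax needs patchelf to build its wheels
  if ! command -v patchelf >/dev/null 2>&1; then
    echo "missing: patchelf (pip install patchelf)" >&2
    status=1
  fi
  # The library the narrowed glob is expected to match
  if [ ! -e "${ROCM_PATH:-}/lib/libamdhip64.so.6" ]; then
    echo "missing: ${ROCM_PATH:-}/lib/libamdhip64.so.6" >&2
    status=1
  fi
  # BUILD.tpl must exist and no longer contain the old glob
  if [ ! -f "$tpl" ]; then
    echo "missing: $tpl" >&2
    status=1
  elif grep -qF 'libamdhip*.so.6*' "$tpl"; then
    echo "BUILD.tpl still uses the libamdhip*.so.6* glob" >&2
    status=1
  fi
  return "$status"
}
```

Usage: `preflight_jax_rocm && python build/build.py build ...` with the arguments shown above.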
If you are building on NFS, the build will be much faster if Bazel stores its intermediate files on local storage. To do this, add
--bazel_startup_options='--output_user_root=/tmp/bazel-jax-build'
to the build command above.
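On Linux, Bazel keeps its output under `~/.cache/bazel` by default, so the directory worth checking is usually `$HOME`. A minimal sketch, assuming GNU coreutils `stat` (whose `-f -c %T` prints the file system type, e.g. `nfs`); `bazel_nfs_opts` is a hypothetical name. It prints the extra option only when the given directory is on NFS:

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the Bazel startup option from this guide when
# the given directory (default: $HOME) is on an NFS mount, else print nothing.
bazel_nfs_opts() {
  local dir="${1:-$HOME}"
  local fstype
  fstype=$(stat -f -c %T "$dir" 2>/dev/null)
  case "$fstype" in
    nfs*) echo "--bazel_startup_options=--output_user_root=/tmp/bazel-jax-build" ;;
  esac
}
```

Usage: `python build/build.py build ... $(bazel_nfs_opts "$HOME")`. The substitution is left unquoted on purpose so that it adds no argument at all when the output is empty.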
To build jax 0.4.35-qa, check out the corresponding branches of jax and xla, then run:
export JAXLIB_RELEASE=1
python ./build/build.py --noenable_cuda --enable_rocm --build_gpu_kernel_plugin=rocm --build_gpu_plugin=true --build_gpu_pjrt_plugin=true --rocm_amdgpu_targets=gfx942 --bazel_options=--override_repository=xla="$XLA_PATH" --rocm_path="$ROCM_PATH"