Our goal is to run jaxmarl_testdrive.py. This is a basic script that will only run if everything is installed correctly, which should be the case by the end of this section. I've only tested this on macOS with an M1 chip. I suggest using mamba instead of conda.
- Make a conda environment with

    conda create -n siggame python=3.9

- Activate the environment with

    conda activate siggame

- Install jax. If you're on macOS and using conda, run:

    conda install -c conda-forge jaxlib=0.4.19
    conda install -c conda-forge jax

  Or, if you're using mamba (better), run:

    mamba install jaxlib=0.4.19
    mamba install jax

  Verify that it is installed correctly by running

    python -c 'import jax; print(jax.numpy.arange(10))'

  You may get some warnings, but you should end up with an array of the numbers 0 to 9.
- Install JaxMARL. This cannot be done by installing the package from PyPI with pip, since we need to run the algorithms (see the JaxMARL installation instructions for why).
  Clone JaxMARL in a separate directory with:

    git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
    pip install -e .

- If you're on a Mac, you may still be unable to run jaxmarl_testdrive.py because of a MuJoCo error. We don't actually need MuJoCo, so you can just comment out the following lines in the JaxMARL package (I hate how dirty this is):

    # In /JaxMARL/jaxmarl/environments/__init__.py, comment out line 20
    ...
    from .overcooked import Overcooked, overcooked_layouts
    # from .mabrax import Ant, Humanoid, Hopper, Walker2d, HalfCheetah
    from .hanabi import HanabiGame
    ...

    # In /JaxMARL/jaxmarl/registration.py, comment out lines 19-23
    ...
    HeuristicEnemySMAX,
    LearnedPolicyEnemySMAX,
    SwitchRiddle,
    # Ant,
    # Humanoid,
    # Hopper,
    # Walker2d,
    # HalfCheetah,
    InTheGrid,
    InTheGrid_2p,
    HanabiGame,
    ...

  This should work for JaxMARL v0.0.2. If you're using a different version, you may need to find the lines yourself.
- We also need torch and torchvision, which can be installed with mamba (torchvision pulls in torch as a dependency):

    mamba install torchvision

  The conda analog would be something like (untested):

    conda install -c conda-forge torchvision

- We no longer need jax-dataloader. (It used to be installed with pip install jax-dataloader.)
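The manual MuJoCo workaround for the Mac setup can also be scripted. What follows is a hypothetical helper (the file name patch_mabrax.py and the command-line usage are assumptions, not part of this repo or of JaxMARL), written against the import lines quoted above for JaxMARL v0.0.2. Instead of relying on line numbers, it comments out the `from .mabrax import ...` line and the bare `Ant,` / `Humanoid,` / `Hopper,` / `Walker2d,` / `HalfCheetah,` entries of an import list.

```python
"""Hypothetical helper that automates the MuJoCo (mabrax) workaround.

Assumes the import lines look like the ones quoted for JaxMARL v0.0.2;
other versions may need a different pattern.
"""
import re
import sys
from pathlib import Path

# Matches either the whole "from .mabrax import ..." line, or a line that is
# only one of the mabrax class names followed by a comma (an import-list entry).
MABRAX_LINE = re.compile(
    r"^(?P<indent>[ \t]*)"
    r"(?P<code>from \.mabrax import .*|(?:Ant|Humanoid|Hopper|Walker2d|HalfCheetah),)"
    r"[ \t]*$"
)

# Files touched by the manual workaround, relative to the JaxMARL clone.
TARGETS = ("jaxmarl/environments/__init__.py", "jaxmarl/registration.py")


def comment_out_mabrax(source: str) -> str:
    """Return `source` with the mabrax import lines commented out.

    Already-commented lines never match, so running this twice is a no-op.
    """
    out = []
    for line in source.splitlines(keepends=True):
        body = line.rstrip("\r\n")
        ending = line[len(body):]  # preserve the original line ending
        match = MABRAX_LINE.match(body)
        if match:
            out.append(f"{match['indent']}# {match['code']}{ending}")
        else:
            out.append(line)
    return "".join(out)


def patch_file(path: Path) -> bool:
    """Patch `path` in place; return True if its contents changed."""
    original = path.read_text()
    patched = comment_out_mabrax(original)
    if patched != original:
        path.write_text(patched)
    return patched != original


if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit("usage: python patch_mabrax.py /path/to/JaxMARL")
    repo = Path(sys.argv[1])
    for rel in TARGETS:
        changed = patch_file(repo / rel)
        print(f"{rel}: {'patched' if changed else 'unchanged'}")
```

Matching whole lines rather than every occurrence of the names is deliberate: a blanket search would also comment out a statement like `env = Ant(**env_kwargs)` inside a function body, which could leave an empty `if`/`elif` branch and a syntax error.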
Finally, we should be able to run jaxmarl_testdrive.py with:
    python jaxmarl_testdrive.py

To set up on Oscar, follow the Oscar JAX instructions, which are copied below.
- Run the following:

    module purge
    unset LD_LIBRARY_PATH
    module load cuda cudnn

- In a separate directory, create a Python virtual environment:

    python -m venv jax.venv

- Activate it:

    source jax.venv/bin/activate

- Install jax:

    pip install --upgrade pip
    pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  Verify by running

    python -c 'from jax.lib import xla_bridge;print(xla_bridge.get_backend().platform)'

  You should see gpu.
- Install JaxMARL in a separate directory:

    git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
    pip install -e .

- Install torchvision (this may take a while):
    pip install torchvision

You should be able to run jaxmarl_testdrive.py.
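Before running the script on either setup, it can help to check what is installed and which backend JAX picked. The sketch that follows is a hypothetical helper (the file name preflight.py and the optional expected-backend argument are assumptions, and it assumes the editable JaxMARL install registers under the distribution name jaxmarl). It uses `jax.default_backend()`, which reports the same platform string as the `xla_bridge` verification command in the Oscar steps.

```python
"""Hypothetical preflight check to run before jaxmarl_testdrive.py.

Reports installed versions of the packages used in this guide and the
backend JAX selected. Nothing here is part of this repo or of JaxMARL.
"""
import sys
from importlib import metadata

# Distribution names; "jaxmarl" assumes the editable install registers under that name.
PACKAGES = ("jax", "jaxlib", "jaxmarl", "torch", "torchvision")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_platform():
    """Return JAX's default backend name (e.g. "cpu" or "gpu"), or None if JAX can't be imported."""
    try:
        import jax
    except ImportError:
        return None
    # Public counterpart of xla_bridge.get_backend().platform.
    return jax.default_backend()


if __name__ == "__main__":
    # Optional first argument: the backend you expect, e.g. "gpu" on Oscar.
    expected = sys.argv[1] if len(sys.argv) > 1 else None

    versions = installed_versions()
    for name, version in versions.items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")

    platform = jax_platform()
    print(f"JAX backend: {platform}")

    missing = [name for name, version in versions.items() if version is None]
    wrong_backend = expected is not None and platform != expected
    sys.exit(1 if missing or wrong_backend else 0)
```

On Oscar, `python preflight.py gpu` should exit with status 0 once the CUDA modules are loaded and the job is running on a node with a GPU. On the macOS setup, which installs no GPU plugin, the reported backend will typically be cpu.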
Check the readme.md in the base_experiment/ folder for more information on that experiment.