Brax + Pufferlib + CARBS for GPU-accelerated robotics RL
Clone this repository.

```bash
git clone https://github.com/kywch/brax-trainer.git
```

Then, go to the directory.

```bash
cd brax-trainer
```
Using pixi, set up the virtual environment and install the dependencies. Install pixi if you haven't already; see the pixi documentation for more details. The following command is for Linux.

```bash
curl -fsSL https://pixi.sh/install.sh | bash
```

The following command sets up the virtual environment and installs the dependencies.

```bash
pixi install
```

Then, activate the virtual environment.

```bash
pixi shell
```

To check that both PyTorch and JAX are installed correctly with CUDA support, run the following commands.

```bash
pixi run test_torch
pixi run test_jax
```
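If you want to run the same kind of check from your own script, you can ask each framework which device it sees. The sketch below uses only standard PyTorch and JAX calls (`torch.cuda.is_available()` and `jax.default_backend()`). It is not necessarily what the `test_torch` and `test_jax` tasks run, and the helper name `check_backends` is hypothetical.

```python
import importlib.util


def check_backends():
    """Report whether PyTorch and JAX can see a GPU.

    Hypothetical helper, not part of brax-trainer. A framework that is not
    installed is reported as None instead of raising ImportError.
    """
    status = {"torch_cuda": None, "jax_backend": None}

    if importlib.util.find_spec("torch") is not None:
        import torch
        # True only if PyTorch was built with CUDA and a GPU is visible.
        status["torch_cuda"] = torch.cuda.is_available()

    if importlib.util.find_spec("jax") is not None:
        import jax
        # "gpu" when JAX found a CUDA device, "cpu" otherwise.
        status["jax_backend"] = jax.default_backend()

    return status


if __name__ == "__main__":
    print(check_backends())
```

Inside `pixi shell` on a working CUDA setup, `torch_cuda` should be `True` and `jax_backend` should be `"gpu"`.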
Train a policy. In the virtual environment, run:

```bash
python brax_trainer/train.py -m train
```

Or, run the pixi task:

```bash
pixi run train -m train
```
Evaluate the trained policy.

```bash
python brax_trainer/train.py -m eval -p <path_to_model>
```

To make a video of the trained policy, run:

```bash
python brax_trainer/train.py -m video -p <path_to_model>
```

Try these with the pre-trained policy for the Ant env, `brax_ant_policy.pt`.
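To evaluate or render several checkpoints in one go, a small driver can invoke the same command line once per file. This sketch uses only the `-m` and `-p` flags shown above. The helper names (`build_command`, `run_for_checkpoints`) and the `checkpoints` directory are hypothetical, and it assumes you run it from the repository root inside the pixi environment.

```python
import subprocess
import sys
from pathlib import Path

# Entry point used in the commands above, relative to the repository root.
TRAIN_SCRIPT = "brax_trainer/train.py"


def build_command(mode, model_path=None):
    """Build the argv for brax_trainer/train.py in the given mode."""
    # sys.executable is the interpreter of the active (pixi) environment.
    cmd = [sys.executable, TRAIN_SCRIPT, "-m", mode]
    if model_path is not None:
        cmd += ["-p", str(model_path)]
    return cmd


def run_for_checkpoints(checkpoint_dir, mode="eval"):
    """Run eval (or video) for every .pt file in checkpoint_dir, in name order."""
    for ckpt in sorted(Path(checkpoint_dir).glob("*.pt")):
        print("Running", mode, "on", ckpt)
        # check=True stops at the first checkpoint whose run fails.
        subprocess.run(build_command(mode, ckpt), check=True)


if __name__ == "__main__":
    # Hypothetical usage: python eval_all.py checkpoints
    run_for_checkpoints(sys.argv[1] if len(sys.argv) > 1 else "checkpoints")
```

Pass `mode="video"` to render each checkpoint instead of evaluating it.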
Sweep the hyperparameters using CARBS.

```bash
python brax_trainer/train.py -m sweep
```