Caution
This library is still experimental and under active development. Expect bugs and changing interfaces. If you encounter any bugs or other issues, please let us know via the issue tracker. If you are an RL developer and want to collaborate, feel free to contact us.
This project follows these design principles:
- Algorithms are functions!
- Algorithms are implemented in single files.
- Policies and value functions are data containers (see the sketch after this list).
- Our environment interface is Gymnasium.
- We use JAX for everything.
- We use Chex to write reliable code.
- For optimization algorithms we use Optax.
- For probability distributions we use TensorFlow Probability.
- For all neural networks we use Flax NNX.
- To save checkpoints we use Orbax.
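To illustrate what these principles mean in practice, here is a minimal, hypothetical sketch (this is not rl-blox code; PolicyMLP and all names are made up for illustration): a policy is simply a Flax NNX module that holds its parameters, and a training algorithm would be a plain function that updates it using Optax.

import jax
import jax.numpy as jnp
import optax
from flax import nnx

class PolicyMLP(nnx.Module):
    """A policy is a data container: an NNX module that holds its parameters."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int, rngs: nnx.Rngs):
        self.hidden_layer = nnx.Linear(obs_dim, hidden, rngs=rngs)
        self.output_layer = nnx.Linear(hidden, act_dim, rngs=rngs)

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        return self.output_layer(jax.nn.relu(self.hidden_layer(obs)))

policy = PolicyMLP(obs_dim=3, act_dim=1, hidden=128, rngs=nnx.Rngs(0))
action = policy(jnp.zeros(3))  # forward pass; the parameters live inside the module
tx = optax.adam(3e-4)  # a training function ("algorithms are functions") would apply Optax updates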
The easiest way to install is via PyPI:
pip install rl-blox
Alternatively, for example if you want to develop extensions for the library, you can install rl-blox from source:
git clone [email protected]:mlaux1/rl-blox.git
After cloning the repository, it is recommended to install the library in editable mode.
pip install -e .
To run the provided examples, use pip install 'rl-blox[examples]'.
To install development dependencies, use pip install 'rl-blox[dev]'.
To enable logging with aim, use pip install 'rl-blox[logging]'.
You can install all optional dependencies (except logging) with pip install 'rl-blox[all]'.
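As an optional sanity check that the installation worked, you can try importing the package (the import name rl_blox is the one used in the example below):

python -c "import rl_blox"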
We currently provide implementations of the following algorithms (ordered from SotA to classic RL algorithms): MR.Q, TD7, TD3+LAP, PE-TS, SAC, TD3, DDPG, DDQN, DQN, double Q-learning, CMA-ES, Dyna-Q, actor-critic, REINFORCE, Q-learning, MC.
RL-BLOX relies on Gymnasium's environment interface. The following example trains SAC on the Pendulum environment:
import gymnasium as gym

from rl_blox.algorithm.sac import create_sac_state, train_sac
from rl_blox.logging.checkpointer import OrbaxCheckpointer
from rl_blox.logging.logger import AIMLogger, LoggerList

env_name = "Pendulum-v1"
env = gym.make(env_name)
seed = 1
verbose = 1

env = gym.wrappers.RecordEpisodeStatistics(env)

hparams_models = dict(
    policy_hidden_nodes=[128, 128],
    policy_learning_rate=3e-4,
    q_hidden_nodes=[512, 512],
    q_learning_rate=1e-3,
    seed=seed,
)
hparams_algorithm = dict(
    total_timesteps=11_000,
    buffer_size=11_000,
    gamma=0.99,
    learning_starts=5_000,
)

if verbose:
    print(
        "This example uses the AIM logger. You will not see any output on "
        "stdout. Run 'aim up' to analyze the progress."
    )

checkpointer = OrbaxCheckpointer("/tmp/rl-blox/sac_example/", verbose=verbose)
logger = LoggerList([
    AIMLogger(),
    # uncomment to store checkpoints
    # checkpointer,
])
logger.define_experiment(
    env_name=env_name,
    algorithm_name="SAC",
    hparams=hparams_models | hparams_algorithm,
)
logger.define_checkpoint_frequency("policy", 1_000)

sac_state = create_sac_state(env, **hparams_models)
sac_result = train_sac(
    env,
    sac_state.policy,
    sac_state.policy_optimizer,
    sac_state.q,
    sac_state.q_optimizer,
    logger=logger,
    **hparams_algorithm,
)

env.close()

policy, _, q, _, _, _, _ = sac_result
# Do something with the trained policy...
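After training, you could roll out the learned policy in the environment. The snippet below is only a sketch of the standard Gymnasium interaction loop with a placeholder action; how exactly the trained policy is queried depends on rl-blox's policy interface, which is not shown in this README.

import gymnasium as gym

env = gym.make("Pendulum-v1", render_mode="human")
obs, info = env.reset(seed=1)
for _ in range(200):
    # Placeholder: query the trained policy here instead of sampling randomly.
    # The exact call signature of the rl-blox policy is not shown in this README.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()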
You can build the Sphinx documentation with
pip install -e '.[doc]'
cd doc
make html
The HTML documentation will be available under doc/build/html/index.html.
If you wish to report bugs, please use the issue tracker. If you would like to contribute to RL-BLOX, just open an issue or a pull request. The target branch for pull requests is the development branch, which is merged into master for new releases. If you have questions about the software, please ask them in the discussions section.
The recommended workflow to add a new feature, add documentation, or fix a bug is the following:
- Push your changes to a branch (e.g. feature/x, doc/y, or fix/z) of your fork of the RL-BLOX repository.
- Open a pull request to the main branch.
It is forbidden to directly push to the main branch.
Run the tests with
pip install -e '.[dev]'
pytest
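To run only a subset of the tests, you can use pytest's standard keyword filter, for example (assuming some test names mention the algorithm you are interested in):

pytest -k sac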
Semantic versioning must be used: the major version number is incremented when the API changes in a backwards-incompatible way, the minor version when new functionality is added in a backwards-compatible manner, and the patch version for bugfixes, documentation updates, etc.
This library is currently developed at the Robotics Group of the University of Bremen together with the Robotics Innovation Center of the German Research Center for Artificial Intelligence (DFKI) in Bremen.