Stoa provides a lightweight, JAX-native interface for reinforcement learning environments. It defines a common abstraction layer that enables different environment libraries to work together seamlessly in JAX workflows.
⚠️ Early Development – Core abstractions are in place, but the library is still growing!
- Common Interface: A standardized
Environmentbase class that defines the contract for RL environments in JAX. - JAX-Native Design: Pure-functional
stepandresetoperations compatible with JAX transformations likejitandvmap. - Environment Wrappers: A flexible system for composing and extending environments with additional functionality.
- Space Definitions: Structured representations for observation, action, and state spaces.
- TimeStep Protocol: A standardized
TimeStepstructure to represent environment transitions with clear termination and truncation signals.
You can install the core stoa library via pip:
pip install stoa-envThis minimal installation includes the core API and wrappers but no specific environment adapters.
Adapters for external environment libraries are available as optional extras. You can install them individually or all at once.
Install a specific adapter:
# Example for Gymnax
pip install "stoa-env[gymnax]"
# Example for Brax
pip install "stoa-env[brax]"Install all available adapters:
pip install "stoa-env[all]"Stoa currently supports the following JAX-native environment libraries:
- Brax
- Gymnax
- Jumanji
- Kinetix
- Navix
- PGX (Game environments)
- MuJoCo Playground
- XLand-Minigrid
- Craftax (Gymnax Adapter)
Stoa provides a rich set of wrappers to modify and extend environment behavior:
- Core Wrappers:
AutoResetWrapper,RecordEpisodeMetrics,AddRNGKey,VmapWrapper. - Observation Wrappers:
FlattenObservationWrapper,FrameStackingWrapper,ObservationExtractWrapper,AddActionMaskWrapper,AddStartFlagAndPrevAction,AddStepCountWrapper,MakeChannelLast,ObservationTypeWrapper. - Action Space Wrappers:
MultiDiscreteToDiscreteWrapper,MultiBoundedToBoundedWrapper. - Utility Wrappers:
EpisodeStepLimitWrapper,ConsistentExtrasWrapper.
Here's how to adapt a gymnax environment and compose it with several wrappers:
import jax
import gymnax
from stoa import GymnaxToStoa, FlattenObservationWrapper, AutoResetWrapper, RecordEpisodeMetrics
# 1. Instantiate a base environment from a supported library
gymnax_env, env_params = gymnax.make("CartPole-v1")
# 2. Adapt the environment to the Stoa interface
env = GymnaxToStoa(gymnax_env, env_params)
# 3. Apply standard wrappers
# Note: The order of wrappers matters.
env = AutoResetWrapper(env, next_obs_in_extras=True)
env = RecordEpisodeMetrics(env)
# JIT compile the reset and step functions for performance
env.reset = jax.jit(env.reset)
env.step = jax.jit(env.step)
# 4. Interact with the environment
rng_key = jax.random.PRNGKey(0)
state, timestep = env.reset(rng_key)
total_reward = 0
for _ in range(100):
action = env.action_space().sample(rng_key)
state, timestep = env.step(state, action)
total_reward += timestep.reward
if timestep.last():
# Access metrics recorded by the RecordEpisodeMetrics wrapper
episode_return = timestep.extras['episode_metrics']['episode_return']
print(f"Episode finished. Return: {episode_return}")
# The state has been auto-reset, so we can continue the loop
total_reward = 0For more advanced usage, we refer to Stoix which has fully utilised Stoa for its environment abstractions.
- Documentation: Expand documentation with detailed tutorials and API references.
- More Wrappers: Add more common utility wrappers (e.g., observation normalization, reward clipping).
- Integration Examples: Provide examples of how to integrate
stoawith popular JAX-based RL libraries.
We're building Stoa to provide a common foundation for JAX-based RL research. Contributions are welcome!
- Stoix – Distributed single-agent RL in JAX
- Gymnax – Classic control environments in JAX
- Brax – Physics-based environments in JAX
- Jumanji – Board games and optimization problems in JAX
- Navix – Grid-world environments in JAX
- PGX - Classic board and card game environments in JAX
- Kinetix - Robotics environments in JAX
- Craftax - Crafter in JAX
If you use Stoa, please cite it!
@misc{toledo2025stoa,
author = {Edan Toledo},
title = {Stoa: A JAX-Native Interface for Reinforcement Learning Environments},
year = {2025},
url = {https://github.com/EdanToledo/Stoa}
}