amacati
Contributor

@amacati amacati commented Jun 8, 2025

Description

This PR is a first shot at making spaces generic over all Array API compatible frameworks (e.g. numpy, jax, torch, ...).

The benefit of this change would be to make gymnasium compatible with the growing list of environments written in jax or torch. Strictly speaking, the ArrayConversion wrappers (e.g. JaxToTorch) are already outside of what we can currently express with gymnasium.spaces: there is no way to let the action and observation spaces reflect that we are expecting jax.Arrays as observations and torch.Tensors as inputs (save for custom spaces, which are rarely used).

This PR is also meant as a basis for discussing whether this requires breaking changes, new core dependencies (i.e. array_api_compat would become a core dependency), etc.

One fundamental change is that each box is now also linked with a device. This allows users to express that e.g. the observation is an array on the GPU. The example below shows how this impacts space.contains.

Open Challenges

  • space.sample: Sampling will require several changes. It does not make sense to seed a torch space with a numpy random generator
  • All tests can be rewritten to test against all Array API compatible frameworks. I would like to postpone this until we have a clearer picture of how things should be implemented though
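
To illustrate the seeding problem, here is a minimal sketch of how each framework exposes its own random state. The helper `make_rng` and its string keys are hypothetical and not part of gymnasium or this PR; only the framework calls themselves (`np.random.default_rng`, `torch.Generator().manual_seed`, `jax.random.PRNGKey`) are real.

```python
from typing import Any


def make_rng(framework: str, seed: int) -> Any:
    """Return a framework-native random state for `seed` (hypothetical helper)."""
    # Imports are local so the sketch only needs the framework actually requested.
    if framework == "numpy":
        import numpy as np

        return np.random.default_rng(seed)  # a numpy.random.Generator
    if framework == "torch":
        import torch

        return torch.Generator().manual_seed(seed)  # a torch.Generator
    if framework == "jax":
        import jax

        return jax.random.PRNGKey(seed)  # an explicit key passed to each sampling call
    raise ValueError(f"No RNG known for framework {framework!r}")
```

The three return types share no common interface, which is why a single numpy-based seeding path cannot serve every backend.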

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Example

With the changes in this PR, it is already possible to run the following:

import array_api_strict as xp
import jax
import jax.numpy as jp
import numpy as np
import torch

from gymnasium.spaces import Box


box = Box(low=np.zeros(3, dtype=np.float32), high=np.zeros(3, dtype=np.float32))
print(box)  # Box(0.0, 0.0, (3,), <class 'numpy.float32'>, cpu)

strict_box = Box(low=xp.zeros(3, dtype=xp.float32), high=xp.zeros(3, dtype=xp.float32))
print(strict_box)  # Box(Array(0., dtype=array_api_strict.float32), Array(0., dtype=array_api_strict.float32), (3,), array_api_strict.float32, array_api_strict.Device('CPU_DEVICE'))

torch_box = Box(low=torch.zeros(3), high=torch.ones(3))
print(torch_box)  # Box(tensor(0.), tensor(1.), (3,), torch.float32, cpu)

jax_box = Box(low=jp.zeros(3), high=jp.ones(3))
print(jax_box)  # Box(0.0, 1.0, (3,), <class 'jax.numpy.float32'>, cuda:0)

cpu_device = jax.devices("cpu")[0]
print(jax_box.contains(jp.array([0.5, 0.5, 0.5], device=cpu_device)))  # False, devices do not match
gpu_device = jax.devices("gpu")[0]
print(jax_box.contains(jp.array([0.5, 0.5, 0.5], device=gpu_device)))  # True

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@amacati amacati changed the base branch from main to array-api June 8, 2025 21:31
@pseudo-rnd-thoughts
Member

@amacati I've updated the array-api branch to match main, minimising the number of changes between them so that we can see the differences more easily.

@amacati
Contributor Author

amacati commented Jun 10, 2025

Thanks! Looking at the tests and numpy docs, numpy only supports the Array API standard since version 2.0. This should be the reason why build-necessary fails. We'd have to update that dependency, but bumping the numpy requirement is not a small thing. I remember that we got around that by making all Array API features optional when adding the ArrayConversion wrappers. If the Array API becomes part of the core spaces that would no longer work. What's your stance on this?
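
For reference, the optional-dependency pattern mentioned above can be sketched roughly as follows. This is a generic illustration, not the code used for the ArrayConversion wrappers; `has_module` and `require_array_api` are hypothetical names.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module is installed, without importing it."""
    return importlib.util.find_spec(name) is not None


def require_array_api() -> None:
    """Raise a helpful error if the optional Array API dependency is missing."""
    if not has_module("array_api_compat"):
        raise ImportError(
            "array_api_compat is required for Array API features; "
            "install it to use them."
        )
```

Once the spaces themselves depend on the Array API, a guard like this would sit on the import path of every space, which is why the optional approach stops working.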

@pseudo-rnd-thoughts
Member

pseudo-rnd-thoughts commented Jun 11, 2025

I'm starting to believe this is a Gymnasium 2.0 kind of feature change.
This would give us more liberty to change the API where necessary, though I wish to make as few changes as possible.

  • It would allow us to drop NumPy 1.x for 2.2+
  • Also we could simplify the Box space if we wished as currently it looks like a mound of technical debt to me.

In the end, I would want to update the whole project to support array-api, though we can of course start with Spaces only first.

Thoughts? If we are doing this, it would be good to get some more voices involved to check their thoughts as well.

@amacati
Contributor Author

amacati commented Jun 11, 2025

I would definitely be on board with converting everything. The spaces were just a good module to start with. I also agree about the technical debt. Box could be simplified a lot if low and high were always arrays.

This could very well be something for a gymnasium 2.0 release, and it would be great to hear from more people to see where the different interests lie.

@Jammf
Contributor

Jammf commented Jun 11, 2025

I agree that it'd make more sense as a 2.0 change, and a good opportunity to clean up the tech debt with Box with more dramatic changes. I also like the idea of having low and high be Arrays, constrained to having the same backing library, dtype, and device (and perhaps shape?). Then it'd be impossible to pass in values for low/high that are outside the bounds of the Box's dtype, so all that logic could be removed.

It'd also make typing Box easier, since then we could have something like:

class Box(Space[T]):
    def __init__(self, low: T, high: T, seed: SeedType): ...
    def contains(self, x: T) -> bool: ...
    def sample(self) -> T: ...

with T having the appropriate bounds. Then, if numpy (or other libraries) ever get something like shape typing, that should just work without any changes on our end.
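
A rough, self-contained sketch of that idea follows, assuming low and high must share the same array type, shape, dtype, and device. Everything here is hypothetical: `ArrayLike` is one possible bound for T, `FakeArray` is a stand-in for a real array type, `compatible` covers only the metadata half of `contains`, and the `seed` parameter is left out for brevity.

```python
from dataclasses import dataclass
from typing import Generic, Protocol, TypeVar


class ArrayLike(Protocol):
    """Minimal attributes the sketch relies on (hypothetical bound for T)."""

    @property
    def shape(self) -> tuple[int, ...]: ...
    @property
    def dtype(self) -> str: ...
    @property
    def device(self) -> str: ...


T = TypeVar("T", bound=ArrayLike)


@dataclass(frozen=True)
class FakeArray:
    """Stand-in for a real array type, used only to exercise the sketch."""

    shape: tuple[int, ...]
    dtype: str
    device: str


class Box(Generic[T]):
    def __init__(self, low: T, high: T) -> None:
        # low and high must come from the same library and agree on
        # shape, dtype, and device.
        if type(low) is not type(high):
            raise TypeError("low and high must use the same array type")
        for attr in ("shape", "dtype", "device"):
            if getattr(low, attr) != getattr(high, attr):
                raise ValueError(f"low and high differ in {attr}")
        self.low, self.high = low, high

    def compatible(self, x: T) -> bool:
        """Metadata half of `contains`; the bounds check is omitted here."""
        return (
            type(x) is type(self.low)
            and x.shape == self.low.shape
            and x.dtype == self.low.dtype
            and x.device == self.low.device
        )
```

With the constraints enforced in the constructor, the space can take its dtype and device directly from low, so no separate dtype argument needs to be validated against the bounds.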

@@ -354,6 +334,8 @@ def test_infinite_space(low, high, shape, dtype):


def test_legacy_state_pickling():

This is a legacy fix from like gym v0.21, I think we can remove this

@pseudo-rnd-thoughts pseudo-rnd-thoughts changed the title Add first draft of Array API Box [Array-API] Add support for Box space Jun 11, 2025
@amacati
Contributor Author

amacati commented Jun 11, 2025

This would be more appropriate for a general discussion on what a 2.0 version should look like, but has there ever been a discussion about making Env a Protocol? The main utilities Env brings are numpy seeding and some unwrap functionalities, some of which would have to either yield to a more generalizable random backend or move out of the core Env if we make everything fully Array API generic. One benefit of a Protocol would be that envs written without inheritance of gymnasium.core.Env are still compatible with code that uses a potential gymnasium.api.Env Protocol for type checking (and even for instance checks with @runtime_checkable).
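
As a minimal sketch of the Protocol idea: the `Env` below is a hypothetical, methods-only interface and not the current `gymnasium.core.Env`; `CountingEnv` is a toy environment invented for the example.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class Env(Protocol):
    """Hypothetical structural Env interface (methods only)."""

    def reset(self, *, seed: Optional[int] = None) -> tuple[Any, dict[str, Any]]: ...

    def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: ...

    def close(self) -> None: ...


class CountingEnv:
    """Toy environment that does not inherit from anything in gymnasium."""

    def __init__(self) -> None:
        self.t = 0

    def reset(self, *, seed=None):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        return self.t, 0.0, self.t >= 3, False, {}

    def close(self) -> None:
        pass


print(isinstance(CountingEnv(), Env))  # True
print(isinstance(object(), Env))  # False
```

Note that `@runtime_checkable` only checks that the named methods exist; it does not verify their signatures or return types, so static type checking remains the stronger guarantee.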

@Jammf
Contributor

Jammf commented Jun 14, 2025

If all the implementation is removed from it, I think making Env a Protocol makes sense. That might also be a good time to add a RenderableEnv subprotocol to make rendering support optional, along the lines of what #842 was proposing.
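
A possible shape for such a subprotocol, as a sketch only: the method sets are simplified, and `HeadlessEnv`, `DrawnEnv`, and `maybe_render` are hypothetical names used for illustration.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Env(Protocol):
    """Simplified base protocol without rendering."""

    def reset(self) -> Any: ...

    def step(self, action: Any) -> Any: ...


@runtime_checkable
class RenderableEnv(Env, Protocol):
    """Env that additionally supports rendering."""

    def render(self) -> Any: ...


class HeadlessEnv:
    def reset(self):
        return 0

    def step(self, action):
        return 0


class DrawnEnv(HeadlessEnv):
    def render(self):
        return "frame"


def maybe_render(env: Env) -> Any:
    # Only call render() on environments that structurally provide it.
    return env.render() if isinstance(env, RenderableEnv) else None
```

Environments that never render would then satisfy `Env` without having to stub out a `render` method.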

@amacati
Contributor Author

amacati commented Jun 16, 2025

If we want this to be more of a 2.0 feature, I could also draft a version with significantly reduced complexity that would not be backward compatible. But it does not make much sense to iterate further on this PR before we are sure what direction this should take.


3 participants