Description
Describe the bug
jax.interpreters.xla.backend_specific_translations is deprecated as of jax v0.4.29 (https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-29-june-10-2024), and accessing it now raises an error instead of returning the translation table. This breaks envpool's XLA mode: calling env.xla() fails with

AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.
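Because the removed attribute now raises AttributeError on access, a library that has to support both older and newer JAX releases could feature-detect it rather than touching it unconditionally. A minimal sketch; the flag name HAS_LEGACY_XLA_TRANSLATIONS is hypothetical and not part of envpool:

```python
# Hypothetical compatibility check, not envpool code.
try:
    from jax.interpreters import xla

    # hasattr() returns False when the module-level __getattr__ raises
    # AttributeError, which is what happens for backend_specific_translations
    # from jax 0.4.29 onward.
    HAS_LEGACY_XLA_TRANSLATIONS = hasattr(xla, "backend_specific_translations")
except ImportError:
    # Treat a missing jax.interpreters.xla module as "no legacy tables".
    HAS_LEGACY_XLA_TRANSLATIONS = False

if HAS_LEGACY_XLA_TRANSLATIONS:
    print("legacy XLA translation tables are available")
else:
    print("register lowering rules via jax.interpreters.mlir instead")
```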
To Reproduce
import envpool

env = envpool.make("Breakout-v5", env_type="gym")
handle, recv, send, step = env.xla()
Traceback (excerpt):
  File "/lib/python3.10/site-packages/envpool/python/lax.py", line 30, in xla
    _handle, _recv, _send = make_xla(self)
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 124, in make_xla
    methods.append(_make_xla_function(obj, handle, name, specs, capsules))
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 94, in _make_xla_function
    xla.backend_specific_translations["cpu"][prim] = partial(
  File "/lib/python3.10/site-packages/jax/_src/deprecations.py", line 52, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.
Expected behavior
env.xla() returns the handle and the recv, send, and step functions without raising an error.
System info
import envpool, jax, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
> 0.8.4 1.26.4 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] linux
print(jax.__version__)
> 0.4.29
Reason and Possible fixes
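As the error message suggests, the custom primitives created in envpool/python/xla_template.py should be registered via jax.interpreters.mlir instead, i.e. by attaching a lowering rule to each primitive with jax.interpreters.mlir.register_lowering rather than assigning into xla.backend_specific_translations["cpu"].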
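A minimal sketch of the mlir-based registration, assuming a toy primitive named double (hypothetical, not from envpool) whose lowering is generated from a Python reference implementation with mlir.lower_fun. envpool's real primitives wrap native XLA custom-call targets (the capsules passed to _make_xla_function), so an actual fix would need a lowering rule that emits a custom call; this sketch only shows where register_lowering takes the place of the old translation-table assignment.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older releases such as 0.4.29

# Hypothetical toy primitive; envpool's primitives wrap C++ custom calls instead.
double_p = Primitive("double")


def _double_impl(x):
    # Reference implementation, used for eager (non-jit) evaluation.
    return 2 * x


def _double_abstract_eval(x):
    # The output has the same shape and dtype as the input aval.
    return x


double_p.def_impl(_double_impl)
double_p.def_abstract_eval(_double_abstract_eval)

# Replaces the old xla.backend_specific_translations["cpu"][prim] = ... line.
# lower_fun traces the Python implementation and lowers it to MLIR; with no
# platform argument the rule applies to every backend.
mlir.register_lowering(double_p, mlir.lower_fun(_double_impl, multiple_results=False))


def double(x):
    return double_p.bind(x)


print(jax.jit(double)(jnp.arange(3.0)))  # [0. 2. 4.]
```

Passing platform="cpu" to register_lowering would restrict the rule to the CPU backend, mirroring the old backend_specific_translations["cpu"] entry.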
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)