Skip to content

Conversation

JehandadKhan
Copy link
Collaborator

Motivation

This PR adds a version check to the PJRT wheel similar to the one added to the Plugin Wheel here

Technical Details

The module init file has been updated to parse the version of JAX and the PJRT wheel. If the version of the PJRT wheel is ahead of the JAX wheel, loading would fail. The behavior can be overridden with the JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK env var.

Test Plan

The change was tested by updating the versions of jax/jaxlib and then verifying the behavior of the plugin. Please note that the PJRT plugin is lazily loaded, meaning you have to access the device to trigger the behavior such as issuing jax.devices() in a python shell.

Test Result

JAX/JAXlib Version PJRT Version Result
0.6.2 0.6.0 Plugin loaded
0.5.2 0.6.0 Plugin not loaded
0.5.2 0.6.0 Plugin loaded by using the env var

@JehandadKhan JehandadKhan marked this pull request as draft September 25, 2025 22:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant