Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions jaxlib/plugin_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from collections.abc import Sequence
import importlib
import os
import re
from types import ModuleType
import warnings
Expand Down Expand Up @@ -67,13 +68,53 @@ def _parse_version(v: str) -> tuple[int, ...]:
raise ValueError(f"Unable to parse version string '{v}'")
return tuple(int(x) for x in m.group(0).split("."))

if _parse_version(jaxlib_version) != _parse_version(plugin_version):
jaxlib_parsed = _parse_version(jaxlib_version)
plugin_parsed = _parse_version(plugin_version)

# Check if this is a ROCm plugin
is_rocm_plugin = any(rocm_name in plugin_name for rocm_name in _PLUGIN_MODULE_NAMES.get("rocm", []))

# Check if version checking is disabled via environment variable (ROCm plugins only)
if is_rocm_plugin and os.environ.get("JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK", "").lower() in ("1", "true", "yes"):
warnings.warn(
f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
"it is not compatible with the installed jaxlib version "
f"{jaxlib_version}, so it will not be used.",
f"JAX ROCm plugin {plugin_name} version checking has been disabled via "
"JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK environment variable. "
"This may lead to compatibility issues.",
RuntimeWarning,
)
return True

if is_rocm_plugin:
# For ROCm plugins: major versions must match, plugin minor version must be <= jaxlib minor version
if len(jaxlib_parsed) < 2 or len(plugin_parsed) < 2:
# If either version doesn't have major.minor format, fall back to exact match
compatible = jaxlib_parsed == plugin_parsed
else:
# Check major version match and minor version constraint
major_match = jaxlib_parsed[0] == plugin_parsed[0]
minor_compatible = plugin_parsed[1] <= jaxlib_parsed[1]
compatible = major_match and minor_compatible
else:
# For non-ROCm plugins: require exact version match
compatible = jaxlib_parsed == plugin_parsed

if not compatible:
if is_rocm_plugin:
warnings.warn(
f"JAX ROCm plugin {plugin_name} version {plugin_version} is installed, but "
"it is not compatible with the installed jaxlib version "
f"{jaxlib_version}. ROCm plugins require matching major versions and "
"plugin minor version <= jaxlib minor version, so it will not be used."
f"Use JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK=1 to override",
RuntimeWarning,
)
else:
warnings.warn(
f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
"it is not compatible with the installed jaxlib version "
f"{jaxlib_version}, so it will not be used.",
RuntimeWarning,
)
return False
return True

Expand Down
Loading