Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions jax_rocm_plugin/pjrt/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import logging
import os
import pathlib
import re
import warnings

from jax._src.lib import xla_client
import jax._src.xla_bridge as xb
Expand All @@ -36,6 +38,81 @@
logger = logging.getLogger(__name__)


def _get_jax_version():
"""Get JAX version for version checking."""
try:
import jax
return jax.__version__
except ImportError:
logger.warning("Could not import jax to check version compatibility")
return None


def _get_pjrt_wheel_version():
"""Get the current PJRT wheel version."""
try:
from . import version
return version.__version__
except ImportError:
logger.warning("Could not import PJRT wheel version module")
return None


def _check_pjrt_wheel_version(pjrt_wheel_name: str, jax_version: str, pjrt_version: str) -> bool:
"""Check if PJRT wheel version is compatible with JAX version.

This implements the same logic as jaxlib.plugin_support.check_plugin_version
for ROCm PJRT wheels.
"""
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")

def _parse_version(v: str) -> tuple[int, ...]:
m = version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse version string '{v}'")
return tuple(int(x) for x in m.group(0).split("."))

try:
jax_parsed = _parse_version(jax_version)
pjrt_parsed = _parse_version(pjrt_version)
except ValueError as e:
logger.warning(f"Failed to parse version strings: {e}")
return True # Allow loading if we can't parse versions

# Check if version checking is disabled via environment variable
if os.environ.get("JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK", "").lower() in ("1", "true", "yes"):
warnings.warn(
f"JAX ROCm PJRT wheel {pjrt_wheel_name} version checking has been disabled via "
"JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK environment variable. "
"This may lead to compatibility issues.",
RuntimeWarning,
)
return True

# For ROCm PJRT wheels: major versions must match, PJRT wheel minor version must be <= JAX minor version
if len(jax_parsed) < 2 or len(pjrt_parsed) < 2:
# If either version doesn't have major.minor format, fall back to exact match
compatible = jax_parsed == pjrt_parsed
else:
# Check major version match and minor version constraint
major_match = jax_parsed[0] == pjrt_parsed[0]
minor_compatible = pjrt_parsed[1] <= jax_parsed[1]
compatible = major_match and minor_compatible

if not compatible:
warnings.warn(
f"JAX ROCm PJRT wheel {pjrt_wheel_name} version {pjrt_version} is installed, but "
"it is not compatible with the installed JAX version "
f"{jax_version}. ROCm PJRT wheels require matching major versions and "
"PJRT wheel minor version <= JAX minor version, so it will not be used. "
f"Use JAX_DEBUG_SKIP_ROCM_PLUGIN_VERSION_CHECK=1 to override",
RuntimeWarning,
)
return False
return True


def _get_library_path():
base_path = pathlib.Path(__file__).resolve().parent
installed_path = (
Expand Down Expand Up @@ -132,6 +209,19 @@ def initialize():

set_rocm_paths(path)

# Version compatibility check
jax_version = _get_jax_version()
pjrt_version = _get_pjrt_wheel_version()

if jax_version and pjrt_version:
pjrt_wheel_name = __package__ or "jax_rocm_pjrt"
if not _check_pjrt_wheel_version(pjrt_wheel_name, jax_version, pjrt_version):
logger.warning(f"Skipping ROCm PJRT plugin initialization due to version incompatibility")
return
logger.info(f"ROCm PJRT wheel version {pjrt_version} is compatible with JAX version {jax_version}")
else:
logger.warning("Could not perform PJRT wheel version check, proceeding with plugin initialization")

options = xla_client.generate_pjrt_gpu_plugin_options()
options["platform_name"] = "ROCM"
c_api = xb.register_plugin(
Expand Down
Loading