9 changes: 7 additions & 2 deletions build/tools/utils.py
@@ -206,8 +206,13 @@ def get_clangpp_path(clang_path):
clang_path = pathlib.Path(clang_path)
clang_exec_name = clang_path.name
clangpp_exec_name = clang_exec_name
if "clang++" not in clang_exec_name:
clangpp_exec_name = clang_exec_name.replace("clang", "clang++")
clangpp_path = clang_path.parent / clang_exec_name
# Try to match what the user passed in (either clang-18 or clang)
if "clang++" not in clangpp_exec_name:
clangpp_exec_name = clangpp_exec_name.replace("clang", "clang++")
clangpp_path = clang_path.parent / clangpp_exec_name
if not clangpp_path.exists():
clangpp_exec_name = "clang++"
clangpp_path = clang_path.parent / clangpp_exec_name
if not clangpp_path.exists():
raise FileNotFoundError(
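
The diff view above flattens indentation, so here is a hedged, self-contained sketch of the new lookup order (guess_clangpp_path is a hypothetical name and the nesting is assumed from the visible lines, not copied verbatim):

import pathlib

def guess_clangpp_path(clang_path) -> pathlib.Path:
    """Sketch: derive a sibling clang++ from the given clang path, then fall back."""
    clang_path = pathlib.Path(clang_path)
    clangpp_exec_name = clang_path.name
    # Try to match what the user passed in (either clang-18 or clang).
    if "clang++" not in clangpp_exec_name:
        clangpp_exec_name = clangpp_exec_name.replace("clang", "clang++")
    clangpp_path = clang_path.parent / clangpp_exec_name
    if not clangpp_path.exists():
        # Fall back to an unsuffixed clang++ next to the given compiler.
        clangpp_path = clang_path.parent / "clang++"
    if not clangpp_path.exists():
        raise FileNotFoundError(f"no clang++ found next to {clang_path}")
    return clangpp_path

# e.g. ".../bin/clang-18" -> ".../bin/clang++-18" if present, otherwise ".../bin/clang++"
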
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/conditionals.py
@@ -353,7 +353,7 @@ def _check_branch_outputs(
f'to the output of {n}{component(p)}'
for p, aval1, aval2 in zip(paths, out_avals1, out_avals2)
for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)]
if not core.typematch(a1, a2) and
if not core.typematch(a1, a2) and
isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray)
and a1.vma != a2.vma and a2.vma - a1.vma]

2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -447,7 +447,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'applying `jax.lax.pvary(..., {tuple(out_aval.vma - in_aval.vma)})` '
f'to the initial carry value corresponding to {component(path)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval) and
if not core.typematch(in_aval, out_aval) and
isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray)
and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma]

4 changes: 4 additions & 0 deletions tests/debugging_primitives_test.py
@@ -806,6 +806,8 @@ def f2(x):
self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")

def test_unordered_print_with_pjit(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: tests/debugging_primitives_test.py::DebugPrintParallelTest::test_unordered_print_with_pjit")
def f(x):
debug_print("{}", x, ordered=False)
return x
@@ -841,6 +843,8 @@ def f(x):
self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")

def test_unordered_print_of_pjit_of_while(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: tests/debugging_primitives_test.py::DebugPrintParallelTest::test_unordered_print_of_pjit_of_while")
def f(x):
def cond(carry):
i, *_ = carry
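
The same two-line ROCm guard recurs in most of the files below. As a hedged sketch only (skip_on_rocm is a hypothetical helper, not part of jtu), the pattern could be factored into a decorator; note that is_device_rocm is a function in jax._src.test_util and has to be called, as the pre-existing call sites elsewhere in this diff do:

import functools
import unittest

from jax._src import test_util as jtu  # same alias these tests already use

def skip_on_rocm(reason: str):
    """Hypothetical decorator: skip the wrapped test when running on ROCm."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            if jtu.is_device_rocm():  # must be called; the bare attribute is always truthy
                raise unittest.SkipTest(f"Skip on ROCm: {reason}")
            return test_fn(self, *args, **kwargs)
        return wrapper
    return decorator

# Usage (hypothetical):
# @skip_on_rocm("DebugPrintParallelTest::test_unordered_print_with_pjit")
# def test_unordered_print_with_pjit(self): ...
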
2 changes: 1 addition & 1 deletion tests/experimental_rnn_test.py
@@ -40,7 +40,7 @@ def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
# TODO(ruturaj4): Bidirectional doesn't quite work well with rocm.
if bidirectional and jtu.is_device_rocm():
self.skipTest("Bidirectional mode is not available for ROCm.")
self.skipTest("Skip on ROCm: tests/experimental_rnn_test.py::RnnTest::test_lstm: Bidirectional mode is not available for ROCm.")

num_directions = 2 if bidirectional else 1
seq_length_key, root_key = jax.random.split(jax.random.PRNGKey(0))
4 changes: 3 additions & 1 deletion tests/export_harnesses_multi_platform_test.py
@@ -38,7 +38,6 @@
from jax._src.internal_test_util import test_harnesses
from jax import random


def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
if not parts:
return re.compile("matches_no_test")
@@ -192,6 +191,9 @@ def test_all_gather(self, *, dtype):
self.export_and_compare_to_native(f, x)

def test_random_with_threefry_gpu_kernel_lowering(self):
if jtu.is_device_rocm() and jtu.get_rocm_version() > (6, 5):
self.skipTest("Skip on ROCm: test_random_with_threefry_gpu_kernel_lowering")

# On GPU we use a custom call for threefry2x32
with config.threefry_gpu_kernel_lowering(True):
def f(x):
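
A note on the version gate above: as used in this PR, jtu.get_rocm_version() returns a comparable tuple, so the guard is an ordinary lexicographic tuple comparison. A small sketch of the boundary behaviour, assuming that return type:

# The gate "> (6, 5)" excludes ROCm 6.5 itself; "< (6, 5)" (used later in ops_test.py) catches older releases.
assert (7, 0) > (6, 5)        # skipped by this guard
assert not ((6, 5) > (6, 5))  # ROCm 6.5 still runs this test
assert (6, 4) < (6, 5)        # caught by the pre-existing ops_test.py guard
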
2 changes: 2 additions & 0 deletions tests/ffi_test.py
@@ -270,6 +270,8 @@ def test_invalid_result_type(self):

@jtu.run_on_devices("gpu", "cpu")
def test_shard_map(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map")
mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)

6 changes: 6 additions & 0 deletions tests/lax_numpy_test.py
@@ -1562,6 +1562,12 @@ def testTrimZerosNotOneDArray(self):
)
@jax.default_matmul_precision("float32")
def testPoly(self, a_shape, dtype, rank):
if jtu.is_device_rocm() and a_shape == (12,) and dtype in (np.int32, np.int8) and rank == 2:
self.skipTest(f"Skip on ROCm: testPoly: a_shape == (12,) and dtype == {dtype} and rank == 2")

if jtu.is_device_rocm() and a_shape == (6,) and dtype == np.float32 and rank == 2:
self.skipTest("Skip on ROCm: testPoly: a_shape == (6,) and dtype == np.float32 and rank == 2")

if dtype in (np.float16, jnp.bfloat16, np.int16):
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]):
2 changes: 2 additions & 0 deletions tests/linalg_test.py
@@ -874,6 +874,8 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype):
@jax.default_matmul_precision("float32")
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm):
if algorithm is not None:
if jtu.is_device_rocm() and algorithm == lax.linalg.SvdAlgorithm.JACOBI and dtype in {np.float32, np.complex64}:
self.skipTest("Skip on ROCm: testSVD Jacobi tests")
if hermitian:
self.skipTest("Hermitian SVD doesn't support the algorithm parameter.")
if not jtu.test_device_matches(["cpu", "gpu"]):
2 changes: 2 additions & 0 deletions tests/multiprocess_gpu_test.py
@@ -92,6 +92,8 @@ def test_gpu_distributed_initialize(self):

def test_distributed_jax_visible_devices(self):
"""Test jax_visible_devices works in distributed settings."""
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_distributed_jax_visible_devices")
if not jtu.test_device_matches(['gpu']):
raise unittest.SkipTest('Tests only for GPU.')

2 changes: 2 additions & 0 deletions tests/nn_test.py
@@ -201,6 +201,8 @@ def fwd(a, b, is_ref=False):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if jtu.is_device_rocm() and dtype == jnp.float16 and group_num == 4 and impl == 'xla':
self.skipTest("Skip on ROCm: testDotProductAttention[21,23]")
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
6 changes: 6 additions & 0 deletions tests/pallas/gpu_ops_test.py
@@ -175,6 +175,8 @@ def test_fused_attention_fwd(
use_fwd,
use_segment_ids,
):
if jtu.is_device_rocm() and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q', 128), ('block_k', 128)) and causal and use_fwd and use_segment_ids:
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4")
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
@@ -270,6 +272,10 @@ def test_fused_attention_bwd(
causal,
use_segment_ids,
):
test_name = str(self).split()[0]
if jtu.is_device_rocm() and test_name in {"test_fused_attention_bwd7", "test_fused_attention_bwd8"}:
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[7,8]")

k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
4 changes: 4 additions & 0 deletions tests/pallas/gpu_paged_attention_test.py
@@ -122,6 +122,10 @@ def test_paged_attention(
k_splits,
attn_logits_soft_cap,
):
test_name = str(self).split()[0]
skip_numbers = {0, 1, 3, 5, 6, 7, 9}
if jtu.is_device_rocm() and test_name in {f"test_paged_attention{i}" for i in skip_numbers}:
self.skipTest("Skip on ROCm: tests/pallas/gpu_paged_attention_test.py::PagedAttentionKernelTest::test_paged_attention[0,1,3,5,6,7,9]")
max_kv_len = 2048
seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32)
q, k_pages, v_pages, block_tables = _generate_qkv(
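
Both Pallas test files above key their skips on the parameterized method name recovered via str(self). As background (an observation about unittest itself, not something this PR adds): TestCase.__str__ renders as "<method> (<class>)", so splitting on whitespace yields the generated name, and _testMethodName is an equivalent spelling. A small self-contained check:

import unittest

class _Demo(unittest.TestCase):
    def test_paged_attention3(self):
        pass

case = _Demo("test_paged_attention3")
assert str(case).split()[0] == "test_paged_attention3"   # the pattern used above
assert case._testMethodName == "test_paged_attention3"   # equivalent, arguably clearer
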
22 changes: 22 additions & 0 deletions tests/pallas/ops_test.py
@@ -528,6 +528,8 @@ def kernel(x_ref, o_ref):
@hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4,
min_size_exp=1))
def test_select_n(self, args):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_select_n")
pred, *cases = args
scalar_pred = not pred.shape

@@ -558,6 +560,8 @@ def kernel(*refs):
)
@hp.given(hps.data())
def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
if jtu.is_device_rocm() and name in {"logistic", "reciprocal"}:
self.skipTest("Skip on ROCm: test_unary_primitives_[logistic,reciprocal]")
if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]:
self.skip_if_mosaic_gpu()

@@ -1481,6 +1485,8 @@ def kernel(x_ref, y_ref, o_ref):
for fn, dtype in itertools.product(*args)
)
def test_binary(self, f, dtype):
if jtu.is_device_rocm() and f == jnp.bitwise_right_shift and dtype == "uint32":
self.skipTest("Skip on ROCm: binary_bitwise_right_shift for uint32")
self.skip_if_mosaic_gpu()

if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
@@ -1553,6 +1559,9 @@ def kernel(o_ref):

@parameterized.parameters("float16", "bfloat16", "float32")
def test_approx_tanh(self, dtype):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_approx_tanh")

self.skip_if_mosaic_gpu()

if jtu.test_device_matches(["tpu"]):
@@ -1582,6 +1591,9 @@ def kernel(x_ref, o_ref):
)

def test_elementwise_inline_asm(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_elementwise_inline_asm")

self.skip_if_mosaic_gpu()

if jtu.test_device_matches(["tpu"]):
@@ -1862,12 +1874,17 @@ def f(x_ref, o_ref):
trans_y=[False, True],
)
def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
test_name = str(self).split()[0]
skip_numbers = list(range(12, 20))
if jtu.is_device_rocm() and jtu.get_rocm_version() == (7, 0) and test_name in {f"test_dot{i}" for i in skip_numbers}:
self.skipTest("Skip on ROCm: test_dot[12-19]")
if (
jtu.is_device_rocm() and
jtu.get_rocm_version() < (6, 5)
):
# TODO(psanal35): Investigate the root cause
self.skipTest("ROCm <6.5 issue: some test cases fail (fixed in ROCm 6.5.0)")
self.skip_if_mosaic_gpu()

# TODO(apaszke): Remove after 12 weeks have passed.
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
@@ -2152,6 +2169,9 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
)
def test_scalar_atomic(self, op, value, numpy_op):
if jtu.is_device_rocm() and value.dtype == np.float32 and op in (pl.atomic_min, pl.atomic_max):
self.skipTest("Skip on ROCm: test_scalar_atomic_(max/min)_f32")

self.skip_if_mosaic_gpu()

# The Pallas TPU lowering currently supports only blocks of rank >= 1
@@ -2390,6 +2410,8 @@ def reduce(x_ref, y_ref):
dtype=["float16", "float32", "int32", "uint32"],
)
def test_cumsum(self, dtype, axis):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_cumsum")
self.skip_if_mosaic_gpu()

if jtu.test_device_matches(["tpu"]):
2 changes: 2 additions & 0 deletions tests/pallas/tpu_pallas_interpret_test.py
@@ -92,6 +92,8 @@ def setUp(self):
self.skipTest(f'requires 1 device, found {self.num_devices}')

def test_matmul_example(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_matmul_example")
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]

2 changes: 2 additions & 0 deletions tests/pgle_test.py
@@ -481,6 +481,8 @@ def check_if_cache_hit(event):
@parameterized.parameters([True, False])
@jtu.thread_unsafe_test()
def testAutoPgleWithCommandBuffers(self, enable_compilation_cache):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: tests/pgle_test.py::PgleTest::testAutoPgleWithCommandBuffers")
with (config.pgle_profiling_runs(1),
config.enable_compilation_cache(enable_compilation_cache),
config.enable_pgle(True),
8 changes: 8 additions & 0 deletions tests/python_callback_test.py
@@ -948,6 +948,8 @@ def g(x):
np.testing.assert_allclose(g(x), x)

def test_can_shard_pure_callback_maximally(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_can_shard_pure_callback_maximally")
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

spec = jax.sharding.PartitionSpec('x')
@@ -967,6 +969,8 @@ def f(x):
)

def test_can_shard_pure_callback_maximally_with_sharding(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_can_shard_pure_callback_maximally_with_sharding")
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

spec = jax.sharding.PartitionSpec('x')
@@ -1240,6 +1244,8 @@ def f(x, y):
def test_can_use_io_callback_in_pjit(
self, *, ordered: bool, with_sharding: bool
):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_can_use_io_callback_in_pjit")
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), ['dev'])

@@ -1298,6 +1304,8 @@ def f(x):
self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

def test_sequence_pjit_io_callback_ordered(self):
if jtu.is_device_rocm():
self.skipTest("Skip on ROCm: test_sequence_pjit_io_callback_ordered")
# A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
# pair on a different device assignment.
_collected: list[int] = []
6 changes: 6 additions & 0 deletions tests/shape_poly_test.py
@@ -3624,6 +3624,12 @@ def tearDown(self):
#one_containing="",
)
def test_harness(self, harness: PolyHarness):

if jtu.is_device_rocm and harness.group_name == "svd" and harness.dtype in {np.float32, np.complex64}:
raise unittest.SkipTest("Skip for ROCm: shape_poly_test svd for float32 and complex64.")
if jtu.is_device_rocm and harness.group_name == "vmap_eigh" and harness.dtype == np.complex64:
raise unittest.SkipTest("Skip for ROCm: shape_poly_test vmap_eigh for complex64.")

# We do not expect the associative scan error on TPUs
if harness.expect_error == expect_error_associative_scan and jtu.test_device_matches(["tpu"]):
harness.expect_error = None