Skip to content

Commit 33b2c8d

Browse files
Update scalar conversion tests for non-0d arrays
1 parent 96c722c commit 33b2c8d

File tree

1 file changed

+51
-24
lines changed

1 file changed

+51
-24
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
import pytest
23+
from numpy.testing import assert_raises_regex
2324

2425
import dpctl
2526
import dpctl.memory as dpm
@@ -282,34 +283,60 @@ def test_properties(dt):
282283
V.mT
283284

284285

285-
@pytest.mark.parametrize("func", [bool, float, int, complex])
286286
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
287287
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
288-
def test_copy_scalar_with_func(func, shape, dtype):
289-
try:
290-
X = dpt.usm_ndarray(shape, dtype=dtype)
291-
except dpctl.SyclDeviceCreationError:
292-
pytest.skip("No SYCL devices available")
293-
Y = np.arange(1, X.size + 1, dtype=dtype)
294-
X.usm_data.copy_from_host(Y.view("|u1"))
295-
Y.shape = tuple()
296-
assert func(X) == func(Y)
288+
class TestCopyScalar:
289+
def test_copy_bool_scalar_with_func(self, shape, dtype):
290+
try:
291+
X = dpt.usm_ndarray(shape, dtype=dtype)
292+
except dpctl.SyclDeviceCreationError:
293+
pytest.skip("No SYCL devices available")
294+
Y = np.arange(1, X.size + 1, dtype=dtype)
295+
X.usm_data.copy_from_host(Y.view("|u1"))
296+
Y.shape = tuple()
297+
assert bool(X) == bool(Y)
297298

299+
@pytest.mark.parametrize("func", [float, int, complex])
300+
def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
301+
try:
302+
X = dpt.usm_ndarray(shape, dtype=dtype)
303+
except dpctl.SyclDeviceCreationError:
304+
pytest.skip("No SYCL devices available")
305+
Y = np.arange(1, X.size + 1, dtype=dtype)
306+
X.usm_data.copy_from_host(Y.view("|u1"))
307+
Y.shape = tuple()
308+
# Non-0D numeric arrays must not be convertible to Python scalars
309+
if len(shape) != 0:
310+
assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X)
311+
else:
312+
# 0D arrays are allowed to convert
313+
assert func(X) == func(Y)
314+
315+
def test_copy_bool_scalar_with_method(self, shape, dtype):
316+
try:
317+
X = dpt.usm_ndarray(shape, dtype=dtype)
318+
except dpctl.SyclDeviceCreationError:
319+
pytest.skip("No SYCL devices available")
320+
Y = np.arange(1, X.size + 1, dtype=dtype)
321+
X.usm_data.copy_from_host(Y.view("|u1"))
322+
Y.shape = tuple()
323+
assert getattr(X, "__bool__")() == getattr(Y, "__bool__")()
298324

299-
@pytest.mark.parametrize(
300-
"method", ["__bool__", "__float__", "__int__", "__complex__"]
301-
)
302-
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
303-
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
304-
def test_copy_scalar_with_method(method, shape, dtype):
305-
try:
306-
X = dpt.usm_ndarray(shape, dtype=dtype)
307-
except dpctl.SyclDeviceCreationError:
308-
pytest.skip("No SYCL devices available")
309-
Y = np.arange(1, X.size + 1, dtype=dtype)
310-
X.usm_data.copy_from_host(Y.view("|u1"))
311-
Y.shape = tuple()
312-
assert getattr(X, method)() == getattr(Y, method)()
325+
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
326+
def test_copy_numeric_scalar_with_method(self, method, shape, dtype):
327+
try:
328+
X = dpt.usm_ndarray(shape, dtype=dtype)
329+
except dpctl.SyclDeviceCreationError:
330+
pytest.skip("No SYCL devices available")
331+
Y = np.arange(1, X.size + 1, dtype=dtype)
332+
X.usm_data.copy_from_host(Y.view("|u1"))
333+
Y.shape = tuple()
334+
if len(shape) != 0:
335+
assert_raises_regex(
336+
TypeError, "only 0-dimensional arrays", getattr(X, method)
337+
)
338+
else:
339+
assert getattr(X, method)() == getattr(Y, method)()
313340

314341

315342
@pytest.mark.parametrize("func", [bool, float, int, complex])

0 commit comments

Comments
 (0)