|
20 | 20 |
|
21 | 21 | import numpy as np |
22 | 22 | import pytest |
| 23 | +from numpy.testing import assert_raises_regex |
23 | 24 |
|
24 | 25 | import dpctl |
25 | 26 | import dpctl.memory as dpm |
@@ -282,34 +283,60 @@ def test_properties(dt): |
282 | 283 | V.mT |
283 | 284 |
|
284 | 285 |
|
285 | | -@pytest.mark.parametrize("func", [bool, float, int, complex]) |
286 | 286 | @pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) |
287 | 287 | @pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) |
288 | | -def test_copy_scalar_with_func(func, shape, dtype): |
289 | | - try: |
290 | | - X = dpt.usm_ndarray(shape, dtype=dtype) |
291 | | - except dpctl.SyclDeviceCreationError: |
292 | | - pytest.skip("No SYCL devices available") |
293 | | - Y = np.arange(1, X.size + 1, dtype=dtype) |
294 | | - X.usm_data.copy_from_host(Y.view("|u1")) |
295 | | - Y.shape = tuple() |
296 | | - assert func(X) == func(Y) |
| 288 | +class TestCopyScalar: |
| 289 | + def test_copy_bool_scalar_with_func(self, shape, dtype): |
| 290 | + try: |
| 291 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 292 | + except dpctl.SyclDeviceCreationError: |
| 293 | + pytest.skip("No SYCL devices available") |
| 294 | + Y = np.arange(1, X.size + 1, dtype=dtype) |
| 295 | + X.usm_data.copy_from_host(Y.view("|u1")) |
| 296 | + Y.shape = tuple() |
| 297 | + assert bool(X) == bool(Y) |
297 | 298 |
|
| 299 | + @pytest.mark.parametrize("func", [float, int, complex]) |
| 300 | + def test_copy_numeric_scalar_with_func(self, func, shape, dtype): |
| 301 | + try: |
| 302 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 303 | + except dpctl.SyclDeviceCreationError: |
| 304 | + pytest.skip("No SYCL devices available") |
| 305 | + Y = np.arange(1, X.size + 1, dtype=dtype) |
| 306 | + X.usm_data.copy_from_host(Y.view("|u1")) |
| 307 | + Y.shape = tuple() |
| 308 | + # Non-0D numeric arrays must not be convertible to Python scalars |
| 309 | + if len(shape) != 0: |
| 310 | + assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X) |
| 311 | + else: |
| 312 | + # 0D arrays are allowed to convert |
| 313 | + assert func(X) == func(Y) |
| 314 | + |
| 315 | + def test_copy_bool_scalar_with_method(self, shape, dtype): |
| 316 | + try: |
| 317 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 318 | + except dpctl.SyclDeviceCreationError: |
| 319 | + pytest.skip("No SYCL devices available") |
| 320 | + Y = np.arange(1, X.size + 1, dtype=dtype) |
| 321 | + X.usm_data.copy_from_host(Y.view("|u1")) |
| 322 | + Y.shape = tuple() |
| 323 | + assert getattr(X, "__bool__")() == getattr(Y, "__bool__")() |
298 | 324 |
|
299 | | -@pytest.mark.parametrize( |
300 | | - "method", ["__bool__", "__float__", "__int__", "__complex__"] |
301 | | -) |
302 | | -@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) |
303 | | -@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) |
304 | | -def test_copy_scalar_with_method(method, shape, dtype): |
305 | | - try: |
306 | | - X = dpt.usm_ndarray(shape, dtype=dtype) |
307 | | - except dpctl.SyclDeviceCreationError: |
308 | | - pytest.skip("No SYCL devices available") |
309 | | - Y = np.arange(1, X.size + 1, dtype=dtype) |
310 | | - X.usm_data.copy_from_host(Y.view("|u1")) |
311 | | - Y.shape = tuple() |
312 | | - assert getattr(X, method)() == getattr(Y, method)() |
| 325 | + @pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"]) |
| 326 | + def test_copy_numeric_scalar_with_method(self, method, shape, dtype): |
| 327 | + try: |
| 328 | + X = dpt.usm_ndarray(shape, dtype=dtype) |
| 329 | + except dpctl.SyclDeviceCreationError: |
| 330 | + pytest.skip("No SYCL devices available") |
| 331 | + Y = np.arange(1, X.size + 1, dtype=dtype) |
| 332 | + X.usm_data.copy_from_host(Y.view("|u1")) |
| 333 | + Y.shape = tuple() |
| 334 | + if len(shape) != 0: |
| 335 | + assert_raises_regex( |
| 336 | + TypeError, "only 0-dimensional arrays", getattr(X, method) |
| 337 | + ) |
| 338 | + else: |
| 339 | + assert getattr(X, method)() == getattr(Y, method)() |
313 | 340 |
|
314 | 341 |
|
315 | 342 | @pytest.mark.parametrize("func", [bool, float, int, complex]) |
|
0 commit comments