|
| 1 | +from numba.core.imputils import lower_builtin |
| 2 | +import numba_dppy.experimental_numpy_lowering_overload as dpnp_lowering |
| 3 | +from numba import types |
| 4 | +from numba.core.typing import signature |
| 5 | +from numba.core.extending import overload, register_jitable |
| 6 | +from . import stubs |
| 7 | +import numpy as np |
| 8 | +from numba_dppy.dpctl_functions import _DPCTL_FUNCTIONS |
| 9 | + |
| 10 | + |
def get_dpnp_fptr(fn_name, type_names):
    """Look up the DPNP function pointer for ``fn_name``.

    The lookup is delegated to ``dpnp_fptr_interface.get_dpnp_fn_ptr``. That
    module is imported at call time, inside this function, rather than when
    this module is imported.

    Args:
        fn_name: Name of the DPNP function to resolve.
        type_names: Type names forwarded unchanged to the lookup.

    Returns:
        Whatever ``get_dpnp_fn_ptr`` returns for the given arguments.
    """
    from . import dpnp_fptr_interface

    return dpnp_fptr_interface.get_dpnp_fn_ptr(fn_name, type_names)
| 16 | + |
| 17 | + |
@register_jitable
def _check_finite_matrix(a):
    """Raise ``np.linalg.LinAlgError`` if ``a`` contains an inf or NaN.

    Every element of ``a`` is visited through ``np.nditer``; the first
    non-finite value encountered aborts the scan with an exception.
    """
    for element in np.nditer(a):
        value = element.item()
        if np.isfinite(value):
            continue
        raise np.linalg.LinAlgError("Array must not contain infs or NaNs.")
| 23 | + |
| 24 | + |
@register_jitable
def _dummy_liveness_func(a):
    """Return the first item of ``a``.

    Callers pass a list of variables that must be preserved through dead
    code elimination; reading an item from that list is the only thing this
    function does.
    """
    first = a[0]
    return first
| 29 | + |
| 30 | + |
class RetrieveDpnpFnPtr(types.ExternalFunctionPointer):
    """``ExternalFunctionPointer`` that also records a DPNP function identity.

    Args:
        fn_name: Name of the DPNP function, stored as ``self.fn_name``.
        type_names: Type names, stored as ``self.type_names``.
        sig: Signature forwarded to ``ExternalFunctionPointer``.
        get_pointer: Callable forwarded to ``ExternalFunctionPointer``.
    """

    def __init__(self, fn_name, type_names, sig, get_pointer):
        # Attributes are assigned before the base initialiser runs, exactly
        # as in the original ordering.
        self.fn_name, self.type_names = fn_name, type_names
        super().__init__(sig, get_pointer)
| 36 | + |
| 37 | + |
class _DPNP_EXTENSION:
    """Factory for numba function-pointer types that wrap DPNP functions.

    Constructing an instance calls ``dpnp_lowering.ensure_dpnp`` with the
    given name.
    """

    def __init__(self, name):
        dpnp_lowering.ensure_dpnp(name)

    @classmethod
    def dpnp_sum(cls, fn_name, type_names):
        """Build an ``ExternalFunctionPointer`` type for a DPNP sum function.

        Args:
            fn_name: DPNP function name passed to ``get_dpnp_fptr``.
            type_names: Type names passed to ``get_dpnp_fptr``.

        Returns:
            A ``types.ExternalFunctionPointer`` with signature
            ``void(voidptr, voidptr, int64)`` whose ``get_pointer`` callback
            returns the pointer resolved here.
        """
        native_ptr = get_dpnp_fptr(fn_name, type_names)
        c_sig = signature(
            types.void, types.voidptr, types.voidptr, types.int64
        )

        def get_pointer(obj):
            # The pointer was resolved above; the argument is ignored.
            return native_ptr

        return types.ExternalFunctionPointer(c_sig, get_pointer=get_pointer)
| 52 | + |
| 53 | + |
@overload(stubs.dpnp.sum)
def dpnp_sum_impl(a):
    """Numba overload of ``stubs.dpnp.sum`` backed by the ``dpnp_sum`` function.

    The returned implementation copies ``a`` into a buffer obtained from
    ``dpctl_malloc_shared``, calls the DPNP function on it, copies the single
    result element back into a NumPy array and frees both buffers.

    Args:
        a: Numba array type of the argument (its ``dtype.name`` selects the
            DPNP function variant).

    Returns:
        The implementation function to be compiled by numba. At run time it
        raises ``ValueError`` for an empty array and otherwise returns the
        first element of a one-element result array of ``a``'s dtype.
    """
    extension = _DPNP_EXTENSION("sum")
    dpctl = _DPCTL_FUNCTIONS()

    native_sum = extension.dpnp_sum("dpnp_sum", [a.dtype.name, "NONE"])

    current_queue = dpctl.dpctl_get_current_queue()
    malloc_shared = dpctl.dpctl_malloc_shared()
    queue_memcpy = dpctl.dpctl_queue_memcpy()
    free_with_queue = dpctl.dpctl_free_with_queue()

    def sum_impl(a):
        if a.size == 0:
            raise ValueError("Passed Empty array")

        queue = current_queue()

        # Stage the input in a buffer the native function can read.
        input_nbytes = a.size * a.itemsize
        input_buf = malloc_shared(input_nbytes, queue)
        queue_memcpy(queue, input_buf, a.ctypes, input_nbytes)

        # One element of the input's item size receives the result.
        result_buf = malloc_shared(a.itemsize, queue)

        native_sum(input_buf, result_buf, a.size)

        out = np.empty(1, dtype=a.dtype)
        queue_memcpy(queue, out.ctypes, result_buf, out.size * out.itemsize)

        free_with_queue(input_buf, queue)
        free_with_queue(result_buf, queue)

        # Reference `out` once more so it is preserved through dead code
        # elimination.
        _dummy_liveness_func([out.size])

        return out[0]

    return sum_impl
0 commit comments