diff --git a/numba_cuda/numba/cuda/core/base.py b/numba_cuda/numba/cuda/core/base.py index 877e2d054..56eed6211 100644 --- a/numba_cuda/numba/cuda/core/base.py +++ b/numba_cuda/numba/cuda/core/base.py @@ -1040,7 +1040,7 @@ def make_data_helper(self, builder, typ, ref=None): return self._make_helper(builder, typ, ref=ref, kind="data") def make_array(self, typ): - from numba.np import arrayobj + from numba.cuda.np import arrayobj return arrayobj.make_array(typ) @@ -1048,7 +1048,7 @@ def populate_array(self, arr, **kwargs): """ Populate array structure. """ - from numba.np import arrayobj + from numba.cuda.np import arrayobj return arrayobj.populate_array(arr, **kwargs) diff --git a/numba_cuda/numba/cuda/core/config.py b/numba_cuda/numba/cuda/core/config.py index b7adebba7..d9c2cfb67 100644 --- a/numba_cuda/numba/cuda/core/config.py +++ b/numba_cuda/numba/cuda/core/config.py @@ -149,7 +149,7 @@ def check_numba_config(self): "numba.config is deprecated for numba-cuda " "and support for configuration values from it " "will be removed in a future release. " - "Please use numba.cuda.config." + "Please use numba.cuda.core.config." ) warnings.warn(msg, category=DeprecationWarning) self.value = config_value @@ -610,7 +610,7 @@ def num_threads_default(): "NUMBA_NUM_THREADS" in globals() and globals()["NUMBA_NUM_THREADS"] != _NUMBA_NUM_THREADS ): - from numba.np.ufunc import parallel + from numba.cuda.np.ufunc import parallel if parallel._is_initialized: raise RuntimeError( diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index 34899f9df..3d8b9afa9 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -38,7 +38,7 @@ from numba.core.typing import signature from numba.cuda.core import postproc, rewrites -from numba.np.unsafe.ndarray import empty_inferred as unsafe_empty_inferred +from numba.cuda.np.unsafe.ndarray import empty_inferred as unsafe_empty_inferred import numpy as np import operator from numba.cuda.misc.special import prange @@ -1119,7 +1119,7 @@ def codegen(context, builder, sig, args): intp_t = context.get_value_type(types.intp) iterobj = context.make_helper(builder, iterty, value=value) arrayty = iterty.array_type - from numba.np.arrayobj import make_array + from numba.cuda.np.arrayobj import make_array ary = make_array(arrayty)(context, builder, value=iterobj.array) shape = cgutils.unpack_tuple(builder, ary.shape) diff --git a/numba_cuda/numba/cuda/cudadrv/devicearray.py b/numba_cuda/numba/cuda/cudadrv/devicearray.py index 0f496b620..4d38806a1 100644 --- a/numba_cuda/numba/cuda/cudadrv/devicearray.py +++ b/numba_cuda/numba/cuda/cudadrv/devicearray.py @@ -21,9 +21,9 @@ from numba.cuda.cudadrv import driver as _driver from numba.core import types from numba.cuda.core import config -from numba.np.unsafe.ndarray import to_fixed_tuple -from numba.np.numpy_support import numpy_version -from numba.np import numpy_support +from numba.cuda.np.unsafe.ndarray import to_fixed_tuple +from numba.cuda.np.numpy_support import numpy_version +from numba.cuda.np import numpy_support from numba.cuda.api_util import prepare_shape_strides_dtype from numba.core.errors import NumbaPerformanceWarning from warnings import warn diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py index 3a326286c..c78e6c74a 100644 --- a/numba_cuda/numba/cuda/cudaimpl.py +++ b/numba_cuda/numba/cuda/cudaimpl.py @@ -14,8 +14,8 @@ from numba.core.datamodel import models from 
numba.core import types from numba.cuda import cgutils -from numba.np import ufunc_db -from numba.np.npyimpl import register_ufuncs +from numba.cuda.np import ufunc_db +from numba.cuda.np.npyimpl import register_ufuncs from .cudadrv import nvvm from numba import cuda from numba.cuda import nvvmutils, stubs diff --git a/numba_cuda/numba/cuda/kernels/reduction.py b/numba_cuda/numba/cuda/kernels/reduction.py index 129f525bf..463db8846 100644 --- a/numba_cuda/numba/cuda/kernels/reduction.py +++ b/numba_cuda/numba/cuda/kernels/reduction.py @@ -5,7 +5,7 @@ A library written in CUDA Python for generating reduction kernels """ -from numba.np.numpy_support import from_dtype +from numba.cuda.np.numpy_support import from_dtype _WARPSIZE = 32 diff --git a/numba_cuda/numba/cuda/kernels/transpose.py b/numba_cuda/numba/cuda/kernels/transpose.py index fd031d21d..01e2670b0 100644 --- a/numba_cuda/numba/cuda/kernels/transpose.py +++ b/numba_cuda/numba/cuda/kernels/transpose.py @@ -4,7 +4,7 @@ from numba import cuda from numba.cuda.cudadrv.driver import driver import math -from numba.np import numpy_support as nps +from numba.cuda.np import numpy_support as nps def transpose(a, b=None): diff --git a/numba_cuda/numba/cuda/misc/cffiimpl.py b/numba_cuda/numba/cuda/misc/cffiimpl.py index 4bb0c4b21..e4057e7df 100644 --- a/numba_cuda/numba/cuda/misc/cffiimpl.py +++ b/numba_cuda/numba/cuda/misc/cffiimpl.py @@ -7,7 +7,7 @@ from numba.core.imputils import Registry from numba.core import types -from numba.np import arrayobj +from numba.cuda.np import arrayobj registry = Registry("cffiimpl") diff --git a/numba_cuda/numba/cuda/np/arraymath.py b/numba_cuda/numba/cuda/np/arraymath.py new file mode 100644 index 000000000..57d853d56 --- /dev/null +++ b/numba_cuda/numba/cuda/np/arraymath.py @@ -0,0 +1,5199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implementation of math operations on Array objects. +""" + +import math +from collections import namedtuple +import operator + +import llvmlite.ir +import numpy as np + +from numba.core import types +from numba.cuda import cgutils +from numba.cuda.extending import overload, overload_method, register_jitable +from numba.cuda.np.numpy_support import ( + as_dtype, + type_can_asarray, + type_is_scalar, + numpy_version, + is_nonelike, + check_is_integer, + lt_floats, + lt_complex, +) +from numba.core.imputils import ( + impl_ret_borrowed, + impl_ret_new_ref, + impl_ret_untracked, + Registry, +) +from numba.cuda.np.arrayobj import ( + make_array, + load_item, + store_item, + _empty_nd_impl, +) +from numba.cuda.np.linalg import ensure_blas + +from numba.cuda.extending import intrinsic +from numba.core.errors import ( + RequireLiteralValue, + TypingError, + NumbaValueError, + NumbaNotImplementedError, + NumbaTypeError, +) +from numba.cpython.unsafe.tuple import tuple_setitem + +registry = Registry("np.arraymath") +lower = registry.lower + + +def _check_blas(): + # Checks if a BLAS is available so e.g. dot will work + try: + ensure_blas() + except ImportError: + return False + return True + + +_HAVE_BLAS = _check_blas() + + +@intrinsic +def _create_tuple_result_shape(tyctx, shape_list, shape_tuple): + """ + This routine converts shape list where the axis dimension has already + been popped to a tuple for indexing of the same size. The original shape + tuple is also required because it contains a length field at compile time + whereas the shape list does not. 
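+ For example, with an original shape tuple of (2, 3, 4) and axis 1 + already popped, the shape list [2, 4] is converted to the index + tuple (2, 4).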
+ """ + + # The new tuple's size is one less than the original tuple since axis + # dimension removed. + nd = len(shape_tuple) - 1 + # The return type of this intrinsic is an int tuple of length nd. + tupty = types.UniTuple(types.intp, nd) + # The function signature for this intrinsic. + function_sig = tupty(shape_list, shape_tuple) + + def codegen(cgctx, builder, signature, args): + lltupty = cgctx.get_value_type(tupty) + # Create an empty int tuple. + tup = cgutils.get_null_value(lltupty) + + # Get the shape list from the args and we don't need shape tuple. + [in_shape, _] = args + + def array_indexer(a, i): + return a[i] + + # loop to fill the tuple + for i in range(nd): + dataidx = cgctx.get_constant(types.intp, i) + # compile and call array_indexer + data = cgctx.compile_internal( + builder, + array_indexer, + types.intp(shape_list, types.intp), + [in_shape, dataidx], + ) + tup = builder.insert_value(tup, data, i) + return tup + + return function_sig, codegen + + +@intrinsic +def _gen_index_tuple(tyctx, shape_tuple, value, axis): + """ + Generates a tuple that can be used to index a specific slice from an + array for sum with axis. shape_tuple is the size of the dimensions of + the input array. 'value' is the value to put in the indexing tuple + in the axis dimension and 'axis' is that dimension. For this to work, + axis has to be a const. + """ + if not isinstance(axis, types.Literal): + raise RequireLiteralValue("axis argument must be a constant") + # Get the value of the axis constant. + axis_value = axis.literal_value + # The length of the indexing tuple to be output. + nd = len(shape_tuple) + + # If the axis value is impossible for the given size array then + # just fake it like it was for axis 0. This will stop compile errors + # when it looks like it could be called from array_sum_axis but really + # can't because that routine checks the axis mismatch and raise an + # exception. + if axis_value >= nd: + axis_value = 0 + + # Calculate the type of the indexing tuple. All the non-axis + # dimensions have slice2 type and the axis dimension has int type. + before = axis_value + after = nd - before - 1 + + types_list = [] + types_list += [types.slice2_type] * before + types_list += [types.intp] + types_list += [types.slice2_type] * after + + # Creates the output type of the function. + tupty = types.Tuple(types_list) + # Defines the signature of the intrinsic. + function_sig = tupty(shape_tuple, value, axis) + + def codegen(cgctx, builder, signature, args): + lltupty = cgctx.get_value_type(tupty) + # Create an empty indexing tuple. + tup = cgutils.get_null_value(lltupty) + + # We only need value of the axis dimension here. + # The rest are constants defined above. + [_, value_arg, _] = args + + def create_full_slice(): + return slice(None, None) + + # loop to fill the tuple with slice(None,None) before + # the axis dimension. + + # compile and call create_full_slice + slice_data = cgctx.compile_internal( + builder, create_full_slice, types.slice2_type(), [] + ) + for i in range(0, axis_value): + tup = builder.insert_value(tup, slice_data, i) + + # Add the axis dimension 'value'. + tup = builder.insert_value(tup, value_arg, axis_value) + + # loop to fill the tuple with slice(None,None) after + # the axis dimension. 
+ for i in range(axis_value + 1, nd): + tup = builder.insert_value(tup, slice_data, i) + return tup + + return function_sig, codegen + + +# ---------------------------------------------------------------------------- +# Basic stats and aggregates + + +@lower(np.sum, types.Array) +@lower("array.sum", types.Array) +def array_sum(context, builder, sig, args): + zero = sig.return_type(0) + + def array_sum_impl(arr): + c = zero + for v in np.nditer(arr): + c += v.item() + return c + + res = context.compile_internal( + builder, array_sum_impl, sig, args, locals=dict(c=sig.return_type) + ) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@register_jitable +def _array_sum_axis_nop(arr, v): + return arr + + +def gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero): + def inner(arr, axis): + """ + function that performs sums over one specific axis + + The third parameter to gen_index_tuple that generates the indexing + tuples has to be a const so we can't just pass "axis" through since + that isn't const. We can check for specific values and have + different instances that do take consts. Supporting axis summation + only up to the fourth dimension for now. + + typing/arraydecl.py:sum_expand defines the return type for sum with + axis. It is one dimension less than the input array. + """ + ndim = arr.ndim + + if not is_axis_const: + # Catch where axis is negative or greater than 3. + if axis < 0 or axis > 3: + raise ValueError( + "Numba does not support sum with axis " + "parameter outside the range 0 to 3." + ) + + # Catch the case where the user misspecifies the axis to be + # more than the number of the array's dimensions. + if axis >= ndim: + raise ValueError("axis is out of bounds for array") + + # Convert the shape of the input array to a list. + ashape = list(arr.shape) + # Get the length of the axis dimension. + axis_len = ashape[axis] + # Remove the axis dimension from the list of dimensional lengths. + ashape.pop(axis) + # Convert this shape list back to a tuple using above intrinsic. + ashape_without_axis = _create_tuple_result_shape(ashape, arr.shape) + # Tuple needed here to create output array with correct size. + result = np.full(ashape_without_axis, zero, type(zero)) + + # Iterate through the axis dimension. + for axis_index in range(axis_len): + if is_axis_const: + # constant specialized version works for any valid axis value + index_tuple_generic = _gen_index_tuple( + arr.shape, axis_index, const_axis_val + ) + result += arr[index_tuple_generic] + else: + # Generate a tuple used to index the input array. + # The tuple is ":" in all dimensions except the axis + # dimension where it is "axis_index". 
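+ # For a 2-d array with axis == 0 this accumulates + # result += arr[axis_index, :] one row at a time, so e.g. + # np.array([[1, 2], [3, 4]]).sum(axis=0) produces [4, 6].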
+ if axis == 0: + index_tuple1 = _gen_index_tuple(arr.shape, axis_index, 0) + result += arr[index_tuple1] + elif axis == 1: + index_tuple2 = _gen_index_tuple(arr.shape, axis_index, 1) + result += arr[index_tuple2] + elif axis == 2: + index_tuple3 = _gen_index_tuple(arr.shape, axis_index, 2) + result += arr[index_tuple3] + elif axis == 3: + index_tuple4 = _gen_index_tuple(arr.shape, axis_index, 3) + result += arr[index_tuple4] + return op(result, 0) + + return inner + + +@lower(np.sum, types.Array, types.intp, types.DTypeSpec) +@lower(np.sum, types.Array, types.IntegerLiteral, types.DTypeSpec) +@lower("array.sum", types.Array, types.intp, types.DTypeSpec) +@lower("array.sum", types.Array, types.IntegerLiteral, types.DTypeSpec) +def array_sum_axis_dtype(context, builder, sig, args): + retty = sig.return_type + zero = getattr(retty, "dtype", retty)(0) + # if the return is scalar in type then "take" the 0th element of the + # 0d array accumulator as the return value + if getattr(retty, "ndim", None) is None: + op = np.take + else: + op = _array_sum_axis_nop + [ty_array, ty_axis, ty_dtype] = sig.args + is_axis_const = False + const_axis_val = 0 + if isinstance(ty_axis, types.Literal): + # this special-cases for constant axis + const_axis_val = ty_axis.literal_value + # fix negative axis + if const_axis_val < 0: + const_axis_val = ty_array.ndim + const_axis_val + if const_axis_val < 0 or const_axis_val > ty_array.ndim: + raise ValueError("'axis' entry is out of bounds") + + ty_axis = context.typing_context.resolve_value_type(const_axis_val) + axis_val = context.get_constant(ty_axis, const_axis_val) + # rewrite arguments + args = args[0], axis_val, args[2] + # rewrite sig + sig = sig.replace(args=[ty_array, ty_axis, ty_dtype]) + is_axis_const = True + + gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero) + compiled = register_jitable(gen_impl) + + def array_sum_impl_axis(arr, axis, dtype): + return compiled(arr, axis) + + res = context.compile_internal(builder, array_sum_impl_axis, sig, args) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@lower(np.sum, types.Array, types.DTypeSpec) +@lower("array.sum", types.Array, types.DTypeSpec) +def array_sum_dtype(context, builder, sig, args): + zero = sig.return_type(0) + + def array_sum_impl(arr, dtype): + c = zero + for v in np.nditer(arr): + c += v.item() + return c + + res = context.compile_internal( + builder, array_sum_impl, sig, args, locals=dict(c=sig.return_type) + ) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower(np.sum, types.Array, types.intp) +@lower(np.sum, types.Array, types.IntegerLiteral) +@lower("array.sum", types.Array, types.intp) +@lower("array.sum", types.Array, types.IntegerLiteral) +def array_sum_axis(context, builder, sig, args): + retty = sig.return_type + zero = getattr(retty, "dtype", retty)(0) + # if the return is scalar in type then "take" the 0th element of the + # 0d array accumulator as the return value + if getattr(retty, "ndim", None) is None: + op = np.take + else: + op = _array_sum_axis_nop + [ty_array, ty_axis] = sig.args + is_axis_const = False + const_axis_val = 0 + if isinstance(ty_axis, types.Literal): + # this special-cases for constant axis + const_axis_val = ty_axis.literal_value + # fix negative axis + if const_axis_val < 0: + const_axis_val = ty_array.ndim + const_axis_val + if const_axis_val < 0 or const_axis_val > ty_array.ndim: + msg = f"'axis' entry ({const_axis_val}) is out of bounds" + raise NumbaValueError(msg) + + ty_axis = 
context.typing_context.resolve_value_type(const_axis_val) + axis_val = context.get_constant(ty_axis, const_axis_val) + # rewrite arguments + args = args[0], axis_val + # rewrite sig + sig = sig.replace(args=[ty_array, ty_axis]) + is_axis_const = True + + gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero) + compiled = register_jitable(gen_impl) + + def array_sum_impl_axis(arr, axis): + return compiled(arr, axis) + + res = context.compile_internal(builder, array_sum_impl_axis, sig, args) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +def get_accumulator(dtype, value): + if dtype.type == np.timedelta64: + acc_init = np.int64(value).view(dtype) + else: + acc_init = dtype.type(value) + return acc_init + + +@overload(np.prod) +@overload_method(types.Array, "prod") +def array_prod(a): + if isinstance(a, types.Array): + dtype = as_dtype(a.dtype) + + acc_init = get_accumulator(dtype, 1) + + def array_prod_impl(a): + c = acc_init + for v in np.nditer(a): + c *= v.item() + return c + + return array_prod_impl + + +@overload(np.cumsum) +@overload_method(types.Array, "cumsum") +def array_cumsum(a): + if isinstance(a, types.Array): + is_integer = a.dtype in types.signed_domain + is_bool = a.dtype == types.bool_ + if (is_integer and a.dtype.bitwidth < types.intp.bitwidth) or is_bool: + dtype = as_dtype(types.intp) + else: + dtype = as_dtype(a.dtype) + + acc_init = get_accumulator(dtype, 0) + + def array_cumsum_impl(a): + out = np.empty(a.size, dtype) + c = acc_init + for idx, v in enumerate(a.flat): + c += v + out[idx] = c + return out + + return array_cumsum_impl + + +@overload(np.cumprod) +@overload_method(types.Array, "cumprod") +def array_cumprod(a): + if isinstance(a, types.Array): + is_integer = a.dtype in types.signed_domain + is_bool = a.dtype == types.bool_ + if (is_integer and a.dtype.bitwidth < types.intp.bitwidth) or is_bool: + dtype = as_dtype(types.intp) + else: + dtype = as_dtype(a.dtype) + + acc_init = get_accumulator(dtype, 1) + + def array_cumprod_impl(a): + out = np.empty(a.size, dtype) + c = acc_init + for idx, v in enumerate(a.flat): + c *= v + out[idx] = c + return out + + return array_cumprod_impl + + +@overload(np.mean) +@overload_method(types.Array, "mean") +def array_mean(a): + if isinstance(a, types.Array): + is_number = a.dtype in types.integer_domain | frozenset([types.bool_]) + if is_number: + dtype = as_dtype(types.float64) + else: + dtype = as_dtype(a.dtype) + + acc_init = get_accumulator(dtype, 0) + + def array_mean_impl(a): + # Can't use the naive `arr.sum() / arr.size`, as it would return + # a wrong result on integer sum overflow. 
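+ # e.g. if the sum were accumulated in int8, 127 + 127 would wrap + # to -2 and give a mean of -1.0 instead of 127.0; accumulating via + # acc_init (float64 for integer inputs) avoids this.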
+ c = acc_init + for v in np.nditer(a): + c += v.item() + return c / a.size + + return array_mean_impl + + +@overload(np.var) +@overload_method(types.Array, "var") +def array_var(a): + if isinstance(a, types.Array): + + def array_var_impl(a): + # Compute the mean + m = a.mean() + + # Compute the sum of square diffs + ssd = 0 + for v in np.nditer(a): + val = v.item() - m + ssd += np.real(val * np.conj(val)) + return ssd / a.size + + return array_var_impl + + +@overload(np.std) +@overload_method(types.Array, "std") +def array_std(a): + if isinstance(a, types.Array): + + def array_std_impl(a): + return a.var() ** 0.5 + + return array_std_impl + + +@register_jitable +def min_comparator(a, min_val): + return a < min_val + + +@register_jitable +def max_comparator(a, min_val): + return a > min_val + + +@register_jitable +def return_false(a): + return False + + +@overload(np.min) +@overload(np.amin) +@overload_method(types.Array, "min") +def npy_min(a): + if not isinstance(a, types.Array): + return + + if isinstance(a.dtype, (types.NPDatetime, types.NPTimedelta)): + pre_return_func = np.isnat + comparator = min_comparator + elif isinstance(a.dtype, types.Complex): + pre_return_func = return_false + + def comp_func(a, min_val): + if a.real < min_val.real: + return True + elif a.real == min_val.real: + if a.imag < min_val.imag: + return True + return False + + comparator = register_jitable(comp_func) + elif isinstance(a.dtype, types.Float): + pre_return_func = np.isnan + comparator = min_comparator + else: + pre_return_func = return_false + comparator = min_comparator + + def impl_min(a): + if a.size == 0: + raise ValueError( + "zero-size array to reduction operation " + "minimum which has no identity" + ) + + it = np.nditer(a) + min_value = next(it).take(0) + if pre_return_func(min_value): + return min_value + + for view in it: + v = view.item() + if pre_return_func(v): + return v + if comparator(v, min_value): + min_value = v + return min_value + + return impl_min + + +@overload(np.max) +@overload(np.amax) +@overload_method(types.Array, "max") +def npy_max(a): + if not isinstance(a, types.Array): + return + + if isinstance(a.dtype, (types.NPDatetime, types.NPTimedelta)): + pre_return_func = np.isnat + comparator = max_comparator + elif isinstance(a.dtype, types.Complex): + pre_return_func = return_false + + def comp_func(a, max_val): + if a.real > max_val.real: + return True + elif a.real == max_val.real: + if a.imag > max_val.imag: + return True + return False + + comparator = register_jitable(comp_func) + elif isinstance(a.dtype, types.Float): + pre_return_func = np.isnan + comparator = max_comparator + else: + pre_return_func = return_false + comparator = max_comparator + + def impl_max(a): + if a.size == 0: + raise ValueError( + "zero-size array to reduction operation " + "maximum which has no identity" + ) + + it = np.nditer(a) + max_value = next(it).take(0) + if pre_return_func(max_value): + return max_value + + for view in it: + v = view.item() + if pre_return_func(v): + return v + if comparator(v, max_value): + max_value = v + return max_value + + return impl_max + + +@register_jitable +def array_argmin_impl_datetime(arry): + if arry.size == 0: + raise ValueError("attempt to get argmin of an empty sequence") + it = np.nditer(arry) + min_value = next(it).take(0) + min_idx = 0 + if np.isnat(min_value): + return min_idx + + idx = 1 + for view in it: + v = view.item() + if np.isnat(v): + return idx + if v < min_value: + min_value = v + min_idx = idx + idx += 1 + return min_idx + + 
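+# In the datetime/timedelta variant above, NaT propagates the way NaN +# does for floats: the index of the first NaT wins, e.g. +# np.argmin(np.array(["2021-01-02", "NaT", "2021-01-01"], +# dtype="datetime64[D]")) == 1, as in NumPy (>= 1.18).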
+@register_jitable +def array_argmin_impl_float(arry): + if arry.size == 0: + raise ValueError("attempt to get argmin of an empty sequence") + for v in arry.flat: + min_value = v + min_idx = 0 + break + if np.isnan(min_value): + return min_idx + + idx = 0 + for v in arry.flat: + if np.isnan(v): + return idx + if v < min_value: + min_value = v + min_idx = idx + idx += 1 + return min_idx + + +@register_jitable +def array_argmin_impl_generic(arry): + if arry.size == 0: + raise ValueError("attempt to get argmin of an empty sequence") + for v in arry.flat: + min_value = v + min_idx = 0 + break + else: + raise RuntimeError("unreachable") + + idx = 0 + for v in arry.flat: + if v < min_value: + min_value = v + min_idx = idx + idx += 1 + return min_idx + + +@overload(np.argmin) +@overload_method(types.Array, "argmin") +def array_argmin(a, axis=None): + if isinstance(a.dtype, (types.NPDatetime, types.NPTimedelta)): + flatten_impl = array_argmin_impl_datetime + elif isinstance(a.dtype, types.Float): + flatten_impl = array_argmin_impl_float + else: + flatten_impl = array_argmin_impl_generic + + if is_nonelike(axis): + + def array_argmin_impl(a, axis=None): + return flatten_impl(a) + else: + array_argmin_impl = build_argmax_or_argmin_with_axis_impl( + a, axis, flatten_impl + ) + return array_argmin_impl + + +@register_jitable +def array_argmax_impl_datetime(arry): + if arry.size == 0: + raise ValueError("attempt to get argmax of an empty sequence") + it = np.nditer(arry) + max_value = next(it).take(0) + max_idx = 0 + if np.isnat(max_value): + return max_idx + + idx = 1 + for view in it: + v = view.item() + if np.isnat(v): + return idx + if v > max_value: + max_value = v + max_idx = idx + idx += 1 + return max_idx + + +@register_jitable +def array_argmax_impl_float(arry): + if arry.size == 0: + raise ValueError("attempt to get argmax of an empty sequence") + for v in arry.flat: + max_value = v + max_idx = 0 + break + if np.isnan(max_value): + return max_idx + + idx = 0 + for v in arry.flat: + if np.isnan(v): + return idx + if v > max_value: + max_value = v + max_idx = idx + idx += 1 + return max_idx + + +@register_jitable +def array_argmax_impl_generic(arry): + if arry.size == 0: + raise ValueError("attempt to get argmax of an empty sequence") + for v in arry.flat: + max_value = v + max_idx = 0 + break + + idx = 0 + for v in arry.flat: + if v > max_value: + max_value = v + max_idx = idx + idx += 1 + return max_idx + + +def build_argmax_or_argmin_with_axis_impl(a, axis, flatten_impl): + """ + Given a function that implements the logic for handling a flattened + array, return the implementation function. + """ + check_is_integer(axis, "axis") + retty = types.intp + + tuple_buffer = tuple(range(a.ndim)) + + def impl(a, axis=None): + if axis < 0: + axis = a.ndim + axis + + if axis < 0 or axis >= a.ndim: + raise ValueError("axis is out of bounds") + + # Short circuit 1-dimensional arrays: + if a.ndim == 1: + return flatten_impl(a) + + # Make chosen axis the last axis: + tmp = tuple_buffer + for i in range(axis, a.ndim - 1): + tmp = tuple_setitem(tmp, i, i + 1) + transpose_index = tuple_setitem(tmp, a.ndim - 1, axis) + transposed_arr = a.transpose(transpose_index) + + # Flatten along that axis; since we've transposed, we can just get + # batches off the overall flattened array. 
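+ # e.g. for a (2, 3) input with axis == 0 the transpose has shape + # (3, 2); the raveled array is consumed in batches of m == 2, one + # per output element, so np.argmin(np.array([[3, 1, 2], [0, 5, 1]]), + # axis=0) gives [1, 0, 1].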
+ m = transposed_arr.shape[-1] + raveled = transposed_arr.ravel() + assert raveled.size == a.size + assert transposed_arr.size % m == 0 + out = np.empty(transposed_arr.size // m, retty) + for i in range(out.size): + out[i] = flatten_impl(raveled[i * m : (i + 1) * m]) + + # Reshape based on axis we didn't flatten over: + return out.reshape(transposed_arr.shape[:-1]) + + return impl + + +@overload(np.argmax) +@overload_method(types.Array, "argmax") +def array_argmax(a, axis=None): + if isinstance(a.dtype, (types.NPDatetime, types.NPTimedelta)): + flatten_impl = array_argmax_impl_datetime + elif isinstance(a.dtype, types.Float): + flatten_impl = array_argmax_impl_float + else: + flatten_impl = array_argmax_impl_generic + + if is_nonelike(axis): + + def array_argmax_impl(a, axis=None): + return flatten_impl(a) + else: + array_argmax_impl = build_argmax_or_argmin_with_axis_impl( + a, axis, flatten_impl + ) + return array_argmax_impl + + +@overload(np.all) +@overload_method(types.Array, "all") +def np_all(a): + def flat_all(a): + for v in np.nditer(a): + if not v.item(): + return False + return True + + return flat_all + + +@register_jitable +def _allclose_scalars(a_v, b_v, rtol=1e-05, atol=1e-08, equal_nan=False): + a_v_isnan = np.isnan(a_v) + b_v_isnan = np.isnan(b_v) + + # only one of the values is NaN and the + # other is not. + if (not a_v_isnan and b_v_isnan) or (a_v_isnan and not b_v_isnan): + return False + + # either both of the values are NaN + # or both are numbers + if a_v_isnan and b_v_isnan: + if not equal_nan: + return False + else: + if np.isinf(a_v) or np.isinf(b_v): + return a_v == b_v + + if np.abs(a_v - b_v) > atol + rtol * np.abs(b_v * 1.0): + return False + + return True + + +@overload(np.allclose) +@overload_method(types.Array, "allclose") +def np_allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + if not type_can_asarray(a): + raise TypingError('The first argument "a" must be array-like') + + if not type_can_asarray(b): + raise TypingError('The second argument "b" must be array-like') + + if not isinstance(rtol, (float, types.Float)): + raise TypingError('The third argument "rtol" must be a floating point') + + if not isinstance(atol, (float, types.Float)): + raise TypingError('The fourth argument "atol" must be a floating point') + + if not isinstance(equal_nan, (bool, types.Boolean)): + raise TypingError('The fifth argument "equal_nan" must be a boolean') + + is_a_scalar = isinstance(a, types.Number) + is_b_scalar = isinstance(b, types.Number) + + if is_a_scalar and is_b_scalar: + + def np_allclose_impl_scalar_scalar( + a, b, rtol=1e-05, atol=1e-08, equal_nan=False + ): + return _allclose_scalars( + a, b, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + return np_allclose_impl_scalar_scalar + elif is_a_scalar and not is_b_scalar: + + def np_allclose_impl_scalar_array( + a, b, rtol=1e-05, atol=1e-08, equal_nan=False + ): + b = np.asarray(b) + for bv in np.nditer(b): + if not _allclose_scalars( + a, bv.item(), rtol=rtol, atol=atol, equal_nan=equal_nan + ): + return False + return True + + return np_allclose_impl_scalar_array + elif not is_a_scalar and is_b_scalar: + + def np_allclose_impl_array_scalar( + a, b, rtol=1e-05, atol=1e-08, equal_nan=False + ): + a = np.asarray(a) + for av in np.nditer(a): + if not _allclose_scalars( + av.item(), b, rtol=rtol, atol=atol, equal_nan=equal_nan + ): + return False + return True + + return np_allclose_impl_array_scalar + elif not is_a_scalar and not is_b_scalar: + + def np_allclose_impl_array_array( + a, b, rtol=1e-05, 
atol=1e-08, equal_nan=False + ): + a = np.asarray(a) + b = np.asarray(b) + a_a, b_b = np.broadcast_arrays(a, b) + + for av, bv in np.nditer((a_a, b_b)): + if not _allclose_scalars( + av.item(), + bv.item(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ): + return False + + return True + + return np_allclose_impl_array_array + + +@overload(np.any) +@overload_method(types.Array, "any") +def np_any(a): + def flat_any(a): + for v in np.nditer(a): + if v.item(): + return True + return False + + return flat_any + + +@overload(np.average) +def np_average(a, axis=None, weights=None): + if weights is None or isinstance(weights, types.NoneType): + + def np_average_impl(a, axis=None, weights=None): + arr = np.asarray(a) + return np.mean(arr) + else: + if axis is None or isinstance(axis, types.NoneType): + + def np_average_impl(a, axis=None, weights=None): + arr = np.asarray(a) + weights = np.asarray(weights) + + if arr.shape != weights.shape: + if axis is None: + raise TypeError( + "Numba does not support average when shapes of " + "a and weights differ." + ) + if weights.ndim != 1: + raise TypeError( + "1D weights expected when shapes of " + "a and weights differ." + ) + + scl = np.sum(weights) + if scl == 0.0: + raise ZeroDivisionError( + "Weights sum to zero, can't be normalized." + ) + + avg = np.sum(np.multiply(arr, weights)) / scl + return avg + else: + + def np_average_impl(a, axis=None, weights=None): + raise TypeError("Numba does not support average with axis.") + + return np_average_impl + + +def get_isnan(dtype): + """ + A generic isnan() function + """ + if isinstance(dtype, (types.Float, types.Complex)): + return np.isnan + else: + + @register_jitable + def _trivial_isnan(x): + return False + + return _trivial_isnan + + +@overload(np.iscomplex) +def np_iscomplex(x): + if type_can_asarray(x): + # NumPy uses asanyarray here! + return lambda x: np.asarray(x).imag != 0 + return None + + +@overload(np.isreal) +def np_isreal(x): + if type_can_asarray(x): + # NumPy uses asanyarray here! + return lambda x: np.asarray(x).imag == 0 + return None + + +@overload(np.iscomplexobj) +def iscomplexobj(x): + # Implementation based on NumPy + # https://github.com/numpy/numpy/blob/d9b1e32cb8ef90d6b4a47853241db2a28146a57d/numpy/lib/type_check.py#L282-L320 + dt = determine_dtype(x) + if isinstance(x, types.Optional): + dt = determine_dtype(x.type) + iscmplx = np.issubdtype(dt, np.complexfloating) + + if isinstance(x, types.Optional): + + def impl(x): + if x is None: + return False + return iscmplx + else: + + def impl(x): + return iscmplx + + return impl + + +@overload(np.isrealobj) +def isrealobj(x): + # Return True if x is not a complex type. 
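+ # This is a type-based check, unlike the value-based np.isreal: + # np.isrealobj(np.zeros(3)) is True, while np.isrealobj(1 + 0j) is + # False even though the value is real.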
+ # Implementation based on NumPy + # https://github.com/numpy/numpy/blob/ccfbcc1cd9a4035a467f2e982a565ab27de25b6b/numpy/lib/type_check.py#L290-L322 + def impl(x): + return not np.iscomplexobj(x) + + return impl + + +@overload(np.isscalar) +def np_isscalar(element): + res = type_is_scalar(element) + + def impl(element): + return res + + return impl + + +def is_np_inf_impl(x, out, fn): + # if/else branch should be unified after PR #5606 is merged + if is_nonelike(out): + + def impl(x, out=None): + return np.logical_and(np.isinf(x), fn(np.signbit(x))) + else: + + def impl(x, out=None): + return np.logical_and(np.isinf(x), fn(np.signbit(x)), out) + + return impl + + +@overload(np.isneginf) +def isneginf(x, out=None): + fn = register_jitable(lambda x: x) + return is_np_inf_impl(x, out, fn) + + +@overload(np.isposinf) +def isposinf(x, out=None): + fn = register_jitable(lambda x: ~x) + return is_np_inf_impl(x, out, fn) + + +@register_jitable +def less_than(a, b): + return a < b + + +@register_jitable +def greater_than(a, b): + return a > b + + +@register_jitable +def check_array(a): + if a.size == 0: + raise ValueError("zero-size array to reduction operation not possible") + + +def nan_min_max_factory(comparison_op, is_complex_dtype): + if is_complex_dtype: + + def impl(a): + arr = np.asarray(a) + check_array(arr) + it = np.nditer(arr) + return_val = next(it).take(0) + for view in it: + v = view.item() + if np.isnan(return_val.real) and not np.isnan(v.real): + return_val = v + else: + if comparison_op(v.real, return_val.real): + return_val = v + elif v.real == return_val.real: + if comparison_op(v.imag, return_val.imag): + return_val = v + return return_val + else: + + def impl(a): + arr = np.asarray(a) + check_array(arr) + it = np.nditer(arr) + return_val = next(it).take(0) + for view in it: + v = view.item() + if not np.isnan(v): + if not comparison_op(return_val, v): + return_val = v + return return_val + + return impl + + +real_nanmin = register_jitable( + nan_min_max_factory(less_than, is_complex_dtype=False) +) +real_nanmax = register_jitable( + nan_min_max_factory(greater_than, is_complex_dtype=False) +) +complex_nanmin = register_jitable( + nan_min_max_factory(less_than, is_complex_dtype=True) +) +complex_nanmax = register_jitable( + nan_min_max_factory(greater_than, is_complex_dtype=True) +) + + +@register_jitable +def _isclose_item(x, y, rtol, atol, equal_nan): + if np.isnan(x) and np.isnan(y): + return equal_nan + elif np.isinf(x) and np.isinf(y): + return (x > 0) == (y > 0) + elif np.isinf(x) or np.isinf(y): + return False + else: + return abs(x - y) <= atol + rtol * abs(y) + + +@overload(np.isclose) +def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + if not type_can_asarray(a): + raise TypingError('The first argument "a" must be array-like') + + if not type_can_asarray(b): + raise TypingError('The second argument "b" must be array-like') + + if not isinstance(rtol, (float, types.Float)): + raise TypingError('The third argument "rtol" must be a floating point') + + if not isinstance(atol, (float, types.Float)): + raise TypingError('The fourth argument "atol" must be a floating point') + + if not isinstance(equal_nan, (bool, types.Boolean)): + raise TypingError('The fifth argument "equal_nan" must be a boolean') + + if isinstance(a, types.Array) and isinstance(b, types.Number): + + def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + x = a.reshape(-1) + y = b + out = np.zeros(len(x), np.bool_) + for i in range(len(out)): + out[i] = _isclose_item(x[i], y, rtol, 
atol, equal_nan) + return out.reshape(a.shape) + + elif isinstance(a, types.Number) and isinstance(b, types.Array): + + def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + x = a + y = b.reshape(-1) + out = np.zeros(len(y), np.bool_) + for i in range(len(out)): + out[i] = _isclose_item(x, y[i], rtol, atol, equal_nan) + return out.reshape(b.shape) + + elif isinstance(a, types.Array) and isinstance(b, types.Array): + + def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + shape = np.broadcast_shapes(a.shape, b.shape) + a_ = np.broadcast_to(a, shape) + b_ = np.broadcast_to(b, shape) + + out = np.zeros(len(a_), dtype=np.bool_) + for i, (av, bv) in enumerate(np.nditer((a_, b_))): + out[i] = _isclose_item( + av.item(), bv.item(), rtol, atol, equal_nan + ) + return np.broadcast_to(out, shape) + + else: + + def isclose_impl(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return _isclose_item(a, b, rtol, atol, equal_nan) + + return isclose_impl + + +@overload(np.nanmin) +def np_nanmin(a): + dt = determine_dtype(a) + if np.issubdtype(dt, np.complexfloating): + return complex_nanmin + else: + return real_nanmin + + +@overload(np.nanmax) +def np_nanmax(a): + dt = determine_dtype(a) + if np.issubdtype(dt, np.complexfloating): + return complex_nanmax + else: + return real_nanmax + + +@overload(np.nanmean) +def np_nanmean(a): + if not isinstance(a, types.Array): + return + isnan = get_isnan(a.dtype) + + def nanmean_impl(a): + c = 0.0 + count = 0 + for view in np.nditer(a): + v = view.item() + if not isnan(v): + c += v.item() + count += 1 + # np.divide() doesn't raise ZeroDivisionError + return np.divide(c, count) + + return nanmean_impl + + +@overload(np.nanvar) +def np_nanvar(a): + if not isinstance(a, types.Array): + return + isnan = get_isnan(a.dtype) + + def nanvar_impl(a): + # Compute the mean + m = np.nanmean(a) + + # Compute the sum of square diffs + ssd = 0.0 + count = 0 + for view in np.nditer(a): + v = view.item() + if not isnan(v): + val = v.item() - m + ssd += np.real(val * np.conj(val)) + count += 1 + # np.divide() doesn't raise ZeroDivisionError + return np.divide(ssd, count) + + return nanvar_impl + + +@overload(np.nanstd) +def np_nanstd(a): + if not isinstance(a, types.Array): + return + + def nanstd_impl(a): + return np.nanvar(a) ** 0.5 + + return nanstd_impl + + +@overload(np.nansum) +def np_nansum(a): + if not isinstance(a, types.Array): + return + if isinstance(a.dtype, types.Integer): + retty = types.intp + else: + retty = a.dtype + zero = retty(0) + isnan = get_isnan(a.dtype) + + def nansum_impl(a): + c = zero + for view in np.nditer(a): + v = view.item() + if not isnan(v): + c += v + return c + + return nansum_impl + + +@overload(np.nanprod) +def np_nanprod(a): + if not isinstance(a, types.Array): + return + if isinstance(a.dtype, types.Integer): + retty = types.intp + else: + retty = a.dtype + one = retty(1) + isnan = get_isnan(a.dtype) + + def nanprod_impl(a): + c = one + for view in np.nditer(a): + v = view.item() + if not isnan(v): + c *= v + return c + + return nanprod_impl + + +@overload(np.nancumprod) +def np_nancumprod(a): + if not isinstance(a, types.Array): + return + + if isinstance(a.dtype, (types.Boolean, types.Integer)): + # dtype cannot possibly contain NaN + return lambda a: np.cumprod(a) + else: + retty = a.dtype + is_nan = get_isnan(retty) + one = retty(1) + + def nancumprod_impl(a): + out = np.empty(a.size, retty) + c = one + for idx, v in enumerate(a.flat): + if ~is_nan(v): + c *= v + out[idx] = c + return out + + return 
nancumprod_impl + + +@overload(np.nancumsum) +def np_nancumsum(a): + if not isinstance(a, types.Array): + return + + if isinstance(a.dtype, (types.Boolean, types.Integer)): + # dtype cannot possibly contain NaN + return lambda a: np.cumsum(a) + else: + retty = a.dtype + is_nan = get_isnan(retty) + zero = retty(0) + + def nancumsum_impl(a): + out = np.empty(a.size, retty) + c = zero + for idx, v in enumerate(a.flat): + if ~is_nan(v): + c += v + out[idx] = c + return out + + return nancumsum_impl + + +@register_jitable +def prepare_ptp_input(a): + arr = _asarray(a) + if len(arr) == 0: + raise ValueError("zero-size array reduction not possible") + else: + return arr + + +def _compute_current_val_impl_gen(op, current_val, val): + if isinstance(current_val, types.Complex): + # The sort order for complex numbers is lexicographic. If both the + # real and imaginary parts are non-nan then the order is determined + # by the real parts except when they are equal, in which case the + # order is determined by the imaginary parts. + # https://github.com/numpy/numpy/blob/577a86e/numpy/core/fromnumeric.py#L874-L877 # noqa: E501 + def impl(current_val, val): + if op(val.real, current_val.real): + return val + elif val.real == current_val.real and op( + val.imag, current_val.imag + ): + return val + return current_val + else: + + def impl(current_val, val): + return val if op(val, current_val) else current_val + + return impl + + +def _compute_a_max(current_val, val): + pass + + +def _compute_a_min(current_val, val): + pass + + +@overload(_compute_a_max) +def _compute_a_max_impl(current_val, val): + return _compute_current_val_impl_gen(operator.gt, current_val, val) + + +@overload(_compute_a_min) +def _compute_a_min_impl(current_val, val): + return _compute_current_val_impl_gen(operator.lt, current_val, val) + + +def _early_return(val): + pass + + +@overload(_early_return) +def _early_return_impl(val): + UNUSED = 0 + if isinstance(val, types.Complex): + + def impl(val): + if np.isnan(val.real): + if np.isnan(val.imag): + return True, np.nan + np.nan * 1j + else: + return True, np.nan + 0j + else: + return False, UNUSED + elif isinstance(val, types.Float): + + def impl(val): + if np.isnan(val): + return True, np.nan + else: + return False, UNUSED + else: + + def impl(val): + return False, UNUSED + + return impl + + +@overload(np.ptp) +def np_ptp(a): + if hasattr(a, "dtype"): + if isinstance(a.dtype, types.Boolean): + raise TypingError("Boolean dtype is unsupported (as per NumPy)") + # Numpy raises a TypeError + + def np_ptp_impl(a): + arr = prepare_ptp_input(a) + + a_flat = arr.flat + a_min = a_flat[0] + a_max = a_flat[0] + + for i in range(arr.size): + val = a_flat[i] + take_branch, retval = _early_return(val) + if take_branch: + return retval + a_max = _compute_a_max(a_max, val) + a_min = _compute_a_min(a_min, val) + + return a_max - a_min + + return np_ptp_impl + + +if numpy_version < (2, 0): + overload_method(types.Array, "ptp")(np_ptp) + +# ---------------------------------------------------------------------------- +# Median and partitioning + + +@register_jitable +def nan_aware_less_than(a, b): + if np.isnan(a): + return False + else: + if np.isnan(b): + return True + else: + return a < b + + +def _partition_factory(pivotimpl, argpartition=False): + def _partition(A, low, high, I=None): + mid = (low + high) >> 1 + # NOTE: the pattern of swaps below for the pivot choice and the + # partitioning gives good results (i.e. regular O(n log n)) + # on sorted, reverse-sorted, and uniform arrays. 
Subtle changes + # risk breaking this property. + + # Use median of three {low, middle, high} as the pivot + if pivotimpl(A[mid], A[low]): + A[low], A[mid] = A[mid], A[low] + if argpartition: + I[low], I[mid] = I[mid], I[low] + if pivotimpl(A[high], A[mid]): + A[high], A[mid] = A[mid], A[high] + if argpartition: + I[high], I[mid] = I[mid], I[high] + if pivotimpl(A[mid], A[low]): + A[low], A[mid] = A[mid], A[low] + if argpartition: + I[low], I[mid] = I[mid], I[low] + pivot = A[mid] + + A[high], A[mid] = A[mid], A[high] + if argpartition: + I[high], I[mid] = I[mid], I[high] + i = low + j = high - 1 + while True: + while i < high and pivotimpl(A[i], pivot): + i += 1 + while j >= low and pivotimpl(pivot, A[j]): + j -= 1 + if i >= j: + break + A[i], A[j] = A[j], A[i] + if argpartition: + I[i], I[j] = I[j], I[i] + i += 1 + j -= 1 + # Put the pivot back in its final place (all items before `i` + # are smaller than the pivot, all items at/after `i` are larger) + A[i], A[high] = A[high], A[i] + if argpartition: + I[i], I[high] = I[high], I[i] + return i + + return _partition + + +_partition = register_jitable(_partition_factory(less_than)) +_partition_w_nan = register_jitable(_partition_factory(nan_aware_less_than)) +_argpartition_w_nan = register_jitable( + _partition_factory(nan_aware_less_than, argpartition=True) +) + + +def _select_factory(partitionimpl): + def _select(arry, k, low, high, idx=None): + """ + Select the k'th smallest element in array[low:high + 1]. + """ + i = partitionimpl(arry, low, high, idx) + while i != k: + if i < k: + low = i + 1 + i = partitionimpl(arry, low, high, idx) + else: + high = i - 1 + i = partitionimpl(arry, low, high, idx) + return arry[k] + + return _select + + +_select = register_jitable(_select_factory(_partition)) +_select_w_nan = register_jitable(_select_factory(_partition_w_nan)) +_arg_select_w_nan = register_jitable(_select_factory(_argpartition_w_nan)) + + +@register_jitable +def _select_two(arry, k, low, high): + """ + Select the k'th and k+1'th smallest elements in array[low:high + 1]. + + This is significantly faster than doing two independent selections + for k and k+1. + """ + while True: + assert high > low # by construction + i = _partition(arry, low, high) + if i < k: + low = i + 1 + elif i > k + 1: + high = i - 1 + elif i == k: + _select(arry, k + 1, i + 1, high) + break + else: # i == k + 1 + _select(arry, k, low, i - 1) + break + + return arry[k], arry[k + 1] + + +@register_jitable +def _median_inner(temp_arry, n): + """ + The main logic of the median() call. *temp_arry* must be disposable, + as this function will mutate it. + """ + low = 0 + high = n - 1 + half = n >> 1 + if n & 1 == 0: + a, b = _select_two(temp_arry, half - 1, low, high) + return (a + b) / 2 + else: + return _select(temp_arry, half, low, high) + + +@overload(np.median) +def np_median(a): + if not isinstance(a, types.Array): + return + + def median_impl(a): + # np.median() works on the flattened array, and we need a temporary + # workspace anyway + temp_arry = a.flatten() + n = temp_arry.shape[0] + return _median_inner(temp_arry, n) + + return median_impl + + +@register_jitable +def _collect_percentiles_inner(a, q): + # TODO: This needs rewriting to be closer to NumPy, particularly the nan/inf + # handling which is generally subject to algorithmic changes. 
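+ # The interior case below interpolates linearly between closest + # ranks: for a = [1, 2, 3, 4] and percentile 25, + # rank = 1 + 3 * 0.25 = 1.75, so f = 1, m = 0.75 and + # val = 1 * 0.25 + 2 * 0.75 = 1.75, matching + # np.percentile([1, 2, 3, 4], 25).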
+ n = len(a) + + if n == 1: + # single element array; output same for all percentiles + out = np.full(len(q), a[0], dtype=np.float64) + else: + out = np.empty(len(q), dtype=np.float64) + for i in range(len(q)): + percentile = q[i] + + # bypass pivoting where requested percentile is 100 + if percentile == 100: + val = np.max(a) + # heuristics to handle infinite values a la NumPy + if ~np.all(np.isfinite(a)): + if ~np.isfinite(val): + val = np.nan + + # bypass pivoting where requested percentile is 0 + elif percentile == 0: + val = np.min(a) + # convoluted heuristics to handle infinite values a la NumPy + if ~np.all(np.isfinite(a)): + num_pos_inf = np.sum(a == np.inf) + num_neg_inf = np.sum(a == -np.inf) + num_finite = n - (num_neg_inf + num_pos_inf) + if num_finite == 0: + val = np.nan + if num_pos_inf == 1 and n == 2: + val = np.nan + if num_neg_inf > 1: + val = np.nan + if num_finite == 1: + if num_pos_inf > 1: + if num_neg_inf != 1: + val = np.nan + + else: + # linear interp between closest ranks + rank = 1 + (n - 1) * np.true_divide(percentile, 100.0) + f = math.floor(rank) + m = rank - f + lower, upper = _select_two(a, k=int(f - 1), low=0, high=(n - 1)) + val = lower * (1 - m) + upper * m + out[i] = val + + return out + + +@register_jitable +def _can_collect_percentiles(a, nan_mask, skip_nan): + if skip_nan: + a = a[~nan_mask] + if len(a) == 0: + return False # told to skip nan, but no elements remain + else: + if np.any(nan_mask): + return False # told *not* to skip nan, but nan encountered + + if len(a) == 1: # single element array + val = a[0] + return np.isfinite(val) # can collect percentiles if element is finite + else: + return True + + +@register_jitable +def check_valid(q, q_upper_bound): + valid = True + + # avoid expensive reductions where possible + if q.ndim == 1 and q.size < 10: + for i in range(q.size): + if q[i] < 0.0 or q[i] > q_upper_bound or np.isnan(q[i]): + valid = False + break + else: + if np.any(np.isnan(q)) or np.any(q < 0.0) or np.any(q > q_upper_bound): + valid = False + + return valid + + +@register_jitable +def percentile_is_valid(q): + if not check_valid(q, q_upper_bound=100.0): + raise ValueError("Percentiles must be in the range [0, 100]") + + +@register_jitable +def quantile_is_valid(q): + if not check_valid(q, q_upper_bound=1.0): + raise ValueError("Quantiles must be in the range [0, 1]") + + +@register_jitable +def _collect_percentiles(a, q, check_q, factor, skip_nan): + q = np.asarray(q, dtype=np.float64).flatten() + check_q(q) + q = q * factor + + temp_arry = np.asarray(a, dtype=np.float64).flatten() + nan_mask = np.isnan(temp_arry) + + if _can_collect_percentiles(temp_arry, nan_mask, skip_nan): + temp_arry = temp_arry[~nan_mask] + out = _collect_percentiles_inner(temp_arry, q) + else: + out = np.full(len(q), np.nan) + + return out + + +def _percentile_quantile_inner(a, q, skip_nan, factor, check_q): + """ + The underlying algorithm to find percentiles and quantiles + is the same, hence we converge onto the same code paths + in this inner function implementation + """ + dt = determine_dtype(a) + if np.issubdtype(dt, np.complexfloating): + raise TypingError("Not supported for complex dtype") + # this could be supported, but would require a + # lexicographic comparison + + def np_percentile_q_scalar_impl(a, q): + return _collect_percentiles(a, q, check_q, factor, skip_nan)[0] + + def np_percentile_impl(a, q): + return _collect_percentiles(a, q, check_q, factor, skip_nan) + + if isinstance(q, (types.Number, types.Boolean)): + return 
np_percentile_q_scalar_impl + elif isinstance(q, types.Array) and q.ndim == 0: + return np_percentile_q_scalar_impl + else: + return np_percentile_impl + + +@overload(np.percentile) +def np_percentile(a, q): + return _percentile_quantile_inner( + a, q, skip_nan=False, factor=1.0, check_q=percentile_is_valid + ) + + +@overload(np.nanpercentile) +def np_nanpercentile(a, q): + return _percentile_quantile_inner( + a, q, skip_nan=True, factor=1.0, check_q=percentile_is_valid + ) + + +@overload(np.quantile) +def np_quantile(a, q): + return _percentile_quantile_inner( + a, q, skip_nan=False, factor=100.0, check_q=quantile_is_valid + ) + + +@overload(np.nanquantile) +def np_nanquantile(a, q): + return _percentile_quantile_inner( + a, q, skip_nan=True, factor=100.0, check_q=quantile_is_valid + ) + + +@overload(np.nanmedian) +def np_nanmedian(a): + if not isinstance(a, types.Array): + return + isnan = get_isnan(a.dtype) + + def nanmedian_impl(a): + # Create a temporary workspace with only non-NaN values + temp_arry = np.empty(a.size, a.dtype) + n = 0 + for view in np.nditer(a): + v = view.item() + if not isnan(v): + temp_arry[n] = v + n += 1 + + # all NaNs + if n == 0: + return np.nan + + return _median_inner(temp_arry, n) + + return nanmedian_impl + + +@register_jitable +def np_partition_impl_inner(a, kth_array): + # allocate and fill empty array rather than copy a and mutate in place + # as the latter approach fails to preserve strides + out = np.empty_like(a) + + idx = np.ndindex(a.shape[:-1]) # Numpy default partition axis is -1 + for s in idx: + arry = a[s].copy() + low = 0 + high = len(arry) - 1 + + for kth in kth_array: + _select_w_nan(arry, kth, low, high) + low = kth # narrow span of subsequent partition + + out[s] = arry + return out + + +@register_jitable +def np_argpartition_impl_inner(a, kth_array): + # allocate and fill empty array rather than copy a and mutate in place + # as the latter approach fails to preserve strides + out = np.empty_like(a, dtype=np.intp) + + idx = np.ndindex(a.shape[:-1]) # Numpy default partition axis is -1 + for s in idx: + arry = a[s].copy() + idx_arry = np.arange(len(arry)) + low = 0 + high = len(arry) - 1 + + for kth in kth_array: + _arg_select_w_nan(arry, kth, low, high, idx_arry) + low = kth # narrow span of subsequent partition + + out[s] = idx_arry + return out + + +@register_jitable +def valid_kths(a, kth): + """ + Returns a sorted, unique array of kth values which serve + as indexers for partitioning the input array, a. + + If the absolute value of any of the provided values + is greater than a.shape[-1] an exception is raised since + we are partitioning along the last axis (per Numpy default + behaviour). + + Values less than 0 are transformed to equivalent positive + index values. 
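+ For example, for an input with a.shape[-1] == 5, kth = (3, -1, 3) + becomes the sorted, unique array [3, 4].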
+ """ + # cast boolean to int, where relevant + kth_array = _asarray(kth).astype(np.int64) + + if kth_array.ndim != 1: + raise ValueError("kth must be scalar or 1-D") + # numpy raises ValueError: object too deep for desired array + + if np.any(np.abs(kth_array) >= a.shape[-1]): + raise ValueError("kth out of bounds") + + out = np.empty_like(kth_array) + + for index, val in np.ndenumerate(kth_array): + if val < 0: + out[index] = val + a.shape[-1] # equivalent positive index + else: + out[index] = val + + return np.unique(out) + + +@overload(np.partition) +def np_partition(a, kth): + if not isinstance(a, (types.Array, types.Sequence, types.Tuple)): + raise NumbaTypeError("The first argument must be an array-like") + + if isinstance(a, types.Array) and a.ndim == 0: + msg = "The first argument must be at least 1-D (found 0-D)" + raise NumbaTypeError(msg) + + kthdt = getattr(kth, "dtype", kth) + if not isinstance(kthdt, (types.Boolean, types.Integer)): + # bool gets cast to int subsequently + raise NumbaTypeError("Partition index must be integer") + + def np_partition_impl(a, kth): + a_tmp = _asarray(a) + if a_tmp.size == 0: + return a_tmp.copy() + else: + kth_array = valid_kths(a_tmp, kth) + return np_partition_impl_inner(a_tmp, kth_array) + + return np_partition_impl + + +@overload(np.argpartition) +def np_argpartition(a, kth): + if not isinstance(a, (types.Array, types.Sequence, types.Tuple)): + raise NumbaTypeError("The first argument must be an array-like") + + if isinstance(a, types.Array) and a.ndim == 0: + msg = "The first argument must be at least 1-D (found 0-D)" + raise NumbaTypeError(msg) + + kthdt = getattr(kth, "dtype", kth) + if not isinstance(kthdt, (types.Boolean, types.Integer)): + # bool gets cast to int subsequently + raise NumbaTypeError("Partition index must be integer") + + def np_argpartition_impl(a, kth): + a_tmp = _asarray(a) + if a_tmp.size == 0: + return a_tmp.copy().astype("intp") + else: + kth_array = valid_kths(a_tmp, kth) + return np_argpartition_impl_inner(a_tmp, kth_array) + + return np_argpartition_impl + + +# ---------------------------------------------------------------------------- +# Building matrices + + +@register_jitable +def _tri_impl(N, M, k): + shape = max(0, N), max(0, M) # numpy floors each dimension at 0 + out = np.empty(shape, dtype=np.float64) # numpy default dtype + + for i in range(shape[0]): + m_max = min(max(0, i + k + 1), shape[1]) + out[i, :m_max] = 1 + out[i, m_max:] = 0 + + return out + + +@overload(np.tri) +def np_tri(N, M=None, k=0): + # we require k to be integer, unlike numpy + check_is_integer(k, "k") + + def tri_impl(N, M=None, k=0): + if M is None: + M = N + return _tri_impl(N, M, k) + + return tri_impl + + +@register_jitable +def _make_square(m): + """ + Takes a 1d array and tiles it to form a square matrix + - i.e. 
a facsimile of np.tile(m, (len(m), 1)) + """ + assert m.ndim == 1 + + len_m = len(m) + out = np.empty((len_m, len_m), dtype=m.dtype) + + for i in range(len_m): + out[i] = m + + return out + + +@register_jitable +def np_tril_impl_2d(m, k=0): + mask = np.tri(m.shape[-2], M=m.shape[-1], k=k).astype(np.uint) + return np.where(mask, m, np.zeros_like(m, dtype=m.dtype)) + + +@overload(np.tril) +def my_tril(m, k=0): + # we require k to be integer, unlike numpy + check_is_integer(k, "k") + + def np_tril_impl_1d(m, k=0): + m_2d = _make_square(m) + return np_tril_impl_2d(m_2d, k) + + def np_tril_impl_multi(m, k=0): + mask = np.tri(m.shape[-2], M=m.shape[-1], k=k).astype(np.uint) + idx = np.ndindex(m.shape[:-2]) + z = np.empty_like(m) + zero_opt = np.zeros_like(mask, dtype=m.dtype) + for sel in idx: + z[sel] = np.where(mask, m[sel], zero_opt) + return z + + if m.ndim == 1: + return np_tril_impl_1d + elif m.ndim == 2: + return np_tril_impl_2d + else: + return np_tril_impl_multi + + +@overload(np.tril_indices) +def np_tril_indices(n, k=0, m=None): + # we require integer arguments, unlike numpy + check_is_integer(n, "n") + check_is_integer(k, "k") + if not is_nonelike(m): + check_is_integer(m, "m") + + def np_tril_indices_impl(n, k=0, m=None): + return np.nonzero(np.tri(n, m, k=k)) + + return np_tril_indices_impl + + +@overload(np.tril_indices_from) +def np_tril_indices_from(arr, k=0): + # we require k to be integer, unlike numpy + check_is_integer(k, "k") + + if arr.ndim != 2: + raise TypingError("input array must be 2-d") + + def np_tril_indices_from_impl(arr, k=0): + return np.tril_indices(arr.shape[0], k=k, m=arr.shape[1]) + + return np_tril_indices_from_impl + + +@register_jitable +def np_triu_impl_2d(m, k=0): + mask = np.tri(m.shape[-2], M=m.shape[-1], k=k - 1).astype(np.uint) + return np.where(mask, np.zeros_like(m, dtype=m.dtype), m) + + +@overload(np.triu) +def my_triu(m, k=0): + # we require k to be integer, unlike numpy + check_is_integer(k, "k") + + def np_triu_impl_1d(m, k=0): + m_2d = _make_square(m) + return np_triu_impl_2d(m_2d, k) + + def np_triu_impl_multi(m, k=0): + mask = np.tri(m.shape[-2], M=m.shape[-1], k=k - 1).astype(np.uint) + idx = np.ndindex(m.shape[:-2]) + z = np.empty_like(m) + zero_opt = np.zeros_like(mask, dtype=m.dtype) + for sel in idx: + z[sel] = np.where(mask, zero_opt, m[sel]) + return z + + if m.ndim == 1: + return np_triu_impl_1d + elif m.ndim == 2: + return np_triu_impl_2d + else: + return np_triu_impl_multi + + +@overload(np.triu_indices) +def np_triu_indices(n, k=0, m=None): + # we require integer arguments, unlike numpy + check_is_integer(n, "n") + check_is_integer(k, "k") + if not is_nonelike(m): + check_is_integer(m, "m") + + def np_triu_indices_impl(n, k=0, m=None): + return np.nonzero(1 - np.tri(n, m, k=k - 1)) + + return np_triu_indices_impl + + +@overload(np.triu_indices_from) +def np_triu_indices_from(arr, k=0): + # we require k to be integer, unlike numpy + check_is_integer(k, "k") + + if arr.ndim != 2: + raise TypingError("input array must be 2-d") + + def np_triu_indices_from_impl(arr, k=0): + return np.triu_indices(arr.shape[0], k=k, m=arr.shape[1]) + + return np_triu_indices_from_impl + + +def _prepare_array(arr): + pass + + +@overload(_prepare_array) +def _prepare_array_impl(arr): + if arr in (None, types.none): + return lambda arr: np.array(()) + else: + return lambda arr: _asarray(arr).ravel() + + +def _dtype_of_compound(inobj): + obj = inobj + while True: + if isinstance(obj, (types.Number, types.Boolean)): + return as_dtype(obj) + l = 
getattr(obj, "__len__", None) + if l is not None and l() == 0: # empty tuple or similar + return np.float64 + dt = getattr(obj, "dtype", None) + if dt is None: + raise NumbaTypeError("type has no dtype attr") + if isinstance(obj, types.Sequence): + obj = obj.dtype + else: + return as_dtype(dt) + + +@overload(np.ediff1d) +def np_ediff1d(ary, to_end=None, to_begin=None): + if isinstance(ary, types.Array): + if isinstance(ary.dtype, types.Boolean): + raise NumbaTypeError("Boolean dtype is unsupported (as per NumPy)") + # Numpy tries to do this: return ary[1:] - ary[:-1] which + # results in a TypeError exception being raised + + # Check that to_end and to_begin are compatible with ary + ary_dt = _dtype_of_compound(ary) + to_begin_dt = None + if not (is_nonelike(to_begin)): + to_begin_dt = _dtype_of_compound(to_begin) + to_end_dt = None + if not (is_nonelike(to_end)): + to_end_dt = _dtype_of_compound(to_end) + + if to_begin_dt is not None and not np.can_cast(to_begin_dt, ary_dt): + msg = "dtype of to_begin must be compatible with input ary" + raise NumbaTypeError(msg) + + if to_end_dt is not None and not np.can_cast(to_end_dt, ary_dt): + msg = "dtype of to_end must be compatible with input ary" + raise NumbaTypeError(msg) + + def np_ediff1d_impl(ary, to_end=None, to_begin=None): + # transform each input into an equivalent 1d array + start = _prepare_array(to_begin) + mid = _prepare_array(ary) + end = _prepare_array(to_end) + + out_dtype = mid.dtype + # output array dtype determined by ary dtype, per NumPy + # (for the most part); an exception to the rule is a zero length + # array-like, where NumPy falls back to np.float64; this behaviour + # is *not* replicated + + if len(mid) > 0: + out = np.empty( + (len(start) + len(mid) + len(end) - 1), dtype=out_dtype + ) + start_idx = len(start) + mid_idx = len(start) + len(mid) - 1 + out[:start_idx] = start + out[start_idx:mid_idx] = np.diff(mid) + out[mid_idx:] = end + else: + out = np.empty((len(start) + len(end)), dtype=out_dtype) + start_idx = len(start) + out[:start_idx] = start + out[start_idx:] = end + return out + + return np_ediff1d_impl + + +def _select_element(arr): + pass + + +@overload(_select_element) +def _select_element_impl(arr): + zerod = getattr(arr, "ndim", None) == 0 + if zerod: + + def impl(arr): + x = np.array((1,), dtype=arr.dtype) + x[:] = arr + return x[0] + + return impl + else: + + def impl(arr): + return arr + + return impl + + +def _get_d(dx, x): + pass + + +@overload(_get_d) +def get_d_impl(x, dx): + if is_nonelike(x): + + def impl(x, dx): + return np.asarray(dx) + else: + + def impl(x, dx): + return np.diff(np.asarray(x)) + + return impl + + +@overload(np.trapz) +def np_trapz(y, x=None, dx=1.0): + if isinstance(y, (types.Number, types.Boolean)): + raise TypingError("y cannot be a scalar") + elif isinstance(y, types.Array) and y.ndim == 0: + raise TypingError("y cannot be 0D") + # NumPy raises IndexError: list assignment index out of range + + # inspired by: + # https://github.com/numpy/numpy/blob/7ee52003/numpy/lib/function_base.py#L4040-L4065 # noqa: E501 + def impl(y, x=None, dx=1.0): + yarr = np.asarray(y) + d = _get_d(x, dx) + y_ave = (yarr[..., slice(1, None)] + yarr[..., slice(None, -1)]) / 2.0 + ret = np.sum(d * y_ave, -1) + processed = _select_element(ret) + return processed + + return impl + + +# numpy 2.0 rename np.trapz to np.trapezoid +if numpy_version >= (2, 0): + overload(np.trapezoid)(np_trapz) + + +@register_jitable +def _np_vander(x, N, increasing, out): + """ + Generate an N-column Vandermonde matrix 
from a supplied 1-dimensional + array, x. Store results in an output matrix, out, which is assumed to + be of the required dtype. + + Values are accumulated using np.multiply to match the floating point + precision behaviour of numpy.vander. + """ + m, n = out.shape + assert m == len(x) + assert n == N + + if increasing: + for i in range(N): + if i == 0: + out[:, i] = 1 + else: + out[:, i] = np.multiply(x, out[:, (i - 1)]) + else: + for i in range(N - 1, -1, -1): + if i == N - 1: + out[:, i] = 1 + else: + out[:, i] = np.multiply(x, out[:, (i + 1)]) + + +@register_jitable +def _check_vander_params(x, N): + if x.ndim > 1: + raise ValueError("x must be a one-dimensional array or sequence.") + if N < 0: + raise ValueError("Negative dimensions are not allowed") + + +@overload(np.vander) +def np_vander(x, N=None, increasing=False): + if N not in (None, types.none): + if not isinstance(N, types.Integer): + raise TypingError("Second argument N must be None or an integer") + + def np_vander_impl(x, N=None, increasing=False): + if N is None: + N = len(x) + + _check_vander_params(x, N) + + # allocate output matrix using dtype determined in closure + out = np.empty((len(x), int(N)), dtype=dtype) + + _np_vander(x, N, increasing, out) + return out + + def np_vander_seq_impl(x, N=None, increasing=False): + if N is None: + N = len(x) + + x_arr = np.array(x) + _check_vander_params(x_arr, N) + + # allocate output matrix using dtype inferred when x_arr was created + out = np.empty((len(x), int(N)), dtype=x_arr.dtype) + + _np_vander(x_arr, N, increasing, out) + return out + + if isinstance(x, types.Array): + x_dt = as_dtype(x.dtype) + # replicate numpy behaviour w.r.t.type promotion + dtype = np.promote_types(x_dt, int) + return np_vander_impl + elif isinstance(x, (types.Tuple, types.Sequence)): + return np_vander_seq_impl + + +@overload(np.roll) +def np_roll(a, shift): + if not isinstance(shift, (types.Integer, types.Boolean)): + raise TypingError("shift must be an integer") + + def np_roll_impl(a, shift): + arr = np.asarray(a) + out = np.empty(arr.shape, dtype=arr.dtype) + # empty_like might result in different contiguity vs NumPy + + arr_flat = arr.flat + for i in range(arr.size): + idx = (i + shift) % arr.size + out.flat[idx] = arr_flat[i] + + return out + + if isinstance(a, (types.Number, types.Boolean)): + return lambda a, shift: np.asarray(a) + else: + return np_roll_impl + + +# ---------------------------------------------------------------------------- +# Mathematical functions + +LIKELY_IN_CACHE_SIZE = 8 + + +@register_jitable +def binary_search_with_guess(key, arr, length, guess): + # NOTE: Do not refactor... see note in np_interp function impl below + # this is a facsimile of binary_search_with_guess prior to 1.15: + # https://github.com/numpy/numpy/blob/maintenance/1.15.x/numpy/core/src/multiarray/compiled_base.c # noqa: E501 + # Permanent reference: + # https://github.com/numpy/numpy/blob/3430d78c01a3b9a19adad75f1acb5ae18286da73/numpy/core/src/multiarray/compiled_base.c#L447 # noqa: E501 + imin = 0 + imax = length + + # Handle keys outside of the arr range first + if key > arr[length - 1]: + return length + elif key < arr[0]: + return -1 + + # If len <= 4 use linear search. + # From above we know key >= arr[0] when we start. 
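# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): for keys
# inside [arr[0], arr[-1]], binary_search_with_guess is equivalent to
# np.searchsorted(arr, key, side="right") - 1, i.e. the index of the
# rightmost element <= key; `guess` only seeds the probe and never changes
# the result. Out-of-range keys are special-cased above to -1 and `length`.
import numpy as np

_arr = np.array([0.0, 1.0, 4.0, 9.0, 16.0])
for _key in (0.0, 0.5, 4.0, 15.9):
    _ref = int(np.searchsorted(_arr, _key, side="right")) - 1
    assert _arr[_ref] <= _key
    assert _ref == _arr.size - 1 or _key < _arr[_ref + 1]
# --------------------------------------------------------------------------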
+ if length <= 4: + i = 1 + while i < length and key >= arr[i]: + i += 1 + return i - 1 + + if guess > length - 3: + guess = length - 3 + + if guess < 1: + guess = 1 + + # check most likely values: guess - 1, guess, guess + 1 + if key < arr[guess]: + if key < arr[guess - 1]: + imax = guess - 1 + + # last attempt to restrict search to items in cache + if ( + guess > LIKELY_IN_CACHE_SIZE + and key >= arr[guess - LIKELY_IN_CACHE_SIZE] + ): + imin = guess - LIKELY_IN_CACHE_SIZE + else: + # key >= arr[guess - 1] + return guess - 1 + else: + # key >= arr[guess] + if key < arr[guess + 1]: + return guess + else: + # key >= arr[guess + 1] + if key < arr[guess + 2]: + return guess + 1 + else: + # key >= arr[guess + 2] + imin = guess + 2 + # last attempt to restrict search to items in cache + if (guess < (length - LIKELY_IN_CACHE_SIZE - 1)) and ( + key < arr[guess + LIKELY_IN_CACHE_SIZE] + ): + imax = guess + LIKELY_IN_CACHE_SIZE + + # finally, find index by bisection + while imin < imax: + imid = imin + ((imax - imin) >> 1) + if key >= arr[imid]: + imin = imid + 1 + else: + imax = imid + + return imin - 1 + + +@register_jitable +def np_interp_impl_complex_inner(x, xp, fp, dtype): + # NOTE: Do not refactor... see note in np_interp function impl below + # this is a facsimile of arr_interp_complex post 1.16 with added + # branching to support np1.17 style NaN handling. + # https://github.com/numpy/numpy/blob/maintenance/1.16.x/numpy/core/src/multiarray/compiled_base.c # noqa: E501 + # Permanent reference: + # https://github.com/numpy/numpy/blob/971e2e89d08deeae0139d3011d15646fdac13c92/numpy/core/src/multiarray/compiled_base.c#L628 # noqa: E501 + dz = np.asarray(x) + dx = np.asarray(xp) + dy = np.asarray(fp) + + if len(dx) == 0: + raise ValueError("array of sample points is empty") + + if len(dx) != len(dy): + raise ValueError("fp and xp are not of the same size.") + + if dx.size == 1: + return np.full(dz.shape, fill_value=dy[0], dtype=dtype) + + dres = np.empty(dz.shape, dtype=dtype) + + lenx = dz.size + lenxp = len(dx) + lval = dy[0] + rval = dy[lenxp - 1] + + if lenxp == 1: + xp_val = dx[0] + fp_val = dy[0] + + for i in range(lenx): + x_val = dz.flat[i] + if x_val < xp_val: + dres.flat[i] = lval + elif x_val > xp_val: + dres.flat[i] = rval + else: + dres.flat[i] = fp_val + + else: + j = 0 + + # only pre-calculate slopes if there are relatively few of them. 
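# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): the complex
# path interpolates real and imaginary parts independently with a shared
# slope table; the behaviour replicated is np.interp's own support for a
# complex `fp`:
import numpy as np

assert np.isclose(
    np.interp(0.5, np.array([0.0, 1.0]), np.array([0.0 + 0.0j, 1.0 + 2.0j])),
    0.5 + 1.0j,
)
# --------------------------------------------------------------------------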
+ if lenxp <= lenx: + slopes = np.empty((lenxp - 1), dtype=dtype) + else: + slopes = np.empty(0, dtype=dtype) + + if slopes.size: + for i in range(lenxp - 1): + inv_dx = 1 / (dx[i + 1] - dx[i]) + real = (dy[i + 1].real - dy[i].real) * inv_dx + imag = (dy[i + 1].imag - dy[i].imag) * inv_dx + slopes[i] = real + 1j * imag + + for i in range(lenx): + x_val = dz.flat[i] + + if np.isnan(x_val): + real = x_val + imag = 0.0 + dres.flat[i] = real + 1j * imag + continue + + j = binary_search_with_guess(x_val, dx, lenxp, j) + + if j == -1: + dres.flat[i] = lval + elif j == lenxp: + dres.flat[i] = rval + elif j == lenxp - 1: + dres.flat[i] = dy[j] + elif dx[j] == x_val: + # Avoid potential non-finite interpolation + dres.flat[i] = dy[j] + else: + if slopes.size: + slope = slopes[j] + else: + inv_dx = 1 / (dx[j + 1] - dx[j]) + real = (dy[j + 1].real - dy[j].real) * inv_dx + imag = (dy[j + 1].imag - dy[j].imag) * inv_dx + slope = real + 1j * imag + + # NumPy 1.17 handles NaN correctly - this is a copy of + # innermost part of arr_interp_complex post 1.17: + # https://github.com/numpy/numpy/blob/maintenance/1.17.x/numpy/core/src/multiarray/compiled_base.c # noqa: E501 + # Permanent reference: + # https://github.com/numpy/numpy/blob/91fbe4dde246559fa5b085ebf4bc268e2b89eea8/numpy/core/src/multiarray/compiled_base.c#L798-L812 # noqa: E501 + + # If we get NaN in one direction, try the other + real = slope.real * (x_val - dx[j]) + dy[j].real + if np.isnan(real): + real = slope.real * (x_val - dx[j + 1]) + dy[j + 1].real + if np.isnan(real) and dy[j].real == dy[j + 1].real: + real = dy[j].real + + imag = slope.imag * (x_val - dx[j]) + dy[j].imag + if np.isnan(imag): + imag = slope.imag * (x_val - dx[j + 1]) + dy[j + 1].imag + if np.isnan(imag) and dy[j].imag == dy[j + 1].imag: + imag = dy[j].imag + + dres.flat[i] = real + 1j * imag + + return dres + + +@register_jitable +def np_interp_impl_inner(x, xp, fp, dtype): + # NOTE: Do not refactor... see note in np_interp function impl below + # this is a facsimile of arr_interp post 1.16: + # https://github.com/numpy/numpy/blob/maintenance/1.16.x/numpy/core/src/multiarray/compiled_base.c # noqa: E501 + # Permanent reference: + # https://github.com/numpy/numpy/blob/971e2e89d08deeae0139d3011d15646fdac13c92/numpy/core/src/multiarray/compiled_base.c#L473 # noqa: E501 + dz = np.asarray(x, dtype=np.float64) + dx = np.asarray(xp, dtype=np.float64) + dy = np.asarray(fp, dtype=np.float64) + + if len(dx) == 0: + raise ValueError("array of sample points is empty") + + if len(dx) != len(dy): + raise ValueError("fp and xp are not of the same size.") + + if dx.size == 1: + return np.full(dz.shape, fill_value=dy[0], dtype=dtype) + + dres = np.empty(dz.shape, dtype=dtype) + + lenx = dz.size + lenxp = len(dx) + lval = dy[0] + rval = dy[lenxp - 1] + + if lenxp == 1: + xp_val = dx[0] + fp_val = dy[0] + + for i in range(lenx): + x_val = dz.flat[i] + if x_val < xp_val: + dres.flat[i] = lval + elif x_val > xp_val: + dres.flat[i] = rval + else: + dres.flat[i] = fp_val + + else: + j = 0 + + # only pre-calculate slopes if there are relatively few of them. 
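# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): the slope
# table is a time/space trade-off - it is only materialised when there are
# at least as many query points as intervals, so each slope is reused. The
# semantics replicated are np.interp's: clamp to fp[0]/fp[-1] outside
# [xp[0], xp[-1]], piecewise-linear inside:
import numpy as np

assert np.allclose(
    np.interp(np.array([-0.5, 0.25, 1.5, 2.5]),
              np.array([0.0, 1.0, 2.0]),
              np.array([0.0, 10.0, 20.0])),
    np.array([0.0, 2.5, 15.0, 20.0]),
)
# --------------------------------------------------------------------------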
+ if lenxp <= lenx: + slopes = (dy[1:] - dy[:-1]) / (dx[1:] - dx[:-1]) + else: + slopes = np.empty(0, dtype=dtype) + + for i in range(lenx): + x_val = dz.flat[i] + + if np.isnan(x_val): + dres.flat[i] = x_val + continue + + j = binary_search_with_guess(x_val, dx, lenxp, j) + + if j == -1: + dres.flat[i] = lval + elif j == lenxp: + dres.flat[i] = rval + elif j == lenxp - 1: + dres.flat[i] = dy[j] + elif dx[j] == x_val: + # Avoid potential non-finite interpolation + dres.flat[i] = dy[j] + else: + if slopes.size: + slope = slopes[j] + else: + slope = (dy[j + 1] - dy[j]) / (dx[j + 1] - dx[j]) + + dres.flat[i] = slope * (x_val - dx[j]) + dy[j] + + # NOTE: this is in np1.17 + # https://github.com/numpy/numpy/blob/maintenance/1.17.x/numpy/core/src/multiarray/compiled_base.c # noqa: E501 + # Permanent reference: + # https://github.com/numpy/numpy/blob/91fbe4dde246559fa5b085ebf4bc268e2b89eea8/numpy/core/src/multiarray/compiled_base.c#L610-L616 # noqa: E501 + # + # If we get nan in one direction, try the other + if np.isnan(dres.flat[i]): + dres.flat[i] = slope * (x_val - dx[j + 1]) + dy[j + 1] # noqa: E501 + if np.isnan(dres.flat[i]) and dy[j] == dy[j + 1]: + dres.flat[i] = dy[j] + + return dres + + +@overload(np.interp) +def np_interp(x, xp, fp): + # Replicating basic interp is relatively simple, but matching the behaviour + # of NumPy for edge cases is really quite hard. After a couple of attempts + # to avoid translation of the C source it was deemed necessary. + + if hasattr(xp, "ndim") and xp.ndim > 1: + raise TypingError("xp must be 1D") + if hasattr(fp, "ndim") and fp.ndim > 1: + raise TypingError("fp must be 1D") + + complex_dtype_msg = ( + "Cannot cast array data from complex dtype to float64 dtype" + ) + + xp_dt = determine_dtype(xp) + if np.issubdtype(xp_dt, np.complexfloating): + raise TypingError(complex_dtype_msg) + + fp_dt = determine_dtype(fp) + dtype = np.result_type(fp_dt, np.float64) + + if np.issubdtype(dtype, np.complexfloating): + inner = np_interp_impl_complex_inner + else: + inner = np_interp_impl_inner + + def np_interp_impl(x, xp, fp): + return inner(x, xp, fp, dtype) + + def np_interp_scalar_impl(x, xp, fp): + return inner(x, xp, fp, dtype).flat[0] + + if isinstance(x, types.Number): + if isinstance(x, types.Complex): + raise TypingError(complex_dtype_msg) + return np_interp_scalar_impl + + return np_interp_impl + + +# ---------------------------------------------------------------------------- +# Statistics + + +@register_jitable +def row_wise_average(a): + assert a.ndim == 2 + + m, n = a.shape + out = np.empty((m, 1), dtype=a.dtype) + + for i in range(m): + out[i, 0] = np.sum(a[i, :]) / n + + return out + + +@register_jitable +def np_cov_impl_inner(X, bias, ddof): + # determine degrees of freedom + if ddof is None: + if bias: + ddof = 0 + else: + ddof = 1 + + # determine the normalization factor + fact = X.shape[1] - ddof + + # numpy warns if less than 0 and floors at 0 + fact = max(fact, 0.0) + + # de-mean + X -= row_wise_average(X) + + # calculate result - requires blas + c = np.dot(X, np.conj(X.T)) + c *= np.true_divide(1, fact) + return c + + +def _prepare_cov_input_inner(): + pass + + +@overload(_prepare_cov_input_inner) +def _prepare_cov_input_impl(m, y, rowvar, dtype): + if y in (None, types.none): + + def _prepare_cov_input_inner(m, y, rowvar, dtype): + m_arr = np.atleast_2d(_asarray(m)) + + if not rowvar: + m_arr = m_arr.T + + return m_arr + else: + + def _prepare_cov_input_inner(m, y, rowvar, dtype): + m_arr = np.atleast_2d(_asarray(m)) + y_arr = 
np.atleast_2d(_asarray(y)) + + # transpose if asked to and not a (1, n) vector - this looks + # wrong as you might end up transposing one and not the other, + # but it's what numpy does + if not rowvar: + if m_arr.shape[0] != 1: + m_arr = m_arr.T + if y_arr.shape[0] != 1: + y_arr = y_arr.T + + m_rows, m_cols = m_arr.shape + y_rows, y_cols = y_arr.shape + + if m_cols != y_cols: + raise ValueError("m and y have incompatible dimensions") + + # allocate and fill output array + out = np.empty((m_rows + y_rows, m_cols), dtype=dtype) + out[:m_rows, :] = m_arr + out[-y_rows:, :] = y_arr + + return out + + return _prepare_cov_input_inner + + +@register_jitable +def _handle_m_dim_change(m): + if m.ndim == 2 and m.shape[0] == 1: + msg = ( + "2D array containing a single row is unsupported due to " + "ambiguity in type inference. To use numpy.cov in this case " + "simply pass the row as a 1D array, i.e. m[0]." + ) + raise RuntimeError(msg) + + +_handle_m_dim_nop = register_jitable(lambda x: x) + + +def determine_dtype(array_like): + array_like_dt = np.float64 + if isinstance(array_like, types.Array): + array_like_dt = as_dtype(array_like.dtype) + elif isinstance(array_like, (types.Number, types.Boolean)): + array_like_dt = as_dtype(array_like) + elif isinstance(array_like, (types.UniTuple, types.Tuple)): + coltypes = set() + for val in array_like: + if hasattr(val, "count"): + [coltypes.add(v) for v in val] + else: + coltypes.add(val) + if len(coltypes) > 1: + array_like_dt = np.promote_types(*[as_dtype(ty) for ty in coltypes]) + elif len(coltypes) == 1: + array_like_dt = as_dtype(coltypes.pop()) + + return array_like_dt + + +def check_dimensions(array_like, name): + if isinstance(array_like, types.Array): + if array_like.ndim > 2: + raise NumbaTypeError("{0} has more than 2 dimensions".format(name)) + elif isinstance(array_like, types.Sequence): + if isinstance(array_like.key[0], types.Sequence): + if isinstance(array_like.key[0].key[0], types.Sequence): + msg = "{0} has more than 2 dimensions".format(name) + raise NumbaTypeError(msg) + + +@register_jitable +def _handle_ddof(ddof): + if not np.isfinite(ddof): + raise ValueError("Cannot convert non-finite ddof to integer") + if ddof - int(ddof) != 0: + raise ValueError("ddof must be integral value") + + +_handle_ddof_nop = register_jitable(lambda x: x) + + +@register_jitable +def _prepare_cov_input( + m, y, rowvar, dtype, ddof, _DDOF_HANDLER, _M_DIM_HANDLER +): + _M_DIM_HANDLER(m) + _DDOF_HANDLER(ddof) + return _prepare_cov_input_inner(m, y, rowvar, dtype) + + +def scalar_result_expected(mandatory_input, optional_input): + opt_is_none = optional_input in (None, types.none) + + if isinstance(mandatory_input, types.Array) and mandatory_input.ndim == 1: + return opt_is_none + + if isinstance(mandatory_input, types.BaseTuple): + if all( + isinstance(x, (types.Number, types.Boolean)) + for x in mandatory_input.types + ): + return opt_is_none + else: + if len(mandatory_input.types) == 1 and isinstance( + mandatory_input.types[0], types.BaseTuple + ): + return opt_is_none + + if isinstance(mandatory_input, (types.Number, types.Boolean)): + return opt_is_none + + if isinstance(mandatory_input, types.Sequence): + if ( + not isinstance(mandatory_input.key[0], types.Sequence) + and opt_is_none + ): + return True + + return False + + +@register_jitable +def _clip_corr(x): + return np.where(np.fabs(x) > 1, np.sign(x), x) + + +@register_jitable +def _clip_complex(x): + real = _clip_corr(x.real) + imag = _clip_corr(x.imag) + return real + 1j * imag + + 
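# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): the np.cov
# pipeline below computes the textbook estimator - de-mean each row, then
# C = X @ conj(X).T / (ncols - ddof). A NumPy cross-check with the default
# ddof of 1:
import numpy as np

_X = np.array([[1.0, 2.0, 3.0], [2.0, 1.0, 0.0]])
_Xm = _X - _X.mean(axis=1).reshape(-1, 1)
assert np.allclose(np.dot(_Xm, np.conj(_Xm.T)) / (_X.shape[1] - 1), np.cov(_X))
# --------------------------------------------------------------------------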
+@overload(np.cov)
+def np_cov(m, y=None, rowvar=True, bias=False, ddof=None):
+    # reject problem if m and / or y are more than 2D
+    check_dimensions(m, "m")
+    check_dimensions(y, "y")
+
+    # reject problem if ddof invalid (either upfront if type is
+    # obviously invalid, or later if value found to be non-integral)
+    if ddof in (None, types.none):
+        _DDOF_HANDLER = _handle_ddof_nop
+    else:
+        if isinstance(ddof, (types.Integer, types.Boolean)):
+            _DDOF_HANDLER = _handle_ddof_nop
+        elif isinstance(ddof, types.Float):
+            _DDOF_HANDLER = _handle_ddof
+        else:
+            raise TypingError("ddof must be a real numerical scalar type")
+
+    # special case for 2D array input with 1 row of data - select
+    # handler function which we'll call later when we have access
+    # to the shape of the input array
+    _M_DIM_HANDLER = _handle_m_dim_nop
+    if isinstance(m, types.Array):
+        _M_DIM_HANDLER = _handle_m_dim_change
+
+    # infer result dtype
+    m_dt = determine_dtype(m)
+    y_dt = determine_dtype(y)
+    dtype = np.result_type(m_dt, y_dt, np.float64)
+
+    def np_cov_impl(m, y=None, rowvar=True, bias=False, ddof=None):
+        X = _prepare_cov_input(
+            m, y, rowvar, dtype, ddof, _DDOF_HANDLER, _M_DIM_HANDLER
+        ).astype(dtype)
+
+        if np.any(np.array(X.shape) == 0):
+            return np.full(
+                (X.shape[0], X.shape[0]), fill_value=np.nan, dtype=dtype
+            )
+        else:
+            return np_cov_impl_inner(X, bias, ddof)
+
+    def np_cov_impl_single_variable(
+        m, y=None, rowvar=True, bias=False, ddof=None
+    ):
+        # pass dtype before ddof, matching _prepare_cov_input's signature;
+        # swapping them would hand dtype to the ddof handler
+        X = _prepare_cov_input(
+            m, y, rowvar, dtype, ddof, _DDOF_HANDLER, _M_DIM_HANDLER
+        ).astype(dtype)
+
+        if np.any(np.array(X.shape) == 0):
+            variance = np.nan
+        else:
+            variance = np_cov_impl_inner(X, bias, ddof).flat[0]
+
+        return np.array(variance)
+
+    if scalar_result_expected(m, y):
+        return np_cov_impl_single_variable
+    else:
+        return np_cov_impl
+
+
+@overload(np.corrcoef)
+def np_corrcoef(x, y=None, rowvar=True):
+    x_dt = determine_dtype(x)
+    y_dt = determine_dtype(y)
+    dtype = np.result_type(x_dt, y_dt, np.float64)
+
+    if dtype == np.complex128:
+        clip_fn = _clip_complex
+    else:
+        clip_fn = _clip_corr
+
+    def np_corrcoef_impl(x, y=None, rowvar=True):
+        c = np.cov(x, y, rowvar)
+        d = np.diag(c)
+        stddev = np.sqrt(d.real)
+
+        for i in range(c.shape[0]):
+            c[i, :] /= stddev
+            c[:, i] /= stddev
+
+        return clip_fn(c)
+
+    def np_corrcoef_impl_single_variable(x, y=None, rowvar=True):
+        c = np.cov(x, y, rowvar)
+        return c / c
+
+    if scalar_result_expected(x, y):
+        return np_corrcoef_impl_single_variable
+    else:
+        return np_corrcoef_impl
+
+
+# ----------------------------------------------------------------------------
+# Element-wise computations
+
+
+@overload(np.argwhere)
+def np_argwhere(a):
+    # needs to be much more array-like for the array impl to work, Numba bug
+    # in one of the underlying function calls?
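# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): the array
# branch below relies on the documented equivalence of np.argwhere with
# np.nonzero - one row of coordinates per non-zero element:
import numpy as np

_a = np.array([[0, 1], [2, 0]])
assert np.array_equal(
    np.argwhere(_a), np.transpose(np.vstack(np.nonzero(_a)))
)  # both give [[0, 1], [1, 0]]
# --------------------------------------------------------------------------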
+ + use_scalar = isinstance(a, (types.Number, types.Boolean)) + if type_can_asarray(a) and not use_scalar: + + def impl(a): + arr = np.asarray(a) + if arr.shape == (): + return np.zeros((0, 1), dtype=types.intp) + return np.transpose(np.vstack(np.nonzero(arr))) + else: + falseish = (0, 0) + trueish = (1, 0) + + def impl(a): + if a is not None and bool(a): + return np.zeros(trueish, dtype=types.intp) + else: + return np.zeros(falseish, dtype=types.intp) + + return impl + + +@overload(np.flatnonzero) +def np_flatnonzero(a): + if type_can_asarray(a): + + def impl(a): + arr = np.asarray(a) + return np.nonzero(np.ravel(arr))[0] + else: + + def impl(a): + if a is not None and bool(a): + data = [0] + else: + data = [x for x in range(0)] + return np.array(data, dtype=types.intp) + + return impl + + +@register_jitable +def _fill_diagonal_params(a, wrap): + if a.ndim == 2: + m = a.shape[0] + n = a.shape[1] + step = 1 + n + if wrap: + end = n * m + else: + end = n * min(m, n) + else: + shape = np.array(a.shape) + + if not np.all(np.diff(shape) == 0): + raise ValueError("All dimensions of input must be of equal length") + + step = 1 + (np.cumprod(shape[:-1])).sum() + end = shape.prod() + + return end, step + + +@register_jitable +def _fill_diagonal_scalar(a, val, wrap): + end, step = _fill_diagonal_params(a, wrap) + + for i in range(0, end, step): + a.flat[i] = val + + +@register_jitable +def _fill_diagonal(a, val, wrap): + end, step = _fill_diagonal_params(a, wrap) + ctr = 0 + v_len = len(val) + + for i in range(0, end, step): + a.flat[i] = val[ctr] + ctr += 1 + ctr = ctr % v_len + + +@register_jitable +def _check_val_int(a, val): + iinfo = np.iinfo(a.dtype) + v_min = iinfo.min + v_max = iinfo.max + + # check finite values are within bounds + if np.any(~np.isfinite(val)) or np.any(val < v_min) or np.any(val > v_max): + raise ValueError("Unable to safely conform val to a.dtype") + + +@register_jitable +def _check_val_float(a, val): + finfo = np.finfo(a.dtype) + v_min = finfo.min + v_max = finfo.max + + # check finite values are within bounds + finite_vals = val[np.isfinite(val)] + if np.any(finite_vals < v_min) or np.any(finite_vals > v_max): + raise ValueError("Unable to safely conform val to a.dtype") + + +# no check performed, needed for pathway where no check is required +_check_nop = register_jitable(lambda x, y: x) + + +def _asarray(x): + pass + + +@overload(_asarray) +def _asarray_impl(x): + if isinstance(x, types.Array): + return lambda x: x + elif isinstance(x, (types.Sequence, types.Tuple)): + return lambda x: np.array(x) + elif isinstance(x, (types.Number, types.Boolean)): + ty = as_dtype(x) + return lambda x: np.array([x], dtype=ty) + + +@overload(np.fill_diagonal) +def np_fill_diagonal(a, val, wrap=False): + if a.ndim > 1: + # the following can be simplified after #3088; until then, employ + # a basic mechanism for catching cases where val is of a type/value + # which cannot safely be cast to a.dtype + if isinstance(a.dtype, types.Integer): + checker = _check_val_int + elif isinstance(a.dtype, types.Float): + checker = _check_val_float + else: + checker = _check_nop + + def scalar_impl(a, val, wrap=False): + tmpval = _asarray(val).flatten() + checker(a, tmpval) + _fill_diagonal_scalar(a, val, wrap) + + def non_scalar_impl(a, val, wrap=False): + tmpval = _asarray(val).flatten() + checker(a, tmpval) + _fill_diagonal(a, tmpval, wrap) + + if isinstance(val, (types.Float, types.Integer, types.Boolean)): + return scalar_impl + elif isinstance(val, (types.Tuple, types.Sequence, types.Array)): + 
return non_scalar_impl + else: + msg = "The first argument must be at least 2-D (found %s-D)" % a.ndim + raise TypingError(msg) + + +def _np_round_intrinsic(tp): + # np.round() always rounds half to even + return "llvm.rint.f%d" % (tp.bitwidth,) + + +@intrinsic +def _np_round_float(typingctx, val): + sig = val(val) + + def codegen(context, builder, sig, args): + [val] = args + tp = sig.args[0] + llty = context.get_value_type(tp) + module = builder.module + fnty = llvmlite.ir.FunctionType(llty, [llty]) + fn = cgutils.get_or_insert_function( + module, fnty, _np_round_intrinsic(tp) + ) + res = builder.call(fn, (val,)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return sig, codegen + + +@register_jitable +def round_ndigits(x, ndigits): + if math.isinf(x) or math.isnan(x): + return x + + # NOTE: this is CPython's algorithm, but perhaps this is overkill + # when emulating Numpy's behaviour. + if ndigits >= 0: + if ndigits > 22: + # pow1 and pow2 are each safe from overflow, but + # pow1*pow2 ~= pow(10.0, ndigits) might overflow. + pow1 = 10.0 ** (ndigits - 22) + pow2 = 1e22 + else: + pow1 = 10.0**ndigits + pow2 = 1.0 + y = (x * pow1) * pow2 + if math.isinf(y): + return x + return (_np_round_float(y) / pow2) / pow1 + + else: + pow1 = 10.0 ** (-ndigits) + y = x / pow1 + return _np_round_float(y) * pow1 + + +@overload(np.around) +@overload(np.round) +def impl_np_round(a, decimals=0, out=None): + if not type_can_asarray(a): + raise TypingError('The argument "a" must be array-like') + + if not (isinstance(out, types.Array) or is_nonelike(out)): + msg = 'The argument "out" must be an array if it is provided' + raise TypingError(msg) + + if isinstance(a, (types.Float, types.Integer, types.Complex)): + if is_nonelike(out): + if isinstance(a, types.Float): + + def impl(a, decimals=0, out=None): + if decimals == 0: + return _np_round_float(a) + else: + return round_ndigits(a, decimals) + + return impl + elif isinstance(a, types.Integer): + + def impl(a, decimals=0, out=None): + if decimals == 0: + return a + else: + return int(round_ndigits(a, decimals)) + + return impl + elif isinstance(a, types.Complex): + + def impl(a, decimals=0, out=None): + if decimals == 0: + real = _np_round_float(a.real) + imag = _np_round_float(a.imag) + else: + real = round_ndigits(a.real, decimals) + imag = round_ndigits(a.imag, decimals) + return complex(real, imag) + + return impl + else: + + def impl(a, decimals=0, out=None): + out[0] = np.round(a, decimals) + return out + + return impl + elif isinstance(a, types.Array): + if is_nonelike(out): + + def impl(a, decimals=0, out=None): + out = np.empty_like(a) + return np.round(a, decimals, out) + + return impl + else: + + def impl(a, decimals=0, out=None): + if a.shape != out.shape: + raise ValueError("invalid output shape") + for index, val in np.ndenumerate(a): + out[index] = np.round(val, decimals) + return out + + return impl + + +if numpy_version < (2, 0): + overload(np.round_)(impl_np_round) + + +@overload(np.sinc) +def impl_np_sinc(x): + if isinstance(x, types.Number): + + def impl(x): + if x == 0.0e0: # to match np impl + x = 1e-20 + x *= np.pi # np sinc is the normalised variant + return np.sin(x) / x + + return impl + elif isinstance(x, types.Array): + + def impl(x): + out = np.zeros_like(x) + for index, val in np.ndenumerate(x): + out[index] = np.sinc(val) + return out + + return impl + else: + raise NumbaTypeError('Argument "x" must be a Number or array-like.') + + +@overload(np.angle) +def ov_np_angle(z, deg=False): + deg_mult = 
float(180 / np.pi) + + # non-complex scalar values are accepted as well + if isinstance(z, types.Number): + + def impl(z, deg=False): + if deg: + return np.arctan2(z.imag, z.real) * deg_mult + else: + return np.arctan2(z.imag, z.real) + + return impl + elif isinstance(z, types.Array): + dtype = z.dtype + + if isinstance(dtype, types.Complex): + ret_dtype = dtype.underlying_float + elif isinstance(dtype, types.Float): + ret_dtype = dtype + else: + return + + def impl(z, deg=False): + out = np.zeros_like(z, dtype=ret_dtype) + for index, val in np.ndenumerate(z): + out[index] = np.angle(val, deg) + return out + + return impl + else: + raise NumbaTypeError( + f'Argument "z" must be a complex or Array[complex]. Got {z}' + ) + + +@lower(np.nonzero, types.Array) +@lower("array.nonzero", types.Array) +def array_nonzero(context, builder, sig, args): + aryty = sig.args[0] + # Return type is a N-tuple of 1D C-contiguous arrays + retty = sig.return_type + outaryty = retty.dtype + nouts = retty.count + + ary = make_array(aryty)(context, builder, args[0]) + shape = cgutils.unpack_tuple(builder, ary.shape) + strides = cgutils.unpack_tuple(builder, ary.strides) + data = ary.data + layout = aryty.layout + + # First count the number of non-zero elements + zero = context.get_constant(types.intp, 0) + one = context.get_constant(types.intp, 1) + count = cgutils.alloca_once_value(builder, zero) + with cgutils.loop_nest(builder, shape, zero.type) as indices: + ptr = cgutils.get_item_pointer2( + context, builder, data, shape, strides, layout, indices + ) + val = load_item(context, builder, aryty, ptr) + nz = context.is_true(builder, aryty.dtype, val) + with builder.if_then(nz): + builder.store(builder.add(builder.load(count), one), count) + + # Then allocate output arrays of the right size + out_shape = (builder.load(count),) + outs = [ + _empty_nd_impl(context, builder, outaryty, out_shape)._getvalue() + for i in range(nouts) + ] + outarys = [make_array(outaryty)(context, builder, out) for out in outs] + out_datas = [out.data for out in outarys] + + # And fill them up + index = cgutils.alloca_once_value(builder, zero) + with cgutils.loop_nest(builder, shape, zero.type) as indices: + ptr = cgutils.get_item_pointer2( + context, builder, data, shape, strides, layout, indices + ) + val = load_item(context, builder, aryty, ptr) + nz = context.is_true(builder, aryty.dtype, val) + with builder.if_then(nz): + # Store element indices in output arrays + if not indices: + # For a 0-d array, store 0 in the unique output array + indices = (zero,) + cur = builder.load(index) + for i in range(nouts): + ptr = cgutils.get_item_pointer2( + context, builder, out_datas[i], out_shape, (), "C", [cur] + ) + store_item(context, builder, outaryty, indices[i], ptr) + builder.store(builder.add(cur, one), index) + + tup = context.make_tuple(builder, sig.return_type, outs) + return impl_ret_new_ref(context, builder, sig.return_type, tup) + + +def _where_zero_size_array_impl(dtype): + def impl(condition, x, y): + x_ = np.asarray(x).astype(dtype) + y_ = np.asarray(y).astype(dtype) + return x_ if condition else y_ + + return impl + + +@register_jitable +def _where_generic_inner_impl(cond, x, y, res): + for idx, c in np.ndenumerate(cond): + res[idx] = x[idx] if c else y[idx] + return res + + +@register_jitable +def _where_fast_inner_impl(cond, x, y, res): + cf = cond.flat + xf = x.flat + yf = y.flat + rf = res.flat + for i in range(cond.size): + rf[i] = xf[i] if cf[i] else yf[i] + return res + + +def _where_generic_impl(dtype, layout): + 
use_faster_impl = layout in ("C", "F")
+    # note: `layout` is a plain layout string ("A", "C" or "F") here, so
+    # membership must be tested against strings; one-element sets never
+    # compare equal to a string, which would disable the fast path entirely
+
+    def impl(condition, x, y):
+        cond1, x1, y1 = np.asarray(condition), np.asarray(x), np.asarray(y)
+        shape = np.broadcast_shapes(cond1.shape, x1.shape, y1.shape)
+        cond_ = np.broadcast_to(cond1, shape)
+        x_ = np.broadcast_to(x1, shape)
+        y_ = np.broadcast_to(y1, shape)
+
+        if layout == "F":
+            res = np.empty(shape[::-1], dtype=dtype).T
+        else:
+            res = np.empty(shape, dtype=dtype)
+
+        if use_faster_impl:
+            return _where_fast_inner_impl(cond_, x_, y_, res)
+        else:
+            return _where_generic_inner_impl(cond_, x_, y_, res)
+
+    return impl
+
+
+@overload(np.where)
+def ov_np_where(condition):
+    if not type_can_asarray(condition):
+        msg = 'The argument "condition" must be array-like'
+        raise NumbaTypeError(msg)
+
+    def where_cond_none_none(condition):
+        return np.asarray(condition).nonzero()
+
+    return where_cond_none_none
+
+
+@overload(np.where)
+def ov_np_where_x_y(condition, x, y):
+    if not type_can_asarray(condition):
+        msg = 'The argument "condition" must be array-like'
+        raise NumbaTypeError(msg)
+
+    # corner case: None is a valid value for np.where:
+    # >>> np.where([0, 1], None, 2)
+    # array([None, 2])
+    #
+    # >>> np.where([0, 1], 2, None)
+    # array([2, None])
+    #
+    # >>> np.where([0, 1], None, None)
+    # array([None, None])
+    if is_nonelike(x) or is_nonelike(y):
+        # skip it for now as np.asarray(None) is not supported
+        raise NumbaTypeError('Argument "x" or "y" cannot be None')
+
+    for arg, name in zip((x, y), ("x", "y")):
+        if not type_can_asarray(arg):
+            msg = 'The argument "{}" must be array-like if provided'
+            raise NumbaTypeError(msg.format(name))
+
+    cond_arr = isinstance(condition, types.Array)
+    x_arr = isinstance(x, types.Array)
+    y_arr = isinstance(y, types.Array)
+
+    if cond_arr:
+        x_dt = determine_dtype(x)
+        y_dt = determine_dtype(y)
+        dtype = np.promote_types(x_dt, y_dt)
+
+        # corner case - 0 dim values
+        def check_0_dim(arg):
+            return isinstance(arg, types.Number) or (
+                isinstance(arg, types.Array) and arg.ndim == 0
+            )
+
+        special_0_case = all([check_0_dim(a) for a in (condition, x, y)])
+        if special_0_case:
+            return _where_zero_size_array_impl(dtype)
+
+        layout = condition.layout
+        if x_arr and y_arr:
+            if x.layout == y.layout == condition.layout:
+                layout = x.layout
+            else:
+                layout = "A"
+        return _where_generic_impl(dtype, layout)
+    else:
+
+        def impl(condition, x, y):
+            return np.where(np.asarray(condition), np.asarray(x), np.asarray(y))
+
+        return impl
+
+
+@overload(np.real)
+def np_real(val):
+    def np_real_impl(val):
+        return val.real
+
+    return np_real_impl
+
+
+@overload(np.imag)
+def np_imag(val):
+    def np_imag_impl(val):
+        return val.imag
+
+    return np_imag_impl
+
+
+# ----------------------------------------------------------------------------
+# Misc functions
+
+
+@overload(operator.contains)
+def np_contains(arr, key):
+    if not isinstance(arr, types.Array):
+        return
+
+    def np_contains_impl(arr, key):
+        for x in np.nditer(arr):
+            if x == key:
+                return True
+        return False
+
+    return np_contains_impl
+
+
+@overload(np.count_nonzero)
+def np_count_nonzero(a, axis=None):
+    if not type_can_asarray(a):
+        raise TypingError("The argument to np.count_nonzero must be array-like")
+
+    if is_nonelike(axis):
+
+        def impl(a, axis=None):
+            arr2 = np.ravel(a)
+            return np.sum(arr2 != 0)
+
+        return impl
+    else:
+
+        def impl(a, axis=None):
+            arr2 = a.astype(np.bool_)
+            return np.sum(arr2, axis=axis)
+
+        return impl
+
+
+np_delete_handler_isslice = register_jitable(lambda x: x)
+np_delete_handler_isarray = register_jitable(lambda x: 
np.asarray(x)) + + +@overload(np.delete) +def np_delete(arr, obj): + # Implementation based on numpy + # https://github.com/numpy/numpy/blob/af66e487a57bfd4850f4306e3b85d1dac3c70412/numpy/lib/function_base.py#L4065-L4267 # noqa: E501 + + if not isinstance(arr, (types.Array, types.Sequence)): + raise TypingError("arr must be either an Array or a Sequence") + + if isinstance(obj, (types.Array, types.Sequence, types.SliceType)): + if isinstance(obj, (types.SliceType)): + handler = np_delete_handler_isslice + else: + if not isinstance(obj.dtype, types.Integer): + raise TypingError("obj should be of Integer dtype") + handler = np_delete_handler_isarray + + def np_delete_impl(arr, obj): + arr = np.ravel(np.asarray(arr)) + N = arr.size + + keep = np.ones(N, dtype=np.bool_) + obj = handler(obj) + keep[obj] = False + return arr[keep] + + return np_delete_impl + + else: # scalar value + if not isinstance(obj, types.Integer): + raise TypingError("obj should be of Integer dtype") + + def np_delete_scalar_impl(arr, obj): + arr = np.ravel(np.asarray(arr)) + N = arr.size + pos = obj + + if pos < -N or pos >= N: + raise IndexError("obj must be less than the len(arr)") + # NumPy raises IndexError: index 'i' is out of + # bounds for axis 'x' with size 'n' + + if pos < 0: + pos += N + + return np.concatenate((arr[:pos], arr[pos + 1 :])) + + return np_delete_scalar_impl + + +@overload(np.diff) +def np_diff_impl(a, n=1): + if not isinstance(a, types.Array) or a.ndim == 0: + return + + def diff_impl(a, n=1): + if n == 0: + return a.copy() + if n < 0: + raise ValueError("diff(): order must be non-negative") + size = a.shape[-1] + out_shape = a.shape[:-1] + (max(size - n, 0),) + out = np.empty(out_shape, a.dtype) + if out.size == 0: + return out + + # np.diff() works on each last dimension subarray independently. 
+        # To make things easier, normalize input and output into 2d arrays
+        a2 = a.reshape((-1, size))
+        out2 = out.reshape((-1, out.shape[-1]))
+        # A scratchpad for subarrays
+        work = np.empty(size, a.dtype)
+
+        for major in range(a2.shape[0]):
+            # First iteration: diff a2 into work
+            for i in range(size - 1):
+                work[i] = a2[major, i + 1] - a2[major, i]
+            # Other iterations: diff work into itself
+            for niter in range(1, n):
+                for i in range(size - niter - 1):
+                    work[i] = work[i + 1] - work[i]
+            # Copy final diff into out2
+            out2[major] = work[: size - n]
+
+        return out
+
+    return diff_impl
+
+
+@overload(np.array_equal)
+def np_array_equal(a1, a2):
+    if not (type_can_asarray(a1) and type_can_asarray(a2)):
+        raise TypingError('Both arguments to "array_equal" must be array-like')
+
+    accepted = (types.Boolean, types.Number)
+    if isinstance(a1, accepted) and isinstance(a2, accepted):
+        # special case
+        def impl(a1, a2):
+            return a1 == a2
+    else:
+
+        def impl(a1, a2):
+            a = np.asarray(a1)
+            b = np.asarray(a2)
+            if a.shape == b.shape:
+                return np.all(a == b)
+            return False
+
+    return impl
+
+
+@overload(np.intersect1d)
+def jit_np_intersect1d(ar1, ar2, assume_unique=False):
+    # Not implemented to support return_indices
+    # https://github.com/numpy/numpy/blob/v1.19.0/numpy/lib
+    # /arraysetops.py#L347-L441
+    if not (type_can_asarray(ar1) and type_can_asarray(ar2)):
+        raise TypingError("intersect1d: first two args must be array-like")
+    if not isinstance(assume_unique, (types.Boolean, bool)):
+        raise TypingError(
+            'intersect1d: argument "assume_unique" must be boolean'
+        )
+
+    def np_intersects1d_impl(ar1, ar2, assume_unique=False):
+        ar1 = np.asarray(ar1)
+        ar2 = np.asarray(ar2)
+
+        if not assume_unique:
+            ar1 = np.unique(ar1)
+            ar2 = np.unique(ar2)
+        else:
+            ar1 = ar1.ravel()
+            ar2 = ar2.ravel()
+
+        aux = np.concatenate((ar1, ar2))
+        aux.sort()
+        mask = aux[1:] == aux[:-1]
+        int1d = aux[:-1][mask]
+        return int1d
+
+    return np_intersects1d_impl
+
+
+def validate_1d_array_like(func_name, seq):
+    if isinstance(seq, types.Array):
+        if seq.ndim != 1:
+            raise NumbaTypeError(
+                "{0}(): input should have dimension 1".format(func_name)
+            )
+    elif not isinstance(seq, types.Sequence):
+        raise NumbaTypeError(
+            "{0}(): input should be an array or sequence".format(func_name)
+        )
+
+
+@overload(np.bincount)
+def np_bincount(a, weights=None, minlength=0):
+    validate_1d_array_like("bincount", a)
+
+    if not isinstance(a.dtype, types.Integer):
+        return
+
+    check_is_integer(minlength, "minlength")
+
+    if weights not in (None, types.none):
+        validate_1d_array_like("bincount", weights)
+        # weights is promoted to double in C impl
+        # https://github.com/numpy/numpy/blob/maintenance/1.16.x/numpy/core/src/multiarray/compiled_base.c#L93-L95 # noqa: E501
+        out_dtype = np.float64
+
+        @register_jitable
+        def validate_inputs(a, weights, minlength):
+            if len(a) != len(weights):
+                raise ValueError(
+                    "bincount(): weights and list don't have the same length"
+                )
+
+        @register_jitable
+        def count_item(out, idx, val, weights):
+            out[val] += weights[idx]
+
+    else:
+        out_dtype = types.intp
+
+        @register_jitable
+        def validate_inputs(a, weights, minlength):
+            pass
+
+        @register_jitable
+        def count_item(out, idx, val, weights):
+            out[val] += 1
+
+    def bincount_impl(a, weights=None, minlength=0):
+        validate_inputs(a, weights, minlength)
+        if minlength < 0:
+            raise ValueError("'minlength' must not be negative")
+
+        n = len(a)
+        a_max = a[0] if n > 0 else -1
+        # start from 0 so that a[0] is also validated as non-negative
+        for i in range(n):
+            if a[i] < 0:
+                raise ValueError(
+                    "bincount(): first 
argument must be non-negative" + ) + a_max = max(a_max, a[i]) + + out_length = max(a_max + 1, minlength) + out = np.zeros(out_length, out_dtype) + for i in range(n): + count_item(out, i, a[i], weights) + return out + + return bincount_impl + + +less_than_float = register_jitable(lt_floats) +less_than_complex = register_jitable(lt_complex) + + +@register_jitable +def less_than_or_equal_complex(a, b): + if np.isnan(a.real): + if np.isnan(b.real): + if np.isnan(a.imag): + return np.isnan(b.imag) + else: + if np.isnan(b.imag): + return True + else: + return a.imag <= b.imag + else: + return False + + else: + if np.isnan(b.real): + return True + else: + if np.isnan(a.imag): + if np.isnan(b.imag): + return a.real <= b.real + else: + return False + else: + if np.isnan(b.imag): + return True + else: + if a.real < b.real: + return True + elif a.real == b.real: + return a.imag <= b.imag + return False + + +@register_jitable +def _less_than_or_equal(a, b): + if isinstance(a, complex) or isinstance(b, complex): + return less_than_or_equal_complex(a, b) + + elif isinstance(b, float): + if np.isnan(b): + return True + + return a <= b + + +@register_jitable +def _less_than(a, b): + if isinstance(a, complex) or isinstance(b, complex): + return less_than_complex(a, b) + + elif isinstance(b, float): + return less_than_float(a, b) + + return a < b + + +@register_jitable +def _less_then_datetime64(a, b): + # Original numpy code is at: + # https://github.com/numpy/numpy/blob/3dad50936a8dc534a81a545365f69ee9ab162ffe/numpy/_core/src/npysort/npysort_common.h#L334-L346 + if np.isnat(a): + return 0 + + if np.isnat(b): + return 1 + + return a < b + + +@register_jitable +def _less_then_or_equal_datetime64(a, b): + return not _less_then_datetime64(b, a) + + +def _searchsorted(cmp): + # a facsimile of: + # https://github.com/numpy/numpy/blob/4f84d719657eb455a35fcdf9e75b83eb1f97024a/numpy/core/src/npysort/binsearch.cpp#L61 # noqa: E501 + + def impl(a, key_val, min_idx, max_idx): + while min_idx < max_idx: + # to avoid overflow + mid_idx = min_idx + ((max_idx - min_idx) >> 1) + mid_val = a[mid_idx] + if cmp(mid_val, key_val): + min_idx = mid_idx + 1 + else: + max_idx = mid_idx + return min_idx, max_idx + + return impl + + +VALID_SEARCHSORTED_SIDES = frozenset({"left", "right"}) + + +def make_searchsorted_implementation(np_dtype, side): + assert side in VALID_SEARCHSORTED_SIDES + + if np_dtype.char in "mM": + # is datetime + lt = _less_then_datetime64 + le = _less_then_or_equal_datetime64 + else: + lt = _less_than + le = _less_than_or_equal + + if side == "left": + _impl = _searchsorted(lt) + _cmp = lt + else: + _impl = _searchsorted(le) + _cmp = le + + return register_jitable(_impl), register_jitable(_cmp) + + +@overload(np.searchsorted) +def searchsorted(a, v, side="left"): + side_val = getattr(side, "literal_value", side) + + if side_val not in VALID_SEARCHSORTED_SIDES: + # could change this so that side doesn't need to be + # a compile-time constant + raise NumbaValueError(f"Invalid value given for 'side': {side_val}") + + if isinstance(v, (types.Array, types.Sequence)): + v_dt = as_dtype(v.dtype) + else: + v_dt = as_dtype(v) + + np_dt = np.promote_types(as_dtype(a.dtype), v_dt) + _impl, _cmp = make_searchsorted_implementation(np_dt, side_val) + + if isinstance(v, types.Array): + + def impl(a, v, side="left"): + out = np.empty(v.size, dtype=np.intp) + last_key_val = v.flat[0] + min_idx = 0 + max_idx = len(a) + + for i in range(v.size): + key_val = v.flat[i] + + if _cmp(last_key_val, key_val): + max_idx = len(a) + 
else: + min_idx = 0 + if max_idx < len(a): + max_idx += 1 + else: + max_idx = len(a) + + last_key_val = key_val + min_idx, max_idx = _impl(a, key_val, min_idx, max_idx) + out[i] = min_idx + + return out.reshape(v.shape) + elif isinstance(v, types.Sequence): + + def impl(a, v, side="left"): + v = np.asarray(v) + return np.searchsorted(a, v, side=side) + else: # presumably `v` is scalar + + def impl(a, v, side="left"): + r, _ = _impl(a, v, 0, len(a)) + return r + + return impl + + +@overload(np.digitize) +def np_digitize(x, bins, right=False): + if isinstance(x, types.Array) and x.dtype in types.complex_domain: + raise TypingError("x may not be complex") + + @register_jitable + def _monotonicity(bins): + # all bin edges hold the same value + if len(bins) == 0: + return 1 + + # Skip repeated values at the beginning of the array + last_value = bins[0] + i = 1 + while i < len(bins) and bins[i] == last_value: + i += 1 + + # all bin edges hold the same value + if i == len(bins): + return 1 + + next_value = bins[i] + + if last_value < next_value: + # Possibly monotonic increasing + for i in range(i + 1, len(bins)): + last_value = next_value + next_value = bins[i] + if last_value > next_value: + return 0 + return 1 + + else: + # last > next, possibly monotonic decreasing + for i in range(i + 1, len(bins)): + last_value = next_value + next_value = bins[i] + if last_value < next_value: + return 0 + return -1 + + def digitize_impl(x, bins, right=False): + mono = _monotonicity(bins) + + if mono == 0: + raise ValueError( + "bins must be monotonically increasing or decreasing" + ) + + # this is backwards because the arguments below are swapped + if right: + if mono == -1: + # reverse the bins, and invert the results + return len(bins) - np.searchsorted(bins[::-1], x, side="left") + else: + return np.searchsorted(bins, x, side="left") + else: + if mono == -1: + # reverse the bins, and invert the results + return len(bins) - np.searchsorted(bins[::-1], x, side="right") + else: + return np.searchsorted(bins, x, side="right") + + return digitize_impl + + +_range = range + + +@overload(np.histogram) +def np_histogram(a, bins=10, range=None): + if isinstance(bins, (int, types.Integer)): + # With a uniform distribution of bins, use a fast algorithm + # independent of the number of bins + + if range in (None, types.none): + inf = float("inf") + + def histogram_impl(a, bins=10, range=None): + bin_min = inf + bin_max = -inf + for view in np.nditer(a): + v = view.item() + if bin_min > v: + bin_min = v + if bin_max < v: + bin_max = v + return np.histogram(a, bins, (bin_min, bin_max)) + + else: + + def histogram_impl(a, bins=10, range=None): + if bins <= 0: + raise ValueError( + "histogram(): `bins` should be a positive integer" + ) + bin_min, bin_max = range + if not bin_min <= bin_max: + raise ValueError( + "histogram(): max must be larger than " + "min in range parameter" + ) + + hist = np.zeros(bins, np.intp) + if bin_max > bin_min: + bin_ratio = bins / (bin_max - bin_min) + for view in np.nditer(a): + v = view.item() + b = math.floor((v - bin_min) * bin_ratio) + if 0 <= b < bins: + hist[int(b)] += 1 + elif v == bin_max: + hist[bins - 1] += 1 + + bins_array = np.linspace(bin_min, bin_max, bins + 1) + return hist, bins_array + + else: + # With a custom bins array, use a bisection search + + def histogram_impl(a, bins=10, range=None): + nbins = len(bins) - 1 + for i in _range(nbins): + # Note this also catches NaNs + if not bins[i] <= bins[i + 1]: + raise ValueError( + "histogram(): bins must increase 
monotonically" + ) + + bin_min = bins[0] + bin_max = bins[nbins] + hist = np.zeros(nbins, np.intp) + + if nbins > 0: + for view in np.nditer(a): + v = view.item() + if not bin_min <= v <= bin_max: + # Value is out of bounds, ignore (also catches NaNs) + continue + # Bisect in bins[:-1] + lo = 0 + hi = nbins - 1 + while lo < hi: + # Note the `+ 1` is necessary to avoid an infinite + # loop where mid = lo => lo = mid + mid = (lo + hi + 1) >> 1 + if v < bins[mid]: + hi = mid - 1 + else: + lo = mid + hist[lo] += 1 + + return hist, bins + + return histogram_impl + + +# Create np.finfo, np.iinfo and np.MachAr +# machar +_mach_ar_supported = ( + "ibeta", + "it", + "machep", + "eps", + "negep", + "epsneg", + "iexp", + "minexp", + "xmin", + "maxexp", + "xmax", + "irnd", + "ngrd", + "epsilon", + "tiny", + "huge", + "precision", + "resolution", +) +MachAr = namedtuple("MachAr", _mach_ar_supported) + +# Do not support MachAr field +# finfo +_finfo_supported = ( + "eps", + "epsneg", + "iexp", + "machep", + "max", + "maxexp", + "min", + "minexp", + "negep", + "nexp", + "nmant", + "precision", + "resolution", + "tiny", + "bits", +) + + +finfo = namedtuple("finfo", _finfo_supported) + +# iinfo +_iinfo_supported = ( + "min", + "max", + "bits", +) + +iinfo = namedtuple("iinfo", _iinfo_supported) + + +def generate_xinfo_body(arg, np_func, container, attr): + nbty = getattr(arg, "dtype", arg) + np_dtype = as_dtype(nbty) + try: + f = np_func(np_dtype) + except ValueError: # This exception instance comes from NumPy + # The np function might not support the dtype + return None + data = tuple([getattr(f, x) for x in attr]) + + @register_jitable + def impl(arg): + return container(*data) + + return impl + + +@overload(np.finfo) +def ol_np_finfo(dtype): + fn = generate_xinfo_body(dtype, np.finfo, finfo, _finfo_supported) + + def impl(dtype): + return fn(dtype) + + return impl + + +@overload(np.iinfo) +def ol_np_iinfo(int_type): + fn = generate_xinfo_body(int_type, np.iinfo, iinfo, _iinfo_supported) + + def impl(int_type): + return fn(int_type) + + return impl + + +def _get_inner_prod(dta, dtb): + # gets an inner product implementation, if both types are float then + # BLAS is used else a local function + + @register_jitable + def _innerprod(a, b): + acc = 0 + for i in range(len(a)): + acc = acc + a[i] * b[i] + return acc + + # no BLAS... use local function regardless + if not _HAVE_BLAS: + return _innerprod + + flty = types.real_domain | types.complex_domain + floats = dta in flty and dtb in flty + if not floats: + return _innerprod + else: + a_dt = as_dtype(dta) + b_dt = as_dtype(dtb) + dt = np.promote_types(a_dt, b_dt) + + @register_jitable + def _dot_wrap(a, b): + return np.dot(a.astype(dt), b.astype(dt)) + + return _dot_wrap + + +def _assert_1d(a, func_name): + if isinstance(a, types.Array): + if not a.ndim <= 1: + raise TypingError("%s() only supported on 1D arrays " % func_name) + + +def _np_correlate_core(ap1, ap2, mode, direction): + pass + + +@overload(_np_correlate_core) +def _np_correlate_core_impl(ap1, ap2, mode, direction): + a_dt = as_dtype(ap1.dtype) + b_dt = as_dtype(ap2.dtype) + dt = np.promote_types(a_dt, b_dt) + innerprod = _get_inner_prod(ap1.dtype, ap2.dtype) + + def impl(ap1, ap2, mode, direction): + # Implementation loosely based on `_pyarray_correlate` from + # https://github.com/numpy/numpy/blob/3bce2be74f228684ca2895ad02b63953f37e2a9d/numpy/core/src/multiarray/multiarraymodule.c#L1191 # noqa: E501 + # For "mode": + # Convolve uses 'full' by default. + # Correlate uses 'valid' by default. 
+        # For "direction", +1 to write the return values out in order 0->N
+        # -1 to write them out N->0.
+
+        n1 = len(ap1)
+        n2 = len(ap2)
+
+        if n1 < n2:
+            # This should never occur when called by np.convolve because
+            # _np_correlate.impl swaps arguments based on length.
+            # The same applies for np.correlate.
+            raise ValueError(
+                "'len(ap1)' must be greater than or equal to 'len(ap2)'"
+            )
+
+        length = n1
+        n = n2
+        if mode == "valid":
+            length = length - n + 1
+            n_left = 0
+            n_right = 0
+        elif mode == "full":
+            n_right = n - 1
+            n_left = n - 1
+            length = length + n - 1
+        elif mode == "same":
+            n_left = n // 2
+            n_right = n - n_left - 1
+        else:
+            raise ValueError(
+                "Invalid 'mode'; valid values are 'full', 'same', 'valid'"
+            )
+
+        ret = np.zeros(length, dt)
+
+        if direction == 1:
+            idx = 0
+            inc = 1
+        elif direction == -1:
+            idx = length - 1
+            inc = -1
+        else:
+            raise ValueError("Invalid direction")
+
+        for i in range(n_left):
+            k = i + n - n_left
+            ret[idx] = innerprod(ap1[:k], ap2[-k:])
+            idx = idx + inc
+
+        for i in range(n1 - n2 + 1):
+            ret[idx] = innerprod(ap1[i : i + n2], ap2)
+            idx = idx + inc
+
+        for i in range(n_right):
+            k = n - i - 1
+            ret[idx] = innerprod(ap1[-k:], ap2[:k])
+            idx = idx + inc
+
+        return ret
+
+    return impl
+
+
+@overload(np.correlate)
+def _np_correlate(a, v, mode="valid"):
+    _assert_1d(a, "np.correlate")
+    _assert_1d(v, "np.correlate")
+
+    @register_jitable
+    def op_conj(x):
+        return np.conj(x)
+
+    @register_jitable
+    def op_nop(x):
+        return x
+
+    if a.dtype in types.complex_domain:
+        if v.dtype in types.complex_domain:
+            a_op = op_nop
+            b_op = op_conj
+        else:
+            a_op = op_nop
+            b_op = op_nop
+    else:
+        if v.dtype in types.complex_domain:
+            a_op = op_nop
+            b_op = op_conj
+        else:
+            a_op = op_conj
+            b_op = op_nop
+
+    def impl(a, v, mode="valid"):
+        la = len(a)
+        lv = len(v)
+
+        if la == 0:
+            raise ValueError("'a' cannot be empty")
+        if lv == 0:
+            raise ValueError("'v' cannot be empty")
+
+        if la < lv:
+            return _np_correlate_core(b_op(v), a_op(a), mode, -1)
+        else:
+            return _np_correlate_core(a_op(a), b_op(v), mode, 1)
+
+    return impl
+
+
+@overload(np.convolve)
+def np_convolve(a, v, mode="full"):
+    _assert_1d(a, "np.convolve")
+    _assert_1d(v, "np.convolve")
+
+    def impl(a, v, mode="full"):
+        la = len(a)
+        lv = len(v)
+
+        if la == 0:
+            raise ValueError("'a' cannot be empty")
+        if lv == 0:
+            raise ValueError("'v' cannot be empty")
+
+        if la < lv:
+            return _np_correlate_core(v, a[::-1], mode, 1)
+        else:
+            return _np_correlate_core(a, v[::-1], mode, 1)
+
+    return impl
+
+
+@overload(np.asarray)
+def np_asarray(a, dtype=None):
+    # developer note... keep this function (type_can_asarray) in sync with the
+    # accepted types implementations below!
+    if not type_can_asarray(a):
+        return None
+
+    if isinstance(a, types.Array):
+        if is_nonelike(dtype) or a.dtype == dtype.dtype:
+
+            def impl(a, dtype=None):
+                return a
+        else:
+
+            def impl(a, dtype=None):
+                return a.astype(dtype)
+    elif isinstance(a, (types.Sequence, types.Tuple)):
+        # Nested lists cannot be unpacked, therefore only single lists are
+        # permitted and these conform to Sequence and can be unpacked along
+        # the same path as Tuple.
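# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this patch): this branch
# simply defers to np.array, so inside a jitted function a flat (non-nested)
# list or tuple is expected to behave like:
import numpy as np

assert np.array_equal(np.asarray((1, 2, 3)), np.array([1, 2, 3]))
assert np.asarray([1.0, 2.0]).dtype == np.float64
# --------------------------------------------------------------------------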
+ if is_nonelike(dtype): + + def impl(a, dtype=None): + return np.array(a) + else: + + def impl(a, dtype=None): + return np.array(a, dtype) + elif isinstance(a, (types.Number, types.Boolean)): + dt_conv = a if is_nonelike(dtype) else dtype + ty = as_dtype(dt_conv) + + def impl(a, dtype=None): + return np.array(a, ty) + elif isinstance(a, types.containers.ListType): + if not isinstance(a.dtype, (types.Number, types.Boolean)): + raise TypingError( + "asarray support for List is limited " + "to Boolean and Number types" + ) + + target_dtype = a.dtype if is_nonelike(dtype) else dtype + + def impl(a, dtype=None): + l = len(a) + ret = np.empty(l, dtype=target_dtype) + for i, v in enumerate(a): + ret[i] = v + return ret + elif isinstance(a, types.StringLiteral): + arr = np.asarray(a.literal_value) + + def impl(a, dtype=None): + return arr.copy() + else: + impl = None + + return impl + + +if numpy_version < (2, 0): + + @overload(np.asfarray) + def np_asfarray(a, dtype=np.float64): + # convert numba dtype types into NumPy dtype + if isinstance(dtype, types.Type): + dtype = as_dtype(dtype) + if not np.issubdtype(dtype, np.inexact): + dx = types.float64 + else: + dx = dtype + + def impl(a, dtype=np.float64): + return np.asarray(a, dx) + + return impl + + +@overload(np.extract) +def np_extract(condition, arr): + def np_extract_impl(condition, arr): + cond = np.asarray(condition).flatten() + a = np.asarray(arr) + + if a.size == 0: + raise ValueError("Cannot extract from an empty array") + + # the following looks odd but replicates NumPy... + # https://github.com/numpy/numpy/issues/12859 + if np.any(cond[a.size :]) and cond.size > a.size: + msg = "condition shape inconsistent with arr shape" + raise ValueError(msg) + # NumPy raises IndexError: index 'm' is out of + # bounds for size 'n' + + max_len = min(a.size, cond.size) + out = [a.flat[idx] for idx in range(max_len) if cond[idx]] + + return np.array(out) + + return np_extract_impl + + +@overload(np.select) +def np_select(condlist, choicelist, default=0): + def np_select_arr_impl(condlist, choicelist, default=0): + if len(condlist) != len(choicelist): + raise ValueError( + "list of cases must be same length as list of conditions" + ) + out = default * np.ones(choicelist[0].shape, choicelist[0].dtype) + # should use reversed+zip, but reversed is not available + for i in range(len(condlist) - 1, -1, -1): + cond = condlist[i] + choice = choicelist[i] + out = np.where(cond, choice, out) + return out + + # first we check the types of the input parameters + if not isinstance(condlist, (types.List, types.UniTuple)): + raise NumbaTypeError("condlist must be a List or a Tuple") + if not isinstance(choicelist, (types.List, types.UniTuple)): + raise NumbaTypeError("choicelist must be a List or a Tuple") + if not isinstance(default, (int, types.Number, types.Boolean)): + raise NumbaTypeError("default must be a scalar (number or boolean)") + # the types of the parameters have been checked, now we test the types + # of the content of the parameters + # implementation note: if in the future numba's np.where accepts tuples + # as elements of condlist, then the check below should be extended to + # accept tuples + if not isinstance(condlist[0], types.Array): + raise NumbaTypeError("items of condlist must be arrays") + if not isinstance(choicelist[0], types.Array): + raise NumbaTypeError("items of choicelist must be arrays") + # the types of the parameters and their contents have been checked, + # now we test the dtypes of the content of parameters + if 
isinstance(condlist[0], types.Array): + if not isinstance(condlist[0].dtype, types.Boolean): + raise NumbaTypeError("condlist arrays must contain booleans") + if isinstance(condlist[0], types.UniTuple): + if not ( + isinstance(condlist[0], types.UniTuple) + and isinstance(condlist[0][0], types.Boolean) + ): + raise NumbaTypeError("condlist tuples must only contain booleans") + # the input types are correct, now we perform checks on the dimensions + if ( + isinstance(condlist[0], types.Array) + and condlist[0].ndim != choicelist[0].ndim + ): + raise NumbaTypeError( + "condlist and choicelist elements must have the " + "same number of dimensions" + ) + if isinstance(condlist[0], types.Array) and condlist[0].ndim < 1: + raise NumbaTypeError("condlist arrays must be of at least dimension 1") + + return np_select_arr_impl + + +@overload(np.union1d) +def np_union1d(ar1, ar2): + if not type_can_asarray(ar1) or not type_can_asarray(ar2): + raise TypingError("The arguments to np.union1d must be array-like") + if ( + "unichr" in ar1.dtype.name or "unichr" in ar2.dtype.name + ) and ar1.dtype.name != ar2.dtype.name: + raise TypingError("For Unicode arrays, arrays must have same dtype") + + def union_impl(ar1, ar2): + a = np.ravel(np.asarray(ar1)) + b = np.ravel(np.asarray(ar2)) + return np.unique(np.concatenate((a, b))) + + return union_impl + + +@overload(np.asarray_chkfinite) +def np_asarray_chkfinite(a, dtype=None): + msg = "The argument to np.asarray_chkfinite must be array-like" + if not isinstance(a, (types.Array, types.Sequence, types.Tuple)): + raise TypingError(msg) + + if is_nonelike(dtype): + dt = a.dtype + else: + try: + dt = as_dtype(dtype) + except NumbaNotImplementedError: + raise TypingError("dtype must be a valid Numpy dtype") + + def impl(a, dtype=None): + a = np.asarray(a, dtype=dt) + for i in np.nditer(a): + if not np.isfinite(i): + raise ValueError("array must not contain infs or NaNs") + return a + + return impl + + +@overload(np.unwrap) +def numpy_unwrap(p, discont=None, axis=-1, period=6.283185307179586): + if not isinstance(axis, (int, types.Integer)): + msg = 'The argument "axis" must be an integer' + raise TypingError(msg) + + if not type_can_asarray(p): + msg = 'The argument "p" must be array-like' + raise TypingError(msg) + + if not isinstance( + discont, (types.Integer, types.Float) + ) and not cgutils.is_nonelike(discont): + msg = 'The argument "discont" must be a scalar' + raise TypingError(msg) + + if not isinstance(period, (float, types.Number)): + msg = 'The argument "period" must be a scalar' + raise TypingError(msg) + + slice1 = (slice(1, None, None),) + if isinstance(period, types.Number): + dtype = np.result_type(as_dtype(p.dtype), as_dtype(period)) + else: + dtype = np.result_type(as_dtype(p.dtype), np.float64) + + integer_input = np.issubdtype(dtype, np.integer) + + def impl(p, discont=None, axis=-1, period=6.283185307179586): + if axis != -1: + msg = 'Value for argument "axis" is not supported' + raise ValueError(msg) + # Flatten to a 2D array, keeping axis -1 + p_init = np.asarray(p).astype(dtype) + init_shape = p_init.shape + last_axis = init_shape[-1] + p_new = p_init.reshape((p_init.size // last_axis, last_axis)) + # Manipulate discont and period + if discont is None: + discont = period / 2 + if integer_input: + interval_high, rem = divmod(period, 2) + boundary_ambiguous = rem == 0 + else: + interval_high = period / 2 + boundary_ambiguous = True + interval_low = -interval_high + + # Work on each row separately + for i in range(p_init.size // last_axis): + 
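+            # Wrap the first-order differences of the row into
+            # [interval_low, interval_high), zero out corrections for jumps
+            # smaller than `discont`, then apply the cumulative correction
+            # to every element after the first.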
+            row = p_new[i]
+            dd = np.diff(row)
+            ddmod = np.mod(dd - interval_low, period) + interval_low
+            if boundary_ambiguous:
+                ddmod = np.where(
+                    (ddmod == interval_low) & (dd > 0), interval_high, ddmod
+                )
+            ph_correct = ddmod - dd
+
+            ph_ravel = np.where(
+                np.array([abs(x) for x in dd]) < discont, 0, ph_correct
+            )
+            ph_correct = np.reshape(ph_ravel, ph_correct.shape)
+            up = np.copy(row)
+            up[slice1] = row[slice1] + ph_correct.cumsum()
+            p_new[i] = up
+
+        return p_new.reshape(init_shape)
+
+    return impl
+
+
+# ----------------------------------------------------------------------------
+# Windowing functions
+# - translated from the numpy implementations found in:
+# https://github.com/numpy/numpy/blob/v1.16.1/numpy/lib/function_base.py#L2543-L3233 # noqa: E501
+# at commit: f1c4c758e1c24881560dd8ab1e64ae750
+# - and also, for NumPy >= 1.20, translated from implementations in
+# https://github.com/numpy/numpy/blob/156cd054e007b05d4ac4829e10a369d19dd2b0b1/numpy/lib/function_base.py#L2655-L3065 # noqa: E501
+
+
+@register_jitable
+def np_bartlett_impl(M):
+    n = np.arange(1.0 - M, M, 2)
+    return np.where(np.less_equal(n, 0), 1 + n / (M - 1), 1 - n / (M - 1))
+
+
+@register_jitable
+def np_blackman_impl(M):
+    n = np.arange(1.0 - M, M, 2)
+    return (
+        0.42
+        + 0.5 * np.cos(np.pi * n / (M - 1))
+        + 0.08 * np.cos(2.0 * np.pi * n / (M - 1))
+    )
+
+
+@register_jitable
+def np_hamming_impl(M):
+    n = np.arange(1 - M, M, 2)
+    return 0.54 + 0.46 * np.cos(np.pi * n / (M - 1))
+
+
+@register_jitable
+def np_hanning_impl(M):
+    n = np.arange(1 - M, M, 2)
+    return 0.5 + 0.5 * np.cos(np.pi * n / (M - 1))
+
+
+def window_generator(func):
+    def window_overload(M):
+        if not isinstance(M, types.Integer):
+            raise TypingError("M must be an integer")
+
+        def window_impl(M):
+            if M < 1:
+                return np.array((), dtype=np.float64)
+            if M == 1:
+                return np.ones(1, dtype=np.float64)
+            return func(M)
+
+        return window_impl
+
+    return window_overload
+
+
+overload(np.bartlett)(window_generator(np_bartlett_impl))
+overload(np.blackman)(window_generator(np_blackman_impl))
+overload(np.hamming)(window_generator(np_hamming_impl))
+overload(np.hanning)(window_generator(np_hanning_impl))
+
+
+_i0A = np.array(
+    [
+        -4.41534164647933937950e-18,
+        3.33079451882223809783e-17,
+        -2.43127984654795469359e-16,
+        1.71539128555513303061e-15,
+        -1.16853328779934516808e-14,
+        7.67618549860493561688e-14,
+        -4.85644678311192946090e-13,
+        2.95505266312963983461e-12,
+        -1.72682629144155570723e-11,
+        9.67580903537323691224e-11,
+        -5.18979560163526290666e-10,
+        2.65982372468238665035e-9,
+        -1.30002500998624804212e-8,
+        6.04699502254191894932e-8,
+        -2.67079385394061173391e-7,
+        1.11738753912010371815e-6,
+        -4.41673835845875056359e-6,
+        1.64484480707288970893e-5,
+        -5.75419501008210370398e-5,
+        1.88502885095841655729e-4,
+        -5.76375574538582365885e-4,
+        1.63947561694133579842e-3,
+        -4.32430999505057594430e-3,
+        1.05464603945949983183e-2,
+        -2.37374148058994688156e-2,
+        4.93052842396707084878e-2,
+        -9.49010970480476444210e-2,
+        1.71620901522208775349e-1,
+        -3.04682672343198398683e-1,
+        6.76795274409476084995e-1,
+    ]
+)
+
+_i0B = np.array(
+    [
+        -7.23318048787475395456e-18,
+        -4.83050448594418207126e-18,
+        4.46562142029675999901e-17,
+        3.46122286769746109310e-17,
+        -2.82762398051658348494e-16,
+        -3.42548561967721913462e-16,
+        1.77256013305652638360e-15,
+        3.81168066935262242075e-15,
+        -9.55484669882830764870e-15,
+        -4.15056934728722208663e-14,
+        1.54008621752140982691e-14,
+        3.85277838274214270114e-13,
+        7.18012445138366623367e-13,
+        -1.79417853150680611778e-12,
+        -1.32158118404477131188e-11,
+        -3.14991652796324136454e-11,
+        1.18891471078464383424e-11,
+        4.94060238822496958910e-10,
+        3.39623202570838634515e-9,
+        2.26666899049817806459e-8,
+        2.04891858946906374183e-7,
+        2.89137052083475648297e-6,
+        6.88975834691682398426e-5,
+        3.36911647825569408990e-3,
+        8.04490411014108831608e-1,
+    ]
+)
+
+
+@register_jitable
+def _chbevl(x, vals):
+    b0 = vals[0]
+    b1 = 0.0
+
+    for i in range(1, len(vals)):
+        b2 = b1
+        b1 = b0
+        b0 = x * b1 - b2 + vals[i]
+
+    return 0.5 * (b0 - b2)
+
+
+@register_jitable
+def _i0(x):
+    if x < 0:
+        x = -x
+    if x <= 8.0:
+        y = (0.5 * x) - 2.0
+        return np.exp(x) * _chbevl(y, _i0A)
+
+    return np.exp(x) * _chbevl(32.0 / x - 2.0, _i0B) / np.sqrt(x)
+
+
+@register_jitable
+def _i0n(n, alpha, beta):
+    y = np.empty_like(n, dtype=np.float64)
+    t = _i0(np.float64(beta))
+    for i in range(len(y)):
+        y[i] = _i0(beta * np.sqrt(1 - ((n[i] - alpha) / alpha) ** 2.0)) / t
+
+    return y
+
+
+@overload(np.kaiser)
+def np_kaiser(M, beta):
+    if not isinstance(M, types.Integer):
+        raise TypingError("M must be an integer")
+
+    if not isinstance(beta, (types.Integer, types.Float)):
+        raise TypingError("beta must be an integer or float")
+
+    def np_kaiser_impl(M, beta):
+        if M < 1:
+            return np.array((), dtype=np.float64)
+        if M == 1:
+            return np.ones(1, dtype=np.float64)
+
+        n = np.arange(0, M)
+        alpha = (M - 1) / 2.0
+
+        return _i0n(n, alpha, beta)
+
+    return np_kaiser_impl
+
+
+@register_jitable
+def _cross_operation(a, b, out):
+    def _cross_preprocessing(x):
+        x0 = x[..., 0]
+        x1 = x[..., 1]
+        if x.shape[-1] == 3:
+            x2 = x[..., 2]
+        else:
+            x2 = np.multiply(x.dtype.type(0), x0)
+        return x0, x1, x2
+
+    a0, a1, a2 = _cross_preprocessing(a)
+    b0, b1, b2 = _cross_preprocessing(b)
+
+    cp0 = np.multiply(a1, b2) - np.multiply(a2, b1)
+    cp1 = np.multiply(a2, b0) - np.multiply(a0, b2)
+    cp2 = np.multiply(a0, b1) - np.multiply(a1, b0)
+
+    out[..., 0] = cp0
+    out[..., 1] = cp1
+    out[..., 2] = cp2
+
+
+def _cross(a, b):
+    pass
+
+
+@overload(_cross)
+def _cross_impl(a, b):
+    dtype = np.promote_types(as_dtype(a.dtype), as_dtype(b.dtype))
+    if a.ndim == 1 and b.ndim == 1:
+
+        def impl(a, b):
+            cp = np.empty((3,), dtype)
+            _cross_operation(a, b, cp)
+            return cp
+    else:
+
+        def impl(a, b):
+            shape = np.add(a[..., 0], b[..., 0]).shape
+            cp = np.empty(shape + (3,), dtype)
+            _cross_operation(a, b, cp)
+            return cp
+
+    return impl
+
+
+@overload(np.cross)
+def np_cross(a, b):
+    if not type_can_asarray(a) or not type_can_asarray(b):
+        raise TypingError("Inputs must be array-like.")
+
+    def impl(a, b):
+        a_ = np.asarray(a)
+        b_ = np.asarray(b)
+        if a_.shape[-1] not in (2, 3) or b_.shape[-1] not in (2, 3):
+            raise ValueError(
+                (
+                    "Incompatible dimensions for cross product\n"
+                    "(dimension must be 2 or 3)"
+                )
+            )
+
+        if a_.shape[-1] == 3 or b_.shape[-1] == 3:
+            return _cross(a_, b_)
+        else:
+            raise ValueError(
+                (
+                    "Dimensions for both inputs are 2.\n"
+                    "Please replace your numpy.cross(a, b) call with "
+                    "a call to `cross2d(a, b)` from `numba.cuda.np.extensions`."
+                )
+            )
+
+    return impl
+
+
+@register_jitable
+def _cross2d_operation(a, b):
+    def _cross_preprocessing(x):
+        x0 = x[..., 0]
+        x1 = x[..., 1]
+        return x0, x1
+
+    a0, a1 = _cross_preprocessing(a)
+    b0, b1 = _cross_preprocessing(b)
+
+    cp = np.multiply(a0, b1) - np.multiply(a1, b0)
+    # If ndim of a and b is 1, cp is a scalar.
+    # In this case np.cross returns a 0-D array, containing the scalar.
+    # np.asarray is used to reconcile this case, without introducing
+    # overhead in the case where cp is an actual N-D array.
+    # (recall that np.asarray does not copy existing arrays)
+    return np.asarray(cp)
+
+
+def cross2d(a, b):
+    pass
+
+
+@overload(cross2d)
+def cross2d_impl(a, b):
+    if not type_can_asarray(a) or not type_can_asarray(b):
+        raise TypingError("Inputs must be array-like.")
+
+    def impl(a, b):
+        a_ = np.asarray(a)
+        b_ = np.asarray(b)
+        if a_.shape[-1] != 2 or b_.shape[-1] != 2:
+            raise ValueError(
+                (
+                    "Incompatible dimensions for 2D cross product\n"
+                    "(dimension must be 2 for both inputs)"
+                )
+            )
+        return _cross2d_operation(a_, b_)
+
+    return impl
+
+
+@overload(np.trim_zeros)
+def np_trim_zeros(filt, trim="fb"):
+    if not isinstance(filt, types.Array):
+        raise NumbaTypeError("The first argument must be an array")
+
+    if filt.ndim > 1:
+        raise NumbaTypeError("array must be 1D")
+
+    if not isinstance(trim, (str, types.UnicodeType)):
+        raise NumbaTypeError("The second argument must be a string")
+
+    trim_escapes = numpy_version >= (2, 2)
+
+    def impl(filt, trim="fb"):
+        a_ = np.asarray(filt)
+        first = 0
+        trim = trim.lower()
+        if "f" in trim:
+            for i in a_:
+                if i == 0 or (trim_escapes and i == ""):
+                    first = first + 1
+                else:
+                    break
+        last = len(filt)
+        if "b" in trim:
+            for i in a_[::-1]:
+                if i == 0 or (trim_escapes and i == ""):
+                    last = last - 1
+                else:
+                    break
+        return a_[first:last]
+
+    return impl
+
+
+@overload(np.setxor1d)
+def jit_np_setxor1d(ar1, ar2, assume_unique=False):
+    if not (type_can_asarray(ar1) and type_can_asarray(ar2)):
+        raise TypingError("setxor1d: first two args must be array-like")
+    if not (isinstance(assume_unique, (types.Boolean, bool))):
+        raise TypingError('setxor1d: Argument "assume_unique" must be boolean')
+
+    # https://github.com/numpy/numpy/blob/03b62604eead0f7d279a5a4c094743eb29647368/numpy/lib/arraysetops.py#L477 # noqa: E501
+    def np_setxor1d_impl(ar1, ar2, assume_unique=False):
+        a = np.asarray(ar1)
+        b = np.asarray(ar2)
+
+        if not assume_unique:
+            a = np.unique(a)
+            b = np.unique(b)
+        else:
+            a = a.ravel()
+            b = b.ravel()
+
+        # Implementation very similar to np_intersect1d_impl:
+        # We want union minus the intersect
+        aux = np.concatenate((a, b))
+        aux.sort()
+
+        flag = np.empty(aux.shape[0] + 1, dtype=np.bool_)
+        flag[0] = True
+        flag[-1] = True
+        flag[1:-1] = aux[1:] != aux[:-1]
+        return aux[flag[1:] & flag[:-1]]
+
+    return np_setxor1d_impl
+
+
+@overload(np.setdiff1d)
+def jit_np_setdiff1d(ar1, ar2, assume_unique=False):
+    if not (type_can_asarray(ar1) and type_can_asarray(ar2)):
+        raise TypingError("setdiff1d: first two args must be array-like")
+    if not (isinstance(assume_unique, (types.Boolean, bool))):
+        raise TypingError('setdiff1d: Argument "assume_unique" must be boolean')
+
+    # https://github.com/numpy/numpy/blob/03b62604eead0f7d279a5a4c094743eb29647368/numpy/lib/arraysetops.py#L940 # noqa: E501
+    def np_setdiff1d_impl(ar1, ar2, assume_unique=False):
+        ar1 = np.asarray(ar1)
+        ar2 = np.asarray(ar2)
+        if assume_unique:
+            ar1 = ar1.ravel()
+            ar2 = ar2.ravel()
+        else:
+            ar1 = np.unique(ar1)
+            ar2 = np.unique(ar2)
+        return ar1[np.in1d(ar1, ar2, assume_unique=True, invert=True)]
+
+    return np_setdiff1d_impl
+
+
+@overload(np.in1d)
+def jit_np_in1d(ar1, ar2, assume_unique=False, invert=False):
+    if not (type_can_asarray(ar1) and type_can_asarray(ar2)):
+        raise TypingError("in1d: first two args must be array-like")
+    if not isinstance(assume_unique, (types.Boolean, bool)):
+        raise TypingError('in1d: Argument "assume_unique" must be boolean')
+    if not isinstance(invert, (types.Boolean, bool)):
+        raise TypingError('in1d: Argument "invert" must be boolean')
+
+    def np_in1d_impl(ar1, ar2, assume_unique=False, invert=False):
+        # https://github.com/numpy/numpy/blob/03b62604eead0f7d279a5a4c094743eb29647368/numpy/lib/arraysetops.py#L525 # noqa: E501
+
+        # Ravel both arrays, behavior for the first array could be different
+        ar1 = np.asarray(ar1).ravel()
+        ar2 = np.asarray(ar2).ravel()
+
+        # This brute-force scan is used when it is significantly faster
+        # than the sort-based approach below.
+        # Sorting is also not guaranteed to work on objects but numba does
+        # not support object arrays.
+        if len(ar2) < 10 * len(ar1) ** 0.145:
+            if invert:
+                mask = np.ones(len(ar1), dtype=np.bool_)
+                for a in ar2:
+                    mask &= ar1 != a
+            else:
+                mask = np.zeros(len(ar1), dtype=np.bool_)
+                for a in ar2:
+                    mask |= ar1 == a
+            return mask
+
+        # Otherwise use sorting
+        if not assume_unique:
+            # Equivalent to ar1, inv_idx = np.unique(ar1, return_inverse=True)
+            # https://github.com/numpy/numpy/blob/03b62604eead0f7d279a5a4c094743eb29647368/numpy/lib/arraysetops.py#L358C8-L358C8 # noqa: E501
+            order1 = np.argsort(ar1)
+            aux = ar1[order1]
+            mask = np.empty(aux.shape, dtype=np.bool_)
+            mask[:1] = True
+            mask[1:] = aux[1:] != aux[:-1]
+            ar1 = aux[mask]
+            imask = np.cumsum(mask) - 1
+            inv_idx = np.empty(mask.shape, dtype=np.intp)
+            inv_idx[order1] = imask
+            ar2 = np.unique(ar2)
+
+        ar = np.concatenate((ar1, ar2))
+        # We need this to be a stable sort, so always use 'mergesort'
+        # here. The values from the first array should always come before
+        # the values from the second array.
+        order = ar.argsort(kind="mergesort")
+        sar = ar[order]
+        flag = np.empty(sar.size, np.bool_)
+        if invert:
+            flag[:-1] = sar[1:] != sar[:-1]
+        else:
+            flag[:-1] = sar[1:] == sar[:-1]
+        flag[-1:] = invert
+        ret = np.empty(ar.shape, dtype=np.bool_)
+        ret[order] = flag
+
+        # return ret[:len(ar1)]
+        if assume_unique:
+            return ret[: len(ar1)]
+        else:
+            return ret[inv_idx]
+
+    return np_in1d_impl
+
+
+@overload(np.isin)
+def jit_np_isin(element, test_elements, assume_unique=False, invert=False):
+    if not (type_can_asarray(element) and type_can_asarray(test_elements)):
+        raise TypingError("isin: first two args must be array-like")
+    if not (isinstance(assume_unique, (types.Boolean, bool))):
+        raise TypingError('isin: Argument "assume_unique" must be boolean')
+    if not (isinstance(invert, (types.Boolean, bool))):
+        raise TypingError('isin: Argument "invert" must be boolean')
+
+    # https://github.com/numpy/numpy/blob/03b62604eead0f7d279a5a4c094743eb29647368/numpy/lib/arraysetops.py#L889 # noqa: E501
+    def np_isin_impl(element, test_elements, assume_unique=False, invert=False):
+        element = np.asarray(element)
+        return np.in1d(
+            element, test_elements, assume_unique=assume_unique, invert=invert
+        ).reshape(element.shape)
+
+    return np_isin_impl
diff --git a/numba_cuda/numba/cuda/np/arrayobj.py b/numba_cuda/numba/cuda/np/arrayobj.py
new file mode 100644
index 000000000..f474780b5
--- /dev/null
+++ b/numba_cuda/numba/cuda/np/arrayobj.py
@@ -0,0 +1,7663 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""
+Implementation of operations on Array objects and objects supporting
+the buffer protocol.
+""" + +import functools +import math +import operator +import textwrap + +from llvmlite import ir +from llvmlite.ir import Constant + +import numpy as np + +from numba import pndindex, literal_unroll +from numba.core import types, typing, errors +from numba.cuda import cgutils, extending +from numba.cuda.np.numpy_support import ( + as_dtype, + from_dtype, + carray, + farray, + is_contiguous, + is_fortran, + check_is_integer, + type_is_scalar, + lt_complex, + lt_floats, +) +from numba.cuda.np.numpy_support import ( + type_can_asarray, + numpy_version, + is_nonelike, +) +from numba.core.imputils import ( + iternext_impl, + impl_ret_borrowed, + impl_ret_new_ref, + impl_ret_untracked, + RefType, + Registry, +) +from numba.cuda.typing import signature +from numba.core.types import StringLiteral +from numba.cuda.extending import ( + register_jitable, + overload, + overload_method, + intrinsic, + overload_attribute, +) +from numba.misc import quicksort, mergesort +from numba.cuda.cpython import slicing +from numba.cpython.unsafe.tuple import ( + tuple_setitem, + build_full_slice_tuple, +) +from numba.cuda.extending import overload_classmethod +from numba.core.typing.npydecl import ( + parse_dtype as ty_parse_dtype, + parse_shape as ty_parse_shape, + _parse_nested_sequence, + _sequence_of_arrays, + _choose_concatenation_layout, +) + +registry = Registry("arrayobj") +lower = registry.lower +lower_cast = registry.lower_cast +lower_constant = registry.lower_constant +lower_getattr = registry.lower_getattr +lower_getattr_generic = registry.lower_getattr_generic +lower_setattr = registry.lower_setattr +lower_setattr_generic = registry.lower_setattr_generic + + +def set_range_metadata(builder, load, lower_bound, upper_bound): + """ + Set the "range" metadata on a load instruction. + Note the interval is in the form [lower_bound, upper_bound). + """ + range_operands = [ + Constant(load.type, lower_bound), + Constant(load.type, upper_bound), + ] + md = builder.module.add_metadata(range_operands) + load.set_metadata("range", md) + + +def mark_positive(builder, load): + """ + Mark the result of a load instruction as positive (or zero). + """ + upper_bound = (1 << (load.type.width - 1)) - 1 + set_range_metadata(builder, load, 0, upper_bound) + + +def make_array(array_type): + """ + Return the Structure representation of the given *array_type* + (an instance of types.ArrayCompatible). + + Note this does not call __array_wrap__ in case a new array structure + is being created (rather than populated). + """ + real_array_type = array_type.as_array + base = cgutils.create_struct_proxy(real_array_type) + ndim = real_array_type.ndim + + class ArrayStruct(base): + def _make_refs(self, ref): + sig = signature(real_array_type, array_type) + try: + array_impl = self._context.get_function("__array__", sig) + except NotImplementedError: + return super(ArrayStruct, self)._make_refs(ref) + + # Return a wrapped structure and its unwrapped reference + datamodel = self._context.data_model_manager[array_type] + be_type = self._get_be_type(datamodel) + if ref is None: + outer_ref = cgutils.alloca_once( + self._builder, be_type, zfill=True + ) + else: + outer_ref = ref + # NOTE: __array__ is called with a pointer and expects a pointer + # in return! + ref = array_impl(self._builder, (outer_ref,)) + return outer_ref, ref + + @property + def shape(self): + """ + Override .shape to inform LLVM that its elements are all positive. 
+ """ + builder = self._builder + if ndim == 0: + return base.__getattr__(self, "shape") + + # Unfortunately, we can't use llvm.assume as its presence can + # seriously pessimize performance, + # *and* the range metadata currently isn't improving anything here, + # see https://llvm.org/bugs/show_bug.cgi?id=23848 ! + ptr = self._get_ptr_by_name("shape") + dims = [] + for i in range(ndim): + dimptr = cgutils.gep_inbounds(builder, ptr, 0, i) + load = builder.load(dimptr) + dims.append(load) + mark_positive(builder, load) + + return cgutils.pack_array(builder, dims) + + return ArrayStruct + + +def get_itemsize(context, array_type): + """ + Return the item size for the given array or buffer type. + """ + llty = context.get_data_type(array_type.dtype) + return context.get_abi_sizeof(llty) + + +def load_item(context, builder, arrayty, ptr): + """ + Load the item at the given array pointer. + """ + align = None if arrayty.aligned else 1 + return context.unpack_value(builder, arrayty.dtype, ptr, align=align) + + +def store_item(context, builder, arrayty, val, ptr): + """ + Store the item at the given array pointer. + """ + align = None if arrayty.aligned else 1 + return context.pack_value(builder, arrayty.dtype, val, ptr, align=align) + + +def fix_integer_index(context, builder, idxty, idx, size): + """ + Fix the integer index' type and value for the given dimension size. + """ + if idxty.signed: + ind = context.cast(builder, idx, idxty, types.intp) + ind = slicing.fix_index(builder, ind, size) + else: + ind = context.cast(builder, idx, idxty, types.uintp) + return ind + + +def normalize_index(context, builder, idxty, idx): + """ + Normalize the index type and value. 0-d arrays are converted to scalars. + """ + if isinstance(idxty, types.Array) and idxty.ndim == 0: + assert isinstance(idxty.dtype, types.Integer) + idxary = make_array(idxty)(context, builder, idx) + idxval = load_item(context, builder, idxty, idxary.data) + return idxty.dtype, idxval + else: + return idxty, idx + + +def normalize_indices(context, builder, index_types, indices): + """ + Same as normalize_index(), but operating on sequences of + index types and values. + """ + if len(indices): + index_types, indices = zip( + *[ + normalize_index(context, builder, idxty, idx) + for idxty, idx in zip(index_types, indices) + ] + ) + return index_types, indices + + +def populate_array(array, data, shape, strides, itemsize, meminfo, parent=None): + """ + Helper function for populating array structures. + This avoids forgetting to set fields. + + *shape* and *strides* can be Python tuples or LLVM arrays. + """ + context = array._context + builder = array._builder + datamodel = array._datamodel + # doesn't matter what this array type instance is, it's just to get the + # fields for the datamodel of the standard array type in this context + standard_array = types.Array(types.float64, 1, "C") + standard_array_type_datamodel = context.data_model_manager[standard_array] + required_fields = set(standard_array_type_datamodel._fields) + datamodel_fields = set(datamodel._fields) + # Make sure that the presented array object has a data model that is close + # enough to an array for this function to proceed. + if (required_fields & datamodel_fields) != required_fields: + missing = required_fields - datamodel_fields + msg = ( + f"The datamodel for type {array._fe_type} is missing " + f"field{'s' if len(missing) > 1 else ''} {missing}." 
+ ) + raise ValueError(msg) + + if meminfo is None: + meminfo = Constant( + context.get_value_type(datamodel.get_type("meminfo")), None + ) + + intp_t = context.get_value_type(types.intp) + if isinstance(shape, (tuple, list)): + shape = cgutils.pack_array(builder, shape, intp_t) + if isinstance(strides, (tuple, list)): + strides = cgutils.pack_array(builder, strides, intp_t) + if isinstance(itemsize, int): + itemsize = intp_t(itemsize) + + attrs = dict( + shape=shape, + strides=strides, + data=data, + itemsize=itemsize, + meminfo=meminfo, + ) + + # Set `parent` attribute + if parent is None: + attrs["parent"] = Constant( + context.get_value_type(datamodel.get_type("parent")), None + ) + else: + attrs["parent"] = parent + # Calc num of items from shape + nitems = context.get_constant(types.intp, 1) + unpacked_shape = cgutils.unpack_tuple(builder, shape, shape.type.count) + # (note empty shape => 0d array therefore nitems = 1) + for axlen in unpacked_shape: + nitems = builder.mul(nitems, axlen, flags=["nsw"]) + attrs["nitems"] = nitems + + # Make sure that we have all the fields + got_fields = set(attrs.keys()) + if got_fields != required_fields: + raise ValueError("missing {0}".format(required_fields - got_fields)) + + # Set field value + for k, v in attrs.items(): + setattr(array, k, v) + + return array + + +def update_array_info(aryty, array): + """ + Update some auxiliary information in *array* after some of its fields + were changed. `itemsize` and `nitems` are updated. + """ + context = array._context + builder = array._builder + + # Calc num of items from shape + nitems = context.get_constant(types.intp, 1) + unpacked_shape = cgutils.unpack_tuple(builder, array.shape, aryty.ndim) + for axlen in unpacked_shape: + nitems = builder.mul(nitems, axlen, flags=["nsw"]) + array.nitems = nitems + + array.itemsize = context.get_constant( + types.intp, get_itemsize(context, aryty) + ) + + +def normalize_axis(func_name, arg_name, ndim, axis): + """Constrain axis values to valid positive values.""" + raise NotImplementedError() + + +@overload(normalize_axis) +def normalize_axis_overloads(func_name, arg_name, ndim, axis): + if not isinstance(func_name, StringLiteral): + raise errors.TypingError("func_name must be a str literal.") + if not isinstance(arg_name, StringLiteral): + raise errors.TypingError("arg_name must be a str literal.") + + msg = ( + f"{func_name.literal_value}: Argument {arg_name.literal_value} " + "out of bounds for dimensions of the array" + ) + + def impl(func_name, arg_name, ndim, axis): + if axis < 0: + axis += ndim + if axis < 0 or axis >= ndim: + raise ValueError(msg) + + return axis + + return impl + + +@lower("getiter", types.Buffer) +def getiter_array(context, builder, sig, args): + [arrayty] = sig.args + [array] = args + + iterobj = context.make_helper(builder, sig.return_type) + + zero = context.get_constant(types.intp, 0) + indexptr = cgutils.alloca_once_value(builder, zero) + + iterobj.index = indexptr + iterobj.array = array + + # Incref array + if context.enable_nrt: + context.nrt.incref(builder, arrayty, array) + + res = iterobj._getvalue() + + # Note: a decref on the iterator will dereference all internal MemInfo* + out = impl_ret_new_ref(context, builder, sig.return_type, res) + return out + + +def _getitem_array_single_int(context, builder, return_type, aryty, ary, idx): + """Evaluate `ary[idx]`, where idx is a single int.""" + # optimized form of _getitem_array_generic + shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim) + strides = 
cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) + offset = builder.mul(strides[0], idx) + dataptr = cgutils.pointer_add(builder, ary.data, offset) + view_shapes = shapes[1:] + view_strides = strides[1:] + + if isinstance(return_type, types.Buffer): + # Build array view + retary = make_view( + context, + builder, + aryty, + ary, + return_type, + dataptr, + view_shapes, + view_strides, + ) + return retary._getvalue() + else: + # Load scalar from 0-d result + assert not view_shapes + return load_item(context, builder, aryty, dataptr) + + +@lower("iternext", types.ArrayIterator) +@iternext_impl(RefType.BORROWED) +def iternext_array(context, builder, sig, args, result): + [iterty] = sig.args + [iter] = args + arrayty = iterty.array_type + + iterobj = context.make_helper(builder, iterty, value=iter) + ary = make_array(arrayty)(context, builder, value=iterobj.array) + + (nitems,) = cgutils.unpack_tuple(builder, ary.shape, count=1) + + index = builder.load(iterobj.index) + is_valid = builder.icmp_signed("<", index, nitems) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + value = _getitem_array_single_int( + context, builder, iterty.yield_type, arrayty, ary, index + ) + result.yield_(value) + nindex = cgutils.increment_index(builder, index) + builder.store(nindex, iterobj.index) + + +# ------------------------------------------------------------------------------ +# Basic indexing (with integers and slices only) + + +def basic_indexing( + context, builder, aryty, ary, index_types, indices, boundscheck=None +): + """ + Perform basic indexing on the given array. + A (data pointer, shapes, strides) tuple is returned describing + the corresponding view. + """ + zero = context.get_constant(types.intp, 0) + one = context.get_constant(types.intp, 1) + + shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim) + strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim) + + output_indices = [] + output_shapes = [] + output_strides = [] + + num_newaxes = len([idx for idx in index_types if is_nonelike(idx)]) + ax = 0 + for indexval, idxty in zip(indices, index_types): + if idxty is types.ellipsis: + # Fill up missing dimensions at the middle + n_missing = aryty.ndim - len(indices) + 1 + num_newaxes + for i in range(n_missing): + output_indices.append(zero) + output_shapes.append(shapes[ax]) + output_strides.append(strides[ax]) + ax += 1 + continue + # Regular index value + if isinstance(idxty, types.SliceType): + slice = context.make_helper(builder, idxty, value=indexval) + slicing.guard_invalid_slice(context, builder, idxty, slice) + slicing.fix_slice(builder, slice, shapes[ax]) + output_indices.append(slice.start) + sh = slicing.get_slice_length(builder, slice) + st = slicing.fix_stride(builder, slice, strides[ax]) + output_shapes.append(sh) + output_strides.append(st) + elif isinstance(idxty, types.Integer): + ind = fix_integer_index( + context, builder, idxty, indexval, shapes[ax] + ) + if boundscheck: + cgutils.do_boundscheck(context, builder, ind, shapes[ax], ax) + output_indices.append(ind) + elif is_nonelike(idxty): + output_shapes.append(one) + output_strides.append(zero) + ax -= 1 + else: + raise NotImplementedError("unexpected index type: %s" % (idxty,)) + ax += 1 + + # Fill up missing dimensions at the end + assert ax <= aryty.ndim + while ax < aryty.ndim: + output_shapes.append(shapes[ax]) + output_strides.append(strides[ax]) + ax += 1 + + # No need to check wraparound, as negative indices were already + # fixed in the loop above. 
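+    # Compute the pointer to the first element of the resulting view from
+    # the collected per-dimension start indices.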
+ dataptr = cgutils.get_item_pointer( + context, + builder, + aryty, + ary, + output_indices, + wraparound=False, + boundscheck=False, + ) + return (dataptr, output_shapes, output_strides) + + +def make_view(context, builder, aryty, ary, return_type, data, shapes, strides): + """ + Build a view over the given array with the given parameters. + """ + retary = make_array(return_type)(context, builder) + populate_array( + retary, + data=data, + shape=shapes, + strides=strides, + itemsize=ary.itemsize, + meminfo=ary.meminfo, + parent=ary.parent, + ) + return retary + + +def _getitem_array_generic( + context, builder, return_type, aryty, ary, index_types, indices +): + """ + Return the result of indexing *ary* with the given *indices*, + returning either a scalar or a view. + """ + dataptr, view_shapes, view_strides = basic_indexing( + context, + builder, + aryty, + ary, + index_types, + indices, + boundscheck=context.enable_boundscheck, + ) + + if isinstance(return_type, types.Buffer): + # Build array view + retary = make_view( + context, + builder, + aryty, + ary, + return_type, + dataptr, + view_shapes, + view_strides, + ) + return retary._getvalue() + else: + # Load scalar from 0-d result + assert not view_shapes + return load_item(context, builder, aryty, dataptr) + + +@lower(operator.getitem, types.Buffer, types.Integer) +@lower(operator.getitem, types.Buffer, types.SliceType) +def getitem_arraynd_intp(context, builder, sig, args): + """ + Basic indexing with an integer or a slice. + """ + aryty, idxty = sig.args + ary, idx = args + + assert aryty.ndim >= 1 + ary = make_array(aryty)(context, builder, ary) + + res = _getitem_array_generic( + context, builder, sig.return_type, aryty, ary, (idxty,), (idx,) + ) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower(operator.getitem, types.Buffer, types.BaseTuple) +def getitem_array_tuple(context, builder, sig, args): + """ + Basic or advanced indexing with a tuple. + """ + aryty, tupty = sig.args + ary, tup = args + ary = make_array(aryty)(context, builder, ary) + + index_types = tupty.types + indices = cgutils.unpack_tuple(builder, tup, count=len(tupty)) + + index_types, indices = normalize_indices( + context, builder, index_types, indices + ) + + if any(isinstance(ty, types.Array) for ty in index_types): + # Advanced indexing + return fancy_getitem( + context, builder, sig, args, aryty, ary, index_types, indices + ) + + res = _getitem_array_generic( + context, builder, sig.return_type, aryty, ary, index_types, indices + ) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower(operator.setitem, types.Buffer, types.Any, types.Any) +def setitem_array(context, builder, sig, args): + """ + array[a] = scalar_or_array + array[a,..,b] = scalar_or_array + """ + aryty, idxty, valty = sig.args + ary, idx, val = args + + if isinstance(idxty, types.BaseTuple): + index_types = idxty.types + indices = cgutils.unpack_tuple(builder, idx, count=len(idxty)) + else: + index_types = (idxty,) + indices = (idx,) + + ary = make_array(aryty)(context, builder, ary) + + # First try basic indexing to see if a single array location is denoted. 
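+    # If basic_indexing() raises NotImplementedError (e.g. an array index is
+    # present), or the result is still a view with a non-empty shape, fall
+    # back to the generic fancy assignment path below.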
+    index_types, indices = normalize_indices(
+        context, builder, index_types, indices
+    )
+    try:
+        dataptr, shapes, strides = basic_indexing(
+            context,
+            builder,
+            aryty,
+            ary,
+            index_types,
+            indices,
+            boundscheck=context.enable_boundscheck,
+        )
+    except NotImplementedError:
+        use_fancy_indexing = True
+    else:
+        use_fancy_indexing = bool(shapes)
+
+    if use_fancy_indexing:
+        # Index describes a non-trivial view => use generic slice assignment
+        # (NOTE: this also handles scalar broadcasting)
+        return fancy_setslice(context, builder, sig, args, index_types, indices)
+
+    # Store the source value at the given location
+    val = context.cast(builder, val, valty, aryty.dtype)
+    store_item(context, builder, aryty, val, dataptr)
+
+
+@lower(len, types.Buffer)
+def array_len(context, builder, sig, args):
+    (aryty,) = sig.args
+    (ary,) = args
+    arystty = make_array(aryty)
+    ary = arystty(context, builder, ary)
+    shapeary = ary.shape
+    res = builder.extract_value(shapeary, 0)
+    return impl_ret_untracked(context, builder, sig.return_type, res)
+
+
+@lower("array.item", types.Array)
+def array_item(context, builder, sig, args):
+    (aryty,) = sig.args
+    (ary,) = args
+    ary = make_array(aryty)(context, builder, ary)
+
+    nitems = ary.nitems
+    with builder.if_then(
+        builder.icmp_signed("!=", nitems, nitems.type(1)), likely=False
+    ):
+        msg = "item(): can only convert an array of size 1 to a Python scalar"
+        context.call_conv.return_user_exc(builder, ValueError, (msg,))
+
+    return load_item(context, builder, aryty, ary.data)
+
+
+if numpy_version < (2, 0):
+
+    @lower("array.itemset", types.Array, types.Any)
+    def array_itemset(context, builder, sig, args):
+        aryty, valty = sig.args
+        ary, val = args
+        assert valty == aryty.dtype
+        ary = make_array(aryty)(context, builder, ary)
+
+        nitems = ary.nitems
+        with builder.if_then(
+            builder.icmp_signed("!=", nitems, nitems.type(1)), likely=False
+        ):
+            msg = "itemset(): can only write to an array of size 1"
+            context.call_conv.return_user_exc(builder, ValueError, (msg,))
+
+        store_item(context, builder, aryty, val, ary.data)
+        return context.get_dummy_value()
+
+
+# ------------------------------------------------------------------------------
+# Advanced / fancy indexing
+
+
+class Indexer(object):
+    """
+    Generic indexer interface, for generating indices over a fancy indexed
+    array on a single dimension.
+    """
+
+    def prepare(self):
+        """
+        Prepare the indexer by initializing any required variables, basic
+        blocks...
+        """
+        raise NotImplementedError
+
+    def get_size(self):
+        """
+        Return this dimension's size as an integer.
+        """
+        raise NotImplementedError
+
+    def get_shape(self):
+        """
+        Return this dimension's shape as a tuple.
+        """
+        raise NotImplementedError
+
+    def get_index_bounds(self):
+        """
+        Return a half-open [lower, upper) range of indices this dimension
+        is guaranteed not to step out of.
+        """
+        raise NotImplementedError
+
+    def loop_head(self):
+        """
+        Start indexation loop. Return an (index, count) tuple.
+        *index* is an integer LLVM value representing the index over this
+        dimension.
+        *count* is either an integer LLVM value representing the current
+        iteration count, or None if this dimension should be omitted from
+        the indexation result.
+        """
+        raise NotImplementedError
+
+    def loop_tail(self):
+        """
+        Finish indexation loop.
+        """
+        raise NotImplementedError
+
+
+class EntireIndexer(Indexer):
+    """
+    Compute indices along an entire array dimension.
+ """ + + def __init__(self, context, builder, aryty, ary, dim): + self.context = context + self.builder = builder + self.aryty = aryty + self.ary = ary + self.dim = dim + self.ll_intp = self.context.get_value_type(types.intp) + + def prepare(self): + builder = self.builder + self.size = builder.extract_value(self.ary.shape, self.dim) + self.index = cgutils.alloca_once(builder, self.ll_intp) + self.bb_start = builder.append_basic_block() + self.bb_end = builder.append_basic_block() + + def get_size(self): + return self.size + + def get_shape(self): + return (self.size,) + + def get_index_bounds(self): + # [0, size) + return (self.ll_intp(0), self.size) + + def loop_head(self): + builder = self.builder + # Initialize loop variable + self.builder.store(Constant(self.ll_intp, 0), self.index) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_start) + cur_index = builder.load(self.index) + with builder.if_then( + builder.icmp_signed(">=", cur_index, self.size), likely=False + ): + builder.branch(self.bb_end) + return cur_index, cur_index + + def loop_tail(self): + builder = self.builder + next_index = cgutils.increment_index(builder, builder.load(self.index)) + builder.store(next_index, self.index) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_end) + + +class IntegerIndexer(Indexer): + """ + Compute indices from a single integer. + """ + + def __init__(self, context, builder, idx): + self.context = context + self.builder = builder + self.idx = idx + self.ll_intp = self.context.get_value_type(types.intp) + + def prepare(self): + pass + + def get_size(self): + return Constant(self.ll_intp, 1) + + def get_shape(self): + return () + + def get_index_bounds(self): + # [idx, idx+1) + return (self.idx, self.builder.add(self.idx, self.get_size())) + + def loop_head(self): + return self.idx, None + + def loop_tail(self): + pass + + +class IntegerArrayIndexer(Indexer): + """ + Compute indices from an array of integer indices. 
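+    Each index loaded from the index array is normalized against the
+    dimension size, so negative indices wrap around.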
+ """ + + def __init__(self, context, builder, idxty, idxary, size): + self.context = context + self.builder = builder + self.idxty = idxty + self.idxary = idxary + self.size = size + assert idxty.ndim == 1 + self.ll_intp = self.context.get_value_type(types.intp) + + def prepare(self): + builder = self.builder + self.idx_size = cgutils.unpack_tuple(builder, self.idxary.shape)[0] + self.idx_index = cgutils.alloca_once(builder, self.ll_intp) + self.bb_start = builder.append_basic_block() + self.bb_end = builder.append_basic_block() + + def get_size(self): + return self.idx_size + + def get_shape(self): + return (self.idx_size,) + + def get_index_bounds(self): + # Pessimal heuristic, as we don't want to scan for the min and max + return (self.ll_intp(0), self.size) + + def loop_head(self): + builder = self.builder + # Initialize loop variable + self.builder.store(Constant(self.ll_intp, 0), self.idx_index) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_start) + cur_index = builder.load(self.idx_index) + with builder.if_then( + builder.icmp_signed(">=", cur_index, self.idx_size), likely=False + ): + builder.branch(self.bb_end) + # Load the actual index from the array of indices + index = _getitem_array_single_int( + self.context, + builder, + self.idxty.dtype, + self.idxty, + self.idxary, + cur_index, + ) + index = fix_integer_index( + self.context, builder, self.idxty.dtype, index, self.size + ) + return index, cur_index + + def loop_tail(self): + builder = self.builder + next_index = cgutils.increment_index( + builder, builder.load(self.idx_index) + ) + builder.store(next_index, self.idx_index) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_end) + + +class BooleanArrayIndexer(Indexer): + """ + Compute indices from an array of boolean predicates. 
+ """ + + def __init__(self, context, builder, idxty, idxary): + self.context = context + self.builder = builder + self.idxty = idxty + self.idxary = idxary + assert idxty.ndim == 1 + self.ll_intp = self.context.get_value_type(types.intp) + self.zero = Constant(self.ll_intp, 0) + + def prepare(self): + builder = self.builder + self.size = cgutils.unpack_tuple(builder, self.idxary.shape)[0] + self.idx_index = cgutils.alloca_once(builder, self.ll_intp) + self.count = cgutils.alloca_once(builder, self.ll_intp) + self.bb_start = builder.append_basic_block() + self.bb_tail = builder.append_basic_block() + self.bb_end = builder.append_basic_block() + + def get_size(self): + builder = self.builder + count = cgutils.alloca_once_value(builder, self.zero) + # Sum all true values + with cgutils.for_range(builder, self.size) as loop: + c = builder.load(count) + pred = _getitem_array_single_int( + self.context, + builder, + self.idxty.dtype, + self.idxty, + self.idxary, + loop.index, + ) + c = builder.add(c, builder.zext(pred, c.type)) + builder.store(c, count) + + return builder.load(count) + + def get_shape(self): + return (self.get_size(),) + + def get_index_bounds(self): + # Pessimal heuristic, as we don't want to scan for the + # first and last true items + return (self.ll_intp(0), self.size) + + def loop_head(self): + builder = self.builder + # Initialize loop variable + self.builder.store(self.zero, self.idx_index) + self.builder.store(self.zero, self.count) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_start) + cur_index = builder.load(self.idx_index) + cur_count = builder.load(self.count) + with builder.if_then( + builder.icmp_signed(">=", cur_index, self.size), likely=False + ): + builder.branch(self.bb_end) + # Load the predicate and branch if false + pred = _getitem_array_single_int( + self.context, + builder, + self.idxty.dtype, + self.idxty, + self.idxary, + cur_index, + ) + with builder.if_then(builder.not_(pred)): + builder.branch(self.bb_tail) + # Increment the count for next iteration + next_count = cgutils.increment_index(builder, cur_count) + builder.store(next_count, self.count) + return cur_index, cur_count + + def loop_tail(self): + builder = self.builder + builder.branch(self.bb_tail) + builder.position_at_end(self.bb_tail) + next_index = cgutils.increment_index( + builder, builder.load(self.idx_index) + ) + builder.store(next_index, self.idx_index) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_end) + + +class SliceIndexer(Indexer): + """ + Compute indices along a slice. 
+ """ + + def __init__(self, context, builder, aryty, ary, dim, idxty, slice): + self.context = context + self.builder = builder + self.aryty = aryty + self.ary = ary + self.dim = dim + self.idxty = idxty + self.slice = slice + self.ll_intp = self.context.get_value_type(types.intp) + self.zero = Constant(self.ll_intp, 0) + + def prepare(self): + builder = self.builder + # Fix slice for the dimension's size + self.dim_size = builder.extract_value(self.ary.shape, self.dim) + slicing.guard_invalid_slice( + self.context, builder, self.idxty, self.slice + ) + slicing.fix_slice(builder, self.slice, self.dim_size) + self.is_step_negative = cgutils.is_neg_int(builder, self.slice.step) + # Create loop entities + self.index = cgutils.alloca_once(builder, self.ll_intp) + self.count = cgutils.alloca_once(builder, self.ll_intp) + self.bb_start = builder.append_basic_block() + self.bb_end = builder.append_basic_block() + + def get_size(self): + return slicing.get_slice_length(self.builder, self.slice) + + def get_shape(self): + return (self.get_size(),) + + def get_index_bounds(self): + lower, upper = slicing.get_slice_bounds(self.builder, self.slice) + return lower, upper + + def loop_head(self): + builder = self.builder + # Initialize loop variable + self.builder.store(self.slice.start, self.index) + self.builder.store(self.zero, self.count) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_start) + cur_index = builder.load(self.index) + cur_count = builder.load(self.count) + is_finished = builder.select( + self.is_step_negative, + builder.icmp_signed("<=", cur_index, self.slice.stop), + builder.icmp_signed(">=", cur_index, self.slice.stop), + ) + with builder.if_then(is_finished, likely=False): + builder.branch(self.bb_end) + return cur_index, cur_count + + def loop_tail(self): + builder = self.builder + next_index = builder.add( + builder.load(self.index), self.slice.step, flags=["nsw"] + ) + builder.store(next_index, self.index) + next_count = cgutils.increment_index(builder, builder.load(self.count)) + builder.store(next_count, self.count) + builder.branch(self.bb_start) + builder.position_at_end(self.bb_end) + + +class FancyIndexer(object): + """ + Perform fancy indexing on the given array. 
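+    One Indexer is created per dimension; the overall iteration is the
+    nested product of the per-dimension loops, and newaxis entries insert
+    unit-length dimensions into the result shape.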
+ """ + + def __init__(self, context, builder, aryty, ary, index_types, indices): + self.context = context + self.builder = builder + self.aryty = aryty + self.shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim) + self.strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim) + self.ll_intp = self.context.get_value_type(types.intp) + self.newaxes = [] + + indexers = [] + num_newaxes = len([idx for idx in index_types if is_nonelike(idx)]) + + ax = 0 # keeps track of position of original axes + new_ax = 0 # keeps track of position for inserting new axes + for indexval, idxty in zip(indices, index_types): + if idxty is types.ellipsis: + # Fill up missing dimensions at the middle + n_missing = aryty.ndim - len(indices) + 1 + num_newaxes + for i in range(n_missing): + indexer = EntireIndexer(context, builder, aryty, ary, ax) + indexers.append(indexer) + ax += 1 + new_ax += 1 + continue + + # Regular index value + if isinstance(idxty, types.SliceType): + slice = context.make_helper(builder, idxty, indexval) + indexer = SliceIndexer( + context, builder, aryty, ary, ax, idxty, slice + ) + indexers.append(indexer) + elif isinstance(idxty, types.Integer): + ind = fix_integer_index( + context, builder, idxty, indexval, self.shapes[ax] + ) + indexer = IntegerIndexer(context, builder, ind) + indexers.append(indexer) + elif isinstance(idxty, types.Array): + idxary = make_array(idxty)(context, builder, indexval) + if isinstance(idxty.dtype, types.Integer): + indexer = IntegerArrayIndexer( + context, builder, idxty, idxary, self.shapes[ax] + ) + elif isinstance(idxty.dtype, types.Boolean): + indexer = BooleanArrayIndexer( + context, builder, idxty, idxary + ) + else: + assert 0 + indexers.append(indexer) + elif is_nonelike(idxty): + self.newaxes.append(new_ax) + ax -= 1 + else: + raise AssertionError("unexpected index type: %s" % (idxty,)) + ax += 1 + new_ax += 1 + + # Fill up missing dimensions at the end + assert ax <= aryty.ndim, (ax, aryty.ndim) + while ax < aryty.ndim: + indexer = EntireIndexer(context, builder, aryty, ary, ax) + indexers.append(indexer) + ax += 1 + + assert len(indexers) == aryty.ndim, (len(indexers), aryty.ndim) + self.indexers = indexers + + def prepare(self): + for i in self.indexers: + i.prepare() + + one = self.context.get_constant(types.intp, 1) + + # Compute the resulting shape given by the indices + res_shape = [i.get_shape() for i in self.indexers] + + # At every position where newaxis/None is present insert + # one as a constant shape in the resulting list of shapes. + for i in self.newaxes: + res_shape.insert(i, (one,)) + + # Store the shape as a tuple, we can't do a simple + # tuple(res_shape) here since res_shape is a list + # of tuples which may be differently sized. + self.indexers_shape = sum(res_shape, ()) + + def get_shape(self): + """ + Get the resulting data shape as Python tuple. + """ + return self.indexers_shape + + def get_offset_bounds(self, strides, itemsize): + """ + Get a half-open [lower, upper) range of byte offsets spanned by + the indexer with the given strides and itemsize. The indexer is + guaranteed to not go past those bounds. 
+ """ + assert len(strides) == self.aryty.ndim + builder = self.builder + is_empty = cgutils.false_bit + zero = self.ll_intp(0) + one = self.ll_intp(1) + lower = zero + upper = zero + for indexer, shape, stride in zip( + self.indexers, self.indexers_shape, strides + ): + is_empty = builder.or_( + is_empty, builder.icmp_unsigned("==", shape, zero) + ) + # Compute [lower, upper) indices on this dimension + lower_index, upper_index = indexer.get_index_bounds() + lower_offset = builder.mul(stride, lower_index) + upper_offset = builder.mul(stride, builder.sub(upper_index, one)) + # Adjust total interval + is_downwards = builder.icmp_signed("<", stride, zero) + lower = builder.add( + lower, builder.select(is_downwards, upper_offset, lower_offset) + ) + upper = builder.add( + upper, builder.select(is_downwards, lower_offset, upper_offset) + ) + # Make interval half-open + upper = builder.add(upper, itemsize) + # Adjust for empty shape + lower = builder.select(is_empty, zero, lower) + upper = builder.select(is_empty, zero, upper) + return lower, upper + + def begin_loops(self): + indices, counts = zip(*(i.loop_head() for i in self.indexers)) + return indices, counts + + def end_loops(self): + for i in reversed(self.indexers): + i.loop_tail() + + +def fancy_getitem( + context, builder, sig, args, aryty, ary, index_types, indices +): + shapes = cgutils.unpack_tuple(builder, ary.shape) + strides = cgutils.unpack_tuple(builder, ary.strides) + data = ary.data + + indexer = FancyIndexer(context, builder, aryty, ary, index_types, indices) + indexer.prepare() + + # Construct output array + out_ty = sig.return_type + out_shapes = indexer.get_shape() + + out = _empty_nd_impl(context, builder, out_ty, out_shapes) + out_data = out.data + out_idx = cgutils.alloca_once_value( + builder, context.get_constant(types.intp, 0) + ) + + # Loop on source and copy to destination + indices, _ = indexer.begin_loops() + + # No need to check for wraparound, as the indexers all ensure + # a positive index is returned. + ptr = cgutils.get_item_pointer2( + context, + builder, + data, + shapes, + strides, + aryty.layout, + indices, + wraparound=False, + boundscheck=context.enable_boundscheck, + ) + val = load_item(context, builder, aryty, ptr) + + # Since the destination is C-contiguous, no need for multi-dimensional + # indexing. + cur = builder.load(out_idx) + ptr = builder.gep(out_data, [cur]) + store_item(context, builder, out_ty, val, ptr) + next_idx = cgutils.increment_index(builder, cur) + builder.store(next_idx, out_idx) + + indexer.end_loops() + + return impl_ret_new_ref(context, builder, out_ty, out._getvalue()) + + +@lower(operator.getitem, types.Buffer, types.Array) +def fancy_getitem_array(context, builder, sig, args): + """ + Advanced or basic indexing with an array. + """ + aryty, idxty = sig.args + ary, idx = args + ary = make_array(aryty)(context, builder, ary) + if idxty.ndim == 0: + # 0-d array index acts as a basic integer index + idxty, idx = normalize_index(context, builder, idxty, idx) + res = _getitem_array_generic( + context, builder, sig.return_type, aryty, ary, (idxty,), (idx,) + ) + return impl_ret_borrowed(context, builder, sig.return_type, res) + else: + # Advanced indexing + return fancy_getitem( + context, builder, sig, args, aryty, ary, (idxty,), (idx,) + ) + + +def offset_bounds_from_strides(context, builder, arrty, arr, shapes, strides): + """ + Compute a half-open range [lower, upper) of byte offsets from the + array's data pointer, that bound the in-memory extent of the array. 
+ + This mimics offset_bounds_from_strides() from + numpy/core/src/private/mem_overlap.c + """ + itemsize = arr.itemsize + zero = itemsize.type(0) + one = zero.type(1) + if arrty.layout in "CF": + # Array is contiguous: contents are laid out sequentially + # starting from arr.data and upwards + lower = zero + upper = builder.mul(itemsize, arr.nitems) + else: + # Non-contiguous array: need to examine strides + lower = zero + upper = zero + for i in range(arrty.ndim): + # Compute the largest byte offset on this dimension + # max_axis_offset = strides[i] * (shapes[i] - 1) + # (shapes[i] == 0 is catered for by the empty array case below) + max_axis_offset = builder.mul( + strides[i], builder.sub(shapes[i], one) + ) + is_upwards = builder.icmp_signed(">=", max_axis_offset, zero) + # Expand either upwards or downwards depending on stride + upper = builder.select( + is_upwards, builder.add(upper, max_axis_offset), upper + ) + lower = builder.select( + is_upwards, lower, builder.add(lower, max_axis_offset) + ) + # Return a half-open range + upper = builder.add(upper, itemsize) + # Adjust for empty arrays + is_empty = builder.icmp_signed("==", arr.nitems, zero) + upper = builder.select(is_empty, zero, upper) + lower = builder.select(is_empty, zero, lower) + + return lower, upper + + +def compute_memory_extents(context, builder, lower, upper, data): + """ + Given [lower, upper) byte offsets and a base data pointer, + compute the memory pointer bounds as pointer-sized integers. + """ + data_ptr_as_int = builder.ptrtoint(data, lower.type) + start = builder.add(data_ptr_as_int, lower) + end = builder.add(data_ptr_as_int, upper) + return start, end + + +def get_array_memory_extents( + context, builder, arrty, arr, shapes, strides, data +): + """ + Compute a half-open range [start, end) of pointer-sized integers + which fully contain the array data. + """ + lower, upper = offset_bounds_from_strides( + context, builder, arrty, arr, shapes, strides + ) + return compute_memory_extents(context, builder, lower, upper, data) + + +def extents_may_overlap(context, builder, a_start, a_end, b_start, b_end): + """ + Whether two memory extents [a_start, a_end) and [b_start, b_end) + may overlap. + """ + # Comparisons are unsigned, since we are really comparing pointers + may_overlap = builder.and_( + builder.icmp_unsigned("<", a_start, b_end), + builder.icmp_unsigned("<", b_start, a_end), + ) + return may_overlap + + +def maybe_copy_source( + context, builder, use_copy, srcty, src, src_shapes, src_strides, src_data +): + ptrty = src_data.type + + copy_layout = "C" + copy_data = cgutils.alloca_once_value(builder, src_data) + copy_shapes = src_shapes + copy_strides = None # unneeded for contiguous arrays + + with builder.if_then(use_copy, likely=False): + # Allocate temporary scratchpad + # XXX: should we use a stack-allocated array for very small + # data sizes? 
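+        # The scratchpad is a C-contiguous copy of the source array; it is
+        # only allocated when `use_copy` is true (i.e. the source and
+        # destination extents may overlap) and is freed in src_cleanup().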
+        allocsize = builder.mul(src.itemsize, src.nitems)
+        data = context.nrt.allocate(builder, allocsize)
+        voidptrty = data.type
+        data = builder.bitcast(data, ptrty)
+        builder.store(data, copy_data)
+
+        # Copy source data into scratchpad
+        intp_t = context.get_value_type(types.intp)
+
+        with cgutils.loop_nest(builder, src_shapes, intp_t) as indices:
+            src_ptr = cgutils.get_item_pointer2(
+                context,
+                builder,
+                src_data,
+                src_shapes,
+                src_strides,
+                srcty.layout,
+                indices,
+            )
+            dest_ptr = cgutils.get_item_pointer2(
+                context,
+                builder,
+                data,
+                copy_shapes,
+                copy_strides,
+                copy_layout,
+                indices,
+            )
+            builder.store(builder.load(src_ptr), dest_ptr)
+
+    def src_getitem(source_indices):
+        src_ptr = cgutils.alloca_once(builder, ptrty)
+        with builder.if_else(use_copy, likely=False) as (if_copy, otherwise):
+            with if_copy:
+                builder.store(
+                    cgutils.get_item_pointer2(
+                        context,
+                        builder,
+                        builder.load(copy_data),
+                        copy_shapes,
+                        copy_strides,
+                        copy_layout,
+                        source_indices,
+                        wraparound=False,
+                    ),
+                    src_ptr,
+                )
+            with otherwise:
+                builder.store(
+                    cgutils.get_item_pointer2(
+                        context,
+                        builder,
+                        src_data,
+                        src_shapes,
+                        src_strides,
+                        srcty.layout,
+                        source_indices,
+                        wraparound=False,
+                    ),
+                    src_ptr,
+                )
+        return load_item(context, builder, srcty, builder.load(src_ptr))
+
+    def src_cleanup():
+        # Deallocate memory
+        with builder.if_then(use_copy, likely=False):
+            data = builder.load(copy_data)
+            data = builder.bitcast(data, voidptrty)
+            context.nrt.free(builder, data)
+
+    return src_getitem, src_cleanup
+
+
+def _bc_adjust_dimension(context, builder, shapes, strides, target_shape):
+    """
+    Preprocess dimensions for broadcasting.
+    Returns (shapes, strides) whose ndim matches that of *target_shape*.
+    When expanding to a higher ndim, the returned shapes and strides are
+    prepended with ones and zeros, respectively.
+    When truncating to a lower ndim, the shapes are checked at runtime:
+    all extra dimensions must have a size of 1.
+    """
+    zero = context.get_constant(types.uintp, 0)
+    one = context.get_constant(types.uintp, 1)
+
+    # Adjust for broadcasting to higher dimension
+    if len(target_shape) > len(shapes):
+        nd_diff = len(target_shape) - len(shapes)
+        # Fill missing shapes with one, strides with zeros
+        shapes = [one] * nd_diff + shapes
+        strides = [zero] * nd_diff + strides
+    # Adjust for broadcasting to lower dimension
+    elif len(target_shape) < len(shapes):
+        # Accepted if all extra dims have shape 1
+        nd_diff = len(shapes) - len(target_shape)
+        dim_is_one = [
+            builder.icmp_unsigned("==", sh, one) for sh in shapes[:nd_diff]
+        ]
+        accepted = functools.reduce(builder.and_, dim_is_one, cgutils.true_bit)
+        # Check error
+        with builder.if_then(builder.not_(accepted), likely=False):
+            msg = "cannot broadcast source array for assignment"
+            context.call_conv.return_user_exc(builder, ValueError, (msg,))
+        # Truncate extra shapes, strides
+        shapes = shapes[nd_diff:]
+        strides = strides[nd_diff:]
+
+    return shapes, strides
+
+
+def _bc_adjust_shape_strides(context, builder, shapes, strides, target_shape):
+    """
+    Broadcast shapes and strides to target_shape given that their ndim already
+    matches. At each position where the shape is 1 and does not match the
+    target dim, the shape is set to the target's value and the stride is
+    set to zero.
+ """ + bc_shapes = [] + bc_strides = [] + zero = context.get_constant(types.uintp, 0) + one = context.get_constant(types.uintp, 1) + # Adjust all mismatching ones in shape + mismatch = [ + builder.icmp_signed("!=", tar, old) + for tar, old in zip(target_shape, shapes) + ] + src_is_one = [builder.icmp_signed("==", old, one) for old in shapes] + preds = [builder.and_(x, y) for x, y in zip(mismatch, src_is_one)] + bc_shapes = [ + builder.select(p, tar, old) + for p, tar, old in zip(preds, target_shape, shapes) + ] + bc_strides = [ + builder.select(p, zero, old) for p, old in zip(preds, strides) + ] + return bc_shapes, bc_strides + + +def _broadcast_to_shape(context, builder, arrtype, arr, target_shape): + """ + Broadcast the given array to the target_shape. + Returns (array_type, array) + """ + # Compute broadcasted shape and strides + shapes = cgutils.unpack_tuple(builder, arr.shape) + strides = cgutils.unpack_tuple(builder, arr.strides) + + shapes, strides = _bc_adjust_dimension( + context, builder, shapes, strides, target_shape + ) + shapes, strides = _bc_adjust_shape_strides( + context, builder, shapes, strides, target_shape + ) + new_arrtype = arrtype.copy(ndim=len(target_shape), layout="A") + # Create new view + new_arr = make_array(new_arrtype)(context, builder) + populate_array( + new_arr, + data=arr.data, + shape=cgutils.pack_array(builder, shapes), + strides=cgutils.pack_array(builder, strides), + itemsize=arr.itemsize, + meminfo=arr.meminfo, + parent=arr.parent, + ) + return new_arrtype, new_arr + + +@intrinsic +def _numpy_broadcast_to(typingctx, array, shape): + ret = array.copy(ndim=shape.count, layout="A", readonly=True) + sig = ret(array, shape) + + def codegen(context, builder, sig, args): + src, shape_ = args + srcty = sig.args[0] + + src = make_array(srcty)(context, builder, src) + shape_ = cgutils.unpack_tuple(builder, shape_) + _, dest = _broadcast_to_shape( + context, + builder, + srcty, + src, + shape_, + ) + + # Hack to get np.broadcast_to to return a read-only array + setattr( + dest, + "parent", + Constant( + context.get_value_type(dest._datamodel.get_type("parent")), None + ), + ) + + res = dest._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + return sig, codegen + + +@intrinsic +def get_readonly_array(typingctx, arr): + # returns a copy of arr which is readonly + ret = arr.copy(readonly=True) + sig = ret(arr) + + def codegen(context, builder, sig, args): + [src] = args + srcty = sig.args[0] + + dest = make_array(srcty)(context, builder, src) + # Hack to return a read-only array + dest.parent = cgutils.get_null_value(dest.parent.type) + res = dest._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + return sig, codegen + + +@register_jitable +def _can_broadcast(array, dest_shape): + src_shape = array.shape + src_ndim = len(src_shape) + dest_ndim = len(dest_shape) + if src_ndim > dest_ndim: + raise ValueError( + "input operand has more dimensions than allowed " + "by the axis remapping" + ) + for size in dest_shape: + if size < 0: + raise ValueError( + "all elements of broadcast shape must be non-negative" + ) + + # based on _broadcast_onto function in numba/np/npyimpl.py + src_index = 0 + dest_index = dest_ndim - src_ndim + while src_index < src_ndim: + src_dim = src_shape[src_index] + dest_dim = dest_shape[dest_index] + # possible cases for (src_dim, dest_dim): + # * (1, 1) -> Ok + # * (>1, 1) -> Error! + # * (>1, >1) -> src_dim == dest_dim else error! 
+ # * (1, >1) -> Ok + if src_dim == dest_dim or src_dim == 1: + src_index += 1 + dest_index += 1 + else: + raise ValueError( + "operands could not be broadcast together with remapped shapes" + ) + + +def _default_broadcast_to_impl(array, shape): + array = np.asarray(array) + _can_broadcast(array, shape) + return _numpy_broadcast_to(array, shape) + + +@overload(np.broadcast_to) +def numpy_broadcast_to(array, shape): + if not type_can_asarray(array): + raise errors.TypingError( + 'The first argument "array" must be array-like' + ) + + if isinstance(shape, types.Integer): + + def impl(array, shape): + return np.broadcast_to(array, (shape,)) + + return impl + + elif isinstance(shape, types.UniTuple): + if not isinstance(shape.dtype, types.Integer): + msg = 'The second argument "shape" must be a tuple of integers' + raise errors.TypingError(msg) + return _default_broadcast_to_impl + + elif isinstance(shape, types.Tuple) and shape.count > 0: + # check if all types are integers + if not all([isinstance(typ, types.IntegerLiteral) for typ in shape]): + msg = f'"{shape}" object cannot be interpreted as an integer' + raise errors.TypingError(msg) + return _default_broadcast_to_impl + elif isinstance(shape, types.Tuple) and shape.count == 0: + is_scalar_array = isinstance(array, types.Array) and array.ndim == 0 + if type_is_scalar(array) or is_scalar_array: + + def impl(array, shape): # broadcast_to(array, ()) + # Array type must be supported by "type_can_asarray" + # Quick note that unicode types are not supported! + array = np.asarray(array) + return get_readonly_array(array) + + return impl + + else: + msg = "Cannot broadcast a non-scalar to a scalar array" + raise errors.TypingError(msg) + else: + msg = ( + 'The argument "shape" must be a tuple or an integer. Got %s' % shape + ) + raise errors.TypingError(msg) + + +@register_jitable +def numpy_broadcast_shapes_list(r, m, shape): + for i in range(len(shape)): + k = m - len(shape) + i + tmp = shape[i] + if tmp < 0: + raise ValueError("negative dimensions are not allowed") + if tmp == 1: + continue + if r[k] == 1: + r[k] = tmp + elif r[k] != tmp: + raise ValueError( + "shape mismatch: objects cannot be broadcast to a single shape" + ) + + +@overload(np.broadcast_shapes) +def ol_numpy_broadcast_shapes(*args): + # Based on https://github.com/numpy/numpy/blob/f702b26fff3271ba6a6ba29a021fc19051d1f007/numpy/core/src/multiarray/iterators.c#L1129-L1212 # noqa + for idx, arg in enumerate(args): + is_int = isinstance(arg, types.Integer) + is_int_tuple = isinstance(arg, types.UniTuple) and isinstance( + arg.dtype, types.Integer + ) + is_empty_tuple = isinstance(arg, types.Tuple) and len(arg.types) == 0 + if not (is_int or is_int_tuple or is_empty_tuple): + msg = ( + f"Argument {idx} must be either an int or tuple[int]. 
Got {arg}" + ) + raise errors.TypingError(msg) + + # discover the number of dimensions + m = 0 + for arg in args: + if isinstance(arg, types.Integer): + m = max(m, 1) + elif isinstance(arg, types.BaseTuple): + m = max(m, len(arg)) + + if m == 0: + return lambda *args: () + else: + tup_init = (1,) * m + + def impl(*args): + # propagate args + r = [1] * m + tup = tup_init + for arg in literal_unroll(args): + if isinstance(arg, tuple) and len(arg) > 0: + numpy_broadcast_shapes_list(r, m, arg) + elif isinstance(arg, int): + numpy_broadcast_shapes_list(r, m, (arg,)) + for idx, elem in enumerate(r): + tup = tuple_setitem(tup, idx, elem) + return tup + + return impl + + +@overload(np.broadcast_arrays) +def numpy_broadcast_arrays(*args): + for idx, arg in enumerate(args): + if not type_can_asarray(arg): + raise errors.TypingError(f'Argument "{idx}" must be array-like') + + unified_dtype = None + dt = None + for arg in args: + if isinstance(arg, (types.Array, types.BaseTuple)): + dt = arg.dtype + else: + dt = arg + + if unified_dtype is None: + unified_dtype = dt + elif unified_dtype != dt: + raise errors.TypingError( + "Mismatch of argument types. Numba cannot " + "broadcast arrays with different types. " + f"Got {args}" + ) + + # number of dimensions + m = 0 + for idx, arg in enumerate(args): + if isinstance(arg, types.ArrayCompatible): + m = max(m, arg.ndim) + elif isinstance(arg, (types.Number, types.Boolean, types.BaseTuple)): + m = max(m, 1) + else: + raise errors.TypingError(f"Unhandled type {arg}") + + tup_init = (0,) * m + + def impl(*args): + # find out the output shape + # we can't call np.broadcast_shapes here since args may have arrays + # with different shapes and it is not possible to create a list + # with those shapes dynamically + shape = [1] * m + for array in literal_unroll(args): + numpy_broadcast_shapes_list(shape, m, np.asarray(array).shape) + + tup = tup_init + + for i in range(m): + tup = tuple_setitem(tup, i, shape[i]) + + # numpy checks if the input arrays have the same shape as `shape` + outs = [] + for array in literal_unroll(args): + outs.append(np.broadcast_to(np.asarray(array), tup)) + return outs + + return impl + + +def raise_with_shape_context(src_shapes, index_shape): + """Targets should implement this if they wish to specialize the error + handling/messages. The overload implementation takes two tuples as arguments + and should raise a ValueError.""" + raise NotImplementedError + + +@overload(raise_with_shape_context) +def ol_raise_with_shape_context_generic(src_shapes, index_shape): + # This overload is for a "generic" target, which makes no assumption about + # the NRT or string support, but does assume exceptions can be raised. + if ( + isinstance(src_shapes, types.UniTuple) + and isinstance(index_shape, types.UniTuple) + and src_shapes.dtype == index_shape.dtype + and isinstance(src_shapes.dtype, types.Integer) + ): + + def impl(src_shapes, index_shape): + raise ValueError("cannot assign slice from input of different size") + + return impl + + +def fancy_setslice(context, builder, sig, args, index_types, indices): + """ + Implement slice assignment for arrays. This implementation works for + basic as well as fancy indexing, since there's no functional difference + between the two for indexed assignment. 
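+    For example, both ``a[1:4] = b`` (basic indexing) and
+    ``a[[0, 2, 3]] = b`` (fancy indexing) are lowered through this path.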
+ """ + aryty, _, srcty = sig.args + ary, _, src = args + + ary = make_array(aryty)(context, builder, ary) + dest_shapes = cgutils.unpack_tuple(builder, ary.shape) + dest_strides = cgutils.unpack_tuple(builder, ary.strides) + dest_data = ary.data + + indexer = FancyIndexer(context, builder, aryty, ary, index_types, indices) + indexer.prepare() + + def raise_shape_mismatch_error(context, builder, src_shapes, index_shape): + # This acts as the "trampoline" to raise a ValueError in the case + # of the source and destination shapes mismatch at runtime. It resolves + # the public overload stub `raise_with_shape_context` + fnty = context.typing_context.resolve_value_type( + raise_with_shape_context + ) + argtys = ( + types.UniTuple(types.int64, len(src_shapes)), + types.UniTuple(types.int64, len(index_shape)), + ) + raise_sig = fnty.get_call_type(context.typing_context, argtys, {}) + func = context.get_function(fnty, raise_sig) + func( + builder, + ( + context.make_tuple(builder, raise_sig.args[0], src_shapes), + context.make_tuple(builder, raise_sig.args[1], index_shape), + ), + ) + + if isinstance(srcty, types.Buffer): + # Source is an array + src_dtype = srcty.dtype + index_shape = indexer.get_shape() + src = make_array(srcty)(context, builder, src) + # Broadcast source array to shape + srcty, src = _broadcast_to_shape( + context, builder, srcty, src, index_shape + ) + src_shapes = cgutils.unpack_tuple(builder, src.shape) + src_strides = cgutils.unpack_tuple(builder, src.strides) + src_data = src.data + + # Check shapes are equal + shape_error = cgutils.false_bit + assert len(index_shape) == len(src_shapes) + + for u, v in zip(src_shapes, index_shape): + shape_error = builder.or_( + shape_error, builder.icmp_signed("!=", u, v) + ) + + with builder.if_then(shape_error, likely=False): + raise_shape_mismatch_error( + context, builder, src_shapes, index_shape + ) + + # Check for array overlap + src_start, src_end = get_array_memory_extents( + context, builder, srcty, src, src_shapes, src_strides, src_data + ) + + dest_lower, dest_upper = indexer.get_offset_bounds( + dest_strides, ary.itemsize + ) + dest_start, dest_end = compute_memory_extents( + context, builder, dest_lower, dest_upper, dest_data + ) + + use_copy = extents_may_overlap( + context, builder, src_start, src_end, dest_start, dest_end + ) + + src_getitem, src_cleanup = maybe_copy_source( + context, + builder, + use_copy, + srcty, + src, + src_shapes, + src_strides, + src_data, + ) + + elif isinstance(srcty, types.Sequence): + src_dtype = srcty.dtype + + # Check shape is equal to sequence length + index_shape = indexer.get_shape() + assert len(index_shape) == 1 + len_impl = context.get_function(len, signature(types.intp, srcty)) + seq_len = len_impl(builder, (src,)) + + shape_error = builder.icmp_signed("!=", index_shape[0], seq_len) + + with builder.if_then(shape_error, likely=False): + raise_shape_mismatch_error( + context, builder, (seq_len,), (index_shape[0],) + ) + + def src_getitem(source_indices): + (idx,) = source_indices + getitem_impl = context.get_function( + operator.getitem, + signature(src_dtype, srcty, types.intp), + ) + return getitem_impl(builder, (src, idx)) + + def src_cleanup(): + pass + + else: + # Source is a scalar (broadcast or not, depending on destination + # shape). 
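+        # For example, a[1:4] = 7: the same scalar value is stored at
+        # every selected destination index.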
+        src_dtype = srcty
+
+        def src_getitem(source_indices):
+            return src
+
+        def src_cleanup():
+            pass
+
+    zero = context.get_constant(types.uintp, 0)
+    # Loop on destination and copy from source to destination
+    dest_indices, counts = indexer.begin_loops()
+
+    # Source is iterated in natural order
+
+    # Counts represent a counter for the number of times a specified axis
+    # is being accessed; during setitem they are used as source
+    # indices
+    counts = list(counts)
+
+    # We need to artificially introduce the index zero wherever a
+    # newaxis is present within the indexer. These always remain
+    # zero.
+    for i in indexer.newaxes:
+        counts.insert(i, zero)
+
+    source_indices = [c for c in counts if c is not None]
+
+    val = src_getitem(source_indices)
+
+    # Cast to the destination dtype (cross-dtype slice assignment is allowed)
+    val = context.cast(builder, val, src_dtype, aryty.dtype)
+
+    # No need to check for wraparound, as the indexers all ensure
+    # a positive index is returned.
+    dest_ptr = cgutils.get_item_pointer2(
+        context,
+        builder,
+        dest_data,
+        dest_shapes,
+        dest_strides,
+        aryty.layout,
+        dest_indices,
+        wraparound=False,
+        boundscheck=context.enable_boundscheck,
+    )
+    store_item(context, builder, aryty, val, dest_ptr)
+
+    indexer.end_loops()
+
+    src_cleanup()
+
+    return context.get_dummy_value()
+
+
+# ------------------------------------------------------------------------------
+# Shape / layout altering
+
+
+def vararg_to_tuple(context, builder, sig, args):
+    aryty = sig.args[0]
+    dimtys = sig.args[1:]
+    # values
+    ary = args[0]
+    dims = args[1:]
+    # coerce all types to intp
+    dims = [
+        context.cast(builder, val, ty, types.intp)
+        for ty, val in zip(dimtys, dims)
+    ]
+    # make a tuple
+    shape = cgutils.pack_array(builder, dims, dims[0].type)
+
+    shapety = types.UniTuple(dtype=types.intp, count=len(dims))
+    new_sig = typing.signature(sig.return_type, aryty, shapety)
+    new_args = ary, shape
+
+    return new_sig, new_args
+
+
+@lower("array.transpose", types.Array)
+def array_transpose(context, builder, sig, args):
+    return array_T(context, builder, sig.args[0], args[0])
+
+
+def permute_arrays(axis, shape, strides):
+    if len(axis) != len(set(axis)):
+        raise ValueError("repeated axis in transpose")
+    dim = len(shape)
+    for x in axis:
+        if x >= dim or abs(x) > dim:
+            raise ValueError(
+                "axis is out of bounds for array of given dimension"
+            )
+
+    shape[:] = shape[axis]
+    strides[:] = strides[axis]
+
+
+# Transposing an array involves permuting the shape and strides of the array
+# based on the given axes.
+@lower("array.transpose", types.Array, types.BaseTuple)
+def array_transpose_tuple(context, builder, sig, args):
+    aryty = sig.args[0]
+    ary = make_array(aryty)(context, builder, args[0])
+
+    axisty, axis = sig.args[1], args[1]
+    num_axis, dtype = axisty.count, axisty.dtype
+
+    ll_intp = context.get_value_type(types.intp)
+    ll_ary_size = ir.ArrayType(ll_intp, num_axis)
+
+    # Allocate memory for axes, shapes, and strides arrays.
+    arys = [axis, ary.shape, ary.strides]
+    ll_arys = [cgutils.alloca_once(builder, ll_ary_size) for _ in arys]
+
+    # Store axes, shapes, and strides arrays to the allocated memory.
+    for src, dst in zip(arys, ll_arys):
+        builder.store(src, dst)
+
+    np_ary_ty = types.Array(dtype=dtype, ndim=1, layout="C")
+    np_itemsize = context.get_constant(
+        types.intp, context.get_abi_sizeof(ll_intp)
+    )
+
+    # Form NumPy arrays for axes, shapes, and strides arrays.
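+    # For example, transposing a (2, 3, 4) array with axes == (2, 0, 1)
+    # permutes the shape to (4, 2, 3) and the strides in the same way.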
+ np_arys = [make_array(np_ary_ty)(context, builder) for _ in arys] + + # Roughly, `np_ary = np.array(ll_ary)` for each of axes, shapes, and strides + for np_ary, ll_ary in zip(np_arys, ll_arys): + populate_array( + np_ary, + data=builder.bitcast(ll_ary, ll_intp.as_pointer()), + shape=[context.get_constant(types.intp, num_axis)], + strides=[np_itemsize], + itemsize=np_itemsize, + meminfo=None, + ) + + # Pass NumPy arrays formed above to permute_arrays function that permutes + # shapes and strides based on axis contents. + context.compile_internal( + builder, + permute_arrays, + typing.signature(types.void, np_ary_ty, np_ary_ty, np_ary_ty), + [a._getvalue() for a in np_arys], + ) + + # Make a new array based on permuted shape and strides and return it. + ret = make_array(sig.return_type)(context, builder) + populate_array( + ret, + data=ary.data, + shape=builder.load(ll_arys[1]), + strides=builder.load(ll_arys[2]), + itemsize=ary.itemsize, + meminfo=ary.meminfo, + parent=ary.parent, + ) + res = ret._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower("array.transpose", types.Array, types.VarArg(types.Any)) +def array_transpose_vararg(context, builder, sig, args): + new_sig, new_args = vararg_to_tuple(context, builder, sig, args) + return array_transpose_tuple(context, builder, new_sig, new_args) + + +@overload(np.transpose) +def numpy_transpose(a, axes=None): + if isinstance(a, types.BaseTuple): + raise errors.TypingError("np.transpose does not accept tuples") + + if axes is None: + + def np_transpose_impl(a, axes=None): + return a.transpose() + else: + + def np_transpose_impl(a, axes=None): + return a.transpose(axes) + + return np_transpose_impl + + +@lower_getattr(types.Array, "T") +def array_T(context, builder, typ, value): + if typ.ndim <= 1: + res = value + else: + ary = make_array(typ)(context, builder, value) + ret = make_array(typ)(context, builder) + shapes = cgutils.unpack_tuple(builder, ary.shape, typ.ndim) + strides = cgutils.unpack_tuple(builder, ary.strides, typ.ndim) + populate_array( + ret, + data=ary.data, + shape=cgutils.pack_array(builder, shapes[::-1]), + strides=cgutils.pack_array(builder, strides[::-1]), + itemsize=ary.itemsize, + meminfo=ary.meminfo, + parent=ary.parent, + ) + res = ret._getvalue() + return impl_ret_borrowed(context, builder, typ, res) + + +@overload(np.logspace) +def numpy_logspace(start, stop, num=50): + if not isinstance(start, types.Number): + raise errors.TypingError('The first argument "start" must be a number') + if not isinstance(stop, types.Number): + raise errors.TypingError('The second argument "stop" must be a number') + if not isinstance(num, (int, types.Integer)): + raise errors.TypingError('The third argument "num" must be an integer') + + def impl(start, stop, num=50): + y = np.linspace(start, stop, num) + return np.power(10.0, y) + + return impl + + +@overload(np.geomspace) +def numpy_geomspace(start, stop, num=50): + if not isinstance(start, types.Number): + msg = 'The argument "start" must be a number' + raise errors.TypingError(msg) + + if not isinstance(stop, types.Number): + msg = 'The argument "stop" must be a number' + raise errors.TypingError(msg) + + if not isinstance(num, (int, types.Integer)): + msg = 'The argument "num" must be an integer' + raise errors.TypingError(msg) + + if any(isinstance(arg, types.Complex) for arg in [start, stop]): + result_dtype = from_dtype( + np.result_type(as_dtype(start), as_dtype(stop), None) + ) + + def impl(start, stop, num=50): + if start == 0 or stop 
== 0:
+                raise ValueError("Geometric sequence cannot include zero")
+            start = result_dtype(start)
+            stop = result_dtype(stop)
+            if numpy_version < (2, 0):
+                both_imaginary = (start.real == 0) & (stop.real == 0)
+                both_negative = (np.sign(start) == -1) & (np.sign(stop) == -1)
+                out_sign = 1
+                if both_imaginary:
+                    start = start.imag
+                    stop = stop.imag
+                    out_sign = 1j
+                if both_negative:
+                    start = -start
+                    stop = -stop
+                    out_sign = -out_sign
+            else:
+                out_sign = np.sign(start)
+                start /= out_sign
+                stop /= out_sign
+
+            logstart = np.log10(start)
+            logstop = np.log10(stop)
+            result = np.logspace(logstart, logstop, num)
+            # Make sure the endpoints match the start and stop arguments.
+            # This is necessary because 10 ** np.log10(x) is not
+            # necessarily exactly equal to x.
+            if num > 0:
+                result[0] = start
+            if num > 1:
+                result[-1] = stop
+            return out_sign * result
+
+    else:
+
+        def impl(start, stop, num=50):
+            if start == 0 or stop == 0:
+                raise ValueError("Geometric sequence cannot include zero")
+            both_negative = (np.sign(start) == -1) & (np.sign(stop) == -1)
+            out_sign = 1
+            if both_negative:
+                start = -start
+                stop = -stop
+                out_sign = -out_sign
+            logstart = np.log10(start)
+            logstop = np.log10(stop)
+            result = np.logspace(logstart, logstop, num)
+            # Make sure the endpoints match the start and stop arguments.
+            # This is necessary because 10 ** np.log10(x) is not
+            # necessarily exactly equal to x.
+            if num > 0:
+                result[0] = start
+            if num > 1:
+                result[-1] = stop
+            return out_sign * result
+
+    return impl
+
+
+@overload(np.rot90)
+def numpy_rot90(m, k=1):
+    # To support the axes argument, axis support would first need to be
+    # added to np.flip.
+    if not isinstance(k, (int, types.Integer)):
+        raise errors.TypingError('The second argument "k" must be an integer')
+    if not isinstance(m, types.Array):
+        raise errors.TypingError('The first argument "m" must be an array')
+
+    if m.ndim < 2:
+        raise errors.NumbaValueError("Input must be >= 2-d.")
+
+    def impl(m, k=1):
+        k = k % 4
+        if k == 0:
+            return m[:]
+        elif k == 1:
+            return np.swapaxes(np.fliplr(m), 0, 1)
+        elif k == 2:
+            return np.flipud(np.fliplr(m))
+        elif k == 3:
+            return np.fliplr(np.swapaxes(m, 0, 1))
+        else:
+            raise AssertionError  # unreachable
+
+    return impl
+
+
+def _attempt_nocopy_reshape(
+    context, builder, aryty, ary, newnd, newshape, newstrides
+):
+    """
+    Call into Numba_attempt_nocopy_reshape() for the given array type
+    and instance, and the specified new shape.
+
+    Return value is non-zero if successful, and the array pointed to
+    by *newstrides* will be filled up with the computed results.
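+    For example, a C-contiguous (2, 6) array can be reshaped to (3, 4)
+    without copying, whereas a transposed (strided) view generally cannot.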
+ """ + ll_intp = context.get_value_type(types.intp) + ll_intp_star = ll_intp.as_pointer() + ll_intc = context.get_value_type(types.intc) + fnty = ir.FunctionType( + ll_intc, + [ + # nd, *dims, *strides + ll_intp, + ll_intp_star, + ll_intp_star, + # newnd, *newdims, *newstrides + ll_intp, + ll_intp_star, + ll_intp_star, + # itemsize, is_f_order + ll_intp, + ll_intc, + ], + ) + fn = cgutils.get_or_insert_function( + builder.module, fnty, "numba_attempt_nocopy_reshape" + ) + + nd = ll_intp(aryty.ndim) + shape = cgutils.gep_inbounds(builder, ary._get_ptr_by_name("shape"), 0, 0) + strides = cgutils.gep_inbounds( + builder, ary._get_ptr_by_name("strides"), 0, 0 + ) + newnd = ll_intp(newnd) + newshape = cgutils.gep_inbounds(builder, newshape, 0, 0) + newstrides = cgutils.gep_inbounds(builder, newstrides, 0, 0) + is_f_order = ll_intc(0) + res = builder.call( + fn, + [ + nd, + shape, + strides, + newnd, + newshape, + newstrides, + ary.itemsize, + is_f_order, + ], + ) + return res + + +def normalize_reshape_value(origsize, shape): + num_neg_value = 0 + known_size = 1 + for ax, s in enumerate(shape): + if s < 0: + num_neg_value += 1 + neg_ax = ax + else: + known_size *= s + + if num_neg_value == 0: + if origsize != known_size: + raise ValueError("total size of new array must be unchanged") + + elif num_neg_value == 1: + # Infer negative dimension + if known_size == 0: + inferred = 0 + ok = origsize == 0 + else: + inferred = origsize // known_size + ok = origsize % known_size == 0 + if not ok: + raise ValueError("total size of new array must be unchanged") + shape[neg_ax] = inferred + + else: + raise ValueError("multiple negative shape values") + + +@lower("array.reshape", types.Array, types.BaseTuple) +def array_reshape(context, builder, sig, args): + aryty = sig.args[0] + retty = sig.return_type + + shapety = sig.args[1] + shape = args[1] + + ll_intp = context.get_value_type(types.intp) + ll_shape = ir.ArrayType(ll_intp, shapety.count) + + ary = make_array(aryty)(context, builder, args[0]) + + # We will change the target shape in this slot + # (see normalize_reshape_value() below) + newshape = cgutils.alloca_once(builder, ll_shape) + builder.store(shape, newshape) + + # Create a shape array pointing to the value of newshape. 
+ # (roughly, `shape_ary = np.array(ary.shape)`) + shape_ary_ty = types.Array(dtype=shapety.dtype, ndim=1, layout="C") + shape_ary = make_array(shape_ary_ty)(context, builder) + shape_itemsize = context.get_constant( + types.intp, context.get_abi_sizeof(ll_intp) + ) + populate_array( + shape_ary, + data=builder.bitcast(newshape, ll_intp.as_pointer()), + shape=[context.get_constant(types.intp, shapety.count)], + strides=[shape_itemsize], + itemsize=shape_itemsize, + meminfo=None, + ) + + # Compute the original array size + size = ary.nitems + + # Call our normalizer which will fix the shape array in case of negative + # shape value + context.compile_internal( + builder, + normalize_reshape_value, + typing.signature(types.void, types.uintp, shape_ary_ty), + [size, shape_ary._getvalue()], + ) + + # Perform reshape (nocopy) + newnd = shapety.count + newstrides = cgutils.alloca_once(builder, ll_shape) + + ok = _attempt_nocopy_reshape( + context, builder, aryty, ary, newnd, newshape, newstrides + ) + fail = builder.icmp_unsigned("==", ok, ok.type(0)) + + with builder.if_then(fail): + msg = "incompatible shape for array" + context.call_conv.return_user_exc(builder, NotImplementedError, (msg,)) + + ret = make_array(retty)(context, builder) + populate_array( + ret, + data=ary.data, + shape=builder.load(newshape), + strides=builder.load(newstrides), + itemsize=ary.itemsize, + meminfo=ary.meminfo, + parent=ary.parent, + ) + res = ret._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower("array.reshape", types.Array, types.VarArg(types.Any)) +def array_reshape_vararg(context, builder, sig, args): + new_sig, new_args = vararg_to_tuple(context, builder, sig, args) + return array_reshape(context, builder, new_sig, new_args) + + +if numpy_version < (2, 1): + + @overload(np.reshape) + def np_reshape(a, newshape): + def np_reshape_impl(a, newshape): + return a.reshape(newshape) + + return np_reshape_impl +else: + + @overload(np.reshape) + def np_reshape(a, shape): + def np_reshape_impl(a, shape): + return a.reshape(shape) + + return np_reshape_impl + + +@overload(np.resize) +def numpy_resize(a, new_shape): + if not type_can_asarray(a): + msg = 'The argument "a" must be array-like' + raise errors.TypingError(msg) + + if not ( + ( + isinstance(new_shape, types.UniTuple) + and isinstance(new_shape.dtype, types.Integer) + ) + or isinstance(new_shape, types.Integer) + ): + msg = ( + 'The argument "new_shape" must be an integer or a tuple of integers' + ) + raise errors.TypingError(msg) + + def impl(a, new_shape): + a = np.asarray(a) + a = np.ravel(a) + + if isinstance(new_shape, tuple): + new_size = 1 + for dim_length in np.asarray(new_shape): + new_size *= dim_length + if dim_length < 0: + msg = "All elements of `new_shape` must be non-negative" + raise ValueError(msg) + else: + if new_shape < 0: + msg2 = "All elements of `new_shape` must be non-negative" + raise ValueError(msg2) + new_size = new_shape + + if a.size == 0: + return np.zeros(new_shape).astype(a.dtype) + + repeats = -(-new_size // a.size) # ceil division + res = a + for i in range(repeats - 1): + res = np.concatenate((res, a)) + res = res[:new_size] + + return np.reshape(res, new_shape) + + return impl + + +@overload(np.append) +def np_append(arr, values, axis=None): + if not type_can_asarray(arr): + raise errors.TypingError('The first argument "arr" must be array-like') + + if not type_can_asarray(values): + raise errors.TypingError( + 'The second argument "values" must be array-like' + ) + + if 
is_nonelike(axis): + + def impl(arr, values, axis=None): + arr = np.ravel(np.asarray(arr)) + values = np.ravel(np.asarray(values)) + return np.concatenate((arr, values)) + else: + if not isinstance(axis, types.Integer): + raise errors.TypingError( + 'The third argument "axis" must be an integer' + ) + + def impl(arr, values, axis=None): + return np.concatenate((arr, values), axis=axis) + + return impl + + +@lower("array.ravel", types.Array) +def array_ravel(context, builder, sig, args): + # Only support no argument version (default order='C') + def imp_nocopy(ary): + """No copy version""" + return ary.reshape(ary.size) + + def imp_copy(ary): + """Copy version""" + return ary.flatten() + + # If the input array is C layout already, use the nocopy version + if sig.args[0].layout == "C": + imp = imp_nocopy + # otherwise, use flatten under-the-hood + else: + imp = imp_copy + + res = context.compile_internal(builder, imp, sig, args) + res = impl_ret_new_ref(context, builder, sig.return_type, res) + return res + + +@lower(np.ravel, types.Array) +def np_ravel(context, builder, sig, args): + def np_ravel_impl(a): + return a.ravel() + + return context.compile_internal(builder, np_ravel_impl, sig, args) + + +@lower("array.flatten", types.Array) +def array_flatten(context, builder, sig, args): + # Only support flattening to C layout currently. + def imp(ary): + return ary.copy().reshape(ary.size) + + res = context.compile_internal(builder, imp, sig, args) + res = impl_ret_new_ref(context, builder, sig.return_type, res) + return res + + +@register_jitable +def _np_clip_impl(a, a_min, a_max, out): + # Both a_min and a_max are numpy arrays + ret = np.empty_like(a) if out is None else out + a_b, a_min_b, a_max_b = np.broadcast_arrays(a, a_min, a_max) + for index in np.ndindex(a_b.shape): + val_a = a_b[index] + val_a_min = a_min_b[index] + val_a_max = a_max_b[index] + ret[index] = min(max(val_a, val_a_min), val_a_max) + + return ret + + +@register_jitable +def _np_clip_impl_none(a, b, use_min, out): + for index in np.ndindex(a.shape): + val_a = a[index] + val_b = b[index] + if use_min: + out[index] = min(val_a, val_b) + else: + out[index] = max(val_a, val_b) + return out + + +@overload(np.clip) +def np_clip(a, a_min, a_max, out=None): + if not type_can_asarray(a): + raise errors.TypingError('The argument "a" must be array-like') + + if not isinstance(a_min, types.NoneType) and not type_can_asarray(a_min): + raise errors.TypingError( + ('The argument "a_min" must be a number or an array-like') + ) + + if not isinstance(a_max, types.NoneType) and not type_can_asarray(a_max): + raise errors.TypingError( + 'The argument "a_max" must be a number or an array-like' + ) + + if not (isinstance(out, types.Array) or is_nonelike(out)): + msg = 'The argument "out" must be an array if it is provided' + raise errors.TypingError(msg) + + # TODO: support scalar a (issue #3469) + a_min_is_none = a_min is None or isinstance(a_min, types.NoneType) + a_max_is_none = a_max is None or isinstance(a_max, types.NoneType) + + if a_min_is_none and a_max_is_none: + # Raises value error when both a_min and a_max are None + def np_clip_nn(a, a_min, a_max, out=None): + raise ValueError("array_clip: must set either max or min") + + return np_clip_nn + + a_min_is_scalar = isinstance(a_min, types.Number) + a_max_is_scalar = isinstance(a_max, types.Number) + + if a_min_is_scalar and a_max_is_scalar: + + def np_clip_ss(a, a_min, a_max, out=None): + # a_min and a_max are scalars + # since their shape will be empty + # so broadcasting is not 
needed at all + ret = np.empty_like(a) if out is None else out + for index in np.ndindex(a.shape): + val_a = a[index] + ret[index] = min(max(val_a, a_min), a_max) + + return ret + + return np_clip_ss + elif a_min_is_scalar and not a_max_is_scalar: + if a_max_is_none: + + def np_clip_sn(a, a_min, a_max, out=None): + # a_min is a scalar + # since its shape will be empty + # so broadcasting is not needed at all + ret = np.empty_like(a) if out is None else out + for index in np.ndindex(a.shape): + val_a = a[index] + ret[index] = max(val_a, a_min) + + return ret + + return np_clip_sn + else: + + def np_clip_sa(a, a_min, a_max, out=None): + # a_min is a scalar + # since its shape will be empty + # broadcast it to shape of a + # by using np.full_like + a_min_full = np.full_like(a, a_min) + return _np_clip_impl(a, a_min_full, a_max, out) + + return np_clip_sa + elif not a_min_is_scalar and a_max_is_scalar: + if a_min_is_none: + + def np_clip_ns(a, a_min, a_max, out=None): + # a_max is a scalar + # since its shape will be empty + # so broadcasting is not needed at all + ret = np.empty_like(a) if out is None else out + for index in np.ndindex(a.shape): + val_a = a[index] + ret[index] = min(val_a, a_max) + + return ret + + return np_clip_ns + else: + + def np_clip_as(a, a_min, a_max, out=None): + # a_max is a scalar + # since its shape will be empty + # broadcast it to shape of a + # by using np.full_like + a_max_full = np.full_like(a, a_max) + return _np_clip_impl(a, a_min, a_max_full, out) + + return np_clip_as + else: + # Case where exactly one of a_min or a_max is None + if a_min_is_none: + + def np_clip_na(a, a_min, a_max, out=None): + # a_max is a numpy array but a_min is None + ret = np.empty_like(a) if out is None else out + a_b, a_max_b = np.broadcast_arrays(a, a_max) + return _np_clip_impl_none(a_b, a_max_b, True, ret) + + return np_clip_na + elif a_max_is_none: + + def np_clip_an(a, a_min, a_max, out=None): + # a_min is a numpy array but a_max is None + ret = np.empty_like(a) if out is None else out + a_b, a_min_b = np.broadcast_arrays(a, a_min) + return _np_clip_impl_none(a_b, a_min_b, False, ret) + + return np_clip_an + else: + + def np_clip_aa(a, a_min, a_max, out=None): + # Both a_min and a_max are clearly arrays + # because none of the above branches + # returned + return _np_clip_impl(a, a_min, a_max, out) + + return np_clip_aa + + +@overload_method(types.Array, "clip") +def array_clip(a, a_min=None, a_max=None, out=None): + def impl(a, a_min=None, a_max=None, out=None): + return np.clip(a, a_min, a_max, out) + + return impl + + +def _change_dtype(context, builder, oldty, newty, ary): + """ + Attempt to fix up *ary* for switching from *oldty* to *newty*. + + See Numpy's array_descr_set() + (np/core/src/multiarray/getset.c). + Attempt to fix the array's shape and strides for a new dtype. + False is returned on failure, True on success. + """ + assert oldty.ndim == newty.ndim + assert oldty.layout == newty.layout + + new_layout = ord(newty.layout) + any_layout = ord("A") + c_layout = ord("C") + f_layout = ord("F") + + int8 = types.int8 + + def imp(nd, dims, strides, old_itemsize, new_itemsize, layout): + # Attempt to update the layout due to limitation of the numba + # type system. 
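+        # For example, viewing a C-contiguous (2, 3) int32 array as int8
+        # rescales the last dimension to 12 and its stride to 1 byte.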
+ if layout == any_layout: + # Test rightmost stride to be contiguous + if strides[-1] == old_itemsize: + # Process this as if it is C contiguous + layout = int8(c_layout) + # Test leftmost stride to be F contiguous + elif strides[0] == old_itemsize: + # Process this as if it is F contiguous + layout = int8(f_layout) + + if old_itemsize != new_itemsize and (layout == any_layout or nd == 0): + return False + + if layout == c_layout: + i = nd - 1 + else: + i = 0 + + if new_itemsize < old_itemsize: + # If it is compatible, increase the size of the dimension + # at the end (or at the front if F-contiguous) + if (old_itemsize % new_itemsize) != 0: + return False + + newdim = old_itemsize // new_itemsize + dims[i] *= newdim + strides[i] = new_itemsize + + elif new_itemsize > old_itemsize: + # Determine if last (or first if F-contiguous) dimension + # is compatible + bytelength = dims[i] * old_itemsize + if (bytelength % new_itemsize) != 0: + return False + + dims[i] = bytelength // new_itemsize + strides[i] = new_itemsize + + else: + # Same item size: nothing to do (this also works for + # non-contiguous arrays). + pass + + return True + + old_itemsize = context.get_constant( + types.intp, get_itemsize(context, oldty) + ) + new_itemsize = context.get_constant( + types.intp, get_itemsize(context, newty) + ) + + nd = context.get_constant(types.intp, newty.ndim) + shape_data = cgutils.gep_inbounds( + builder, ary._get_ptr_by_name("shape"), 0, 0 + ) + strides_data = cgutils.gep_inbounds( + builder, ary._get_ptr_by_name("strides"), 0, 0 + ) + + shape_strides_array_type = types.Array(dtype=types.intp, ndim=1, layout="C") + arycls = context.make_array(shape_strides_array_type) + + shape_constant = cgutils.pack_array( + builder, [context.get_constant(types.intp, newty.ndim)] + ) + + sizeof_intp = context.get_abi_sizeof(context.get_data_type(types.intp)) + sizeof_intp = context.get_constant(types.intp, sizeof_intp) + strides_constant = cgutils.pack_array(builder, [sizeof_intp]) + + shape_ary = arycls(context, builder) + + populate_array( + shape_ary, + data=shape_data, + shape=shape_constant, + strides=strides_constant, + itemsize=sizeof_intp, + meminfo=None, + ) + + strides_ary = arycls(context, builder) + populate_array( + strides_ary, + data=strides_data, + shape=shape_constant, + strides=strides_constant, + itemsize=sizeof_intp, + meminfo=None, + ) + + shape = shape_ary._getvalue() + strides = strides_ary._getvalue() + args = [ + nd, + shape, + strides, + old_itemsize, + new_itemsize, + context.get_constant(types.int8, new_layout), + ] + + sig = signature( + types.boolean, + types.intp, # nd + shape_strides_array_type, # dims + shape_strides_array_type, # strides + types.intp, # old_itemsize + types.intp, # new_itemsize + types.int8, # layout + ) + + res = context.compile_internal(builder, imp, sig, args) + update_array_info(newty, ary) + res = impl_ret_borrowed(context, builder, sig.return_type, res) + return res + + +@overload(np.shape) +def np_shape(a): + if not type_can_asarray(a): + raise errors.TypingError("The argument to np.shape must be array-like") + + def impl(a): + return np.asarray(a).shape + + return impl + + +@overload(np.size) +def np_size(a): + if not type_can_asarray(a): + raise errors.TypingError("The argument to np.size must be array-like") + + def impl(a): + return np.asarray(a).size + + return impl + + +# ------------------------------------------------------------------------------ + + +@overload(np.unique) +def np_unique(ar): + def np_unique_impl(ar): + b = np.sort(ar.ravel()) 
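+        # e.g. ar = [2, 1, 2] sorts to b = [1, 2, 2]; below, b[i + 1] is
+        # kept only when it differs from b[i], giving [1, 2].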
+ head = list(b[:1]) + tail = [x for i, x in enumerate(b[1:]) if b[i] != x] + return np.array(head + tail) + + return np_unique_impl + + +@overload(np.repeat) +def np_repeat(a, repeats): + # Implementation for repeats being a scalar is a module global function + # (see below) because it might be called from the implementation below. + + def np_repeat_impl_repeats_array_like(a, repeats): + # implementation if repeats is an array like + repeats_array = np.asarray(repeats, dtype=np.int64) + # if it is a singleton array, invoke the scalar implementation + if repeats_array.shape[0] == 1: + return np_repeat_impl_repeats_scaler(a, repeats_array[0]) + if np.any(repeats_array < 0): + raise ValueError("negative dimensions are not allowed") + asa = np.asarray(a) + aravel = asa.ravel() + n = aravel.shape[0] + if aravel.shape != repeats_array.shape: + raise ValueError("operands could not be broadcast together") + to_return = np.empty(np.sum(repeats_array), dtype=asa.dtype) + pos = 0 + for i in range(n): + to_return[pos : pos + repeats_array[i]] = aravel[i] + pos += repeats_array[i] + return to_return + + # type checking + if isinstance( + a, + ( + types.Array, + types.List, + types.BaseTuple, + types.Number, + types.Boolean, + ), + ): + if isinstance(repeats, types.Integer): + return np_repeat_impl_repeats_scaler + elif isinstance(repeats, (types.Array, types.List)): + if isinstance(repeats.dtype, types.Integer): + return np_repeat_impl_repeats_array_like + + raise errors.TypingError( + "The repeats argument must be an integer " + "or an array-like of integer dtype" + ) + + +@register_jitable +def np_repeat_impl_repeats_scaler(a, repeats): + if repeats < 0: + raise ValueError("negative dimensions are not allowed") + asa = np.asarray(a) + aravel = asa.ravel() + n = aravel.shape[0] + if repeats == 0: + return np.empty(0, dtype=asa.dtype) + elif repeats == 1: + return np.copy(aravel) + else: + to_return = np.empty(n * repeats, dtype=asa.dtype) + for i in range(n): + to_return[i * repeats : (i + 1) * repeats] = aravel[i] + return to_return + + +@extending.overload_method(types.Array, "repeat") +def array_repeat(a, repeats): + def array_repeat_impl(a, repeats): + return np.repeat(a, repeats) + + return array_repeat_impl + + +@intrinsic +def _intrin_get_itemsize(tyctx, dtype): + """Computes the itemsize of the dtype""" + sig = types.intp(dtype) + + def codegen(cgctx, builder, sig, llargs): + llty = cgctx.get_data_type(sig.args[0].dtype) + llintp = cgctx.get_data_type(sig.return_type) + return llintp(cgctx.get_abi_sizeof(llty)) + + return sig, codegen + + +def _compatible_view(a, dtype): + pass + + +@overload(_compatible_view) +def ol_compatible_view(a, dtype): + """Determines if the array and dtype are compatible for forming a view.""" + + # NOTE: NumPy 1.23+ uses this check. 
+ # Code based on: + # https://github.com/numpy/numpy/blob/750ad21258cfc00663586d5a466e24f91b48edc7/numpy/core/src/multiarray/getset.c#L500-L555 # noqa: E501 + def impl(a, dtype): + dtype_size = _intrin_get_itemsize(dtype) + if dtype_size != a.itemsize: + # catch forbidden cases + if a.ndim == 0: + msg1 = ( + "Changing the dtype of a 0d array is only supported " + "if the itemsize is unchanged" + ) + raise ValueError(msg1) + else: + # NumPy has a check here for subarray type conversion which + # Numba doesn't support + pass + + # Resize on last axis only + axis = a.ndim - 1 + p1 = a.shape[axis] != 1 + p2 = a.size != 0 + p3 = a.strides[axis] != a.itemsize + if p1 and p2 and p3: + msg2 = ( + "To change to a dtype of a different size, the last " + "axis must be contiguous" + ) + raise ValueError(msg2) + + if dtype_size < a.itemsize: + if dtype_size == 0 or a.itemsize % dtype_size != 0: + msg3 = ( + "When changing to a smaller dtype, its size must " + "be a divisor of the size of original dtype" + ) + raise ValueError(msg3) + else: + newdim = a.shape[axis] * a.itemsize + if newdim % dtype_size != 0: + msg4 = ( + "When changing to a larger dtype, its size must be " + "a divisor of the total size in bytes of the last " + "axis of the array." + ) + raise ValueError(msg4) + + return impl + + +@lower("array.view", types.Array, types.DTypeSpec) +def array_view(context, builder, sig, args): + aryty = sig.args[0] + retty = sig.return_type + + ary = make_array(aryty)(context, builder, args[0]) + ret = make_array(retty)(context, builder) + # Copy all fields, casting the "data" pointer appropriately + fields = set(ret._datamodel._fields) + for k in sorted(fields): + val = getattr(ary, k) + if k == "data": + ptrty = ret.data.type + ret.data = builder.bitcast(val, ptrty) + else: + setattr(ret, k, val) + + tyctx = context.typing_context + fnty = tyctx.resolve_value_type(_compatible_view) + _compatible_view_sig = fnty.get_call_type(tyctx, (*sig.args,), {}) + impl = context.get_function(fnty, _compatible_view_sig) + impl(builder, args) + + ok = _change_dtype(context, builder, aryty, retty, ret) + fail = builder.icmp_unsigned("==", ok, Constant(ok.type, 0)) + + with builder.if_then(fail): + msg = "new type not compatible with array" + context.call_conv.return_user_exc(builder, ValueError, (msg,)) + + res = ret._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +# ------------------------------------------------------------------------------ +# Array attributes + + +@lower_getattr(types.Array, "dtype") +def array_dtype(context, builder, typ, value): + res = context.get_dummy_value() + return impl_ret_untracked(context, builder, typ, res) + + +@lower_getattr(types.Array, "shape") +@lower_getattr(types.MemoryView, "shape") +def array_shape(context, builder, typ, value): + arrayty = make_array(typ) + array = arrayty(context, builder, value) + res = array.shape + return impl_ret_untracked(context, builder, typ, res) + + +@lower_getattr(types.Array, "strides") +@lower_getattr(types.MemoryView, "strides") +def array_strides(context, builder, typ, value): + arrayty = make_array(typ) + array = arrayty(context, builder, value) + res = array.strides + return impl_ret_untracked(context, builder, typ, res) + + +@lower_getattr(types.Array, "ndim") +@lower_getattr(types.MemoryView, "ndim") +def array_ndim(context, builder, typ, value): + res = context.get_constant(types.intp, typ.ndim) + return impl_ret_untracked(context, builder, typ, res) + + +@lower_getattr(types.Array, "size") +def 
array_size(context, builder, typ, value):
+    arrayty = make_array(typ)
+    array = arrayty(context, builder, value)
+    res = array.nitems
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.Array, "itemsize")
+@lower_getattr(types.MemoryView, "itemsize")
+def array_itemsize(context, builder, typ, value):
+    arrayty = make_array(typ)
+    array = arrayty(context, builder, value)
+    res = array.itemsize
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.Array, "nbytes")
+@lower_getattr(types.MemoryView, "nbytes")
+def array_nbytes(context, builder, typ, value):
+    """
+    nbytes = size * itemsize
+    """
+    arrayty = make_array(typ)
+    array = arrayty(context, builder, value)
+    res = builder.mul(array.nitems, array.itemsize)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.MemoryView, "contiguous")
+def array_contiguous(context, builder, typ, value):
+    res = context.get_constant(types.boolean, typ.is_contig)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.MemoryView, "c_contiguous")
+def array_c_contiguous(context, builder, typ, value):
+    res = context.get_constant(types.boolean, typ.is_c_contig)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.MemoryView, "f_contiguous")
+def array_f_contiguous(context, builder, typ, value):
+    res = context.get_constant(types.boolean, typ.is_f_contig)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.MemoryView, "readonly")
+def array_readonly(context, builder, typ, value):
+    res = context.get_constant(types.boolean, not typ.mutable)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+# array.ctypes
+
+
+@lower_getattr(types.Array, "ctypes")
+def array_ctypes(context, builder, typ, value):
+    arrayty = make_array(typ)
+    array = arrayty(context, builder, value)
+    # Create new ArrayCType structure
+    act = types.ArrayCTypes(typ)
+    ctinfo = context.make_helper(builder, act)
+    ctinfo.data = array.data
+    ctinfo.meminfo = array.meminfo
+    res = ctinfo._getvalue()
+    return impl_ret_borrowed(context, builder, act, res)
+
+
+@lower_getattr(types.ArrayCTypes, "data")
+def array_ctypes_data(context, builder, typ, value):
+    ctinfo = context.make_helper(builder, typ, value=value)
+    res = ctinfo.data
+    # Convert it to an integer
+    res = builder.ptrtoint(res, context.get_value_type(types.intp))
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_cast(types.ArrayCTypes, types.CPointer)
+@lower_cast(types.ArrayCTypes, types.voidptr)
+def array_ctypes_to_pointer(context, builder, fromty, toty, val):
+    ctinfo = context.make_helper(builder, fromty, value=val)
+    res = ctinfo.data
+    res = builder.bitcast(res, context.get_value_type(toty))
+    return impl_ret_untracked(context, builder, toty, res)
+
+
+def _call_contiguous_check(checker, context, builder, aryty, ary):
+    """Helper to invoke the contiguous checker function on an array
+
+    Args
+    ----
+    checker :
+        ``numba.cuda.np.numpy_support.is_contiguous``, or
+        ``numba.cuda.np.numpy_support.is_fortran``.
+    context : target context
+    builder : llvm ir builder
+    aryty : numba type
+    ary : llvm value
+    """
+    ary = make_array(aryty)(context, builder, value=ary)
+    tup_intp = types.UniTuple(types.intp, aryty.ndim)
+    itemsize = context.get_abi_sizeof(context.get_value_type(aryty.dtype))
+    check_sig = signature(types.bool_, tup_intp, tup_intp, types.intp)
+    check_args = [
+        ary.shape,
+        ary.strides,
+        context.get_constant(types.intp, itemsize),
+    ]
+    is_contig = context.compile_internal(
+        builder, checker, check_sig, check_args
+    )
+    return is_contig
+
+
+# array.flags
+
+
+@lower_getattr(types.Array, "flags")
+def array_flags(context, builder, typ, value):
+    flagsobj = context.make_helper(builder, types.ArrayFlags(typ))
+    flagsobj.parent = value
+    res = flagsobj._getvalue()
+    context.nrt.incref(builder, typ, value)
+    return impl_ret_new_ref(context, builder, typ, res)
+
+
+@lower_getattr(types.ArrayFlags, "contiguous")
+@lower_getattr(types.ArrayFlags, "c_contiguous")
+def array_flags_c_contiguous(context, builder, typ, value):
+    if typ.array_type.layout != "C":
+        # any layout can still be contiguous
+        flagsobj = context.make_helper(builder, typ, value=value)
+        res = _call_contiguous_check(
+            is_contiguous, context, builder, typ.array_type, flagsobj.parent
+        )
+    else:
+        val = typ.array_type.layout == "C"
+        res = context.get_constant(types.boolean, val)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+@lower_getattr(types.ArrayFlags, "f_contiguous")
+def array_flags_f_contiguous(context, builder, typ, value):
+    if typ.array_type.layout != "F":
+        # any layout can still be contiguous
+        flagsobj = context.make_helper(builder, typ, value=value)
+        res = _call_contiguous_check(
+            is_fortran, context, builder, typ.array_type, flagsobj.parent
+        )
+    else:
+        layout = typ.array_type.layout
+        val = layout == "F" if typ.array_type.ndim > 1 else layout in "CF"
+        res = context.get_constant(types.boolean, val)
+    return impl_ret_untracked(context, builder, typ, res)
+
+
+# ------------------------------------------------------------------------------
+# .real / .imag
+
+
+@lower_getattr(types.Array, "real")
+def array_real_part(context, builder, typ, value):
+    if typ.dtype in types.complex_domain:
+        return array_complex_attr(context, builder, typ, value, attr="real")
+    elif typ.dtype in types.number_domain:
+        # as an identity function
+        return impl_ret_borrowed(context, builder, typ, value)
+    else:
+        raise NotImplementedError("unsupported .real for {}".format(typ.dtype))
+
+
+@lower_getattr(types.Array, "imag")
+def array_imag_part(context, builder, typ, value):
+    if typ.dtype in types.complex_domain:
+        return array_complex_attr(context, builder, typ, value, attr="imag")
+    elif typ.dtype in types.number_domain:
+        # return a readonly zero array
+        sig = signature(typ.copy(readonly=True), typ)
+        arrtype, shapes = _parse_empty_like_args(context, builder, sig, [value])
+        ary = _empty_nd_impl(context, builder, arrtype, shapes)
+        cgutils.memset(
+            builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
+        )
+        return impl_ret_new_ref(
+            context, builder, sig.return_type, ary._getvalue()
+        )
+    else:
+        raise NotImplementedError("unsupported .imag for {}".format(typ.dtype))
+
+
+def array_complex_attr(context, builder, typ, value, attr):
+    """
+    Given a complex array, its memory layout is:
+
+        R C R C R C
+        ^   ^   ^
+
+    (`R` indicates a float for the real part;
+    `C` indicates a float for the imaginary part;
+    the `^` indicates the start of each element)
+
+    To get the real part, we can simply change the
dtype and itemsize to that + of the underlying float type. The new layout is: + + R x R x R x + ^ ^ ^ + + (`x` indicates unused) + + A load operation will use the dtype to determine the number of bytes to + load. + + To get the imaginary part, we shift the pointer by 1 float offset and + change the dtype and itemsize. The new layout is: + + x C x C x C + ^ ^ ^ + """ + if attr not in ["real", "imag"] or typ.dtype not in types.complex_domain: + raise NotImplementedError("cannot get attribute `{}`".format(attr)) + + arrayty = make_array(typ) + array = arrayty(context, builder, value) + + # sizeof underlying float type + flty = typ.dtype.underlying_float + sizeof_flty = context.get_abi_sizeof(context.get_data_type(flty)) + itemsize = array.itemsize.type(sizeof_flty) + + # cast data pointer to float type + llfltptrty = context.get_value_type(flty).as_pointer() + dataptr = builder.bitcast(array.data, llfltptrty) + + # add offset + if attr == "imag": + dataptr = builder.gep(dataptr, [ir.IntType(32)(1)]) + + # make result + resultty = typ.copy(dtype=flty, layout="A") + result = make_array(resultty)(context, builder) + repl = dict(data=dataptr, itemsize=itemsize) + cgutils.copy_struct(result, array, repl) + return impl_ret_borrowed(context, builder, resultty, result._getvalue()) + + +@overload_method(types.Array, "conj") +@overload_method(types.Array, "conjugate") +def array_conj(arr): + def impl(arr): + return np.conj(arr) + + return impl + + +# ------------------------------------------------------------------------------ +# DType attribute + + +def dtype_type(context, builder, dtypety, dtypeval): + # Just return a dummy opaque value + return context.get_dummy_value() + + +lower_getattr(types.DType, "type")(dtype_type) +lower_getattr(types.DType, "kind")(dtype_type) + + +# ------------------------------------------------------------------------------ +# static_getitem on Numba numerical types to create "array" types + + +@lower("static_getitem", types.NumberClass, types.Any) +def static_getitem_number_clazz(context, builder, sig, args): + """This handles the "static_getitem" when a Numba type is subscripted e.g: + var = typed.List.empty_list(float64[::1, :]) + It only allows this on simple numerical types. Compound types, like + records, are not supported. + """ + retty = sig.return_type + if isinstance(retty, types.Array): + # This isn't used or practically accessible, but has to exist, so just + # put in a NULL of the right type. + res = context.get_value_type(retty)(None) + return impl_ret_untracked(context, builder, retty, res) + else: + # This should be unreachable unless the implementation on the Type + # metaclass is changed. + msg = ( + "Unreachable; the definition of __getitem__ on the " + "numba.types.abstract.Type metaclass should prevent access." + ) + raise errors.LoweringError(msg) + + +# ------------------------------------------------------------------------------ +# Structured / record lookup + + +@lower_getattr_generic(types.Array) +def array_record_getattr(context, builder, typ, value, attr): + """ + Generic getattr() implementation for record arrays: fetch the given + record member, i.e. a subarray. 
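+    For example, ``arr.x`` on an array with dtype
+    ``[('x', 'f8'), ('y', 'i4')]`` returns a strided view over the ``x``
+    field of each record.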
+ """ + arrayty = make_array(typ) + array = arrayty(context, builder, value) + + rectype = typ.dtype + if not isinstance(rectype, types.Record): + raise NotImplementedError( + "attribute %r of %s not defined" % (attr, typ) + ) + dtype = rectype.typeof(attr) + offset = rectype.offset(attr) + + if isinstance(dtype, types.NestedArray): + resty = typ.copy( + dtype=dtype.dtype, ndim=typ.ndim + dtype.ndim, layout="A" + ) + else: + resty = typ.copy(dtype=dtype, layout="A") + + raryty = make_array(resty) + + rary = raryty(context, builder) + + constoffset = context.get_constant(types.intp, offset) + + newdataptr = cgutils.pointer_add( + builder, + array.data, + constoffset, + return_type=rary.data.type, + ) + if isinstance(dtype, types.NestedArray): + # new shape = recarray shape + inner dimension from nestedarray + shape = cgutils.unpack_tuple(builder, array.shape, typ.ndim) + shape += [context.get_constant(types.intp, i) for i in dtype.shape] + # new strides = recarray strides + strides of the inner nestedarray + strides = cgutils.unpack_tuple(builder, array.strides, typ.ndim) + strides += [context.get_constant(types.intp, i) for i in dtype.strides] + # New datasize = size of elements of the nestedarray + datasize = context.get_abi_sizeof(context.get_data_type(dtype.dtype)) + else: + # New shape, strides, and datasize match the underlying array + shape = array.shape + strides = array.strides + datasize = context.get_abi_sizeof(context.get_data_type(dtype)) + populate_array( + rary, + data=newdataptr, + shape=shape, + strides=strides, + itemsize=context.get_constant(types.intp, datasize), + meminfo=array.meminfo, + parent=array.parent, + ) + res = rary._getvalue() + return impl_ret_borrowed(context, builder, resty, res) + + +@lower("static_getitem", types.Array, types.StringLiteral) +def array_record_getitem(context, builder, sig, args): + index = args[1] + if not isinstance(index, str): + # This will fallback to normal getitem + raise NotImplementedError + return array_record_getattr(context, builder, sig.args[0], args[0], index) + + +@lower_getattr_generic(types.Record) +def record_getattr(context, builder, typ, value, attr): + """ + Generic getattr() implementation for records: get the given record member. + """ + context.sentry_record_alignment(typ, attr) + offset = typ.offset(attr) + elemty = typ.typeof(attr) + + if isinstance(elemty, types.NestedArray): + # Only a nested array's *data* is stored in a structured array, + # so we create an array structure to point to that data. 
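+        # For example, a field declared as ('pos', np.float32, (3,)) types
+        # as a NestedArray; its shape and strides are compile-time constants.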
+ aryty = make_array(elemty) + ary = aryty(context, builder) + dtype = elemty.dtype + newshape = [context.get_constant(types.intp, s) for s in elemty.shape] + newstrides = [ + context.get_constant(types.intp, s) for s in elemty.strides + ] + newdata = cgutils.get_record_member( + builder, value, offset, context.get_data_type(dtype) + ) + populate_array( + ary, + data=newdata, + shape=cgutils.pack_array(builder, newshape), + strides=cgutils.pack_array(builder, newstrides), + itemsize=context.get_constant(types.intp, elemty.size), + meminfo=None, + parent=None, + ) + res = ary._getvalue() + return impl_ret_borrowed(context, builder, typ, res) + else: + dptr = cgutils.get_record_member( + builder, value, offset, context.get_data_type(elemty) + ) + align = None if typ.aligned else 1 + res = context.unpack_value(builder, elemty, dptr, align) + return impl_ret_borrowed(context, builder, typ, res) + + +@lower_setattr_generic(types.Record) +def record_setattr(context, builder, sig, args, attr): + """ + Generic setattr() implementation for records: set the given record member. + """ + typ, valty = sig.args + target, val = args + + context.sentry_record_alignment(typ, attr) + offset = typ.offset(attr) + elemty = typ.typeof(attr) + + if isinstance(elemty, types.NestedArray): + # Copy the data from the RHS into the nested array + val_struct = cgutils.create_struct_proxy(valty)( + context, builder, value=args[1] + ) + src = val_struct.data + dest = cgutils.get_record_member( + builder, target, offset, src.type.pointee + ) + cgutils.memcpy( + builder, dest, src, context.get_constant(types.intp, elemty.nitems) + ) + else: + # Set the given scalar record member + dptr = cgutils.get_record_member( + builder, target, offset, context.get_data_type(elemty) + ) + val = context.cast(builder, val, valty, elemty) + align = None if typ.aligned else 1 + context.pack_value(builder, elemty, val, dptr, align=align) + + +@lower("static_getitem", types.Record, types.StringLiteral) +def record_static_getitem_str(context, builder, sig, args): + """ + Record.__getitem__ redirects to getattr() + """ + impl = context.get_getattr(sig.args[0], args[1]) + return impl(context, builder, sig.args[0], args[0], args[1]) + + +@lower("static_getitem", types.Record, types.IntegerLiteral) +def record_static_getitem_int(context, builder, sig, args): + """ + Record.__getitem__ redirects to getattr() + """ + idx = sig.args[1].literal_value + fields = list(sig.args[0].fields) + ll_field = context.insert_const_string(builder.module, fields[idx]) + impl = context.get_getattr(sig.args[0], ll_field) + return impl(context, builder, sig.args[0], args[0], fields[idx]) + + +@lower("static_setitem", types.Record, types.StringLiteral, types.Any) +def record_static_setitem_str(context, builder, sig, args): + """ + Record.__setitem__ redirects to setattr() + """ + recty, _, valty = sig.args + rec, idx, val = args + getattr_sig = signature(sig.return_type, recty, valty) + impl = context.get_setattr(idx, getattr_sig) + assert impl is not None + return impl(builder, (rec, val)) + + +@lower("static_setitem", types.Record, types.IntegerLiteral, types.Any) +def record_static_setitem_int(context, builder, sig, args): + """ + Record.__setitem__ redirects to setattr() + """ + recty, _, valty = sig.args + rec, idx, val = args + getattr_sig = signature(sig.return_type, recty, valty) + fields = list(sig.args[0].fields) + impl = context.get_setattr(fields[idx], getattr_sig) + assert impl is not None + return impl(builder, (rec, val)) + + +# 
------------------------------------------------------------------------------
+# Constant arrays and records
+
+
+@lower_constant(types.Array)
+def constant_array(context, builder, ty, pyval):
+    """
+    Create a constant array (mechanism is target-dependent).
+    """
+    return context.make_constant_array(builder, ty, pyval)
+
+
+@lower_constant(types.Record)
+def constant_record(context, builder, ty, pyval):
+    """
+    Create a record constant as a stack-allocated array of bytes.
+    """
+    lty = ir.ArrayType(ir.IntType(8), pyval.nbytes)
+    val = lty(bytearray(pyval.tobytes()))
+    return cgutils.alloca_once_value(builder, val)
+
+
+@lower_constant(types.Bytes)
+def constant_bytes(context, builder, ty, pyval):
+    """
+    Create a constant array from bytes (mechanism is target-dependent).
+    """
+    buf = np.array(bytearray(pyval), dtype=np.uint8)
+    return context.make_constant_array(builder, ty, buf)
+
+
+# ------------------------------------------------------------------------------
+# Comparisons
+
+
+@lower(operator.is_, types.Array, types.Array)
+def array_is(context, builder, sig, args):
+    aty, bty = sig.args
+    if aty != bty:
+        return cgutils.false_bit
+
+    def array_is_impl(a, b):
+        return (
+            a.shape == b.shape
+            and a.strides == b.strides
+            and a.ctypes.data == b.ctypes.data
+        )
+
+    return context.compile_internal(builder, array_is_impl, sig, args)
+
+
+# ------------------------------------------------------------------------------
+# Hash
+
+
+@overload_attribute(types.Array, "__hash__")
+def ol_array_hash(arr):
+    return lambda arr: None
+
+
+# ------------------------------------------------------------------------------
+# builtin `np.flat` implementation
+
+
+def make_array_flat_cls(flatiterty):
+    """
+    Return the Structure representation of the given *flatiterty* (an
+    instance of types.NumpyFlatType).
+    """
+    return _make_flattening_iter_cls(flatiterty, "flat")
+
+
+def make_array_ndenumerate_cls(nditerty):
+    """
+    Return the Structure representation of the given *nditerty* (an
+    instance of types.NumpyNdEnumerateType).
+    """
+    return _make_flattening_iter_cls(nditerty, "ndenumerate")
+
+
+def _increment_indices(
+    context,
+    builder,
+    ndim,
+    shape,
+    indices,
+    end_flag=None,
+    loop_continue=None,
+    loop_break=None,
+):
+    zero = context.get_constant(types.intp, 0)
+
+    bbend = builder.append_basic_block("end_increment")
+
+    if end_flag is not None:
+        builder.store(cgutils.false_byte, end_flag)
+
+    for dim in reversed(range(ndim)):
+        idxptr = cgutils.gep_inbounds(builder, indices, dim)
+        idx = cgutils.increment_index(builder, builder.load(idxptr))
+
+        count = shape[dim]
+        in_bounds = builder.icmp_signed("<", idx, count)
+        with cgutils.if_likely(builder, in_bounds):
+            # New index is still in bounds
+            builder.store(idx, idxptr)
+            if loop_continue is not None:
+                loop_continue(dim)
+            builder.branch(bbend)
+        # Index out of bounds => reset it and move on to the outer index
+        builder.store(zero, idxptr)
+        if loop_break is not None:
+            loop_break(dim)
+
+    if end_flag is not None:
+        builder.store(cgutils.true_byte, end_flag)
+    builder.branch(bbend)
+
+    builder.position_at_end(bbend)
+
+
+def _increment_indices_array(
+    context, builder, arrty, arr, indices, end_flag=None
+):
+    shape = cgutils.unpack_tuple(builder, arr.shape, arrty.ndim)
+    _increment_indices(context, builder, arrty.ndim, shape, indices, end_flag)
+
+
+def make_nditer_cls(nditerty):
+    """
+    Return the Structure representation of the given *nditerty* (an
+    instance of types.NumpyNdIterType).
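+
+    Each input is walked by a sub-iterator whose kind ("flat", "indexed",
+    "0d" or "scalar") was decided at typing time; see the factory classes
+    defined below.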
+ """ + ndim = nditerty.ndim + layout = nditerty.layout + narrays = len(nditerty.arrays) + nshapes = ndim if nditerty.need_shaped_indexing else 1 + + class BaseSubIter(object): + """ + Base class for sub-iterators of a nditer() instance. + """ + + def __init__(self, nditer, member_name, start_dim, end_dim): + self.nditer = nditer + self.member_name = member_name + self.start_dim = start_dim + self.end_dim = end_dim + self.ndim = end_dim - start_dim + + def set_member_ptr(self, ptr): + setattr(self.nditer, self.member_name, ptr) + + @functools.cached_property + def member_ptr(self): + return getattr(self.nditer, self.member_name) + + def init_specific(self, context, builder): + pass + + def loop_continue(self, context, builder, logical_dim): + pass + + def loop_break(self, context, builder, logical_dim): + pass + + class FlatSubIter(BaseSubIter): + """ + Sub-iterator walking a contiguous array in physical order, with + support for broadcasting (the index is reset on the outer dimension). + """ + + def init_specific(self, context, builder): + zero = context.get_constant(types.intp, 0) + self.set_member_ptr(cgutils.alloca_once_value(builder, zero)) + + def compute_pointer(self, context, builder, indices, arrty, arr): + index = builder.load(self.member_ptr) + return builder.gep(arr.data, [index]) + + def loop_continue(self, context, builder, logical_dim): + if logical_dim == self.ndim - 1: + # Only increment index inside innermost logical dimension + index = builder.load(self.member_ptr) + index = cgutils.increment_index(builder, index) + builder.store(index, self.member_ptr) + + def loop_break(self, context, builder, logical_dim): + if logical_dim == 0: + # At the exit of outermost logical dimension, reset index + zero = context.get_constant(types.intp, 0) + builder.store(zero, self.member_ptr) + elif logical_dim == self.ndim - 1: + # Inside innermost logical dimension, increment index + index = builder.load(self.member_ptr) + index = cgutils.increment_index(builder, index) + builder.store(index, self.member_ptr) + + class TrivialFlatSubIter(BaseSubIter): + """ + Sub-iterator walking a contiguous array in physical order, + *without* support for broadcasting. + """ + + def init_specific(self, context, builder): + assert not nditerty.need_shaped_indexing + + def compute_pointer(self, context, builder, indices, arrty, arr): + assert len(indices) <= 1, len(indices) + return builder.gep(arr.data, indices) + + class IndexedSubIter(BaseSubIter): + """ + Sub-iterator walking an array in logical order. + """ + + def compute_pointer(self, context, builder, indices, arrty, arr): + assert len(indices) == self.ndim + return cgutils.get_item_pointer( + context, builder, arrty, arr, indices, wraparound=False + ) + + class ZeroDimSubIter(BaseSubIter): + """ + Sub-iterator "walking" a 0-d array. + """ + + def compute_pointer(self, context, builder, indices, arrty, arr): + return arr.data + + class ScalarSubIter(BaseSubIter): + """ + Sub-iterator "walking" a scalar value. + """ + + def compute_pointer(self, context, builder, indices, arrty, arr): + return arr + + class NdIter(cgutils.create_struct_proxy(nditerty)): + """ + .nditer() implementation. + + Note: 'F' layout means the shape is iterated in reverse logical order, + so indices and shapes arrays have to be reversed as well. 
+ """ + + @functools.cached_property + def subiters(self): + l = [] + factories = { + "flat": FlatSubIter + if nditerty.need_shaped_indexing + else TrivialFlatSubIter, + "indexed": IndexedSubIter, + "0d": ZeroDimSubIter, + "scalar": ScalarSubIter, + } + for i, sub in enumerate(nditerty.indexers): + kind, start_dim, end_dim, _ = sub + member_name = "index%d" % i + factory = factories[kind] + l.append(factory(self, member_name, start_dim, end_dim)) + return l + + def init_specific(self, context, builder, arrtys, arrays): + """ + Initialize the nditer() instance for the specific array inputs. + """ + zero = context.get_constant(types.intp, 0) + + # Store inputs + self.arrays = context.make_tuple( + builder, types.Tuple(arrtys), arrays + ) + # Create slots for scalars + for i, ty in enumerate(arrtys): + if not isinstance(ty, types.Array): + member_name = "scalar%d" % i + # XXX as_data()? + slot = cgutils.alloca_once_value(builder, arrays[i]) + setattr(self, member_name, slot) + + arrays = self._arrays_or_scalars(context, builder, arrtys, arrays) + + # Extract iterator shape (the shape of the most-dimensional input) + main_shape_ty = types.UniTuple(types.intp, ndim) + main_shape = None + main_nitems = None + for i, arrty in enumerate(arrtys): + if isinstance(arrty, types.Array) and arrty.ndim == ndim: + main_shape = arrays[i].shape + main_nitems = arrays[i].nitems + break + else: + # Only scalar inputs => synthesize a dummy shape + assert ndim == 0 + main_shape = context.make_tuple(builder, main_shape_ty, ()) + main_nitems = context.get_constant(types.intp, 1) + + # Validate shapes of array inputs + def check_shape(shape, main_shape): + n = len(shape) + for i in range(n): + if shape[i] != main_shape[len(main_shape) - n + i]: + raise ValueError( + "nditer(): operands could not be broadcast together" + ) + + for arrty, arr in zip(arrtys, arrays): + if isinstance(arrty, types.Array) and arrty.ndim > 0: + sig = signature( + types.none, + types.UniTuple(types.intp, arrty.ndim), + main_shape_ty, + ) + context.compile_internal( + builder, check_shape, sig, (arr.shape, main_shape) + ) + + # Compute shape and size + shapes = cgutils.unpack_tuple(builder, main_shape) + if layout == "F": + shapes = shapes[::-1] + + # If shape is empty, mark iterator exhausted + shape_is_empty = builder.icmp_signed("==", main_nitems, zero) + exhausted = builder.select( + shape_is_empty, cgutils.true_byte, cgutils.false_byte + ) + + if not nditerty.need_shaped_indexing: + # Flatten shape to make iteration faster on small innermost + # dimensions (e.g. a (100000, 3) shape) + shapes = (main_nitems,) + assert len(shapes) == nshapes + + indices = cgutils.alloca_once(builder, zero.type, size=nshapes) + for dim in range(nshapes): + idxptr = cgutils.gep_inbounds(builder, indices, dim) + builder.store(zero, idxptr) + + self.indices = indices + self.shape = cgutils.pack_array(builder, shapes, zero.type) + self.exhausted = cgutils.alloca_once_value(builder, exhausted) + + # Initialize subiterators + for subiter in self.subiters: + subiter.init_specific(context, builder) + + def iternext_specific(self, context, builder, result): + """ + Compute next iteration of the nditer() instance. 
+ """ + bbend = builder.append_basic_block("end") + + # Branch early if exhausted + exhausted = cgutils.as_bool_bit( + builder, builder.load(self.exhausted) + ) + with cgutils.if_unlikely(builder, exhausted): + result.set_valid(False) + builder.branch(bbend) + + arrtys = nditerty.arrays + arrays = cgutils.unpack_tuple(builder, self.arrays) + arrays = self._arrays_or_scalars(context, builder, arrtys, arrays) + indices = self.indices + + # Compute iterated results + result.set_valid(True) + views = self._make_views(context, builder, indices, arrtys, arrays) + views = [v._getvalue() for v in views] + if len(views) == 1: + result.yield_(views[0]) + else: + result.yield_( + context.make_tuple(builder, nditerty.yield_type, views) + ) + + shape = cgutils.unpack_tuple(builder, self.shape) + _increment_indices( + context, + builder, + len(shape), + shape, + indices, + self.exhausted, + functools.partial(self._loop_continue, context, builder), + functools.partial(self._loop_break, context, builder), + ) + + builder.branch(bbend) + builder.position_at_end(bbend) + + def _loop_continue(self, context, builder, dim): + for sub in self.subiters: + if sub.start_dim <= dim < sub.end_dim: + sub.loop_continue(context, builder, dim - sub.start_dim) + + def _loop_break(self, context, builder, dim): + for sub in self.subiters: + if sub.start_dim <= dim < sub.end_dim: + sub.loop_break(context, builder, dim - sub.start_dim) + + def _make_views(self, context, builder, indices, arrtys, arrays): + """ + Compute the views to be yielded. + """ + views = [None] * narrays + indexers = nditerty.indexers + subiters = self.subiters + rettys = nditerty.yield_type + if isinstance(rettys, types.BaseTuple): + rettys = list(rettys) + else: + rettys = [rettys] + indices = [ + builder.load(cgutils.gep_inbounds(builder, indices, i)) + for i in range(nshapes) + ] + + for sub, subiter in zip(indexers, subiters): + _, _, _, array_indices = sub + sub_indices = indices[subiter.start_dim : subiter.end_dim] + if layout == "F": + sub_indices = sub_indices[::-1] + for i in array_indices: + assert views[i] is None + views[i] = self._make_view( + context, + builder, + sub_indices, + rettys[i], + arrtys[i], + arrays[i], + subiter, + ) + assert all(v for v in views) + return views + + def _make_view( + self, context, builder, indices, retty, arrty, arr, subiter + ): + """ + Compute a 0d view for a given input array. + """ + assert isinstance(retty, types.Array) and retty.ndim == 0 + + ptr = subiter.compute_pointer(context, builder, indices, arrty, arr) + view = context.make_array(retty)(context, builder) + + itemsize = get_itemsize(context, retty) + shape = context.make_tuple( + builder, types.UniTuple(types.intp, 0), () + ) + strides = context.make_tuple( + builder, types.UniTuple(types.intp, 0), () + ) + # HACK: meminfo=None avoids expensive refcounting operations + # on ephemeral views + populate_array(view, ptr, shape, strides, itemsize, meminfo=None) + return view + + def _arrays_or_scalars(self, context, builder, arrtys, arrays): + # Return a list of either array structures or pointers to + # scalar slots + l = [] + for i, (arrty, arr) in enumerate(zip(arrtys, arrays)): + if isinstance(arrty, types.Array): + l.append( + context.make_array(arrty)(context, builder, value=arr) + ) + else: + l.append(getattr(self, "scalar%d" % i)) + return l + + return NdIter + + +def make_ndindex_cls(nditerty): + """ + Return the Structure representation of the given *nditerty* (an + instance of types.NumpyNdIndexType). 
+ """ + ndim = nditerty.ndim + + class NdIndexIter(cgutils.create_struct_proxy(nditerty)): + """ + .ndindex() implementation. + """ + + def init_specific(self, context, builder, shapes): + zero = context.get_constant(types.intp, 0) + indices = cgutils.alloca_once( + builder, zero.type, size=context.get_constant(types.intp, ndim) + ) + exhausted = cgutils.alloca_once_value(builder, cgutils.false_byte) + + for dim in range(ndim): + idxptr = cgutils.gep_inbounds(builder, indices, dim) + builder.store(zero, idxptr) + # 0-sized dimensions really indicate an empty array, + # but we have to catch that condition early to avoid + # a bug inside the iteration logic. + dim_size = shapes[dim] + dim_is_empty = builder.icmp_unsigned("==", dim_size, zero) + with cgutils.if_unlikely(builder, dim_is_empty): + builder.store(cgutils.true_byte, exhausted) + + self.indices = indices + self.exhausted = exhausted + self.shape = cgutils.pack_array(builder, shapes, zero.type) + + def iternext_specific(self, context, builder, result): + zero = context.get_constant(types.intp, 0) + + bbend = builder.append_basic_block("end") + + exhausted = cgutils.as_bool_bit( + builder, builder.load(self.exhausted) + ) + with cgutils.if_unlikely(builder, exhausted): + result.set_valid(False) + builder.branch(bbend) + + indices = [ + builder.load(cgutils.gep_inbounds(builder, self.indices, dim)) + for dim in range(ndim) + ] + for load in indices: + mark_positive(builder, load) + + result.yield_(cgutils.pack_array(builder, indices, zero.type)) + result.set_valid(True) + + shape = cgutils.unpack_tuple(builder, self.shape, ndim) + _increment_indices( + context, builder, ndim, shape, self.indices, self.exhausted + ) + + builder.branch(bbend) + builder.position_at_end(bbend) + + return NdIndexIter + + +def _make_flattening_iter_cls(flatiterty, kind): + assert kind in ("flat", "ndenumerate") + + array_type = flatiterty.array_type + + if array_type.layout == "C": + + class CContiguousFlatIter(cgutils.create_struct_proxy(flatiterty)): + """ + .flat() / .ndenumerate() implementation for C-contiguous arrays. + """ + + def init_specific(self, context, builder, arrty, arr): + zero = context.get_constant(types.intp, 0) + self.index = cgutils.alloca_once_value(builder, zero) + # We can't trust strides[-1] to always contain the right + # step value, see + # http://docs.scipy.org/doc/numpy-dev/release.html#npy-relaxed-strides-checking # noqa: E501 + self.stride = arr.itemsize + + if kind == "ndenumerate": + # Zero-initialize the indices array. + indices = cgutils.alloca_once( + builder, + zero.type, + size=context.get_constant(types.intp, arrty.ndim), + ) + + for dim in range(arrty.ndim): + idxptr = cgutils.gep_inbounds(builder, indices, dim) + builder.store(zero, idxptr) + + self.indices = indices + + # NOTE: Using gep() instead of explicit pointer addition helps + # LLVM vectorize the loop (since the stride is known and + # constant). This is not possible in the non-contiguous case, + # where the strides are unknown at compile-time. 
+ + def iternext_specific(self, context, builder, arrty, arr, result): + ndim = arrty.ndim + nitems = arr.nitems + + index = builder.load(self.index) + is_valid = builder.icmp_signed("<", index, nitems) + result.set_valid(is_valid) + + with cgutils.if_likely(builder, is_valid): + ptr = builder.gep(arr.data, [index]) + value = load_item(context, builder, arrty, ptr) + if kind == "flat": + result.yield_(value) + else: + # ndenumerate(): fetch and increment indices + indices = self.indices + idxvals = [ + builder.load( + cgutils.gep_inbounds(builder, indices, dim) + ) + for dim in range(ndim) + ] + idxtuple = cgutils.pack_array(builder, idxvals) + result.yield_( + cgutils.make_anonymous_struct( + builder, [idxtuple, value] + ) + ) + _increment_indices_array( + context, builder, arrty, arr, indices + ) + + index = cgutils.increment_index(builder, index) + builder.store(index, self.index) + + def getitem(self, context, builder, arrty, arr, index): + ptr = builder.gep(arr.data, [index]) + return load_item(context, builder, arrty, ptr) + + def setitem(self, context, builder, arrty, arr, index, value): + ptr = builder.gep(arr.data, [index]) + store_item(context, builder, arrty, value, ptr) + + return CContiguousFlatIter + + else: + + class FlatIter(cgutils.create_struct_proxy(flatiterty)): + """ + Generic .flat() / .ndenumerate() implementation for + non-contiguous arrays. + It keeps track of pointers along each dimension in order to + minimize computations. + """ + + def init_specific(self, context, builder, arrty, arr): + zero = context.get_constant(types.intp, 0) + data = arr.data + ndim = arrty.ndim + shapes = cgutils.unpack_tuple(builder, arr.shape, ndim) + + indices = cgutils.alloca_once( + builder, + zero.type, + size=context.get_constant(types.intp, arrty.ndim), + ) + pointers = cgutils.alloca_once( + builder, + data.type, + size=context.get_constant(types.intp, arrty.ndim), + ) + exhausted = cgutils.alloca_once_value( + builder, cgutils.false_byte + ) + + # Initialize indices and pointers with their start values. + for dim in range(ndim): + idxptr = cgutils.gep_inbounds(builder, indices, dim) + ptrptr = cgutils.gep_inbounds(builder, pointers, dim) + builder.store(data, ptrptr) + builder.store(zero, idxptr) + # 0-sized dimensions really indicate an empty array, + # but we have to catch that condition early to avoid + # a bug inside the iteration logic (see issue #846). 
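+                    # An empty dimension means there are zero items in
+                    # total, so the iterator is marked exhausted before
+                    # the first call.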
+ dim_size = shapes[dim] + dim_is_empty = builder.icmp_unsigned("==", dim_size, zero) + with cgutils.if_unlikely(builder, dim_is_empty): + builder.store(cgutils.true_byte, exhausted) + + self.indices = indices + self.pointers = pointers + self.exhausted = exhausted + + def iternext_specific(self, context, builder, arrty, arr, result): + ndim = arrty.ndim + shapes = cgutils.unpack_tuple(builder, arr.shape, ndim) + strides = cgutils.unpack_tuple(builder, arr.strides, ndim) + indices = self.indices + pointers = self.pointers + + zero = context.get_constant(types.intp, 0) + + bbend = builder.append_basic_block("end") + + # Catch already computed iterator exhaustion + is_exhausted = cgutils.as_bool_bit( + builder, builder.load(self.exhausted) + ) + with cgutils.if_unlikely(builder, is_exhausted): + result.set_valid(False) + builder.branch(bbend) + result.set_valid(True) + + # Current pointer inside last dimension + last_ptr = cgutils.gep_inbounds(builder, pointers, ndim - 1) + ptr = builder.load(last_ptr) + value = load_item(context, builder, arrty, ptr) + if kind == "flat": + result.yield_(value) + else: + # ndenumerate() => yield (indices, value) + idxvals = [ + builder.load( + cgutils.gep_inbounds(builder, indices, dim) + ) + for dim in range(ndim) + ] + idxtuple = cgutils.pack_array(builder, idxvals) + result.yield_( + cgutils.make_anonymous_struct( + builder, [idxtuple, value] + ) + ) + + # Update indices and pointers by walking from inner + # dimension to outer. + for dim in reversed(range(ndim)): + idxptr = cgutils.gep_inbounds(builder, indices, dim) + idx = cgutils.increment_index(builder, builder.load(idxptr)) + + count = shapes[dim] + stride = strides[dim] + in_bounds = builder.icmp_signed("<", idx, count) + with cgutils.if_likely(builder, in_bounds): + # Index is valid => pointer can simply be incremented. 
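+                        # The saved pointer for this dimension advances by
+                        # its stride; inner dimensions restart from the new
+                        # pointer.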
+ builder.store(idx, idxptr) + ptrptr = cgutils.gep_inbounds(builder, pointers, dim) + ptr = builder.load(ptrptr) + ptr = cgutils.pointer_add(builder, ptr, stride) + builder.store(ptr, ptrptr) + # Reset pointers in inner dimensions + for inner_dim in range(dim + 1, ndim): + ptrptr = cgutils.gep_inbounds( + builder, pointers, inner_dim + ) + builder.store(ptr, ptrptr) + builder.branch(bbend) + # Reset index and continue with next dimension + builder.store(zero, idxptr) + + # End of array + builder.store(cgutils.true_byte, self.exhausted) + builder.branch(bbend) + + builder.position_at_end(bbend) + + def _ptr_for_index(self, context, builder, arrty, arr, index): + ndim = arrty.ndim + shapes = cgutils.unpack_tuple(builder, arr.shape, count=ndim) + strides = cgutils.unpack_tuple(builder, arr.strides, count=ndim) + + # First convert the flattened index into a regular n-dim index + indices = [] + for dim in reversed(range(ndim)): + indices.append(builder.urem(index, shapes[dim])) + index = builder.udiv(index, shapes[dim]) + indices.reverse() + + ptr = cgutils.get_item_pointer2( + context, + builder, + arr.data, + shapes, + strides, + arrty.layout, + indices, + ) + return ptr + + def getitem(self, context, builder, arrty, arr, index): + ptr = self._ptr_for_index(context, builder, arrty, arr, index) + return load_item(context, builder, arrty, ptr) + + def setitem(self, context, builder, arrty, arr, index, value): + ptr = self._ptr_for_index(context, builder, arrty, arr, index) + store_item(context, builder, arrty, value, ptr) + + return FlatIter + + +@lower_getattr(types.Array, "flat") +def make_array_flatiter(context, builder, arrty, arr): + flatitercls = make_array_flat_cls(types.NumpyFlatType(arrty)) + flatiter = flatitercls(context, builder) + + flatiter.array = arr + + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, ref=flatiter._get_ptr_by_name("array")) + + flatiter.init_specific(context, builder, arrty, arr) + + res = flatiter._getvalue() + return impl_ret_borrowed(context, builder, types.NumpyFlatType(arrty), res) + + +@lower("iternext", types.NumpyFlatType) +@iternext_impl(RefType.BORROWED) +def iternext_numpy_flatiter(context, builder, sig, args, result): + [flatiterty] = sig.args + [flatiter] = args + + flatitercls = make_array_flat_cls(flatiterty) + flatiter = flatitercls(context, builder, value=flatiter) + + arrty = flatiterty.array_type + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, value=flatiter.array) + + flatiter.iternext_specific(context, builder, arrty, arr, result) + + +@lower(operator.getitem, types.NumpyFlatType, types.Integer) +def iternext_numpy_getitem(context, builder, sig, args): + flatiterty = sig.args[0] + flatiter, index = args + + flatitercls = make_array_flat_cls(flatiterty) + flatiter = flatitercls(context, builder, value=flatiter) + + arrty = flatiterty.array_type + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, value=flatiter.array) + + res = flatiter.getitem(context, builder, arrty, arr, index) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower(operator.setitem, types.NumpyFlatType, types.Integer, types.Any) +def iternext_numpy_getitem_any(context, builder, sig, args): + flatiterty = sig.args[0] + flatiter, index, value = args + + flatitercls = make_array_flat_cls(flatiterty) + flatiter = flatitercls(context, builder, value=flatiter) + + arrty = flatiterty.array_type + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, value=flatiter.array) + + 
flatiter.setitem(context, builder, arrty, arr, index, value) + return context.get_dummy_value() + + +@lower(len, types.NumpyFlatType) +def iternext_numpy_getitem_flat(context, builder, sig, args): + flatiterty = sig.args[0] + flatitercls = make_array_flat_cls(flatiterty) + flatiter = flatitercls(context, builder, value=args[0]) + + arrcls = context.make_array(flatiterty.array_type) + arr = arrcls(context, builder, value=flatiter.array) + return arr.nitems + + +@lower(np.ndenumerate, types.Array) +def make_array_ndenumerate(context, builder, sig, args): + (arrty,) = sig.args + (arr,) = args + nditercls = make_array_ndenumerate_cls(types.NumpyNdEnumerateType(arrty)) + nditer = nditercls(context, builder) + + nditer.array = arr + + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, ref=nditer._get_ptr_by_name("array")) + + nditer.init_specific(context, builder, arrty, arr) + + res = nditer._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower("iternext", types.NumpyNdEnumerateType) +@iternext_impl(RefType.BORROWED) +def iternext_numpy_nditer(context, builder, sig, args, result): + [nditerty] = sig.args + [nditer] = args + + nditercls = make_array_ndenumerate_cls(nditerty) + nditer = nditercls(context, builder, value=nditer) + + arrty = nditerty.array_type + arrcls = context.make_array(arrty) + arr = arrcls(context, builder, value=nditer.array) + + nditer.iternext_specific(context, builder, arrty, arr, result) + + +@lower(pndindex, types.VarArg(types.Integer)) +@lower(np.ndindex, types.VarArg(types.Integer)) +def make_array_ndindex(context, builder, sig, args): + """ndindex(*shape)""" + shape = [ + context.cast(builder, arg, argty, types.intp) + for argty, arg in zip(sig.args, args) + ] + + nditercls = make_ndindex_cls(types.NumpyNdIndexType(len(shape))) + nditer = nditercls(context, builder) + nditer.init_specific(context, builder, shape) + + res = nditer._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower(pndindex, types.BaseTuple) +@lower(np.ndindex, types.BaseTuple) +def make_array_ndindex_tuple(context, builder, sig, args): + """ndindex(shape)""" + ndim = sig.return_type.ndim + if ndim > 0: + idxty = sig.args[0].dtype + tup = args[0] + + shape = cgutils.unpack_tuple(builder, tup, ndim) + shape = [context.cast(builder, idx, idxty, types.intp) for idx in shape] + else: + shape = [] + + nditercls = make_ndindex_cls(types.NumpyNdIndexType(len(shape))) + nditer = nditercls(context, builder) + nditer.init_specific(context, builder, shape) + + res = nditer._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +@lower("iternext", types.NumpyNdIndexType) +@iternext_impl(RefType.BORROWED) +def iternext_numpy_ndindex(context, builder, sig, args, result): + [nditerty] = sig.args + [nditer] = args + + nditercls = make_ndindex_cls(nditerty) + nditer = nditercls(context, builder, value=nditer) + + nditer.iternext_specific(context, builder, result) + + +@lower(np.nditer, types.Any) +def make_array_nditer(context, builder, sig, args): + """ + nditer(...) 
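+
+    The argument may be a single array or a tuple of inputs; scalar
+    members of a tuple are copied into slots and treated as 0-d inputs.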
+ """ + nditerty = sig.return_type + arrtys = nditerty.arrays + + if isinstance(sig.args[0], types.BaseTuple): + arrays = cgutils.unpack_tuple(builder, args[0]) + else: + arrays = [args[0]] + + nditer = make_nditer_cls(nditerty)(context, builder) + nditer.init_specific(context, builder, arrtys, arrays) + + res = nditer._getvalue() + return impl_ret_borrowed(context, builder, nditerty, res) + + +@lower("iternext", types.NumpyNdIterType) +@iternext_impl(RefType.BORROWED) +def iternext_numpy_nditer2(context, builder, sig, args, result): + [nditerty] = sig.args + [nditer] = args + + nditer = make_nditer_cls(nditerty)(context, builder, value=nditer) + nditer.iternext_specific(context, builder, result) + + +@lower(operator.eq, types.DType, types.DType) +def dtype_eq_impl(context, builder, sig, args): + arg1, arg2 = sig.args + res = ir.Constant(ir.IntType(1), int(arg1 == arg2)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ------------------------------------------------------------------------------ +# Numpy array constructors + + +def _empty_nd_impl(context, builder, arrtype, shapes): + """Utility function used for allocating a new array during LLVM code + generation (lowering). Given a target context, builder, array + type, and a tuple or list of lowered dimension sizes, returns a + LLVM value pointing at a Numba runtime allocated array. + """ + arycls = make_array(arrtype) + ary = arycls(context, builder) + + datatype = context.get_data_type(arrtype.dtype) + itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) + + # compute array length + arrlen = context.get_constant(types.intp, 1) + overflow = Constant(ir.IntType(1), 0) + for s in shapes: + arrlen_mult = builder.smul_with_overflow(arrlen, s) + arrlen = builder.extract_value(arrlen_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) + + if arrtype.ndim == 0: + strides = () + elif arrtype.layout == "C": + strides = [itemsize] + for dimension_size in reversed(shapes[1:]): + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(reversed(strides)) + elif arrtype.layout == "F": + strides = [itemsize] + for dimension_size in shapes[:-1]: + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(strides) + else: + raise NotImplementedError( + "Don't know how to allocate array with layout '{0}'.".format( + arrtype.layout + ) + ) + + # Check overflow, numpy also does this after checking order + allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) + allocsize = builder.extract_value(allocsize_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) + + with builder.if_then(overflow, likely=False): + # Raise same error as numpy, see: + # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "array is too big; `arr.size * arr.dtype.itemsize` is larger than" + " the maximum possible size.", + ), + ) + + dtype = arrtype.dtype + align_val = context.get_preferred_array_alignment(dtype) + align = context.get_constant(types.uint32, align_val) + args = (context.get_dummy_value(), allocsize, align) + + mip = types.MemInfoPointer(types.voidptr) + arytypeclass = types.TypeRef(type(arrtype)) + argtypes = signature(mip, arytypeclass, types.intp, types.uint32) + + meminfo = context.compile_internal(builder, _call_allocator, argtypes, args) + data = 
context.nrt.meminfo_data(builder, meminfo) + + intp_t = context.get_value_type(types.intp) + shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) + strides_array = cgutils.pack_array(builder, strides, ty=intp_t) + + populate_array( + ary, + data=builder.bitcast(data, datatype.as_pointer()), + shape=shape_array, + strides=strides_array, + itemsize=itemsize, + meminfo=meminfo, + ) + + return ary + + +@overload_classmethod(types.Array, "_allocate") +def _ol_array_allocate(cls, allocsize, align): + """Implements a Numba-only default target (cpu) classmethod on the array + type. + """ + + def impl(cls, allocsize, align): + return intrin_alloc(allocsize, align) + + return impl + + +def _call_allocator(arrtype, size, align): + """Trampoline to call the intrinsic used for allocation""" + return arrtype._allocate(size, align) + + +@intrinsic +def intrin_alloc(typingctx, allocsize, align): + """Intrinsic to call into the allocator for Array""" + + def codegen(context, builder, signature, args): + [allocsize, align] = args + meminfo = context.nrt.meminfo_alloc_aligned(builder, allocsize, align) + return meminfo + + mip = types.MemInfoPointer(types.voidptr) # return untyped pointer + sig = signature(mip, allocsize, align) + return sig, codegen + + +def _parse_shape(context, builder, ty, val): + """ + Parse the shape argument to an array constructor. + """ + + def safecast_intp(context, builder, src_t, src): + """Cast src to intp only if value can be maintained""" + intp_t = context.get_value_type(types.intp) + intp_width = intp_t.width + intp_ir = ir.IntType(intp_width) + maxval = Constant(intp_ir, ((1 << intp_width - 1) - 1)) + if src_t.width < intp_width: + res = builder.sext(src, intp_ir) + elif src_t.width >= intp_width: + is_larger = builder.icmp_signed(">", src, maxval) + with builder.if_then(is_larger, likely=False): + context.call_conv.return_user_exc( + builder, + ValueError, + ("Cannot safely convert value to intp",), + ) + if src_t.width > intp_width: + res = builder.trunc(src, intp_ir) + else: + res = src + return res + + if isinstance(ty, types.Integer): + ndim = 1 + passed_shapes = [context.cast(builder, val, ty, types.intp)] + else: + assert isinstance(ty, types.BaseTuple) + ndim = ty.count + passed_shapes = cgutils.unpack_tuple(builder, val, count=ndim) + + shapes = [] + for s in passed_shapes: + shapes.append(safecast_intp(context, builder, s.type, s)) + + zero = context.get_constant_generic(builder, types.intp, 0) + for dim in range(ndim): + is_neg = builder.icmp_signed("<", shapes[dim], zero) + with cgutils.if_unlikely(builder, is_neg): + context.call_conv.return_user_exc( + builder, ValueError, ("negative dimensions not allowed",) + ) + + return shapes + + +def _parse_empty_args(context, builder, sig, args): + """ + Parse the arguments of a np.empty(), np.zeros() or np.ones() call. + """ + arrshapetype = sig.args[0] + arrshape = args[0] + arrtype = sig.return_type + return arrtype, _parse_shape(context, builder, arrshapetype, arrshape) + + +def _parse_empty_like_args(context, builder, sig, args): + """ + Parse the arguments of a np.empty_like(), np.zeros_like() or + np.ones_like() call. 
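+
+    Returns the result array type and the unpacked shape of the prototype
+    (an empty shape when the prototype is a scalar).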
+ """ + arytype = sig.args[0] + if isinstance(arytype, types.Array): + ary = make_array(arytype)(context, builder, value=args[0]) + shapes = cgutils.unpack_tuple(builder, ary.shape, count=arytype.ndim) + return sig.return_type, shapes + else: + return sig.return_type, () + + +def _check_const_str_dtype(fname, dtype): + if isinstance(dtype, types.UnicodeType): + msg = f"If np.{fname} dtype is a string it must be a string constant." + raise errors.TypingError(msg) + + +@intrinsic +def numpy_empty_nd(tyctx, ty_shape, ty_dtype, ty_retty_ref): + ty_retty = ty_retty_ref.instance_type + sig = ty_retty(ty_shape, ty_dtype, ty_retty_ref) + + def codegen(cgctx, builder, sig, llargs): + arrtype, shapes = _parse_empty_args(cgctx, builder, sig, llargs) + ary = _empty_nd_impl(cgctx, builder, arrtype, shapes) + return ary._getvalue() + + return sig, codegen + + +@overload(np.empty) +def ol_np_empty(shape, dtype=float): + _check_const_str_dtype("empty", dtype) + if ( + dtype is float + or (isinstance(dtype, types.Function) and dtype.typing_key is float) + or is_nonelike(dtype) + ): # default + nb_dtype = types.double + else: + nb_dtype = ty_parse_dtype(dtype) + + ndim = ty_parse_shape(shape) + if nb_dtype is not None and ndim is not None: + retty = types.Array(dtype=nb_dtype, ndim=ndim, layout="C") + + def impl(shape, dtype=float): + return numpy_empty_nd(shape, dtype, retty) + + return impl + else: + msg = f"Cannot parse input types to function np.empty({shape}, {dtype})" + raise errors.TypingError(msg) + + +@intrinsic +def numpy_empty_like_nd(tyctx, ty_prototype, ty_dtype, ty_retty_ref): + ty_retty = ty_retty_ref.instance_type + sig = ty_retty(ty_prototype, ty_dtype, ty_retty_ref) + + def codegen(cgctx, builder, sig, llargs): + arrtype, shapes = _parse_empty_like_args(cgctx, builder, sig, llargs) + ary = _empty_nd_impl(cgctx, builder, arrtype, shapes) + return ary._getvalue() + + return sig, codegen + + +@overload(np.empty_like) +def ol_np_empty_like(arr, dtype=None): + _check_const_str_dtype("empty_like", dtype) + if not is_nonelike(dtype): + nb_dtype = ty_parse_dtype(dtype) + elif isinstance(arr, types.Array): + nb_dtype = arr.dtype + else: + nb_dtype = arr + if nb_dtype is not None: + if isinstance(arr, types.Array): + layout = arr.layout if arr.layout != "A" else "C" + retty = arr.copy(dtype=nb_dtype, layout=layout, readonly=False) + else: + retty = types.Array(nb_dtype, 0, "C") + else: + msg = ( + "Cannot parse input types to function " + f"np.empty_like({arr}, {dtype})" + ) + raise errors.TypingError(msg) + + def impl(arr, dtype=None): + return numpy_empty_like_nd(arr, dtype, retty) + + return impl + + +@intrinsic +def _zero_fill_array_method(tyctx, self): + sig = types.none(self) + + def codegen(cgctx, builder, sig, llargs): + ary = make_array(sig.args[0])(cgctx, builder, llargs[0]) + cgutils.memset( + builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0 + ) + + return sig, codegen + + +@overload_method(types.Array, "_zero_fill") +def ol_array_zero_fill(self): + """Adds a `._zero_fill` method to zero fill an array using memset.""" + + def impl(self): + _zero_fill_array_method(self) + + return impl + + +@overload(np.zeros) +def ol_np_zeros(shape, dtype=float): + _check_const_str_dtype("zeros", dtype) + + def impl(shape, dtype=float): + arr = np.empty(shape, dtype=dtype) + arr._zero_fill() + return arr + + return impl + + +@overload(np.zeros_like) +def ol_np_zeros_like(a, dtype=None): + _check_const_str_dtype("zeros_like", dtype) + + # NumPy uses 'a' as the arg name for the array-like + def 
impl(a, dtype=None): + arr = np.empty_like(a, dtype=dtype) + arr._zero_fill() + return arr + + return impl + + +@overload(np.ones_like) +def ol_np_ones_like(a, dtype=None): + _check_const_str_dtype("ones_like", dtype) + + # NumPy uses 'a' as the arg name for the array-like + def impl(a, dtype=None): + arr = np.empty_like(a, dtype=dtype) + arr_flat = arr.flat + for idx in range(len(arr_flat)): + arr_flat[idx] = 1 + return arr + + return impl + + +@overload(np.full) +def impl_np_full(shape, fill_value, dtype=None): + _check_const_str_dtype("full", dtype) + if not is_nonelike(dtype): + nb_dtype = ty_parse_dtype(dtype) + else: + nb_dtype = fill_value + + def full(shape, fill_value, dtype=None): + arr = np.empty(shape, nb_dtype) + arr_flat = arr.flat + for idx in range(len(arr_flat)): + arr_flat[idx] = fill_value + return arr + + return full + + +@overload(np.full_like) +def impl_np_full_like(a, fill_value, dtype=None): + _check_const_str_dtype("full_like", dtype) + + def full_like(a, fill_value, dtype=None): + arr = np.empty_like(a, dtype) + arr_flat = arr.flat + for idx in range(len(arr_flat)): + arr_flat[idx] = fill_value + return arr + + return full_like + + +@overload(np.ones) +def ol_np_ones(shape, dtype=None): + # for some reason the NumPy default for dtype is None in the source but + # ends up as np.float64 by definition. + _check_const_str_dtype("ones", dtype) + + def impl(shape, dtype=None): + arr = np.empty(shape, dtype=dtype) + arr_flat = arr.flat + for idx in range(len(arr_flat)): + arr_flat[idx] = 1 + return arr + + return impl + + +@overload(np.identity) +def impl_np_identity(n, dtype=None): + _check_const_str_dtype("identity", dtype) + if not is_nonelike(dtype): + nb_dtype = ty_parse_dtype(dtype) + else: + nb_dtype = types.double + + def identity(n, dtype=None): + arr = np.zeros((n, n), nb_dtype) + for i in range(n): + arr[i, i] = 1 + return arr + + return identity + + +def _eye_none_handler(N, M): + pass + + +@extending.overload(_eye_none_handler) +def _eye_none_handler_impl(N, M): + if isinstance(M, types.NoneType): + + def impl(N, M): + return N + else: + + def impl(N, M): + return M + + return impl + + +@extending.overload(np.eye) +def numpy_eye(N, M=None, k=0, dtype=float): + if dtype is None or isinstance(dtype, types.NoneType): + dt = np.dtype(float) + elif isinstance(dtype, (types.DTypeSpec, types.Number)): + # dtype or instance of dtype + dt = as_dtype(getattr(dtype, "dtype", dtype)) + else: + dt = np.dtype(dtype) + + def impl(N, M=None, k=0, dtype=float): + _M = _eye_none_handler(N, M) + arr = np.zeros((N, _M), dt) + if k >= 0: + d = min(N, _M - k) + for i in range(d): + arr[i, i + k] = 1 + else: + d = min(N + k, _M) + for i in range(d): + arr[i - k, i] = 1 + return arr + + return impl + + +@overload(np.diag) +def impl_np_diag(v, k=0): + if not type_can_asarray(v): + raise errors.TypingError('The argument "v" must be array-like') + + if isinstance(v, types.Array): + if v.ndim not in (1, 2): + raise errors.NumbaTypeError("Input must be 1- or 2-d.") + + def diag_impl(v, k=0): + if v.ndim == 1: + s = v.shape + n = s[0] + abs(k) + ret = np.zeros((n, n), v.dtype) + if k >= 0: + for i in range(n - k): + ret[i, k + i] = v[i] + else: + for i in range(n + k): + ret[i - k, i] = v[i] + return ret + else: # 2-d + rows, cols = v.shape + if k < 0: + rows = rows + k + if k > 0: + cols = cols - k + n = max(min(rows, cols), 0) + ret = np.empty(n, v.dtype) + if k >= 0: + for i in range(n): + ret[i] = v[i, k + i] + else: + for i in range(n): + ret[i] = v[i - k, i] + return ret + + return 
diag_impl
+
+
+@overload(np.indices)
+def numpy_indices(dimensions):
+    if not isinstance(dimensions, types.UniTuple):
+        msg = 'The argument "dimensions" must be a tuple of integers'
+        raise errors.TypingError(msg)
+
+    if not isinstance(dimensions.dtype, types.Integer):
+        msg = 'The argument "dimensions" must be a tuple of integers'
+        raise errors.TypingError(msg)
+
+    N = len(dimensions)
+    shape = (1,) * N
+
+    def impl(dimensions):
+        res = np.empty((N,) + dimensions, dtype=np.int64)
+        i = 0
+        for dim in dimensions:
+            idx = np.arange(dim, dtype=np.int64).reshape(
+                tuple_setitem(shape, i, dim)
+            )
+            res[i] = idx
+            i += 1
+
+        return res
+
+    return impl
+
+
+@overload(np.diagflat)
+def numpy_diagflat(v, k=0):
+    if not type_can_asarray(v):
+        msg = 'The argument "v" must be array-like'
+        raise errors.TypingError(msg)
+
+    if not isinstance(k, (int, types.Integer)):
+        msg = 'The argument "k" must be an integer'
+        raise errors.TypingError(msg)
+
+    def impl(v, k=0):
+        v = np.asarray(v)
+        v = v.ravel()
+        s = len(v)
+        abs_k = abs(k)
+        n = s + abs_k
+        res = np.zeros((n, n), v.dtype)
+        i = np.maximum(0, -k)
+        j = np.maximum(0, k)
+        for t in range(s):
+            res[i + t, j + t] = v[t]
+
+        return res
+
+    return impl
+
+
+def generate_getitem_setitem_with_axis(ndim, kind):
+    assert kind in ("getitem", "setitem")
+
+    if kind == "getitem":
+        fn = """
+    def _getitem(a, idx, axis):
+        if axis == 0:
+            return a[idx, ...]
+    """
+        for i in range(1, ndim):
+            lst = (":",) * i
+            fn += f"""
+        elif axis == {i}:
+            return a[{", ".join(lst)}, idx, ...]
+    """
+    else:
+        fn = """
+    def _setitem(a, idx, axis, vals):
+        if axis == 0:
+            a[idx, ...] = vals
+    """
+
+        for i in range(1, ndim):
+            lst = (":",) * i
+            fn += f"""
+        elif axis == {i}:
+            a[{", ".join(lst)}, idx, ...] = vals
+    """
+
+    fn = textwrap.dedent(fn)
+    exec(fn, globals())
+    fn = globals()[f"_{kind}"]
+    return register_jitable(fn)
+
+
+@overload(np.take)
+@overload_method(types.Array, "take")
+def numpy_take(a, indices, axis=None):
+    if cgutils.is_nonelike(axis):
+        if isinstance(a, types.Array) and isinstance(indices, types.Integer):
+
+            def take_impl(a, indices, axis=None):
+                if indices > (a.size - 1) or indices < -a.size:
+                    raise IndexError("Index out of bounds")
+                return a.ravel()[indices]
+
+            return take_impl
+
+        if isinstance(a, types.Array) and isinstance(indices, types.Array):
+            F_order = indices.layout == "F"
+
+            def take_impl(a, indices, axis=None):
+                ret = np.empty(indices.size, dtype=a.dtype)
+                if F_order:
+                    walker = indices.copy()  # get C order
+                else:
+                    walker = indices
+                it = np.nditer(walker)
+                i = 0
+                flat = a.ravel()
+                for x in it:
+                    if x > (a.size - 1) or x < -a.size:
+                        raise IndexError("Index out of bounds")
+                    ret[i] = flat[x]
+                    i = i + 1
+                return ret.reshape(indices.shape)
+
+            return take_impl
+
+        if isinstance(a, types.Array) and isinstance(
+            indices, (types.List, types.BaseTuple)
+        ):
+
+            def take_impl(a, indices, axis=None):
+                convert = np.array(indices)
+                return np.take(a, convert)
+
+            return take_impl
+    else:
+        if isinstance(a, types.Array) and isinstance(indices, types.Integer):
+            t = (0,) * (a.ndim - 1)
+
+            # np.squeeze is too hard to implement in Numba as the tuple "t"
+            # needs to be allocated beforehand, and we don't know its size
+            # until the code gets executed.
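+            # For example, np.take(a, 1, axis=1) with a.shape == (3, 4, 5)
+            # first builds an intermediate of shape (3, 1, 5) via
+            # np.take(a, (1,), axis=1), then _squeeze removes axis 1 to
+            # give shape (3, 5).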
+            @register_jitable
+            def _squeeze(r, axis):
+                tup = tuple(t)
+                j = 0
+                assert axis < len(r.shape) and r.shape[axis] == 1, r.shape
+                for idx in range(len(r.shape)):
+                    s = r.shape[idx]
+                    if idx != axis:
+                        tup = tuple_setitem(tup, j, s)
+                        j += 1
+                return r.reshape(tup)
+
+            def take_impl(a, indices, axis=None):
+                r = np.take(a, (indices,), axis=axis)
+                if a.ndim == 1:
+                    return r[0]
+                if axis < 0:
+                    axis += a.ndim
+                return _squeeze(r, axis)
+
+            return take_impl
+
+        if isinstance(a, types.Array) and isinstance(
+            indices, (types.Array, types.List, types.BaseTuple)
+        ):
+            ndim = a.ndim
+
+            _getitem = generate_getitem_setitem_with_axis(ndim, "getitem")
+            _setitem = generate_getitem_setitem_with_axis(ndim, "setitem")
+
+            def take_impl(a, indices, axis=None):
+                if axis < 0:
+                    axis += a.ndim
+
+                if axis < 0 or axis >= a.ndim:
+                    msg = (
+                        f"axis {axis} is out of bounds for array "
+                        f"of dimension {a.ndim}"
+                    )
+                    raise ValueError(msg)
+
+                shape = tuple_setitem(a.shape, axis, len(indices))
+                out = np.empty(shape, dtype=a.dtype)
+                for i in range(len(indices)):
+                    y = _getitem(a, indices[i], axis)
+                    _setitem(out, i, axis, y)
+                return out
+
+            return take_impl
+
+
+def _arange_dtype(*args):
+    bounds = [a for a in args if not isinstance(a, types.NoneType)]
+
+    if any(isinstance(a, types.Complex) for a in bounds):
+        dtype = types.complex128
+    elif any(isinstance(a, types.Float) for a in bounds):
+        dtype = types.float64
+    else:
+        # `np.arange(10).dtype` is always `np.dtype(int)`, aka `np.int_`, which
+        # in all released versions of numpy corresponds to the C `long` type.
+        # Windows 64 is broken by default here because Numba (as of 0.47) does
+        # not differentiate between Python and NumPy integers, so a `typeof(1)`
+        # on w64 is `int64`, i.e. `intp`. This means an arange() will
+        # be typed as arange(int64) and the following will yield int64 as
+        # opposed to int32. Example: without a load of analysis to work out if
+        # the args were wrapped in NumPy int*() calls it's not possible to
+        # detect the difference between `np.arange(10)` and
+        # `np.arange(np.int64(10))`.
+        NPY_TY = getattr(types, "int%s" % (8 * np.dtype(int).itemsize))
+
+        # unliteral these types such that `max` works.
+        unliteral_bounds = [types.unliteral(x) for x in bounds]
+        dtype = max(
+            unliteral_bounds
+            + [
+                NPY_TY,
+            ]
+        )
+
+    return dtype
+
+
+@overload(np.arange)
+def np_arange(start, /, stop=None, step=None, dtype=None):
+    if isinstance(stop, types.Optional):
+        stop = stop.type
+    if isinstance(step, types.Optional):
+        step = step.type
+    if isinstance(dtype, types.Optional):
+        dtype = dtype.type
+
+    if stop is None:
+        stop = types.none
+    if step is None:
+        step = types.none
+    if dtype is None:
+        dtype = types.none
+
+    if (
+        not isinstance(start, types.Number)
+        or not isinstance(stop, (types.NoneType, types.Number))
+        or not isinstance(step, (types.NoneType, types.Number))
+        or not isinstance(dtype, (types.NoneType, types.DTypeSpec))
+    ):
+        return
+
+    if isinstance(dtype, types.NoneType):
+        true_dtype = _arange_dtype(start, stop, step)
+    else:
+        true_dtype = dtype.dtype
+
+    use_complex = any(
+        [isinstance(x, types.Complex) for x in (start, stop, step)]
+    )
+
+    start_value = getattr(start, "literal_value", None)
+    stop_value = getattr(stop, "literal_value", None)
+    step_value = getattr(step, "literal_value", None)
+
+    def impl(start, /, stop=None, step=None, dtype=None):
+        # Allow for improved performance if given literal arguments.
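+        # Literal arguments are compile-time constants, so the None checks
+        # and the use_complex branch below can be pruned at compile time.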
+ lit_start = start_value if start_value is not None else start + lit_stop = stop_value if stop_value is not None else stop + lit_step = step_value if step_value is not None else step + + _step = lit_step if lit_step is not None else 1 + if lit_stop is None: + _start, _stop = 0, lit_start + else: + _start, _stop = lit_start, lit_stop + + if _step == 0: + raise ValueError("Maximum allowed size exceeded") + + nitems_c = (_stop - _start) / _step + nitems_r = int(math.ceil(nitems_c.real)) + + # Binary operator needed for compiler branch pruning. + if use_complex is True: + nitems_i = int(math.ceil(nitems_c.imag)) + nitems = max(min(nitems_i, nitems_r), 0) + else: + nitems = max(nitems_r, 0) + arr = np.empty(nitems, true_dtype) + val = _start + for i in range(nitems): + arr[i] = val + (i * _step) + return arr + + return impl + + +@overload(np.linspace) +def numpy_linspace(start, stop, num=50): + if not all(isinstance(arg, types.Number) for arg in [start, stop]): + return + + if not isinstance(num, (int, types.Integer)): + msg = 'The argument "num" must be an integer' + raise errors.TypingError(msg) + + if any(isinstance(arg, types.Complex) for arg in [start, stop]): + dtype = types.complex128 + else: + dtype = types.float64 + + # Implementation based on https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L24 # noqa: E501 + def linspace(start, stop, num=50): + arr = np.empty(num, dtype) + # The multiply by 1.0 mirrors + # https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L125-L128 # noqa: E501 + # the side effect of this is important... start and stop become the same + # type as `dtype` i.e. 64/128 bits wide (float/complex). This is + # important later when used in the `np.divide`. + start = start * 1.0 + stop = stop * 1.0 + if num == 0: + return arr + div = num - 1 + if div > 0: + delta = stop - start + step = np.divide(delta, div) + for i in range(0, num): + arr[i] = start + (i * step) + else: + arr[0] = start + if num > 1: + arr[-1] = stop + return arr + + return linspace + + +def _array_copy(context, builder, sig, args): + """ + Array copy. 
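+
+    A single raw memcpy is used when source and destination layouts match;
+    otherwise elements are copied one by one through a loop nest.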
+ """ + arytype = sig.args[0] + ary = make_array(arytype)(context, builder, value=args[0]) + shapes = cgutils.unpack_tuple(builder, ary.shape) + + rettype = sig.return_type + ret = _empty_nd_impl(context, builder, rettype, shapes) + + src_data = ary.data + dest_data = ret.data + + assert rettype.layout in "CF" + if arytype.layout == rettype.layout: + # Fast path: memcpy + cgutils.raw_memcpy( + builder, dest_data, src_data, ary.nitems, ary.itemsize, align=1 + ) + + else: + src_strides = cgutils.unpack_tuple(builder, ary.strides) + dest_strides = cgutils.unpack_tuple(builder, ret.strides) + intp_t = context.get_value_type(types.intp) + + with cgutils.loop_nest(builder, shapes, intp_t) as indices: + src_ptr = cgutils.get_item_pointer2( + context, + builder, + src_data, + shapes, + src_strides, + arytype.layout, + indices, + ) + dest_ptr = cgutils.get_item_pointer2( + context, + builder, + dest_data, + shapes, + dest_strides, + rettype.layout, + indices, + ) + builder.store(builder.load(src_ptr), dest_ptr) + + return impl_ret_new_ref(context, builder, sig.return_type, ret._getvalue()) + + +@intrinsic +def _array_copy_intrinsic(typingctx, a): + assert isinstance(a, types.Array) + layout = "F" if a.layout == "F" else "C" + ret = a.copy(layout=layout, readonly=False) + sig = ret(a) + return sig, _array_copy + + +@lower("array.copy", types.Array) +def array_copy(context, builder, sig, args): + return _array_copy(context, builder, sig, args) + + +@overload(np.copy) +def impl_numpy_copy(a): + if isinstance(a, types.Array): + + def numpy_copy(a): + return _array_copy_intrinsic(a) + + return numpy_copy + + +def _as_layout_array(context, builder, sig, args, output_layout): + """ + Common logic for layout conversion function; + e.g. ascontiguousarray and asfortranarray + """ + retty = sig.return_type + aryty = sig.args[0] + assert retty.layout == output_layout, "return-type has incorrect layout" + + if aryty.ndim == 0: + # 0-dim input => asfortranarray() returns a 1-dim array + assert retty.ndim == 1 + ary = make_array(aryty)(context, builder, value=args[0]) + ret = make_array(retty)(context, builder) + + shape = context.get_constant_generic( + builder, + types.UniTuple(types.intp, 1), + (1,), + ) + strides = context.make_tuple( + builder, types.UniTuple(types.intp, 1), (ary.itemsize,) + ) + populate_array( + ret, ary.data, shape, strides, ary.itemsize, ary.meminfo, ary.parent + ) + return impl_ret_borrowed(context, builder, retty, ret._getvalue()) + + elif retty.layout == aryty.layout or ( + aryty.ndim == 1 and aryty.layout in "CF" + ): + # 1-dim contiguous input => return the same array + return impl_ret_borrowed(context, builder, retty, args[0]) + + else: + if aryty.layout == "A": + # There's still chance the array is in contiguous layout, + # just that we don't know at compile time. + # We can do a runtime check. 
+ + # Prepare and call is_contiguous or is_fortran + assert output_layout in "CF" + check_func = is_contiguous if output_layout == "C" else is_fortran + is_contig = _call_contiguous_check( + check_func, context, builder, aryty, args[0] + ) + with builder.if_else(is_contig) as (then, orelse): + # If the array is already contiguous, just return it + with then: + out_then = impl_ret_borrowed( + context, builder, retty, args[0] + ) + then_blk = builder.block + # Otherwise, copy to a new contiguous region + with orelse: + out_orelse = _array_copy(context, builder, sig, args) + orelse_blk = builder.block + # Phi node for the return value + ret_phi = builder.phi(out_then.type) + ret_phi.add_incoming(out_then, then_blk) + ret_phi.add_incoming(out_orelse, orelse_blk) + return ret_phi + + else: + # Return a copy with the right layout + return _array_copy(context, builder, sig, args) + + +@intrinsic +def _as_layout_array_intrinsic(typingctx, a, output_layout): + if not isinstance(output_layout, types.StringLiteral): + raise errors.RequireLiteralValue(output_layout) + + ret = a.copy(layout=output_layout.literal_value, ndim=max(a.ndim, 1)) + sig = ret(a, output_layout) + + return sig, lambda c, b, s, a: _as_layout_array( + c, b, s, a, output_layout=output_layout.literal_value + ) + + +@overload(np.ascontiguousarray) +def array_ascontiguousarray(a): + if not type_can_asarray(a): + raise errors.TypingError('The argument "a" must be array-like') + + if isinstance( + a, + ( + types.Number, + types.Boolean, + ), + ): + + def impl(a): + return np.ascontiguousarray(np.array(a)) + elif isinstance(a, types.Array): + + def impl(a): + return _as_layout_array_intrinsic(a, "C") + + return impl + + +@overload(np.asfortranarray) +def array_asfortranarray(a): + if not type_can_asarray(a): + raise errors.TypingError('The argument "a" must be array-like') + + if isinstance( + a, + ( + types.Number, + types.Boolean, + ), + ): + + def impl(a): + return np.asfortranarray(np.array(a)) + + return impl + elif isinstance(a, types.Array): + + def impl(a): + return _as_layout_array_intrinsic(a, "F") + + return impl + + +@lower("array.astype", types.Array, types.DTypeSpec) +@lower("array.astype", types.Array, types.StringLiteral) +def array_astype(context, builder, sig, args): + arytype = sig.args[0] + ary = make_array(arytype)(context, builder, value=args[0]) + shapes = cgutils.unpack_tuple(builder, ary.shape) + + rettype = sig.return_type + ret = _empty_nd_impl(context, builder, rettype, shapes) + + src_data = ary.data + dest_data = ret.data + + src_strides = cgutils.unpack_tuple(builder, ary.strides) + dest_strides = cgutils.unpack_tuple(builder, ret.strides) + intp_t = context.get_value_type(types.intp) + + with cgutils.loop_nest(builder, shapes, intp_t) as indices: + src_ptr = cgutils.get_item_pointer2( + context, + builder, + src_data, + shapes, + src_strides, + arytype.layout, + indices, + ) + dest_ptr = cgutils.get_item_pointer2( + context, + builder, + dest_data, + shapes, + dest_strides, + rettype.layout, + indices, + ) + item = load_item(context, builder, arytype, src_ptr) + item = context.cast(builder, item, arytype.dtype, rettype.dtype) + store_item(context, builder, rettype, item, dest_ptr) + + return impl_ret_new_ref(context, builder, sig.return_type, ret._getvalue()) + + +@intrinsic +def np_frombuffer(typingctx, buffer, dtype, retty): + ty = retty.instance_type + sig = ty(buffer, dtype, retty) + + def codegen(context, builder, sig, args): + bufty = sig.args[0] + aryty = sig.return_type + + buf = 
make_array(bufty)(context, builder, value=args[0]) + out_ary_ty = make_array(aryty) + out_ary = out_ary_ty(context, builder) + out_datamodel = out_ary._datamodel + + itemsize = get_itemsize(context, aryty) + ll_itemsize = Constant(buf.itemsize.type, itemsize) + nbytes = builder.mul(buf.nitems, buf.itemsize) + + # Check that the buffer size is compatible + rem = builder.srem(nbytes, ll_itemsize) + is_incompatible = cgutils.is_not_null(builder, rem) + with builder.if_then(is_incompatible, likely=False): + msg = "buffer size must be a multiple of element size" + context.call_conv.return_user_exc(builder, ValueError, (msg,)) + + shape = cgutils.pack_array(builder, [builder.sdiv(nbytes, ll_itemsize)]) + strides = cgutils.pack_array(builder, [ll_itemsize]) + data = builder.bitcast( + buf.data, context.get_value_type(out_datamodel.get_type("data")) + ) + + populate_array( + out_ary, + data=data, + shape=shape, + strides=strides, + itemsize=ll_itemsize, + meminfo=buf.meminfo, + parent=buf.parent, + ) + + res = out_ary._getvalue() + return impl_ret_borrowed(context, builder, sig.return_type, res) + + return sig, codegen + + +@overload(np.frombuffer) +def impl_np_frombuffer(buffer, dtype=float): + _check_const_str_dtype("frombuffer", dtype) + + if not isinstance(buffer, types.Buffer) or buffer.layout != "C": + msg = f'Argument "buffer" must be buffer-like. Got {buffer}' + raise errors.TypingError(msg) + + if ( + dtype is float + or (isinstance(dtype, types.Function) and dtype.typing_key is float) + or is_nonelike(dtype) + ): # default + nb_dtype = types.double + else: + nb_dtype = ty_parse_dtype(dtype) + + if nb_dtype is not None: + retty = types.Array( + dtype=nb_dtype, ndim=1, layout="C", readonly=not buffer.mutable + ) + else: + msg = ( + "Cannot parse input types to function " + f"np.frombuffer({buffer}, {dtype})" + ) + raise errors.TypingError(msg) + + def impl(buffer, dtype=float): + return np_frombuffer(buffer, dtype, retty) + + return impl + + +@overload(carray) +def impl_carray(ptr, shape, dtype=None): + if is_nonelike(dtype): + intrinsic_cfarray = get_cfarray_intrinsic("C", None) + + def impl(ptr, shape, dtype=None): + return intrinsic_cfarray(ptr, shape) + + return impl + elif isinstance(dtype, types.DTypeSpec): + intrinsic_cfarray = get_cfarray_intrinsic("C", dtype) + + def impl(ptr, shape, dtype=None): + return intrinsic_cfarray(ptr, shape) + + return impl + + +@overload(farray) +def impl_farray(ptr, shape, dtype=None): + if is_nonelike(dtype): + intrinsic_cfarray = get_cfarray_intrinsic("F", None) + + def impl(ptr, shape, dtype=None): + return intrinsic_cfarray(ptr, shape) + + return impl + elif isinstance(dtype, types.DTypeSpec): + intrinsic_cfarray = get_cfarray_intrinsic("F", dtype) + + def impl(ptr, shape, dtype=None): + return intrinsic_cfarray(ptr, shape) + + return impl + + +def get_cfarray_intrinsic(layout, dtype_): + @intrinsic + def intrinsic_cfarray(typingctx, ptr, shape): + if ptr is types.voidptr: + ptr_dtype = None + elif isinstance(ptr, types.CPointer): + ptr_dtype = ptr.dtype + else: + msg = f"pointer argument expected, got '{ptr}'" + raise errors.NumbaTypeError(msg) + + if dtype_ is None: + if ptr_dtype is None: + msg = "explicit dtype required for void* argument" + raise errors.NumbaTypeError(msg) + dtype = ptr_dtype + elif isinstance(dtype_, types.DTypeSpec): + dtype = dtype_.dtype + if ptr_dtype is not None and dtype != ptr_dtype: + msg = f"mismatching dtype '{dtype}' for pointer type '{ptr}'" + raise errors.NumbaTypeError(msg) + else: + msg = f"invalid dtype spec 
'{dtype_}'" + raise errors.NumbaTypeError(msg) + + ndim = ty_parse_shape(shape) + if ndim is None: + msg = f"invalid shape '{shape}'" + raise errors.NumbaTypeError(msg) + + retty = types.Array(dtype, ndim, layout) + sig = signature(retty, ptr, shape) + return sig, np_cfarray + + return intrinsic_cfarray + + +def np_cfarray(context, builder, sig, args): + """ + numba.cuda.np.numpy_support.carray(...) and + numba.cuda.np.numpy_support.farray(...). + """ + ptrty, shapety = sig.args[:2] + ptr, shape = args[:2] + + aryty = sig.return_type + assert aryty.layout in "CF" + + out_ary = make_array(aryty)(context, builder) + + itemsize = get_itemsize(context, aryty) + ll_itemsize = cgutils.intp_t(itemsize) + + if isinstance(shapety, types.BaseTuple): + shapes = cgutils.unpack_tuple(builder, shape) + else: + shapety = (shapety,) + shapes = (shape,) + shapes = [ + context.cast(builder, value, fromty, types.intp) + for fromty, value in zip(shapety, shapes) + ] + + off = ll_itemsize + strides = [] + if aryty.layout == "F": + for s in shapes: + strides.append(off) + off = builder.mul(off, s) + else: + for s in reversed(shapes): + strides.append(off) + off = builder.mul(off, s) + strides.reverse() + + data = builder.bitcast(ptr, context.get_data_type(aryty.dtype).as_pointer()) + + populate_array( + out_ary, + data=data, + shape=shapes, + strides=strides, + itemsize=ll_itemsize, + # Array is not memory-managed + meminfo=None, + ) + + res = out_ary._getvalue() + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +def _get_seq_size(context, builder, seqty, seq): + if isinstance(seqty, types.BaseTuple): + return context.get_constant(types.intp, len(seqty)) + elif isinstance(seqty, types.Sequence): + len_impl = context.get_function( + len, + signature( + types.intp, + seqty, + ), + ) + return len_impl(builder, (seq,)) + else: + assert 0 + + +def _get_borrowing_getitem(context, seqty): + """ + Return a getitem() implementation that doesn't incref its result. + """ + retty = seqty.dtype + getitem_impl = context.get_function( + operator.getitem, signature(retty, seqty, types.intp) + ) + + def wrap(builder, args): + ret = getitem_impl(builder, args) + if context.enable_nrt: + context.nrt.decref(builder, retty, ret) + return ret + + return wrap + + +def compute_sequence_shape(context, builder, ndim, seqty, seq): + """ + Compute the likely shape of a nested sequence (possibly 0d). + """ + intp_t = context.get_value_type(types.intp) + zero = Constant(intp_t, 0) + + def get_first_item(seqty, seq): + if isinstance(seqty, types.BaseTuple): + if len(seqty) == 0: + return None, None + else: + return seqty[0], builder.extract_value(seq, 0) + else: + getitem_impl = _get_borrowing_getitem(context, seqty) + return seqty.dtype, getitem_impl(builder, (seq, zero)) + + # Compute shape by traversing the first element of each nested + # sequence + shapes = [] + innerty, inner = seqty, seq + + for i in range(ndim): + if i > 0: + innerty, inner = get_first_item(innerty, inner) + shapes.append(_get_seq_size(context, builder, innerty, inner)) + + return tuple(shapes) + + +def check_sequence_shape(context, builder, seqty, seq, shapes): + """ + Check the nested sequence matches the given *shapes*. 
+ """ + + def _fail(): + context.call_conv.return_user_exc( + builder, ValueError, ("incompatible sequence shape",) + ) + + def check_seq_size(seqty, seq, shapes): + if len(shapes) == 0: + return + + size = _get_seq_size(context, builder, seqty, seq) + expected = shapes[0] + mismatch = builder.icmp_signed("!=", size, expected) + with builder.if_then(mismatch, likely=False): + _fail() + + if len(shapes) == 1: + return + + if isinstance(seqty, types.Sequence): + getitem_impl = _get_borrowing_getitem(context, seqty) + with cgutils.for_range(builder, size) as loop: + innerty = seqty.dtype + inner = getitem_impl(builder, (seq, loop.index)) + check_seq_size(innerty, inner, shapes[1:]) + + elif isinstance(seqty, types.BaseTuple): + for i in range(len(seqty)): + innerty = seqty[i] + inner = builder.extract_value(seq, i) + check_seq_size(innerty, inner, shapes[1:]) + + else: + assert 0, seqty + + check_seq_size(seqty, seq, shapes) + + +def assign_sequence_to_array( + context, builder, data, shapes, strides, arrty, seqty, seq +): + """ + Assign a nested sequence contents to an array. The shape must match + the sequence's structure. + """ + + def assign_item(indices, valty, val): + ptr = cgutils.get_item_pointer2( + context, + builder, + data, + shapes, + strides, + arrty.layout, + indices, + wraparound=False, + ) + val = context.cast(builder, val, valty, arrty.dtype) + store_item(context, builder, arrty, val, ptr) + + def assign(seqty, seq, shapes, indices): + if len(shapes) == 0: + assert not isinstance(seqty, (types.Sequence, types.BaseTuple)) + assign_item(indices, seqty, seq) + return + + size = shapes[0] + + if isinstance(seqty, types.Sequence): + getitem_impl = _get_borrowing_getitem(context, seqty) + with cgutils.for_range(builder, size) as loop: + innerty = seqty.dtype + inner = getitem_impl(builder, (seq, loop.index)) + assign(innerty, inner, shapes[1:], indices + (loop.index,)) + + elif isinstance(seqty, types.BaseTuple): + for i in range(len(seqty)): + innerty = seqty[i] + inner = builder.extract_value(seq, i) + index = context.get_constant(types.intp, i) + assign(innerty, inner, shapes[1:], indices + (index,)) + + else: + assert 0, seqty + + assign(seqty, seq, shapes, ()) + + +def np_array_typer(typingctx, object, dtype): + ndim, seq_dtype = _parse_nested_sequence(typingctx, object) + if is_nonelike(dtype): + dtype = seq_dtype + else: + dtype = ty_parse_dtype(dtype) + if dtype is None: + return + return types.Array(dtype, ndim, "C") + + +@intrinsic +def np_array(typingctx, obj, dtype): + _check_const_str_dtype("array", dtype) + ret = np_array_typer(typingctx, obj, dtype) + sig = ret(obj, dtype) + + def codegen(context, builder, sig, args): + arrty = sig.return_type + ndim = arrty.ndim + seqty = sig.args[0] + seq = args[0] + + shapes = compute_sequence_shape(context, builder, ndim, seqty, seq) + assert len(shapes) == ndim + + check_sequence_shape(context, builder, seqty, seq, shapes) + arr = _empty_nd_impl(context, builder, arrty, shapes) + assign_sequence_to_array( + context, builder, arr.data, shapes, arr.strides, arrty, seqty, seq + ) + + return impl_ret_new_ref( + context, builder, sig.return_type, arr._getvalue() + ) + + return sig, codegen + + +@overload(np.array) +def impl_np_array(object, dtype=None): + _check_const_str_dtype("array", dtype) + if not type_can_asarray(object): + raise errors.TypingError('The argument "object" must be array-like') + if not is_nonelike(dtype) and ty_parse_dtype(dtype) is None: + msg = 'The argument "dtype" must be a data-type if it is provided' + 
raise errors.TypingError(msg) + + def impl(object, dtype=None): + return np_array(object, dtype) + + return impl + + +def _normalize_axis(context, builder, func_name, ndim, axis): + zero = axis.type(0) + ll_ndim = axis.type(ndim) + + # Normalize negative axis + is_neg_axis = builder.icmp_signed("<", axis, zero) + axis = builder.select(is_neg_axis, builder.add(axis, ll_ndim), axis) + + # Check axis for bounds + axis_out_of_bounds = builder.or_( + builder.icmp_signed("<", axis, zero), + builder.icmp_signed(">=", axis, ll_ndim), + ) + with builder.if_then(axis_out_of_bounds, likely=False): + msg = "%s(): axis out of bounds" % func_name + context.call_conv.return_user_exc(builder, IndexError, (msg,)) + + return axis + + +def _insert_axis_in_shape(context, builder, orig_shape, ndim, axis): + """ + Compute shape with the new axis inserted + e.g. given original shape (2, 3, 4) and axis=2, + the returned new shape is (2, 3, 1, 4). + """ + assert len(orig_shape) == ndim - 1 + + ll_shty = ir.ArrayType(cgutils.intp_t, ndim) + shapes = cgutils.alloca_once(builder, ll_shty) + + one = cgutils.intp_t(1) + + # 1. copy original sizes at appropriate places + for dim in range(ndim - 1): + ll_dim = cgutils.intp_t(dim) + after_axis = builder.icmp_signed(">=", ll_dim, axis) + sh = orig_shape[dim] + idx = builder.select(after_axis, builder.add(ll_dim, one), ll_dim) + builder.store(sh, cgutils.gep_inbounds(builder, shapes, 0, idx)) + + # 2. insert new size (1) at axis dimension + builder.store(one, cgutils.gep_inbounds(builder, shapes, 0, axis)) + + return cgutils.unpack_tuple(builder, builder.load(shapes)) + + +def _insert_axis_in_strides(context, builder, orig_strides, ndim, axis): + """ + Same as _insert_axis_in_shape(), but with a strides array. + """ + assert len(orig_strides) == ndim - 1 + + ll_shty = ir.ArrayType(cgutils.intp_t, ndim) + strides = cgutils.alloca_once(builder, ll_shty) + + one = cgutils.intp_t(1) + zero = cgutils.intp_t(0) + + # 1. copy original strides at appropriate places + for dim in range(ndim - 1): + ll_dim = cgutils.intp_t(dim) + after_axis = builder.icmp_signed(">=", ll_dim, axis) + idx = builder.select(after_axis, builder.add(ll_dim, one), ll_dim) + builder.store( + orig_strides[dim], cgutils.gep_inbounds(builder, strides, 0, idx) + ) + + # 2. insert new stride at axis dimension + # (the value is indifferent for a 1-sized dimension, we use 0) + builder.store(zero, cgutils.gep_inbounds(builder, strides, 0, axis)) + + return cgutils.unpack_tuple(builder, builder.load(strides)) + + +def expand_dims(context, builder, sig, args, axis): + """ + np.expand_dims() with the given axis. 
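[Reviewer note] `_insert_axis_in_shape` / `_insert_axis_in_strides` above splice a 1-sized dimension into the shape and strides without touching the data, which is exactly NumPy's `expand_dims` view semantics. A quick illustration in plain NumPy:

```python
import numpy as np

a = np.zeros((2, 3, 4))
b = np.expand_dims(a, 2)

assert b.shape == (2, 3, 1, 4)  # 1-sized axis inserted at position 2
assert b.base is a              # a view: no data is copied
```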
+ """ + retty = sig.return_type + ndim = retty.ndim + arrty = sig.args[0] + + arr = make_array(arrty)(context, builder, value=args[0]) + ret = make_array(retty)(context, builder) + + shapes = cgutils.unpack_tuple(builder, arr.shape) + strides = cgutils.unpack_tuple(builder, arr.strides) + + new_shapes = _insert_axis_in_shape(context, builder, shapes, ndim, axis) + new_strides = _insert_axis_in_strides(context, builder, strides, ndim, axis) + + populate_array( + ret, + data=arr.data, + shape=new_shapes, + strides=new_strides, + itemsize=arr.itemsize, + meminfo=arr.meminfo, + parent=arr.parent, + ) + + return ret._getvalue() + + +@intrinsic +def np_expand_dims(typingctx, a, axis): + layout = a.layout if a.ndim <= 1 else "A" + ret = a.copy(ndim=a.ndim + 1, layout=layout) + sig = ret(a, axis) + + def codegen(context, builder, sig, args): + axis = context.cast(builder, args[1], sig.args[1], types.intp) + axis = _normalize_axis( + context, builder, "np.expand_dims", sig.return_type.ndim, axis + ) + + ret = expand_dims(context, builder, sig, args, axis) + return impl_ret_borrowed(context, builder, sig.return_type, ret) + + return sig, codegen + + +@overload(np.expand_dims) +def impl_np_expand_dims(a, axis): + if not isinstance(a, types.Array): + msg = f'First argument "a" must be an array. Got {a}' + raise errors.TypingError(msg) + + if not isinstance(axis, types.Integer): + msg = f'Argument "axis" must be an integer. Got {axis}' + raise errors.TypingError(msg) + + def impl(a, axis): + return np_expand_dims(a, axis) + + return impl + + +def _atleast_nd(minimum, axes): + @intrinsic + def impl(typingcontext, *args): + arrtys = args + rettys = [arg.copy(ndim=max(arg.ndim, minimum)) for arg in args] + + def codegen(context, builder, sig, args): + transform = _atleast_nd_transform(minimum, axes) + arrs = cgutils.unpack_tuple(builder, args[0]) + + rets = [ + transform(context, builder, arr, arrty, retty) + for arr, arrty, retty in zip(arrs, arrtys, rettys) + ] + + if len(rets) > 1: + ret = context.make_tuple(builder, sig.return_type, rets) + else: + ret = rets[0] + return impl_ret_borrowed(context, builder, sig.return_type, ret) + + return signature( + types.Tuple(rettys) if len(rettys) > 1 else rettys[0], + types.StarArgTuple.from_types(args), + ), codegen + + return lambda *args: impl(*args) + + +def _atleast_nd_transform(min_ndim, axes): + """ + Return a callback successively inserting 1-sized dimensions at the + following axes. + """ + assert min_ndim == len(axes) + + def transform(context, builder, arr, arrty, retty): + for i in range(min_ndim): + ndim = i + 1 + if arrty.ndim < ndim: + axis = cgutils.intp_t(axes[i]) + newarrty = arrty.copy(ndim=arrty.ndim + 1) + arr = expand_dims( + context, + builder, + typing.signature(newarrty, arrty), + (arr,), + axis, + ) + arrty = newarrty + + return arr + + return transform + + +@overload(np.atleast_1d) +def np_atleast_1d(*args): + if all(isinstance(arg, types.Array) for arg in args): + return _atleast_nd(1, [0]) + + +@overload(np.atleast_2d) +def np_atleast_2d(*args): + if all(isinstance(arg, types.Array) for arg in args): + return _atleast_nd(2, [0, 0]) + + +@overload(np.atleast_3d) +def np_atleast_3d(*args): + if all(isinstance(arg, types.Array) for arg in args): + return _atleast_nd(3, [0, 0, 2]) + + +def _do_concatenate( + context, + builder, + axis, + arrtys, + arrs, + arr_shapes, + arr_strides, + retty, + ret_shapes, +): + """ + Concatenate arrays along the given axis. 
+ """ + assert len(arrtys) == len(arrs) == len(arr_shapes) == len(arr_strides) + + zero = cgutils.intp_t(0) + + # Allocate return array + ret = _empty_nd_impl(context, builder, retty, ret_shapes) + ret_strides = cgutils.unpack_tuple(builder, ret.strides) + + # Compute the offset by which to bump the destination pointer + # after copying each input array. + # Morally, we need to copy each input array at different start indices + # into the destination array; bumping the destination pointer + # is simply easier than offsetting all destination indices. + copy_offsets = [] + + for arr_sh in arr_shapes: + # offset = ret_strides[axis] * input_shape[axis] + offset = zero + for dim, (size, stride) in enumerate(zip(arr_sh, ret_strides)): + is_axis = builder.icmp_signed("==", axis.type(dim), axis) + addend = builder.mul(size, stride) + offset = builder.select( + is_axis, builder.add(offset, addend), offset + ) + copy_offsets.append(offset) + + # Copy input arrays into the return array + ret_data = ret.data + + for arrty, arr, arr_sh, arr_st, offset in zip( + arrtys, arrs, arr_shapes, arr_strides, copy_offsets + ): + arr_data = arr.data + + # Do the copy loop + # Note the loop nesting is optimized for the destination layout + loop_nest = cgutils.loop_nest( + builder, arr_sh, cgutils.intp_t, order=retty.layout + ) + + with loop_nest as indices: + src_ptr = cgutils.get_item_pointer2( + context, + builder, + arr_data, + arr_sh, + arr_st, + arrty.layout, + indices, + ) + val = load_item(context, builder, arrty, src_ptr) + val = context.cast(builder, val, arrty.dtype, retty.dtype) + dest_ptr = cgutils.get_item_pointer2( + context, + builder, + ret_data, + ret_shapes, + ret_strides, + retty.layout, + indices, + ) + store_item(context, builder, retty, val, dest_ptr) + + # Bump destination pointer + ret_data = cgutils.pointer_add(builder, ret_data, offset) + + return ret + + +def _np_concatenate(context, builder, arrtys, arrs, retty, axis): + ndim = retty.ndim + + arrs = [ + make_array(aty)(context, builder, value=a) + for aty, a in zip(arrtys, arrs) + ] + + axis = _normalize_axis(context, builder, "np.concatenate", ndim, axis) + + # Get input shapes + arr_shapes = [cgutils.unpack_tuple(builder, arr.shape) for arr in arrs] + arr_strides = [cgutils.unpack_tuple(builder, arr.strides) for arr in arrs] + + # Compute return shape: + # - the dimension for the concatenation axis is summed over all inputs + # - other dimensions must match exactly for each input + ret_shapes = [ + cgutils.alloca_once_value(builder, sh) for sh in arr_shapes[0] + ] + + for dim in range(ndim): + is_axis = builder.icmp_signed("==", axis.type(dim), axis) + ret_shape_ptr = ret_shapes[dim] + ret_sh = builder.load(ret_shape_ptr) + other_shapes = [sh[dim] for sh in arr_shapes[1:]] + + with builder.if_else(is_axis) as (on_axis, on_other_dim): + with on_axis: + sh = functools.reduce(builder.add, other_shapes + [ret_sh]) + builder.store(sh, ret_shape_ptr) + + with on_other_dim: + is_ok = cgutils.true_bit + for sh in other_shapes: + is_ok = builder.and_( + is_ok, builder.icmp_signed("==", sh, ret_sh) + ) + with builder.if_then(builder.not_(is_ok), likely=False): + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "np.concatenate(): input sizes over " + "dimension %d do not match" % dim, + ), + ) + + ret_shapes = [builder.load(sh) for sh in ret_shapes] + + ret = _do_concatenate( + context, + builder, + axis, + arrtys, + arrs, + arr_shapes, + arr_strides, + retty, + ret_shapes, + ) + return impl_ret_new_ref(context, builder, retty, 
ret._getvalue()) + + +def _np_stack(context, builder, arrtys, arrs, retty, axis): + ndim = retty.ndim + + zero = cgutils.intp_t(0) + one = cgutils.intp_t(1) + ll_narrays = cgutils.intp_t(len(arrs)) + + arrs = [ + make_array(aty)(context, builder, value=a) + for aty, a in zip(arrtys, arrs) + ] + + axis = _normalize_axis(context, builder, "np.stack", ndim, axis) + + # Check input arrays have the same shape + orig_shape = cgutils.unpack_tuple(builder, arrs[0].shape) + + for arr in arrs[1:]: + is_ok = cgutils.true_bit + for sh, orig_sh in zip( + cgutils.unpack_tuple(builder, arr.shape), orig_shape + ): + is_ok = builder.and_(is_ok, builder.icmp_signed("==", sh, orig_sh)) + with builder.if_then(builder.not_(is_ok), likely=False): + context.call_conv.return_user_exc( + builder, + ValueError, + ("np.stack(): all input arrays must have the same shape",), + ) + + orig_strides = [cgutils.unpack_tuple(builder, arr.strides) for arr in arrs] + + # Compute input shapes and return shape with the new axis inserted + # e.g. given 5 input arrays of shape (2, 3, 4) and axis=1, + # corrected input shape is (2, 1, 3, 4) and return shape is (2, 5, 3, 4). + ll_shty = ir.ArrayType(cgutils.intp_t, ndim) + + input_shapes = cgutils.alloca_once(builder, ll_shty) + ret_shapes = cgutils.alloca_once(builder, ll_shty) + + # 1. copy original sizes at appropriate places + for dim in range(ndim - 1): + ll_dim = cgutils.intp_t(dim) + after_axis = builder.icmp_signed(">=", ll_dim, axis) + sh = orig_shape[dim] + idx = builder.select(after_axis, builder.add(ll_dim, one), ll_dim) + builder.store(sh, cgutils.gep_inbounds(builder, input_shapes, 0, idx)) + builder.store(sh, cgutils.gep_inbounds(builder, ret_shapes, 0, idx)) + + # 2. insert new size at axis dimension + builder.store(one, cgutils.gep_inbounds(builder, input_shapes, 0, axis)) + builder.store( + ll_narrays, cgutils.gep_inbounds(builder, ret_shapes, 0, axis) + ) + + input_shapes = cgutils.unpack_tuple(builder, builder.load(input_shapes)) + input_shapes = [input_shapes] * len(arrs) + ret_shapes = cgutils.unpack_tuple(builder, builder.load(ret_shapes)) + + # Compute input strides for each array with the new axis inserted + input_strides = [ + cgutils.alloca_once(builder, ll_shty) for i in range(len(arrs)) + ] + + # 1. copy original strides at appropriate places + for dim in range(ndim - 1): + ll_dim = cgutils.intp_t(dim) + after_axis = builder.icmp_signed(">=", ll_dim, axis) + idx = builder.select(after_axis, builder.add(ll_dim, one), ll_dim) + for i in range(len(arrs)): + builder.store( + orig_strides[i][dim], + cgutils.gep_inbounds(builder, input_strides[i], 0, idx), + ) + + # 2. 
insert new stride at axis dimension + # (the value is indifferent for a 1-sized dimension, we put 0) + for i in range(len(arrs)): + builder.store( + zero, cgutils.gep_inbounds(builder, input_strides[i], 0, axis) + ) + + input_strides = [ + cgutils.unpack_tuple(builder, builder.load(st)) for st in input_strides + ] + + # Create concatenated array + ret = _do_concatenate( + context, + builder, + axis, + arrtys, + arrs, + input_shapes, + input_strides, + retty, + ret_shapes, + ) + return impl_ret_new_ref(context, builder, retty, ret._getvalue()) + + +def np_concatenate_typer(typingctx, arrays, axis): + if axis is not None and not isinstance(axis, types.Integer): + # Note Numpy allows axis=None, but it isn't documented: + # https://github.com/numpy/numpy/issues/7968 + return + + # does type checking + dtype, ndim = _sequence_of_arrays(typingctx, "np.concatenate", arrays) + if ndim == 0: + msg = "zero-dimensional arrays cannot be concatenated" + raise errors.NumbaTypeError(msg) + + layout = _choose_concatenation_layout(arrays) + + return types.Array(dtype, ndim, layout) + + +@intrinsic +def np_concatenate(typingctx, arrays, axis): + ret = np_concatenate_typer(typingctx, arrays, axis) + assert isinstance(ret, types.Array) + sig = ret(arrays, axis) + + def codegen(context, builder, sig, args): + axis = context.cast(builder, args[1], sig.args[1], types.intp) + return _np_concatenate( + context, + builder, + list(sig.args[0]), + cgutils.unpack_tuple(builder, args[0]), + sig.return_type, + axis, + ) + + return sig, codegen + + +@overload(np.concatenate) +def impl_np_concatenate(arrays, axis=0): + if isinstance(arrays, types.BaseTuple): + + def impl(arrays, axis=0): + return np_concatenate(arrays, axis) + + return impl + + +def _column_stack_dims(context, func_name, arrays): + # column_stack() allows stacking 1-d and 2-d arrays together + for a in arrays: + if a.ndim < 1 or a.ndim > 2: + msg = "np.column_stack() is only defined on 1-d and 2-d arrays" + raise errors.NumbaTypeError(msg) + return 2 + + +@intrinsic +def np_column_stack(typingctx, tup): + dtype, ndim = _sequence_of_arrays( + typingctx, "np.column_stack", tup, dim_chooser=_column_stack_dims + ) + layout = _choose_concatenation_layout(tup) + ret = types.Array(dtype, ndim, layout) + sig = ret(tup) + + def codegen(context, builder, sig, args): + orig_arrtys = list(sig.args[0]) + orig_arrs = cgutils.unpack_tuple(builder, args[0]) + + arrtys = [] + arrs = [] + + axis = context.get_constant(types.intp, 1) + + for arrty, arr in zip(orig_arrtys, orig_arrs): + if arrty.ndim == 2: + arrtys.append(arrty) + arrs.append(arr) + else: + # Convert 1d array to 2d column array: np.expand_dims(a, 1) + assert arrty.ndim == 1 + newty = arrty.copy(ndim=2) + expand_sig = typing.signature(newty, arrty) + newarr = expand_dims(context, builder, expand_sig, (arr,), axis) + + arrtys.append(newty) + arrs.append(newarr) + + return _np_concatenate( + context, builder, arrtys, arrs, sig.return_type, axis + ) + + return sig, codegen + + +@overload(np.column_stack) +def impl_column_stack(tup): + if isinstance(tup, types.BaseTuple): + + def impl(tup): + return np_column_stack(tup) + + return impl + + +def _np_stack_common(context, builder, sig, args, axis): + """ + np.stack() with the given axis value. 
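[Reviewer note] In the `np_column_stack` codegen above, 1-d inputs are promoted to 2-d columns via `expand_dims(a, 1)` and then concatenated along axis 1. The NumPy-equivalent behavior, for reference:

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
assert np.column_stack((a, b)).tolist() == [[1, 4], [2, 5], [3, 6]]
```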
+ """ + return _np_stack( + context, + builder, + list(sig.args[0]), + cgutils.unpack_tuple(builder, args[0]), + sig.return_type, + axis, + ) + + +@intrinsic +def np_stack_common(typingctx, arrays, axis): + # does type checking + dtype, ndim = _sequence_of_arrays(typingctx, "np.stack", arrays) + layout = "F" if all(a.layout == "F" for a in arrays) else "C" + ret = types.Array(dtype, ndim + 1, layout) + sig = ret(arrays, axis) + + def codegen(context, builder, sig, args): + axis = context.cast(builder, args[1], sig.args[1], types.intp) + return _np_stack_common(context, builder, sig, args, axis) + + return sig, codegen + + +@overload(np.stack) +def impl_np_stack(arrays, axis=0): + if isinstance(arrays, types.BaseTuple): + + def impl(arrays, axis=0): + return np_stack_common(arrays, axis) + + return impl + + +def NdStack_typer(typingctx, func_name, arrays, ndim_min): + # does type checking + dtype, ndim = _sequence_of_arrays(typingctx, func_name, arrays) + ndim = max(ndim, ndim_min) + layout = _choose_concatenation_layout(arrays) + ret = types.Array(dtype, ndim, layout) + return ret + + +@intrinsic +def _np_hstack(typingctx, tup): + ret = NdStack_typer(typingctx, "np.hstack", tup, 1) + sig = ret(tup) + + def codegen(context, builder, sig, args): + tupty = sig.args[0] + ndim = tupty[0].ndim + + if ndim == 0: + # hstack() on 0-d arrays returns a 1-d array + axis = context.get_constant(types.intp, 0) + return _np_stack_common(context, builder, sig, args, axis) + + else: + # As a special case, dimension 0 of 1-dimensional arrays + # is "horizontal" + axis = 0 if ndim == 1 else 1 + + def np_hstack_impl(arrays): + return np.concatenate(arrays, axis=axis) + + return context.compile_internal(builder, np_hstack_impl, sig, args) + + return sig, codegen + + +@overload(np.hstack) +def impl_np_hstack(tup): + if isinstance(tup, types.BaseTuple): + + def impl(tup): + return _np_hstack(tup) + + return impl + + +@intrinsic +def _np_vstack(typingctx, tup): + ret = NdStack_typer(typingctx, "np.vstack", tup, 2) + sig = ret(tup) + + def codegen(context, builder, sig, args): + tupty = sig.args[0] + ndim = tupty[0].ndim + + if ndim == 0: + + def np_vstack_impl(arrays): + return np.expand_dims(np.hstack(arrays), 1) + + elif ndim == 1: + # np.stack(arrays, axis=0) + axis = context.get_constant(types.intp, 0) + return _np_stack_common(context, builder, sig, args, axis) + + else: + + def np_vstack_impl(arrays): + return np.concatenate(arrays, axis=0) + + return context.compile_internal(builder, np_vstack_impl, sig, args) + + return sig, codegen + + +@overload(np.vstack) +def impl_np_vstack(tup): + if isinstance(tup, types.BaseTuple): + + def impl(tup): + return _np_vstack(tup) + + return impl + + +if numpy_version >= (2, 0): + overload(np.row_stack)(impl_np_vstack) + + +@intrinsic +def _np_dstack(typingctx, tup): + ret = NdStack_typer(typingctx, "np.dstack", tup, 3) + sig = ret(tup) + + def codegen(context, builder, sig, args): + tupty = sig.args[0] + retty = sig.return_type + ndim = tupty[0].ndim + + if ndim == 0: + + def np_vstack_impl(arrays): + return np.hstack(arrays).reshape(1, 1, -1) + + return context.compile_internal(builder, np_vstack_impl, sig, args) + + elif ndim == 1: + # np.expand_dims(np.stack(arrays, axis=1), axis=0) + axis = context.get_constant(types.intp, 1) + stack_retty = retty.copy(ndim=retty.ndim - 1) + stack_sig = typing.signature(stack_retty, *sig.args) + stack_ret = _np_stack_common( + context, builder, stack_sig, args, axis + ) + + axis = context.get_constant(types.intp, 0) + expand_sig = 
typing.signature(retty, stack_retty) + return expand_dims(context, builder, expand_sig, (stack_ret,), axis) + + elif ndim == 2: + # np.stack(arrays, axis=2) + axis = context.get_constant(types.intp, 2) + return _np_stack_common(context, builder, sig, args, axis) + + else: + + def np_vstack_impl(arrays): + return np.concatenate(arrays, axis=2) + + return context.compile_internal(builder, np_vstack_impl, sig, args) + + return sig, codegen + + +@overload(np.dstack) +def impl_np_dstack(tup): + if isinstance(tup, types.BaseTuple): + + def impl(tup): + return _np_dstack(tup) + + return impl + + +@extending.overload_method(types.Array, "fill") +def arr_fill(arr, val): + def fill_impl(arr, val): + arr[:] = val + return None + + return fill_impl + + +@extending.overload_method(types.Array, "dot") +def array_dot(arr, other): + def dot_impl(arr, other): + return np.dot(arr, other) + + return dot_impl + + +@overload(np.fliplr) +def np_flip_lr(m): + if not type_can_asarray(m): + raise errors.TypingError("Cannot np.fliplr on %s type" % m) + + def impl(m): + A = np.asarray(m) + # this handling is superfluous/dead as < 2d array cannot be indexed as + # present below and so typing fails. If the typing doesn't fail due to + # some future change, this will catch it. + if A.ndim < 2: + raise ValueError("Input must be >= 2-d.") + return A[::, ::-1, ...] + + return impl + + +@overload(np.flipud) +def np_flip_ud(m): + if not type_can_asarray(m): + raise errors.TypingError("Cannot np.flipud on %s type" % m) + + def impl(m): + A = np.asarray(m) + # this handling is superfluous/dead as a 0d array cannot be indexed as + # present below and so typing fails. If the typing doesn't fail due to + # some future change, this will catch it. + if A.ndim < 1: + raise ValueError("Input must be >= 1-d.") + return A[::-1, ...] 
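[Reviewer note] The slice-based `fliplr`/`flipud` implementations above mirror NumPy's reversed views; no data is copied. For reference:

```python
import numpy as np

m = np.arange(6).reshape(2, 3)
assert np.fliplr(m).tolist() == [[2, 1, 0], [5, 4, 3]]  # m[::, ::-1, ...]
assert np.flipud(m).tolist() == [[3, 4, 5], [0, 1, 2]]  # m[::-1, ...]
```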
+
+    return impl
+
+
+@intrinsic
+def _build_flip_slice_tuple(tyctx, sz):
+    """Creates a tuple of slices for np.flip indexing like
+    `(slice(None, None, -1),) * sz`"""
+    if not isinstance(sz, types.IntegerLiteral):
+        raise errors.RequireLiteralValue(sz)
+    size = int(sz.literal_value)
+    tuple_type = types.UniTuple(dtype=types.slice3_type, count=size)
+    sig = tuple_type(sz)
+
+    def codegen(context, builder, signature, args):
+        def impl(length, empty_tuple):
+            out = empty_tuple
+            for i in range(length):
+                out = tuple_setitem(out, i, slice(None, None, -1))
+            return out
+
+        inner_argtypes = [types.intp, tuple_type]
+        inner_sig = typing.signature(tuple_type, *inner_argtypes)
+        ll_idx_type = context.get_value_type(types.intp)
+        # Allocate an empty tuple
+        empty_tuple = context.get_constant_undef(tuple_type)
+        inner_args = [ll_idx_type(size), empty_tuple]
+
+        res = context.compile_internal(builder, impl, inner_sig, inner_args)
+        return res
+
+    return sig, codegen
+
+
+@overload(np.flip)
+def np_flip(m):
+    # A constant value is needed for the slice tuple; types.Array.ndim can
+    # provide this, so at present only types.Array is supported.
+    if not isinstance(m, types.Array):
+        raise errors.TypingError("Cannot np.flip on %s type" % m)
+
+    def impl(m):
+        sl = _build_flip_slice_tuple(m.ndim)
+        return m[sl]
+
+    return impl
+
+
+@overload(np.array_split)
+def np_array_split(ary, indices_or_sections, axis=0):
+    if isinstance(ary, (types.UniTuple, types.ListType, types.List)):
+
+        def impl(ary, indices_or_sections, axis=0):
+            return np.array_split(
+                np.asarray(ary), indices_or_sections, axis=axis
+            )
+
+        return impl
+
+    if isinstance(indices_or_sections, types.Integer):
+
+        def impl(ary, indices_or_sections, axis=0):
+            l, rem = divmod(ary.shape[axis], indices_or_sections)
+            indices = np.cumsum(
+                np.array([l + 1] * rem + [l] * (indices_or_sections - rem - 1))
+            )
+            return np.array_split(ary, indices, axis=axis)
+
+        return impl
+
+    elif isinstance(indices_or_sections, types.IterableType) and isinstance(
+        indices_or_sections.iterator_type.yield_type, types.Integer
+    ):
+
+        def impl(ary, indices_or_sections, axis=0):
+            slice_tup = build_full_slice_tuple(ary.ndim)
+            axis = normalize_axis("np.split", "axis", ary.ndim, axis)
+            out = []
+            prev = 0
+            for cur in indices_or_sections:
+                idx = tuple_setitem(slice_tup, axis, slice(prev, cur))
+                out.append(ary[idx])
+                prev = cur
+            out.append(ary[tuple_setitem(slice_tup, axis, slice(cur, None))])
+            return out
+
+        return impl
+
+    elif isinstance(indices_or_sections, types.Tuple) and all(
+        isinstance(t, types.Integer) for t in indices_or_sections.types
+    ):
+
+        def impl(ary, indices_or_sections, axis=0):
+            slice_tup = build_full_slice_tuple(ary.ndim)
+            axis = normalize_axis("np.split", "axis", ary.ndim, axis)
+            out = []
+            prev = 0
+            for cur in literal_unroll(indices_or_sections):
+                idx = tuple_setitem(slice_tup, axis, slice(prev, cur))
+                out.append(ary[idx])
+                prev = cur
+            out.append(ary[tuple_setitem(slice_tup, axis, slice(cur, None))])
+            return out
+
+        return impl
+
+
+@overload(np.split)
+def np_split(ary, indices_or_sections, axis=0):
+    # This is just a wrapper around array_split, with an extra error if
+    # indices is an int.
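[Reviewer note] The comment above refers to the integer case: `np.array_split` sizes sections via `divmod` (the first `rem` sections get one extra element), while `np.split` raises when the division is uneven. Checked against NumPy:

```python
import numpy as np

# array_split: divmod(5, 3) == (1, 2) -> two sections of 2, then one of 1.
parts = np.array_split(np.arange(5), 3)
assert [p.tolist() for p in parts] == [[0, 1], [2, 3], [4]]

# split: the same request raises, since 5 % 3 != 0.
try:
    np.split(np.arange(5), 3)
except ValueError as e:
    assert "equal division" in str(e)
```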
+ if isinstance(ary, (types.UniTuple, types.ListType, types.List)): + + def impl(ary, indices_or_sections, axis=0): + return np.split(np.asarray(ary), indices_or_sections, axis=axis) + + return impl + + if isinstance(indices_or_sections, types.Integer): + + def impl(ary, indices_or_sections, axis=0): + _, rem = divmod(ary.shape[axis], indices_or_sections) + if rem != 0: + raise ValueError( + "array split does not result in an equal division" + ) + return np.array_split(ary, indices_or_sections, axis=axis) + + return impl + + else: + return np_array_split(ary, indices_or_sections, axis=axis) + + +@overload(np.vsplit) +def numpy_vsplit(ary, indices_or_sections): + if not isinstance(ary, types.Array): + msg = 'The argument "ary" must be an array' + raise errors.TypingError(msg) + + if not isinstance( + indices_or_sections, + (types.Integer, types.Array, types.List, types.UniTuple), + ): + msg = 'The argument "indices_or_sections" must be int or 1d-array' + raise errors.TypingError(msg) + + def impl(ary, indices_or_sections): + if ary.ndim < 2: + raise ValueError( + ("vsplit only works on arrays of 2 or more dimensions") + ) + return np.split(ary, indices_or_sections, axis=0) + + return impl + + +@overload(np.hsplit) +def numpy_hsplit(ary, indices_or_sections): + if not isinstance(ary, types.Array): + msg = 'The argument "ary" must be an array' + raise errors.TypingError(msg) + + if not isinstance( + indices_or_sections, + (types.Integer, types.Array, types.List, types.UniTuple), + ): + msg = 'The argument "indices_or_sections" must be int or 1d-array' + raise errors.TypingError(msg) + + def impl(ary, indices_or_sections): + if ary.ndim == 0: + raise ValueError( + ("hsplit only works on arrays of 1 or more dimensions") + ) + if ary.ndim > 1: + return np.split(ary, indices_or_sections, axis=1) + return np.split(ary, indices_or_sections, axis=0) + + return impl + + +@overload(np.dsplit) +def numpy_dsplit(ary, indices_or_sections): + if not isinstance(ary, types.Array): + msg = 'The argument "ary" must be an array' + raise errors.TypingError(msg) + + if not isinstance( + indices_or_sections, + (types.Integer, types.Array, types.List, types.UniTuple), + ): + msg = 'The argument "indices_or_sections" must be int or 1d-array' + raise errors.TypingError(msg) + + def impl(ary, indices_or_sections): + if ary.ndim < 3: + raise ValueError( + "dsplit only works on arrays of 3 or more dimensions" + ) + return np.split(ary, indices_or_sections, axis=2) + + return impl + + +# ----------------------------------------------------------------------------- +# Sorting + +_sorts = {} + + +def default_lt(a, b): + """ + Trivial comparison function between two keys. + """ + return a < b + + +def get_sort_func(kind, lt_impl, is_argsort=False): + """ + Get a sort implementation of the given kind. 
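[Reviewer note] The dimension guards and axis choices in the `vsplit`/`hsplit`/`dsplit` overloads above match NumPy's (axis 0 for vsplit, axis 1 for hsplit on >= 2-d inputs, axis 2 for dsplit). Illustrative check:

```python
import numpy as np

m = np.arange(12).reshape(2, 6)
assert [p.shape for p in np.vsplit(m, 2)] == [(1, 6), (1, 6)]  # axis 0
assert [p.shape for p in np.hsplit(m, 3)] == [(2, 2)] * 3      # axis 1
```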
+ """ + key = kind, lt_impl.__name__, is_argsort + + try: + return _sorts[key] + except KeyError: + if kind == "quicksort": + sort = quicksort.make_jit_quicksort( + lt=lt_impl, is_argsort=is_argsort, is_np_array=True + ) + func = sort.run_quicksort + elif kind == "mergesort": + sort = mergesort.make_jit_mergesort( + lt=lt_impl, is_argsort=is_argsort + ) + func = sort.run_mergesort + _sorts[key] = func + return func + + +def lt_implementation(dtype): + if isinstance(dtype, types.Float): + return lt_floats + elif isinstance(dtype, types.Complex): + return lt_complex + else: + return default_lt + + +@lower("array.sort", types.Array) +def array_sort(context, builder, sig, args): + arytype = sig.args[0] + + sort_func = get_sort_func( + kind="quicksort", lt_impl=lt_implementation(arytype.dtype) + ) + + def array_sort_impl(arr): + # Note we clobber the return value + sort_func(arr) + + return context.compile_internal(builder, array_sort_impl, sig, args) + + +@overload(np.sort) +def impl_np_sort(a): + if not type_can_asarray(a): + raise errors.TypingError('Argument "a" must be array-like') + + def np_sort_impl(a): + res = a.copy() + res.sort() + return res + + return np_sort_impl + + +@lower("array.argsort", types.Array, types.StringLiteral) +@lower(np.argsort, types.Array, types.StringLiteral) +def array_argsort(context, builder, sig, args): + arytype, kind = sig.args + + sort_func = get_sort_func( + kind=kind.literal_value, + lt_impl=lt_implementation(arytype.dtype), + is_argsort=True, + ) + + def array_argsort_impl(arr): + return sort_func(arr) + + innersig = sig.replace(args=sig.args[:1]) + innerargs = args[:1] + return context.compile_internal( + builder, array_argsort_impl, innersig, innerargs + ) + + +# ------------------------------------------------------------------------------ +# Implicit cast + + +@lower_cast(types.Array, types.Array) +def array_to_array(context, builder, fromty, toty, val): + # Type inference should have prevented illegal array casting. + assert fromty.mutable != toty.mutable or toty.layout == "A" + return val + + +@lower_cast(types.Array, types.UnicodeCharSeq) +@lower_cast(types.Array, types.Float) +@lower_cast(types.Array, types.Integer) +@lower_cast(types.Array, types.Complex) +@lower_cast(types.Array, types.Boolean) +@lower_cast(types.Array, types.NPTimedelta) +@lower_cast(types.Array, types.NPDatetime) +def array0d_to_scalar(context, builder, fromty, toty, val): + def impl(a): + # a is an array(T, 0d, O), T is type, O is order + return a.take(0) + + sig = signature(toty, fromty) + res = context.compile_internal(builder, impl, sig, [val]) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower_cast(types.Array, types.UnicodeCharSeq) +def array_to_unichrseq(context, builder, fromty, toty, val): + def impl(a): + return str(a[()]) + + sig = signature(toty, fromty) + res = context.compile_internal(builder, impl, sig, [val]) + return impl_ret_borrowed(context, builder, sig.return_type, res) + + +# ------------------------------------------------------------------------------ +# Stride tricks + + +def reshape_unchecked(a, shape, strides): + """ + An intrinsic returning a derived array with the given shape and strides. 
+ """ + raise NotImplementedError + + +@extending.type_callable(reshape_unchecked) +def type_reshape_unchecked(context): + def check_shape(shape): + return isinstance(shape, types.BaseTuple) and all( + isinstance(v, types.Integer) for v in shape + ) + + def typer(a, shape, strides): + if not isinstance(a, types.Array): + return + if not check_shape(shape) or not check_shape(strides): + return + if len(shape) != len(strides): + return + return a.copy(ndim=len(shape), layout="A") + + return typer + + +@lower(reshape_unchecked, types.Array, types.BaseTuple, types.BaseTuple) +def impl_shape_unchecked(context, builder, sig, args): + aryty = sig.args[0] + retty = sig.return_type + + ary = make_array(aryty)(context, builder, args[0]) + out = make_array(retty)(context, builder) + shape = cgutils.unpack_tuple(builder, args[1]) + strides = cgutils.unpack_tuple(builder, args[2]) + + populate_array( + out, + data=ary.data, + shape=shape, + strides=strides, + itemsize=ary.itemsize, + meminfo=ary.meminfo, + ) + + res = out._getvalue() + return impl_ret_borrowed(context, builder, retty, res) + + +@extending.overload(np.lib.stride_tricks.as_strided) +def as_strided(x, shape=None, strides=None): + if shape in (None, types.none): + + @register_jitable + def get_shape(x, shape): + return x.shape + else: + + @register_jitable + def get_shape(x, shape): + return shape + + if strides in (None, types.none): + # When *strides* is not passed, as_strided() does a non-size-checking + # reshape(), possibly changing the original strides. This is too + # cumbersome to support right now, and a Web search shows all example + # use cases of as_strided() pass explicit *strides*. + raise errors.TypingError("as_strided() strides argument cannot be None") + else: + + @register_jitable + def get_strides(x, strides): + return strides + + def as_strided_impl(x, shape=None, strides=None): + x = reshape_unchecked(x, get_shape(x, shape), get_strides(x, strides)) + return x + + return as_strided_impl + + +@extending.overload(np.lib.stride_tricks.sliding_window_view) +def sliding_window_view(x, window_shape, axis=None): + # Window shape must be given as either an integer or tuple of integers. + # We also need to generate buffer tuples we can modify to contain the + # final shape and strides (reshape_unchecked does not accept lists). + if isinstance(window_shape, types.Integer): + shape_buffer = tuple(range(x.ndim + 1)) + stride_buffer = tuple(range(x.ndim + 1)) + + @register_jitable + def get_window_shape(window_shape): + return (window_shape,) + + elif isinstance(window_shape, types.UniTuple) and isinstance( + window_shape.dtype, types.Integer + ): + shape_buffer = tuple(range(x.ndim + len(window_shape))) + stride_buffer = tuple(range(x.ndim + len(window_shape))) + + @register_jitable + def get_window_shape(window_shape): + return window_shape + + else: + raise errors.TypingError( + "window_shape must be an integer or tuple of integers" + ) + + # Axis must be integer, tuple of integers, or None for all axes. 
+ if is_nonelike(axis): + + @register_jitable + def get_axis(window_shape, axis, ndim): + return list(range(ndim)) + + elif isinstance(axis, types.Integer): + + @register_jitable + def get_axis(window_shape, axis, ndim): + return [normalize_axis("sliding_window_view", "axis", ndim, axis)] + + elif isinstance(axis, types.UniTuple) and isinstance( + axis.dtype, types.Integer + ): + + @register_jitable + def get_axis(window_shape, axis, ndim): + return [ + normalize_axis("sliding_window_view", "axis", ndim, a) + for a in axis + ] + + else: + raise errors.TypingError( + "axis must be None, an integer or tuple of integers" + ) + + def sliding_window_view_impl(x, window_shape, axis=None): + window_shape = get_window_shape(window_shape) + axis = get_axis(window_shape, axis, x.ndim) + if len(window_shape) != len(axis): + raise ValueError( + "Must provide matching length window_shape and axis" + ) + + # Initialise view details with shape and strides of x. + out_shape = shape_buffer + out_strides = stride_buffer + for i in range(x.ndim): + out_shape = tuple_setitem(out_shape, i, x.shape[i]) + out_strides = tuple_setitem(out_strides, i, x.strides[i]) + + # Trim the dimensions being windowed and set the window shape and + # strides. Note: the same axis can be windowed repeatedly. + i = x.ndim + for ax, dim in zip(axis, window_shape): + if dim < 0: + raise ValueError( + "`window_shape` cannot contain negative values" + ) + if out_shape[ax] < dim: + raise ValueError( + "window_shape cannot be larger than input array shape" + ) + + trimmed = out_shape[ax] - dim + 1 + out_shape = tuple_setitem(out_shape, ax, trimmed) + out_shape = tuple_setitem(out_shape, i, dim) + out_strides = tuple_setitem(out_strides, i, x.strides[ax]) + i += 1 + + # The NumPy version calls as_strided, but our implementation of + # as_strided is effectively a wrapper for reshape_unchecked. + view = reshape_unchecked(x, out_shape, out_strides) + return view + + return sliding_window_view_impl + + +@overload(bool) +def ol_bool(arr): + if isinstance(arr, types.Array): + + def impl(arr): + if arr.size == 0: + if numpy_version < (2, 2): + return False # this is deprecated + else: + raise ValueError( + ( + "The truth value of an empty array is " + "ambiguous. Use `array.size > 0` to " + "check that an array is not empty." + ) + ) + elif arr.size == 1: + return bool(arr.take(0)) + else: + raise ValueError( + ( + "The truth value of an array with more than" + " one element is ambiguous. 
Use a.any() or" + " a.all()" + ) + ) + + return impl + + +@overload(np.swapaxes) +def numpy_swapaxes(a, axis1, axis2): + if not isinstance(axis1, (int, types.Integer)): + raise errors.TypingError( + 'The second argument "axis1" must be an integer' + ) + if not isinstance(axis2, (int, types.Integer)): + raise errors.TypingError( + 'The third argument "axis2" must be an integer' + ) + if not isinstance(a, types.Array): + raise errors.TypingError('The first argument "a" must be an array') + + # create tuple list for transpose + ndim = a.ndim + axes_list = tuple(range(ndim)) + + def impl(a, axis1, axis2): + axis1 = normalize_axis("np.swapaxes", "axis1", ndim, axis1) + axis2 = normalize_axis("np.swapaxes", "axis2", ndim, axis2) + + # to ensure tuple_setitem support of negative values + if axis1 < 0: + axis1 += ndim + if axis2 < 0: + axis2 += ndim + + axes_tuple = tuple_setitem(axes_list, axis1, axis2) + axes_tuple = tuple_setitem(axes_tuple, axis2, axis1) + return np.transpose(a, axes_tuple) + + return impl + + +@register_jitable +def _take_along_axis_impl( + arr, indices, axis, Ni_orig, Nk_orig, indices_broadcast_shape +): + # Based on example code in + # https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/shape_base.py#L90-L103 + # With addition of pre-broadcasting: + # https://github.com/numpy/numpy/issues/19704 + + # Wrap axis, it's used in tuple_setitem so must be (axis >= 0) to ensure + # the GEP is in bounds. + axis = normalize_axis("np.take_along_axis", "axis", arr.ndim, axis) + + # Broadcast the two arrays to matching shapes: + arr_shape = list(arr.shape) + arr_shape[axis] = 1 + for i, (d1, d2) in enumerate(zip(arr_shape, indices.shape)): + if d1 == 1: + new_val = d2 + elif d2 == 1: + new_val = d1 + else: + if d1 != d2: + raise ValueError("`arr` and `indices` dimensions don't match") + new_val = d1 + indices_broadcast_shape = tuple_setitem( + indices_broadcast_shape, i, new_val + ) + arr_broadcast_shape = tuple_setitem( + indices_broadcast_shape, axis, arr.shape[axis] + ) + arr = np.broadcast_to(arr, arr_broadcast_shape) + indices = np.broadcast_to(indices, indices_broadcast_shape) + + Ni = Ni_orig + if len(Ni_orig) > 0: + for i in range(len(Ni)): + Ni = tuple_setitem(Ni, i, arr.shape[i]) + Nk = Nk_orig + if len(Nk_orig) > 0: + for i in range(len(Nk)): + Nk = tuple_setitem(Nk, i, arr.shape[axis + 1 + i]) + + J = indices.shape[axis] # Need not equal M + out = np.empty(Ni + (J,) + Nk, arr.dtype) + + np_s_ = (slice(None, None, None),) + + for ii in np.ndindex(Ni): + for kk in np.ndindex(Nk): + a_1d = arr[ii + np_s_ + kk] + indices_1d = indices[ii + np_s_ + kk] + out_1d = out[ii + np_s_ + kk] + for j in range(J): + out_1d[j] = a_1d[indices_1d[j]] + return out + + +@overload(np.take_along_axis) +def arr_take_along_axis(arr, indices, axis): + if not isinstance(arr, types.Array): + raise errors.TypingError('The first argument "arr" must be an array') + if not isinstance(indices, types.Array): + raise errors.TypingError( + 'The second argument "indices" must be an array' + ) + if not isinstance(indices.dtype, types.Integer): + raise errors.TypingError("The indices array must contain integers") + if is_nonelike(axis): + arr_ndim = 1 + else: + arr_ndim = arr.ndim + if arr_ndim != indices.ndim: + # Matches NumPy error: + raise errors.TypingError( + "`indices` and `arr` must have the same number of dimensions" + ) + + indices_broadcast_shape = tuple(range(indices.ndim)) + if is_nonelike(axis): + + def take_along_axis_impl(arr, indices, axis): + return 
_take_along_axis_impl( + arr.flatten(), indices, 0, (), (), indices_broadcast_shape + ) + else: + check_is_integer(axis, "axis") + if not isinstance(axis, types.IntegerLiteral): + raise errors.NumbaValueError("axis must be a literal value") + axis = axis.literal_value + if axis < 0: + axis = arr.ndim + axis + + if axis < 0 or axis >= arr.ndim: + raise errors.NumbaValueError("axis is out of bounds") + + Ni = tuple(range(axis)) + Nk = tuple(range(axis + 1, arr.ndim)) + + def take_along_axis_impl(arr, indices, axis): + return _take_along_axis_impl( + arr, indices, axis, Ni, Nk, indices_broadcast_shape + ) + + return take_along_axis_impl + + +@overload(np.nan_to_num) +def nan_to_num_impl(x, copy=True, nan=0.0): + if isinstance(x, types.Number): + if isinstance(x, types.Integer): + # Integers do not have nans or infs + def impl(x, copy=True, nan=0.0): + return x + + elif isinstance(x, types.Float): + + def impl(x, copy=True, nan=0.0): + if np.isnan(x): + return nan + elif np.isneginf(x): + return np.finfo(type(x)).min + elif np.isposinf(x): + return np.finfo(type(x)).max + return x + elif isinstance(x, types.Complex): + + def impl(x, copy=True, nan=0.0): + r = np.nan_to_num(x.real, nan=nan) + c = np.nan_to_num(x.imag, nan=nan) + return complex(r, c) + else: + raise errors.TypingError( + "Only Integer, Float, and Complex values are accepted" + ) + + elif type_can_asarray(x): + if isinstance(x.dtype, types.Integer): + # Integers do not have nans or infs + def impl(x, copy=True, nan=0.0): + return x + elif isinstance(x.dtype, types.Float): + + def impl(x, copy=True, nan=0.0): + min_inf = np.finfo(x.dtype).min + max_inf = np.finfo(x.dtype).max + + x_ = np.asarray(x) + output = np.copy(x_) if copy else x_ + + output_flat = output.flat + for i in range(output.size): + if np.isnan(output_flat[i]): + output_flat[i] = nan + elif np.isneginf(output_flat[i]): + output_flat[i] = min_inf + elif np.isposinf(output_flat[i]): + output_flat[i] = max_inf + return output + elif isinstance(x.dtype, types.Complex): + + def impl(x, copy=True, nan=0.0): + x_ = np.asarray(x) + output = np.copy(x_) if copy else x_ + + np.nan_to_num(output.real, copy=False, nan=nan) + np.nan_to_num(output.imag, copy=False, nan=nan) + return output + else: + raise errors.TypingError( + "Only Integer, Float, and Complex values are accepted" + ) + else: + raise errors.TypingError( + "The first argument must be a scalar or an array-like" + ) + return impl diff --git a/numba_cuda/numba/cuda/np/extensions.py b/numba_cuda/numba/cuda/np/extensions.py new file mode 100644 index 000000000..f2b78a0c8 --- /dev/null +++ b/numba_cuda/numba/cuda/np/extensions.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +NumPy extensions. +""" + +from numba.cuda.np.arraymath import cross2d + + +__all__ = ["cross2d"] diff --git a/numba_cuda/numba/cuda/np/linalg.py b/numba_cuda/numba/cuda/np/linalg.py new file mode 100644 index 000000000..f3f5b00b6 --- /dev/null +++ b/numba_cuda/numba/cuda/np/linalg.py @@ -0,0 +1,3087 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause +""" +Implementation of linear algebra operations. 
+""" + +import contextlib +import warnings + +from llvmlite import ir + +import numpy as np +import operator + +from numba.core.imputils import impl_ret_borrowed, impl_ret_new_ref +from numba.cuda.typing import signature +from numba.cuda.extending import intrinsic, overload, register_jitable +from numba.core import types +from numba.cuda import cgutils +from numba.core.errors import ( + TypingError, + NumbaTypeError, + NumbaPerformanceWarning, +) +from .arrayobj import make_array, array_copy +from numba.cuda.np import numpy_support as np_support + +ll_char = ir.IntType(8) +ll_char_p = ll_char.as_pointer() +ll_void_p = ll_char_p +ll_intc = ir.IntType(32) +ll_intc_p = ll_intc.as_pointer() +intp_t = cgutils.intp_t +ll_intp_p = intp_t.as_pointer() + + +# fortran int type, this needs to match the F_INT C declaration in +# _lapack.c and is present to accommodate potential future 64bit int +# based LAPACK use. +F_INT_nptype = np.int32 +F_INT_nbtype = types.int32 + +# BLAS kinds as letters +_blas_kinds = { + types.float32: "s", + types.float64: "d", + types.complex64: "c", + types.complex128: "z", +} + + +def get_blas_kind(dtype, func_name=""): + kind = _blas_kinds.get(dtype) + if kind is None: + raise NumbaTypeError("unsupported dtype for %s()" % (func_name,)) + return kind + + +def ensure_blas(): + try: + import scipy.linalg.cython_blas # noqa: F401 + except ImportError: + raise ImportError("scipy 0.16+ is required for linear algebra") + + +def ensure_lapack(): + try: + import scipy.linalg.cython_lapack # noqa: F401 + except ImportError: + raise ImportError("scipy 0.16+ is required for linear algebra") + + +def make_constant_slot(context, builder, ty, val): + const = context.get_constant_generic(builder, ty, val) + return cgutils.alloca_once_value(builder, const) + + +class _BLAS: + """ + Functions to return type signatures for wrapped + BLAS functions. + """ + + def __init__(self): + ensure_blas() + + @classmethod + def numba_xxnrm2(cls, dtype): + rtype = getattr(dtype, "underlying_float", dtype) + sig = types.intc( + types.char, # kind + types.intp, # n + types.CPointer(dtype), # x + types.intp, # incx + types.CPointer(rtype), + ) # returned + + return types.ExternalFunction("numba_xxnrm2", sig) + + @classmethod + def numba_xxgemm(cls, dtype): + sig = types.intc( + types.char, # kind + types.char, # transa + types.char, # transb + types.intp, # m + types.intp, # n + types.intp, # k + types.CPointer(dtype), # alpha + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # b + types.intp, # ldb + types.CPointer(dtype), # beta + types.CPointer(dtype), # c + types.intp, # ldc + ) + return types.ExternalFunction("numba_xxgemm", sig) + + +class _LAPACK: + """ + Functions to return type signatures for wrapped + LAPACK functions. 
+ """ + + def __init__(self): + ensure_lapack() + + @classmethod + def numba_xxgetrf(cls, dtype): + sig = types.intc( + types.char, # kind + types.intp, # m + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(F_INT_nbtype), # ipiv + ) + return types.ExternalFunction("numba_xxgetrf", sig) + + @classmethod + def numba_ez_xxgetri(cls, dtype): + sig = types.intc( + types.char, # kind + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(F_INT_nbtype), # ipiv + ) + return types.ExternalFunction("numba_ez_xxgetri", sig) + + @classmethod + def numba_ez_rgeev(cls, dtype): + sig = types.intc( + types.char, # kind + types.char, # jobvl + types.char, # jobvr + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # wr + types.CPointer(dtype), # wi + types.CPointer(dtype), # vl + types.intp, # ldvl + types.CPointer(dtype), # vr + types.intp, # ldvr + ) + return types.ExternalFunction("numba_ez_rgeev", sig) + + @classmethod + def numba_ez_cgeev(cls, dtype): + sig = types.intc( + types.char, # kind + types.char, # jobvl + types.char, # jobvr + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # w + types.CPointer(dtype), # vl + types.intp, # ldvl + types.CPointer(dtype), # vr + types.intp, # ldvr + ) + return types.ExternalFunction("numba_ez_cgeev", sig) + + @classmethod + def numba_ez_xxxevd(cls, dtype): + wtype = getattr(dtype, "underlying_float", dtype) + sig = types.intc( + types.char, # kind + types.char, # jobz + types.char, # uplo + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(wtype), # w + ) + return types.ExternalFunction("numba_ez_xxxevd", sig) + + @classmethod + def numba_xxpotrf(cls, dtype): + sig = types.intc( + types.char, # kind + types.char, # uplo + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + ) + return types.ExternalFunction("numba_xxpotrf", sig) + + @classmethod + def numba_ez_gesdd(cls, dtype): + stype = getattr(dtype, "underlying_float", dtype) + sig = types.intc( + types.char, # kind + types.char, # jobz + types.intp, # m + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(stype), # s + types.CPointer(dtype), # u + types.intp, # ldu + types.CPointer(dtype), # vt + types.intp, # ldvt + ) + + return types.ExternalFunction("numba_ez_gesdd", sig) + + @classmethod + def numba_ez_geqrf(cls, dtype): + sig = types.intc( + types.char, # kind + types.intp, # m + types.intp, # n + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # tau + ) + return types.ExternalFunction("numba_ez_geqrf", sig) + + @classmethod + def numba_ez_xxgqr(cls, dtype): + sig = types.intc( + types.char, # kind + types.intp, # m + types.intp, # n + types.intp, # k + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # tau + ) + return types.ExternalFunction("numba_ez_xxgqr", sig) + + @classmethod + def numba_ez_gelsd(cls, dtype): + rtype = getattr(dtype, "underlying_float", dtype) + sig = types.intc( + types.char, # kind + types.intp, # m + types.intp, # n + types.intp, # nrhs + types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(dtype), # b + types.intp, # ldb + types.CPointer(rtype), # S + types.float64, # rcond + types.CPointer(types.intc), # rank + ) + return types.ExternalFunction("numba_ez_gelsd", sig) + + @classmethod + def numba_xgesv(cls, dtype): + sig = types.intc( + types.char, # kind + types.intp, # n + types.intp, # nhrs + 
types.CPointer(dtype), # a + types.intp, # lda + types.CPointer(F_INT_nbtype), # ipiv + types.CPointer(dtype), # b + types.intp, # ldb + ) + return types.ExternalFunction("numba_xgesv", sig) + + +@contextlib.contextmanager +def make_contiguous(context, builder, sig, args): + """ + Ensure that all array arguments are contiguous, if necessary by + copying them. + A new (sig, args) tuple is yielded. + """ + newtys = [] + newargs = [] + copies = [] + for ty, val in zip(sig.args, args): + if not isinstance(ty, types.Array) or ty.layout in "CF": + newty, newval = ty, val + else: + newty = ty.copy(layout="C") + copysig = signature(newty, ty) + newval = array_copy(context, builder, copysig, (val,)) + copies.append((newty, newval)) + newtys.append(newty) + newargs.append(newval) + yield signature(sig.return_type, *newtys), tuple(newargs) + for ty, val in copies: + context.nrt.decref(builder, ty, val) + + +def check_c_int(context, builder, n): + """ + Check whether *n* fits in a C `int`. + """ + _maxint = 2**31 - 1 + + def impl(n): + if n > _maxint: + raise OverflowError("array size too large to fit in C int") + + context.compile_internal( + builder, impl, signature(types.none, types.intp), (n,) + ) + + +def check_blas_return(context, builder, res): + """ + Check the integer error return from one of the BLAS wrappers in + _helperlib.c. + """ + with builder.if_then(cgutils.is_not_null(builder, res), likely=False): + # Those errors shouldn't happen, it's easier to just abort the process + pyapi = context.get_python_api(builder) + pyapi.gil_ensure() + pyapi.fatal_error("BLAS wrapper returned with an error") + + +def check_lapack_return(context, builder, res): + """ + Check the integer error return from one of the LAPACK wrappers in + _helperlib.c. + """ + with builder.if_then(cgutils.is_not_null(builder, res), likely=False): + # Those errors shouldn't happen, it's easier to just abort the process + pyapi = context.get_python_api(builder) + pyapi.gil_ensure() + pyapi.fatal_error("LAPACK wrapper returned with an error") + + +def call_xxdot(context, builder, conjugate, dtype, n, a_data, b_data, out_data): + """ + Call the BLAS vector * vector product function for the given arguments. + """ + fnty = ir.FunctionType( + ir.IntType(32), + [ + ll_char, + ll_char, + intp_t, # kind, conjugate, n + ll_void_p, + ll_void_p, + ll_void_p, # a, b, out + ], + ) + fn = cgutils.get_or_insert_function(builder.module, fnty, "numba_xxdot") + + kind = get_blas_kind(dtype) + kind_val = ir.Constant(ll_char, ord(kind)) + conjugate = ir.Constant(ll_char, int(conjugate)) + + res = builder.call( + fn, + ( + kind_val, + conjugate, + n, + builder.bitcast(a_data, ll_void_p), + builder.bitcast(b_data, ll_void_p), + builder.bitcast(out_data, ll_void_p), + ), + ) + check_blas_return(context, builder, res) + + +def call_xxgemv( + context, builder, do_trans, m_type, m_shapes, m_data, v_data, out_data +): + """ + Call the BLAS matrix * vector product function for the given arguments. 
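+
+    In effect this computes (a sketch; alpha=1.0 and beta=0.0 are fixed
+    below):
+
+        out = op(A) @ v,  where op is transpose if do_trans else identity
+
+    with (m, n) and lda read from the shape directly for 'F' layout and
+    from the transposed shape for 'C' layout, since BLAS is Fortran-ordered.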
+ """ + fnty = ir.FunctionType( + ir.IntType(32), + [ + ll_char, + ll_char, # kind, trans + intp_t, + intp_t, # m, n + ll_void_p, + ll_void_p, + intp_t, # alpha, a, lda + ll_void_p, + ll_void_p, + ll_void_p, # x, beta, y + ], + ) + fn = cgutils.get_or_insert_function(builder.module, fnty, "numba_xxgemv") + + dtype = m_type.dtype + alpha = make_constant_slot(context, builder, dtype, 1.0) + beta = make_constant_slot(context, builder, dtype, 0.0) + + if m_type.layout == "F": + m, n = m_shapes + lda = m_shapes[0] + else: + n, m = m_shapes + lda = m_shapes[1] + + kind = get_blas_kind(dtype) + kind_val = ir.Constant(ll_char, ord(kind)) + trans = ir.Constant(ll_char, ord("t") if do_trans else ord("n")) + + res = builder.call( + fn, + ( + kind_val, + trans, + m, + n, + builder.bitcast(alpha, ll_void_p), + builder.bitcast(m_data, ll_void_p), + lda, + builder.bitcast(v_data, ll_void_p), + builder.bitcast(beta, ll_void_p), + builder.bitcast(out_data, ll_void_p), + ), + ) + check_blas_return(context, builder, res) + + +def call_xxgemm( + context, + builder, + x_type, + x_shapes, + x_data, + y_type, + y_shapes, + y_data, + out_type, + out_shapes, + out_data, +): + """ + Call the BLAS matrix * matrix product function for the given arguments. + """ + fnty = ir.FunctionType( + ir.IntType(32), + [ + ll_char, # kind + ll_char, + ll_char, # transa, transb + intp_t, + intp_t, + intp_t, # m, n, k + ll_void_p, + ll_void_p, + intp_t, # alpha, a, lda + ll_void_p, + intp_t, + ll_void_p, # b, ldb, beta + ll_void_p, + intp_t, # c, ldc + ], + ) + fn = cgutils.get_or_insert_function(builder.module, fnty, "numba_xxgemm") + + m, k = x_shapes + _k, n = y_shapes + dtype = x_type.dtype + alpha = make_constant_slot(context, builder, dtype, 1.0) + beta = make_constant_slot(context, builder, dtype, 0.0) + + trans = ir.Constant(ll_char, ord("t")) + notrans = ir.Constant(ll_char, ord("n")) + + def get_array_param(ty, shapes, data): + return ( + # Transpose if layout different from result's + notrans if ty.layout == out_type.layout else trans, + # Size of the inner dimension in physical array order + shapes[1] if ty.layout == "C" else shapes[0], + # The data pointer, unit-less + builder.bitcast(data, ll_void_p), + ) + + transa, lda, data_a = get_array_param(y_type, y_shapes, y_data) + transb, ldb, data_b = get_array_param(x_type, x_shapes, x_data) + _, ldc, data_c = get_array_param(out_type, out_shapes, out_data) + + kind = get_blas_kind(dtype) + kind_val = ir.Constant(ll_char, ord(kind)) + + res = builder.call( + fn, + ( + kind_val, + transa, + transb, + n, + m, + k, + builder.bitcast(alpha, ll_void_p), + data_a, + lda, + data_b, + ldb, + builder.bitcast(beta, ll_void_p), + data_c, + ldc, + ), + ) + check_blas_return(context, builder, res) + + +def dot_2_mm(context, builder, sig, args): + """ + np.dot(matrix, matrix) + """ + + def dot_impl(a, b): + m, k = a.shape + _k, n = b.shape + if k == 0: + return np.zeros((m, n), a.dtype) + out = np.empty((m, n), a.dtype) + return np.dot(a, b, out) + + res = context.compile_internal(builder, dot_impl, sig, args) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +def dot_2_vm(context, builder, sig, args): + """ + np.dot(vector, matrix) + """ + + def dot_impl(a, b): + (m,) = a.shape + _m, n = b.shape + if m == 0: + return np.zeros((n,), a.dtype) + out = np.empty((n,), a.dtype) + return np.dot(a, b, out) + + res = context.compile_internal(builder, dot_impl, sig, args) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +def dot_2_mv(context, builder, 
sig, args): + """ + np.dot(matrix, vector) + """ + + def dot_impl(a, b): + m, n = a.shape + (_n,) = b.shape + if n == 0: + return np.zeros((m,), a.dtype) + out = np.empty((m,), a.dtype) + return np.dot(a, b, out) + + res = context.compile_internal(builder, dot_impl, sig, args) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +def dot_2_vv(context, builder, sig, args, conjugate=False): + """ + np.dot(vector, vector) + np.vdot(vector, vector) + """ + aty, bty = sig.args + dtype = sig.return_type + a = make_array(aty)(context, builder, args[0]) + b = make_array(bty)(context, builder, args[1]) + (n,) = cgutils.unpack_tuple(builder, a.shape) + + def check_args(a, b): + (m,) = a.shape + (n,) = b.shape + if m != n: + raise ValueError( + "incompatible array sizes for np.dot(a, b) (vector * vector)" + ) + + context.compile_internal( + builder, check_args, signature(types.none, *sig.args), args + ) + check_c_int(context, builder, n) + + out = cgutils.alloca_once(builder, context.get_value_type(dtype)) + call_xxdot(context, builder, conjugate, dtype, n, a.data, b.data, out) + return builder.load(out) + + +@overload(np.dot) +def dot_2(left, right): + """ + np.dot(a, b) + """ + return dot_2_impl("np.dot()", left, right) + + +@overload(operator.matmul) +def matmul_2(left, right): + """ + a @ b + """ + return dot_2_impl("'@'", left, right) + + +def dot_2_impl(name, left, right): + if isinstance(left, types.Array) and isinstance(right, types.Array): + + @intrinsic + def _impl(typingcontext, left, right): + ndims = (left.ndim, right.ndim) + + def _dot2_codegen(context, builder, sig, args): + ensure_blas() + + with make_contiguous(context, builder, sig, args) as ( + sig, + args, + ): + if ndims == (2, 2): + return dot_2_mm(context, builder, sig, args) + elif ndims == (2, 1): + return dot_2_mv(context, builder, sig, args) + elif ndims == (1, 2): + return dot_2_vm(context, builder, sig, args) + elif ndims == (1, 1): + return dot_2_vv(context, builder, sig, args) + else: + raise AssertionError("unreachable") + + if left.dtype != right.dtype: + raise TypingError( + "%s arguments must all have the same dtype" % name + ) + + if ndims == (2, 2): + return_type = types.Array(left.dtype, 2, "C") + elif ndims == (2, 1) or ndims == (1, 2): + return_type = types.Array(left.dtype, 1, "C") + elif ndims == (1, 1): + return_type = left.dtype + else: + raise TypingError( + ("%s: inputs must have compatible dimensions") % name + ) + return signature(return_type, left, right), _dot2_codegen + + if left.layout not in "CF" or right.layout not in "CF": + warnings.warn( + "%s is faster on contiguous arrays, called on %s" + % ( + name, + (left, right), + ), + NumbaPerformanceWarning, + ) + + return lambda left, right: _impl(left, right) + + +@overload(np.vdot) +def vdot(left, right): + """ + np.vdot(a, b) + """ + if isinstance(left, types.Array) and isinstance(right, types.Array): + + @intrinsic + def _impl(typingcontext, left, right): + def codegen(context, builder, sig, args): + ensure_blas() + + with make_contiguous(context, builder, sig, args) as ( + sig, + args, + ): + return dot_2_vv(context, builder, sig, args, conjugate=True) + + if left.ndim != 1 or right.ndim != 1: + raise TypingError("np.vdot() only supported on 1-D arrays") + + if left.dtype != right.dtype: + raise TypingError( + "np.vdot() arguments must all have the same dtype" + ) + return signature(left.dtype, left, right), codegen + + if left.layout not in "CF" or right.layout not in "CF": + warnings.warn( + "np.vdot() is faster on contiguous 
arrays, called on %s" + % ((left, right),), + NumbaPerformanceWarning, + ) + + return lambda left, right: _impl(left, right) + + +def dot_3_vm_check_args(a, b, out): + (m,) = a.shape + _m, n = b.shape + if m != _m: + raise ValueError( + "incompatible array sizes for np.dot(a, b) (vector * matrix)" + ) + if out.shape != (n,): + raise ValueError( + "incompatible output array size for " + "np.dot(a, b, out) (vector * matrix)" + ) + + +def dot_3_mv_check_args(a, b, out): + m, _n = a.shape + (n,) = b.shape + if n != _n: + raise ValueError( + "incompatible array sizes for np.dot(a, b) (matrix * vector)" + ) + if out.shape != (m,): + raise ValueError( + "incompatible output array size for " + "np.dot(a, b, out) (matrix * vector)" + ) + + +def dot_3_vm(context, builder, sig, args): + """ + np.dot(vector, matrix, out) + np.dot(matrix, vector, out) + """ + xty, yty, outty = sig.args + assert outty == sig.return_type + + x = make_array(xty)(context, builder, args[0]) + y = make_array(yty)(context, builder, args[1]) + out = make_array(outty)(context, builder, args[2]) + x_shapes = cgutils.unpack_tuple(builder, x.shape) + y_shapes = cgutils.unpack_tuple(builder, y.shape) + out_shapes = cgutils.unpack_tuple(builder, out.shape) # noqa: F841 + if xty.ndim < yty.ndim: + # Vector * matrix + # Asked for x * y, we will compute y.T * x + mty = yty + m_shapes = y_shapes + v_shape = x_shapes[0] + lda = m_shapes[1] + do_trans = yty.layout == "F" + m_data, v_data = y.data, x.data + check_args = dot_3_vm_check_args + else: + # Matrix * vector + # We will compute x * y + mty = xty + m_shapes = x_shapes + v_shape = y_shapes[0] + lda = m_shapes[0] + do_trans = xty.layout == "C" + m_data, v_data = x.data, y.data + check_args = dot_3_mv_check_args + + context.compile_internal( + builder, check_args, signature(types.none, *sig.args), args + ) + for val in m_shapes: + check_c_int(context, builder, val) + + zero = context.get_constant(types.intp, 0) + both_empty = builder.icmp_signed("==", v_shape, zero) + matrix_empty = builder.icmp_signed("==", lda, zero) + is_empty = builder.or_(both_empty, matrix_empty) + with builder.if_else(is_empty, likely=False) as (empty, nonempty): + with empty: + cgutils.memset( + builder, out.data, builder.mul(out.itemsize, out.nitems), 0 + ) + with nonempty: + call_xxgemv( + context, + builder, + do_trans, + mty, + m_shapes, + m_data, + v_data, + out.data, + ) + + return impl_ret_borrowed(context, builder, sig.return_type, out._getvalue()) + + +def dot_3_mm(context, builder, sig, args): + """ + np.dot(matrix, matrix, out) + """ + xty, yty, outty = sig.args + assert outty == sig.return_type + dtype = xty.dtype + + x = make_array(xty)(context, builder, args[0]) + y = make_array(yty)(context, builder, args[1]) + out = make_array(outty)(context, builder, args[2]) + x_shapes = cgutils.unpack_tuple(builder, x.shape) + y_shapes = cgutils.unpack_tuple(builder, y.shape) + out_shapes = cgutils.unpack_tuple(builder, out.shape) + m, k = x_shapes + _k, n = y_shapes + + # The only case Numpy supports + assert outty.layout == "C" + + def check_args(a, b, out): + m, k = a.shape + _k, n = b.shape + if k != _k: + raise ValueError( + "incompatible array sizes for np.dot(a, b) (matrix * matrix)" + ) + if out.shape != (m, n): + raise ValueError( + "incompatible output array size for " + "np.dot(a, b, out) (matrix * matrix)" + ) + + context.compile_internal( + builder, check_args, signature(types.none, *sig.args), args + ) + + check_c_int(context, builder, m) + check_c_int(context, builder, k) + 
check_c_int(context, builder, n)
+
+    x_data = x.data
+    y_data = y.data
+    out_data = out.data
+
+    # If eliminated dimension is zero, set all entries to zero and return
+    zero = context.get_constant(types.intp, 0)
+    both_empty = builder.icmp_signed("==", k, zero)
+    x_empty = builder.icmp_signed("==", m, zero)
+    y_empty = builder.icmp_signed("==", n, zero)
+    is_empty = builder.or_(both_empty, builder.or_(x_empty, y_empty))
+    with builder.if_else(is_empty, likely=False) as (empty, nonempty):
+        with empty:
+            cgutils.memset(
+                builder, out.data, builder.mul(out.itemsize, out.nitems), 0
+            )
+        with nonempty:
+            # Check if any of the operands is really a 1-d vector represented
+            # as a (1, k) or (k, 1) 2-d array. In those cases, it is pessimal
+            # to call the generic matrix * matrix product BLAS function.
+            one = context.get_constant(types.intp, 1)
+            is_left_vec = builder.icmp_signed("==", m, one)
+            is_right_vec = builder.icmp_signed("==", n, one)
+
+            with builder.if_else(is_right_vec) as (r_vec, r_mat):
+                with r_vec:
+                    with builder.if_else(is_left_vec) as (v_v, m_v):
+                        with v_v:
+                            # V * V
+                            call_xxdot(
+                                context,
+                                builder,
+                                False,
+                                dtype,
+                                k,
+                                x_data,
+                                y_data,
+                                out_data,
+                            )
+                        with m_v:
+                            # M * V
+                            do_trans = xty.layout == outty.layout
+                            call_xxgemv(
+                                context,
+                                builder,
+                                do_trans,
+                                xty,
+                                x_shapes,
+                                x_data,
+                                y_data,
+                                out_data,
+                            )
+                with r_mat:
+                    with builder.if_else(is_left_vec) as (v_m, m_m):
+                        with v_m:
+                            # V * M
+                            do_trans = yty.layout != outty.layout
+                            call_xxgemv(
+                                context,
+                                builder,
+                                do_trans,
+                                yty,
+                                y_shapes,
+                                y_data,
+                                x_data,
+                                out_data,
+                            )
+                        with m_m:
+                            # M * M
+                            call_xxgemm(
+                                context,
+                                builder,
+                                xty,
+                                x_shapes,
+                                x_data,
+                                yty,
+                                y_shapes,
+                                y_data,
+                                outty,
+                                out_shapes,
+                                out_data,
+                            )
+
+    return impl_ret_borrowed(context, builder, sig.return_type, out._getvalue())
+
+
+@overload(np.dot)
+def dot_3(left, right, out):
+    """
+    np.dot(a, b, out)
+    """
+    if (
+        isinstance(left, types.Array)
+        and isinstance(right, types.Array)
+        and isinstance(out, types.Array)
+    ):
+
+        @intrinsic
+        def _impl(typingcontext, left, right, out):
+            def codegen(context, builder, sig, args):
+                ensure_blas()
+
+                with make_contiguous(context, builder, sig, args) as (
+                    sig,
+                    args,
+                ):
+                    ndims = set(x.ndim for x in sig.args[:2])
+                    if ndims == {2}:
+                        return dot_3_mm(context, builder, sig, args)
+                    elif ndims == {1, 2}:
+                        return dot_3_vm(context, builder, sig, args)
+                    else:
+                        raise AssertionError("unreachable")
+
+            if left.dtype != right.dtype or left.dtype != out.dtype:
+                raise TypingError(
+                    "np.dot() arguments must all have the same dtype"
+                )
+
+            return signature(out, left, right, out), codegen
+
+        if (
+            left.layout not in "CF"
+            or right.layout not in "CF"
+            or out.layout not in "CF"
+        ):
+            warnings.warn(
+                "np.dot() is faster on contiguous arrays, called on %s"
+                % ((left, right),),
+                NumbaPerformanceWarning,
+            )
+
+        return lambda left, right, out: _impl(left, right, out)
+
+
+fatal_error_func = types.ExternalFunction("numba_fatal_error", types.intc())
+
+
+@register_jitable
+def _check_finite_matrix(a):
+    for v in np.nditer(a):
+        if not np.isfinite(v.item()):
+            raise np.linalg.LinAlgError("Array must not contain infs or NaNs.")
+
+
+def _check_linalg_matrix(a, func_name, la_prefix=True):
+    # la_prefix is present as some functions, e.g.
np.trace() + # are documented under "linear algebra" but aren't in the + # module + prefix = "np.linalg" if la_prefix else "np" + interp = (prefix, func_name) + # Unpack optional type + if isinstance(a, types.Optional): + a = a.type + if not isinstance(a, types.Array): + msg = "%s.%s() only supported for array types" % interp + raise TypingError(msg, highlighting=False) + if not a.ndim == 2: + msg = "%s.%s() only supported on 2-D arrays." % interp + raise TypingError(msg, highlighting=False) + if not isinstance(a.dtype, (types.Float, types.Complex)): + msg = "%s.%s() only supported on float and complex arrays." % interp + raise TypingError(msg, highlighting=False) + + +def _check_homogeneous_types(func_name, *types): + t0 = types[0].dtype + for t in types[1:]: + if t.dtype != t0: + msg = ( + "np.linalg.%s() only supports inputs that have homogeneous dtypes." + % func_name + ) + raise TypingError(msg, highlighting=False) + + +def _copy_to_fortran_order(): + pass + + +@overload(_copy_to_fortran_order) +def ol_copy_to_fortran_order(a): + # This function copies the array 'a' into a new array with fortran order. + # This exists because the copy routines don't take order flags yet. + F_layout = a.layout == "F" + A_layout = a.layout == "A" + + def impl(a): + if F_layout: + # it's F ordered at compile time, just copy + acpy = np.copy(a) + elif A_layout: + # decide based on runtime value + flag_f = a.flags.f_contiguous + if flag_f: + # it's already F ordered, so copy but in a round about way to + # ensure that the copy is also F ordered + acpy = np.copy(a.T).T + else: + # it's something else ordered, so let asfortranarray deal with + # copying and making it fortran ordered + acpy = np.asfortranarray(a) + else: + # it's C ordered at compile time, asfortranarray it. + acpy = np.asfortranarray(a) + return acpy + + return impl + + +@register_jitable +def _inv_err_handler(r): + if r != 0: + if r < 0: + fatal_error_func() + assert 0 # unreachable + if r > 0: + raise np.linalg.LinAlgError( + "Matrix is singular to machine precision." + ) + + +@register_jitable +def _dummy_liveness_func(a): + """pass a list of variables to be preserved through dead code elimination""" + return a[0] + + +@overload(np.linalg.inv) +def inv_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "inv") + + numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype) + + numba_xxgetri = _LAPACK().numba_ez_xxgetri(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "inv")) + + def inv_impl(a): + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + if n == 0: + return acpy + + ipiv = np.empty(n, dtype=F_INT_nptype) + + r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes) + _inv_err_handler(r) + + r = numba_xxgetri(kind, n, acpy.ctypes, n, ipiv.ctypes) + _inv_err_handler(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, ipiv.size]) + return acpy + + return inv_impl + + +@register_jitable +def _handle_err_maybe_convergence_problem(r): + if r != 0: + if r < 0: + fatal_error_func() + assert 0 # unreachable + if r > 0: + raise ValueError("Internal algorithm failed to converge.") + + +def _check_linalg_1_or_2d_matrix(a, func_name, la_prefix=True): + # la_prefix is present as some functions, e.g. 
np.trace() + # are documented under "linear algebra" but aren't in the + # module + prefix = "np.linalg" if la_prefix else "np" + interp = (prefix, func_name) + # checks that a matrix is 1 or 2D + if not isinstance(a, types.Array): + raise TypingError("%s.%s() only supported for array types " % interp) + if not a.ndim <= 2: + raise TypingError( + "%s.%s() only supported on 1 and 2-D arrays " % interp + ) + if not isinstance(a.dtype, (types.Float, types.Complex)): + raise TypingError( + "%s.%s() only supported on float and complex arrays." % interp + ) + + +@overload(np.linalg.cholesky) +def cho_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "cholesky") + + numba_xxpotrf = _LAPACK().numba_xxpotrf(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "cholesky")) + UP = ord("U") + LO = ord("L") # noqa: F841 + + def cho_impl(a): + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + # The output is allocated in C order + out = a.copy() + + if n == 0: + return out + + # Pass UP since xxpotrf() operates in F order + # The semantics ensure this works fine + # (out is really its Hermitian in F order, but UP instructs + # xxpotrf to compute the Hermitian of the upper triangle + # => they cancel each other) + r = numba_xxpotrf(kind, UP, n, out.ctypes, n) + if r != 0: + if r < 0: + fatal_error_func() + assert 0 # unreachable + if r > 0: + raise np.linalg.LinAlgError("Matrix is not positive definite.") + # Zero out upper triangle, in F order + for col in range(n): + out[:col, col] = 0 + return out + + return cho_impl + + +@overload(np.linalg.eig) +def eig_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "eig") + + numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype) + numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "eig")) + + JOBVL = ord("N") + JOBVR = ord("V") + + def real_eig_impl(a): + """ + eig() implementation for real arrays. + """ + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ldvl = 1 + ldvr = n + wr = np.empty(n, dtype=a.dtype) + wi = np.empty(n, dtype=a.dtype) + vl = np.empty((n, ldvl), dtype=a.dtype) + vr = np.empty((n, ldvr), dtype=a.dtype) + + if n == 0: + return (wr, vr.T) + + r = numba_ez_rgeev( + kind, + JOBVL, + JOBVR, + n, + acpy.ctypes, + n, + wr.ctypes, + wi.ctypes, + vl.ctypes, + ldvl, + vr.ctypes, + ldvr, + ) + _handle_err_maybe_convergence_problem(r) + + # By design numba does not support dynamic return types, however, + # Numpy does. Numpy uses this ability in the case of returning + # eigenvalues/vectors of a real matrix. The return type of + # np.linalg.eig(), when operating on a matrix in real space + # depends on the values present in the matrix itself (recalling + # that eigenvalues are the roots of the characteristic polynomial + # of the system matrix, which will by construction depend on the + # values present in the system matrix). As numba cannot handle + # the case of a runtime decision based domain change relative to + # the input type, if it is required numba raises as below. 
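+        # A concrete instance of such a domain change: the real rotation
+        # matrix [[0., -1.], [1., 0.]] has eigenvalues +1j and -1j, so wi
+        # is nonzero and the check below raises.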
+ if np.any(wi): + raise ValueError("eig() argument must not cause a domain change.") + + # put these in to help with liveness analysis, + # `.ctypes` doesn't keep the vars alive + _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size]) + return (wr, vr.T) + + def cmplx_eig_impl(a): + """ + eig() implementation for complex arrays. + """ + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ldvl = 1 + ldvr = n + w = np.empty(n, dtype=a.dtype) + vl = np.empty((n, ldvl), dtype=a.dtype) + vr = np.empty((n, ldvr), dtype=a.dtype) + + if n == 0: + return (w, vr.T) + + r = numba_ez_cgeev( + kind, + JOBVL, + JOBVR, + n, + acpy.ctypes, + n, + w.ctypes, + vl.ctypes, + ldvl, + vr.ctypes, + ldvr, + ) + _handle_err_maybe_convergence_problem(r) + + # put these in to help with liveness analysis, + # `.ctypes` doesn't keep the vars alive + _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size]) + return (w, vr.T) + + if isinstance(a.dtype, types.scalars.Complex): + return cmplx_eig_impl + else: + return real_eig_impl + + +@overload(np.linalg.eigvals) +def eigvals_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "eigvals") + + numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype) + numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "eigvals")) + + JOBVL = ord("N") + JOBVR = ord("N") + + def real_eigvals_impl(a): + """ + eigvals() implementation for real arrays. + """ + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ldvl = 1 + ldvr = 1 + wr = np.empty(n, dtype=a.dtype) + + if n == 0: + return wr + + wi = np.empty(n, dtype=a.dtype) + + # not referenced but need setting for MKL null check + vl = np.empty((1), dtype=a.dtype) + vr = np.empty((1), dtype=a.dtype) + + r = numba_ez_rgeev( + kind, + JOBVL, + JOBVR, + n, + acpy.ctypes, + n, + wr.ctypes, + wi.ctypes, + vl.ctypes, + ldvl, + vr.ctypes, + ldvr, + ) + _handle_err_maybe_convergence_problem(r) + + # By design numba does not support dynamic return types, however, + # Numpy does. Numpy uses this ability in the case of returning + # eigenvalues/vectors of a real matrix. The return type of + # np.linalg.eigvals(), when operating on a matrix in real space + # depends on the values present in the matrix itself (recalling + # that eigenvalues are the roots of the characteristic polynomial + # of the system matrix, which will by construction depend on the + # values present in the system matrix). As numba cannot handle + # the case of a runtime decision based domain change relative to + # the input type, if it is required numba raises as below. + if np.any(wi): + raise ValueError( + "eigvals() argument must not cause a domain change." + ) + + # put these in to help with liveness analysis, + # `.ctypes` doesn't keep the vars alive + _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size]) + return wr + + def cmplx_eigvals_impl(a): + """ + eigvals() implementation for complex arrays. + """ + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." 
+ raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ldvl = 1 + ldvr = 1 + w = np.empty(n, dtype=a.dtype) + + if n == 0: + return w + + vl = np.empty((1), dtype=a.dtype) + vr = np.empty((1), dtype=a.dtype) + + r = numba_ez_cgeev( + kind, + JOBVL, + JOBVR, + n, + acpy.ctypes, + n, + w.ctypes, + vl.ctypes, + ldvl, + vr.ctypes, + ldvr, + ) + _handle_err_maybe_convergence_problem(r) + + # put these in to help with liveness analysis, + # `.ctypes` doesn't keep the vars alive + _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size]) + return w + + if isinstance(a.dtype, types.scalars.Complex): + return cmplx_eigvals_impl + else: + return real_eigvals_impl + + +@overload(np.linalg.eigh) +def eigh_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "eigh") + + # convert typing floats to numpy floats for use in the impl + w_type = getattr(a.dtype, "underlying_float", a.dtype) + w_dtype = np_support.as_dtype(w_type) + + numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "eigh")) + + JOBZ = ord("V") + UPLO = ord("L") + + def eigh_impl(a): + n = a.shape[-1] + + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + w = np.empty(n, dtype=w_dtype) + + if n == 0: + return (w, acpy) + + r = numba_ez_xxxevd( + kind, # kind + JOBZ, # jobz + UPLO, # uplo + n, # n + acpy.ctypes, # a + n, # lda + w.ctypes, # w + ) + _handle_err_maybe_convergence_problem(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, w.size]) + return (w, acpy) + + return eigh_impl + + +@overload(np.linalg.eigvalsh) +def eigvalsh_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "eigvalsh") + + # convert typing floats to numpy floats for use in the impl + w_type = getattr(a.dtype, "underlying_float", a.dtype) + w_dtype = np_support.as_dtype(w_type) + + numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "eigvalsh")) + + JOBZ = ord("N") + UPLO = ord("L") + + def eigvalsh_impl(a): + n = a.shape[-1] + + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." 
+ raise np.linalg.LinAlgError(msg) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + w = np.empty(n, dtype=w_dtype) + + if n == 0: + return w + + r = numba_ez_xxxevd( + kind, # kind + JOBZ, # jobz + UPLO, # uplo + n, # n + acpy.ctypes, # a + n, # lda + w.ctypes, # w + ) + _handle_err_maybe_convergence_problem(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, w.size]) + return w + + return eigvalsh_impl + + +@overload(np.linalg.svd) +def svd_impl(a, full_matrices=1): + ensure_lapack() + + _check_linalg_matrix(a, "svd") + + # convert typing floats to numpy floats for use in the impl + s_type = getattr(a.dtype, "underlying_float", a.dtype) + s_dtype = np_support.as_dtype(s_type) + + numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "svd")) + + JOBZ_A = ord("A") + JOBZ_S = ord("S") + + def svd_impl(a, full_matrices=1): + n = a.shape[-1] + m = a.shape[-2] + + if n == 0 or m == 0: + raise np.linalg.LinAlgError("Arrays cannot be empty") + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ldu = m + minmn = min(m, n) + + if full_matrices: + JOBZ = JOBZ_A + ucol = m + ldvt = n + else: + JOBZ = JOBZ_S + ucol = minmn + ldvt = minmn + + u = np.empty((ucol, ldu), dtype=a.dtype) + s = np.empty(minmn, dtype=s_dtype) + vt = np.empty((n, ldvt), dtype=a.dtype) + + r = numba_ez_gesdd( + kind, # kind + JOBZ, # jobz + m, # m + n, # n + acpy.ctypes, # a + m, # lda + s.ctypes, # s + u.ctypes, # u + ldu, # ldu + vt.ctypes, # vt + ldvt, # ldvt + ) + _handle_err_maybe_convergence_problem(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, vt.size, u.size, s.size]) + return (u.T, s, vt.T) + + return svd_impl + + +@overload(np.linalg.qr) +def qr_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "qr") + + # Need two functions, the first computes R, storing it in the upper + # triangle of A with the below diagonal part of A containing elementary + # reflectors needed to construct Q. The second turns the below diagonal + # entries of A into Q, storing Q in A (creates orthonormal columns from + # the elementary reflectors). + + numba_ez_geqrf = _LAPACK().numba_ez_geqrf(a.dtype) + numba_ez_xxgqr = _LAPACK().numba_ez_xxgqr(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "qr")) + + def qr_impl(a): + n = a.shape[-1] + m = a.shape[-2] + + if n == 0 or m == 0: + raise np.linalg.LinAlgError("Arrays cannot be empty") + + _check_finite_matrix(a) + + # copy A as it will be destroyed + q = _copy_to_fortran_order(a) + + minmn = min(m, n) + tau = np.empty((minmn), dtype=a.dtype) + + ret = numba_ez_geqrf( + kind, # kind + m, # m + n, # n + q.ctypes, # a + m, # lda + tau.ctypes, # tau + ) + if ret < 0: + fatal_error_func() + assert 0 # unreachable + + # pull out R, this is transposed because of Fortran + r = np.zeros((n, minmn), dtype=a.dtype).T + + # the triangle in R + for i in range(minmn): + for j in range(i + 1): + r[j, i] = q[j, i] + + # and the possible square in R + for i in range(minmn, n): + for j in range(minmn): + r[j, i] = q[j, i] + + ret = numba_ez_xxgqr( + kind, # kind + m, # m + minmn, # n + minmn, # k + q.ctypes, # a + m, # lda + tau.ctypes, # tau + ) + _handle_err_maybe_convergence_problem(ret) + + # help liveness analysis + _dummy_liveness_func([tau.size, q.size]) + return (q[:, :minmn], r) + + return qr_impl + + +# helpers and jitted specialisations required for np.linalg.lstsq +# and np.linalg.solve. These functions have "system" in their name +# as a differentiator. 
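+#
+# For orientation, the scratch-space convention these helpers share
+# (illustrative only, matching the lstsq/solve implementations below):
+#
+#     a:    (m, n) system matrix, copied to Fortran order before the call
+#     b:    (m,) or (m, nrhs) right hand side(s)
+#     bcpy: Fortran-ordered scratch, (max(m, n), nrhs) for gelsd or
+#           (n, nrhs) for gesv, which LAPACK overwrites with the solution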
+ + +def _system_copy_in_b(bcpy, b, nrhs): + """ + Correctly copy 'b' into the 'bcpy' scratch space. + """ + raise NotImplementedError + + +@overload(_system_copy_in_b) +def _system_copy_in_b_impl(bcpy, b, nrhs): + if b.ndim == 1: + + def oneD_impl(bcpy, b, nrhs): + bcpy[: b.shape[-1], 0] = b + + return oneD_impl + else: + + def twoD_impl(bcpy, b, nrhs): + bcpy[: b.shape[-2], :nrhs] = b + + return twoD_impl + + +def _system_compute_nrhs(b): + """ + Compute the number of right hand sides in the system of equations + """ + raise NotImplementedError + + +@overload(_system_compute_nrhs) +def _system_compute_nrhs_impl(b): + if b.ndim == 1: + + def oneD_impl(b): + return 1 + + return oneD_impl + else: + + def twoD_impl(b): + return b.shape[-1] + + return twoD_impl + + +def _system_check_dimensionally_valid(a, b): + """ + Check that AX=B style system input is dimensionally valid. + """ + raise NotImplementedError + + +@overload(_system_check_dimensionally_valid) +def _system_check_dimensionally_valid_impl(a, b): + ndim = b.ndim + if ndim == 1: + + def oneD_impl(a, b): + am = a.shape[-2] + bm = b.shape[-1] + if am != bm: + raise np.linalg.LinAlgError( + "Incompatible array sizes, system is not dimensionally valid." + ) + + return oneD_impl + else: + + def twoD_impl(a, b): + am = a.shape[-2] + bm = b.shape[-2] + if am != bm: + raise np.linalg.LinAlgError( + "Incompatible array sizes, system is not dimensionally valid." + ) + + return twoD_impl + + +def _system_check_non_empty(a, b): + """ + Check that AX=B style system input is not empty. + """ + raise NotImplementedError + + +@overload(_system_check_non_empty) +def _system_check_non_empty_impl(a, b): + ndim = b.ndim + if ndim == 1: + + def oneD_impl(a, b): + am = a.shape[-2] + an = a.shape[-1] + bm = b.shape[-1] + if am == 0 or bm == 0 or an == 0: + raise np.linalg.LinAlgError("Arrays cannot be empty") + + return oneD_impl + else: + + def twoD_impl(a, b): + am = a.shape[-2] + an = a.shape[-1] + bm = b.shape[-2] + bn = b.shape[-1] + if am == 0 or bm == 0 or an == 0 or bn == 0: + raise np.linalg.LinAlgError("Arrays cannot be empty") + + return twoD_impl + + +def _lstsq_residual(b, n, nrhs): + """ + Compute the residual from the 'b' scratch space. + """ + raise NotImplementedError + + +@overload(_lstsq_residual) +def _lstsq_residual_impl(b, n, nrhs): + ndim = b.ndim + dtype = b.dtype + real_dtype = np_support.as_dtype(getattr(dtype, "underlying_float", dtype)) + + if ndim == 1: + if isinstance(dtype, (types.Complex)): + + def cmplx_impl(b, n, nrhs): + res = np.empty((1,), dtype=real_dtype) + res[0] = np.sum(np.abs(b[n:, 0]) ** 2) + return res + + return cmplx_impl + else: + + def real_impl(b, n, nrhs): + res = np.empty((1,), dtype=real_dtype) + res[0] = np.sum(b[n:, 0] ** 2) + return res + + return real_impl + else: + assert ndim == 2 + if isinstance(dtype, (types.Complex)): + + def cmplx_impl(b, n, nrhs): + res = np.empty((nrhs), dtype=real_dtype) + for k in range(nrhs): + res[k] = np.sum(np.abs(b[n:, k]) ** 2) + return res + + return cmplx_impl + else: + + def real_impl(b, n, nrhs): + res = np.empty((nrhs), dtype=real_dtype) + for k in range(nrhs): + res[k] = np.sum(b[n:, k] ** 2) + return res + + return real_impl + + +def _lstsq_solution(b, bcpy, n): + """ + Extract 'x' (the lstsq solution) from the 'bcpy' scratch space. + Note 'b' is only used to check the system input dimension... 
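+
+    Layout sketch, following the gelsd scratch convention used below:
+
+        bcpy[:n, 0]   # 1-D 'b': first n entries of the single column
+        bcpy[:n, :]   # 2-D 'b': leading n rows, one column per RHS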
+    """
+    raise NotImplementedError
+
+
+@overload(_lstsq_solution)
+def _lstsq_solution_impl(b, bcpy, n):
+    if b.ndim == 1:
+
+        def oneD_impl(b, bcpy, n):
+            return bcpy.T.ravel()[:n]
+
+        return oneD_impl
+    else:
+
+        def twoD_impl(b, bcpy, n):
+            return bcpy[:n, :].copy()
+
+        return twoD_impl
+
+
+@overload(np.linalg.lstsq)
+def lstsq_impl(a, b, rcond=-1.0):
+    ensure_lapack()
+
+    _check_linalg_matrix(a, "lstsq")
+
+    # B can be 1D or 2D.
+    _check_linalg_1_or_2d_matrix(b, "lstsq")
+
+    _check_homogeneous_types("lstsq", a, b)
+
+    np_dt = np_support.as_dtype(a.dtype)
+    nb_dt = a.dtype
+
+    # convert typing floats to np floats for use in the impl
+    r_type = getattr(nb_dt, "underlying_float", nb_dt)
+    real_dtype = np_support.as_dtype(r_type)
+
+    # lapack solver
+    numba_ez_gelsd = _LAPACK().numba_ez_gelsd(a.dtype)
+
+    kind = ord(get_blas_kind(nb_dt, "lstsq"))
+
+    # The following functions select specialisations based on
+    # information around 'b'; a lot of this effort is required
+    # as 'b' can be either 1D or 2D, and then there are
+    # some optimisations available depending on real or complex
+    # space.
+
+    def lstsq_impl(a, b, rcond=-1.0):
+        n = a.shape[-1]
+        m = a.shape[-2]
+        nrhs = _system_compute_nrhs(b)
+
+        # check the systems have no inf or NaN
+        _check_finite_matrix(a)
+        _check_finite_matrix(b)
+
+        # check the system is not empty
+        _system_check_non_empty(a, b)
+
+        # check the systems are dimensionally valid
+        _system_check_dimensionally_valid(a, b)
+
+        minmn = min(m, n)
+        maxmn = max(m, n)
+
+        # a is destroyed on exit, copy it
+        acpy = _copy_to_fortran_order(a)
+
+        # b is overwritten on exit with the solution, copy allocate
+        bcpy = np.empty((nrhs, maxmn), dtype=np_dt).T
+        # specialised copy in due to b being 1 or 2D
+        _system_copy_in_b(bcpy, b, nrhs)
+
+        # Allocate returns
+        s = np.empty(minmn, dtype=real_dtype)
+        rank_ptr = np.empty(1, dtype=np.int32)
+
+        r = numba_ez_gelsd(
+            kind,  # kind
+            m,  # m
+            n,  # n
+            nrhs,  # nrhs
+            acpy.ctypes,  # a
+            m,  # lda
+            bcpy.ctypes,  # b
+            maxmn,  # ldb
+            s.ctypes,  # s
+            rcond,  # rcond
+            rank_ptr.ctypes,  # rank
+        )
+        _handle_err_maybe_convergence_problem(r)
+
+        # set rank to that which was computed
+        rank = rank_ptr[0]
+
+        # compute residuals
+        if rank < n or m <= n:
+            res = np.empty((0), dtype=real_dtype)
+        else:
+            # this requires additional dispatch as there's a faster
+            # impl if the result is in the real domain (no abs() required)
+            res = _lstsq_residual(bcpy, n, nrhs)
+
+        # extract 'x', the solution
+        x = _lstsq_solution(b, bcpy, n)
+
+        # help liveness analysis
+        _dummy_liveness_func([acpy.size, bcpy.size, s.size, rank_ptr.size])
+        return (x, res, rank, s[:minmn])
+
+    return lstsq_impl
+
+
+def _solve_compute_return(b, bcpy):
+    """
+    Extract 'x' (the solution) from the 'bcpy' scratch space.
+    Note 'b' is only used to check the system input dimension...
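+
+    Sketch: the xgesv scratch 'bcpy' is (n, nrhs) in Fortran order, so
+
+        bcpy.T.ravel()   # 1-D 'b': flatten the single column to shape (n,)
+        bcpy             # 2-D 'b': returned unchanged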
+ """ + raise NotImplementedError + + +@overload(_solve_compute_return) +def _solve_compute_return_impl(b, bcpy): + if b.ndim == 1: + + def oneD_impl(b, bcpy): + return bcpy.T.ravel() + + return oneD_impl + else: + + def twoD_impl(b, bcpy): + return bcpy + + return twoD_impl + + +@overload(np.linalg.solve) +def solve_impl(a, b): + ensure_lapack() + + _check_linalg_matrix(a, "solve") + _check_linalg_1_or_2d_matrix(b, "solve") + + _check_homogeneous_types("solve", a, b) + + np_dt = np_support.as_dtype(a.dtype) + nb_dt = a.dtype + + # the lapack solver + numba_xgesv = _LAPACK().numba_xgesv(a.dtype) + + kind = ord(get_blas_kind(nb_dt, "solve")) + + def solve_impl(a, b): + n = a.shape[-1] + nrhs = _system_compute_nrhs(b) + + # check the systems have no inf or NaN + _check_finite_matrix(a) + _check_finite_matrix(b) + + # check the systems are dimensionally valid + _system_check_dimensionally_valid(a, b) + + # a is destroyed on exit, copy it + acpy = _copy_to_fortran_order(a) + + # b is overwritten on exit with the solution, copy allocate + bcpy = np.empty((nrhs, n), dtype=np_dt).T + if n == 0: + return _solve_compute_return(b, bcpy) + + # specialised copy in due to b being 1 or 2D + _system_copy_in_b(bcpy, b, nrhs) + + # allocate pivot array (needs to be fortran int size) + ipiv = np.empty(n, dtype=F_INT_nptype) + + r = numba_xgesv( + kind, # kind + n, # n + nrhs, # nhrs + acpy.ctypes, # a + n, # lda + ipiv.ctypes, # ipiv + bcpy.ctypes, # b + n, # ldb + ) + _inv_err_handler(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, bcpy.size, ipiv.size]) + return _solve_compute_return(b, bcpy) + + return solve_impl + + +@overload(np.linalg.pinv) +def pinv_impl(a, rcond=1.0e-15): + ensure_lapack() + + _check_linalg_matrix(a, "pinv") + + # convert typing floats to numpy floats for use in the impl + s_type = getattr(a.dtype, "underlying_float", a.dtype) + s_dtype = np_support.as_dtype(s_type) + + numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) + + numba_xxgemm = _BLAS().numba_xxgemm(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "pinv")) + JOB = ord("S") + + # need conjugate transposes + TRANSA = ord("C") + TRANSB = ord("C") + + # scalar constants + dt = np_support.as_dtype(a.dtype) + zero = np.array([0.0], dtype=dt) + one = np.array([1.0], dtype=dt) + + def pinv_impl(a, rcond=1.0e-15): + # The idea is to build the pseudo-inverse via inverting the singular + # value decomposition of a matrix `A`. Mathematically, this is roughly + # A = U*S*V^H [The SV decomposition of A] + # A^+ = V*(S^+)*U^H [The inverted SV decomposition of A] + # where ^+ is pseudo inversion and ^H is Hermitian transpose. + # As V and U are unitary, their inverses are simply their Hermitian + # transpose. S has singular values on its diagonal and zero elsewhere, + # it is inverted trivially by reciprocal of the diagonal values with + # the exception that zero singular values remain as zero. + # + # The practical implementation can take advantage of a few things to + # gain a few % performance increase: + # * A is destroyed by the SVD algorithm from LAPACK so a copy is + # required, this memory is exactly the right size in which to return + # the pseudo-inverse and so can be reused for this purpose. + # * The pseudo-inverse of S can be applied to either V or U^H, this + # then leaves a GEMM operation to compute the inverse via either: + # A^+ = (V*(S^+))*U^H + # or + # A^+ = V*((S^+)*U^H) + # however application of S^+ to V^H or U is more convenient as they + # are the result of the SVD algorithm. 
The application of the
+        # diagonal system is just a matrix multiplication which results in a
+        # row/column scaling (direction depending). To save effort, this
+        # "matrix multiplication" is applied to the smallest of U or V^H and
+        # only up to the point of "cut-off" (see next note) just as a direct
+        # scaling.
+        # * The cut-off level for application of S^+ can be used to reduce
+        # total effort; this cut-off can come via rcond or may just naturally
+        # be present as a result of zeros in the singular values. Regardless,
+        # there's no need to multiply by zeros in the application of S^+ to
+        # V^H or U as above. Further, the GEMM operation can be shrunk in
+        # effort by noting that the possible zero block generated by the
+        # presence of zeros in S^+ has no effect apart from wasting cycles as
+        # it is all fmadd()s where one operand is zero. The inner dimension
+        # of the GEMM operation can therefore be shrunk accordingly.
+
+        n = a.shape[-1]
+        m = a.shape[-2]
+
+        _check_finite_matrix(a)
+
+        acpy = _copy_to_fortran_order(a)
+
+        if m == 0 or n == 0:
+            return acpy.T.ravel().reshape(a.shape).T
+
+        minmn = min(m, n)
+
+        u = np.empty((minmn, m), dtype=a.dtype)
+        s = np.empty(minmn, dtype=s_dtype)
+        vt = np.empty((n, minmn), dtype=a.dtype)
+
+        r = numba_ez_gesdd(
+            kind,  # kind
+            JOB,  # job
+            m,  # m
+            n,  # n
+            acpy.ctypes,  # a
+            m,  # lda
+            s.ctypes,  # s
+            u.ctypes,  # u
+            m,  # ldu
+            vt.ctypes,  # vt
+            minmn,  # ldvt
+        )
+        _handle_err_maybe_convergence_problem(r)
+
+        # Invert singular values under threshold. Also find the index of
+        # the threshold value as this is the upper limit for the application
+        # of the inverted singular values. Finding this value saves
+        # multiplication by a block of zeros that would be created by the
+        # application of these values to either U or V^H ahead of multiplying
+        # them together. In BLAS parlance, this is done simply by restricting
+        # the `k` dimension to `cut_idx` in `xgemm` whilst keeping the
+        # leading dimensions correct.
+
+        cut_at = s[0] * rcond
+        cut_idx = 0
+        for k in range(minmn):
+            if s[k] > cut_at:
+                s[k] = 1.0 / s[k]
+                cut_idx = k
+        cut_idx += 1
+
+        # Use cut_idx so there's no scaling by 0.
+        if m >= n:
+            # U is largest so apply S^+ to V^H.
+            for i in range(n):
+                for j in range(cut_idx):
+                    vt[i, j] = vt[i, j] * s[j]
+        else:
+            # V^H is largest so apply S^+ to U.
+            for i in range(cut_idx):
+                s_local = s[i]
+                for j in range(minmn):
+                    u[i, j] = u[i, j] * s_local
+
+        # Do (v^H)^H*U^H (obviously one of the matrices includes the S^+
+        # scaling) and write back to acpy. Note the inner dimension of
+        # cut_idx taking account of the possible zero block.
+        # We can store the result in acpy, given we had to create it
+        # for use in the SVD, and it is now redundant and the right size
+        # but wrong shape.
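+        #
+        # As an equation (a sketch; the conjugate transposes are folded
+        # into the TRANSA/TRANSB flags of the call):
+        #
+        #     acpy <- 1.0 * (V^H)^H @ U^H + 0.0 * acpy = V * (S^+) * U^H = A^+
+        #
+        # with the inner dimension K limited to cut_idx as described above.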
+ + r = numba_xxgemm( + kind, + TRANSA, # TRANSA + TRANSB, # TRANSB + n, # M + m, # N + cut_idx, # K + one.ctypes, # ALPHA + vt.ctypes, # A + minmn, # LDA + u.ctypes, # B + m, # LDB + zero.ctypes, # BETA + acpy.ctypes, # C + n, # LDC + ) + + # help liveness analysis + # acpy.size + # vt.size + # u.size + # s.size + # one.size + # zero.size + _dummy_liveness_func( + [acpy.size, vt.size, u.size, s.size, one.size, zero.size] + ) + return acpy.T.ravel().reshape(a.shape).T + + return pinv_impl + + +def _get_slogdet_diag_walker(a): + """ + Walks the diag of a LUP decomposed matrix + uses that det(A) = prod(diag(lup(A))) + and also that log(a)+log(b) = log(a*b) + The return sign is adjusted based on the values found + such that the log(value) stays in the real domain. + """ + if isinstance(a.dtype, types.Complex): + + @register_jitable + def cmplx_diag_walker(n, a, sgn): + # walk diagonal + csgn = sgn + 0.0j + acc = 0.0 + for k in range(n): + absel = np.abs(a[k, k]) + csgn = csgn * (a[k, k] / absel) + acc = acc + np.log(absel) + return (csgn, acc) + + return cmplx_diag_walker + else: + + @register_jitable + def real_diag_walker(n, a, sgn): + # walk diagonal + acc = 0.0 + for k in range(n): + v = a[k, k] + if v < 0.0: + sgn = -sgn + v = -v + acc = acc + np.log(v) + # sgn is a float dtype + return (sgn + 0.0, acc) + + return real_diag_walker + + +@overload(np.linalg.slogdet) +def slogdet_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "slogdet") + + numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "slogdet")) + + diag_walker = _get_slogdet_diag_walker(a) + + ONE = a.dtype(1) + ZERO = getattr(a.dtype, "underlying_float", a.dtype)(0) + + def slogdet_impl(a): + n = a.shape[-1] + if a.shape[-2] != n: + msg = "Last 2 dimensions of the array must be square." + raise np.linalg.LinAlgError(msg) + + if n == 0: + return (ONE, ZERO) + + _check_finite_matrix(a) + + acpy = _copy_to_fortran_order(a) + + ipiv = np.empty(n, dtype=F_INT_nptype) + + r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes) + + if r > 0: + # factorisation failed, return same defaults as np + return (0.0, -np.inf) + _inv_err_handler(r) # catch input-to-lapack problem + + # The following, prior to the call to diag_walker, is present + # to account for the effect of possible permutations to the + # sign of the determinant. + # This is the same idea as in numpy: + # File name `umath_linalg.c.src` e.g. + # https://github.com/numpy/numpy/blob/master/numpy/linalg/umath_linalg.c.src + # in function `@TYPE@_slogdet_single_element`. + sgn = 1 + for k in range(n): + sgn = sgn + (ipiv[k] != (k + 1)) + + sgn = sgn & 1 + if sgn == 0: + sgn = -1 + + # help liveness analysis + _dummy_liveness_func([ipiv.size]) + return diag_walker(n, acpy, sgn) + + return slogdet_impl + + +@overload(np.linalg.det) +def det_impl(a): + ensure_lapack() + + _check_linalg_matrix(a, "det") + + def det_impl(a): + (sgn, slogdet) = np.linalg.slogdet(a) + return sgn * np.exp(slogdet) + + return det_impl + + +def _compute_singular_values(a): + """ + Compute singular values of *a*. 
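+
+    Semantically this matches the NumPy expression below (up to precision);
+    the specialisation calls xgesdd with jobz='N' so the singular vectors
+    are never formed:
+
+        np.linalg.svd(a, compute_uv=False)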
+ """ + raise NotImplementedError + + +@overload(_compute_singular_values) +def _compute_singular_values_impl(a): + """ + Returns a function to compute singular values of `a` + """ + numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "svd")) + + # Flag for "only compute `S`" to give to xgesdd + JOBZ_N = ord("N") + + nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype) + np_ret_type = np_support.as_dtype(nb_ret_type) + np_dtype = np_support.as_dtype(a.dtype) + + # These are not referenced in the computation but must be set + # for MKL. + u = np.empty((1, 1), dtype=np_dtype) + vt = np.empty((1, 1), dtype=np_dtype) + + def sv_function(a): + """ + Computes singular values. + """ + # Don't use the np.linalg.svd impl instead + # call LAPACK to shortcut doing the "reconstruct + # singular vectors from reflectors" step and just + # get back the singular values. + n = a.shape[-1] + m = a.shape[-2] + if m == 0 or n == 0: + raise np.linalg.LinAlgError("Arrays cannot be empty") + _check_finite_matrix(a) + + ldu = m + minmn = min(m, n) + + # need to be >=1 but aren't referenced + ucol = 1 # noqa: F841 + ldvt = 1 + + acpy = _copy_to_fortran_order(a) + + # u and vt are not referenced however need to be + # allocated (as done above) for MKL as it + # checks for ref is nullptr. + s = np.empty(minmn, dtype=np_ret_type) + + r = numba_ez_gesdd( + kind, # kind + JOBZ_N, # jobz + m, # m + n, # n + acpy.ctypes, # a + m, # lda + s.ctypes, # s + u.ctypes, # u + ldu, # ldu + vt.ctypes, # vt + ldvt, # ldvt + ) + _handle_err_maybe_convergence_problem(r) + + # help liveness analysis + _dummy_liveness_func([acpy.size, vt.size, u.size, s.size]) + return s + + return sv_function + + +def _oneD_norm_2(a): + """ + Compute the L2-norm of 1D-array *a*. + """ + raise NotImplementedError + + +@overload(_oneD_norm_2) +def _oneD_norm_2_impl(a): + nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype) + np_ret_type = np_support.as_dtype(nb_ret_type) + + xxnrm2 = _BLAS().numba_xxnrm2(a.dtype) + + kind = ord(get_blas_kind(a.dtype, "norm")) + + def impl(a): + # Just ignore order, calls are guarded to only come + # from cases where order=None or order=2. + n = len(a) + # Call L2-norm routine from BLAS + ret = np.empty((1,), dtype=np_ret_type) + jmp = int(a.strides[0] / a.itemsize) + r = xxnrm2( + kind, # kind + n, # n + a.ctypes, # x + jmp, # incx + ret.ctypes, # result + ) + if r < 0: + fatal_error_func() + assert 0 # unreachable + + # help liveness analysis + # ret.size + # a.size + _dummy_liveness_func([ret.size, a.size]) + return ret[0] + + return impl + + +def _get_norm_impl(x, ord_flag): + # This function is quite involved as norm supports a large + # range of values to select different norm types via kwarg `ord`. + # The implementation below branches on dimension of the input + # (1D or 2D). The default for `ord` is `None` which requires + # special handling in numba, this is dealt with first in each of + # the dimension branches. Following this the various norms are + # computed via code that is in most cases simply a loop version + # of a ufunc based version as found in numpy. + + # The following is common to both 1D and 2D cases. + # Convert typing floats to numpy floats for use in the impl. + # The return type is always a float, numba differs from numpy in + # that it returns an input precision specific value whereas numpy + # always returns np.float64. 
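+    #
+    # Dispatch summary, as a reading aid for the branches below:
+    #
+    #   1D: None/2 -> L2 via BLAS xxnrm2; +/-inf -> max/min of abs(x);
+    #       0 -> count of nonzeros; 1 -> sum of abs(x);
+    #       other p -> sum(abs(x)**p)**(1.0/p)
+    #   2D: None -> Frobenius norm via BLAS on a contiguous view;
+    #       +/-inf -> max/min row sum of abs; +/-1 -> max/min column sum
+    #       of abs; +/-2 -> largest/smallest singular value; else ValueError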
+ nb_ret_type = getattr(x.dtype, "underlying_float", x.dtype) + np_ret_type = np_support.as_dtype(nb_ret_type) + + np_dtype = np_support.as_dtype(x.dtype) # noqa: F841 + xxnrm2 = _BLAS().numba_xxnrm2(x.dtype) # noqa: F841 + kind = ord(get_blas_kind(x.dtype, "norm")) # noqa: F841 + + if x.ndim == 1: + # 1D cases + + # handle "ord" being "None", must be done separately + if ord_flag in (None, types.none): + + def oneD_impl(x, ord=None): + return _oneD_norm_2(x) + else: + + def oneD_impl(x, ord=None): + n = len(x) + + # Shortcut to handle zero length arrays + # this differs slightly to numpy in that + # numpy raises a ValueError for kwarg ord= + # +/-np.inf as the reduction operations like + # max() and min() don't accept zero length + # arrays + if n == 0: + return 0.0 + + # Note: on order == 2 + # This is the same as for ord=="None" but because + # we have to handle "None" specially this condition + # is separated + if ord == 2: + return _oneD_norm_2(x) + elif ord == np.inf: + # max(abs(x)) + ret = abs(x[0]) + for k in range(1, n): + val = abs(x[k]) + if val > ret: + ret = val + return ret + + elif ord == -np.inf: + # min(abs(x)) + ret = abs(x[0]) + for k in range(1, n): + val = abs(x[k]) + if val < ret: + ret = val + return ret + + elif ord == 0: + # sum(x != 0) + ret = 0.0 + for k in range(n): + if x[k] != 0.0: + ret += 1.0 + return ret + + elif ord == 1: + # sum(abs(x)) + ret = 0.0 + for k in range(n): + ret += abs(x[k]) + return ret + + else: + # sum(abs(x)**ord)**(1./ord) + ret = 0.0 + for k in range(n): + ret += abs(x[k]) ** ord + return ret ** (1.0 / ord) + + return oneD_impl + + elif x.ndim == 2: + # 2D cases + + # handle "ord" being "None" + if ord_flag in (None, types.none): + # Force `x` to be C-order, so that we can take a contiguous + # 1D view. + if x.layout == "C": + + @register_jitable + def array_prepare(x): + return x + elif x.layout == "F": + + @register_jitable + def array_prepare(x): + # Legal since L2(x) == L2(x.T) + return x.T + else: + + @register_jitable + def array_prepare(x): + return x.copy() + + # Compute the Frobenius norm, this is the L2,2 induced norm of `x` + # which is the L2-norm of x.ravel() and so can be computed via BLAS + def twoD_impl(x, ord=None): + n = x.size + if n == 0: + # reshape() currently doesn't support zero-sized arrays + return 0.0 + x_c = array_prepare(x) + return _oneD_norm_2(x_c.reshape(n)) + else: + # max value for this dtype + max_val = np.finfo(np_ret_type.type).max + + def twoD_impl(x, ord=None): + n = x.shape[-1] + m = x.shape[-2] + + # Shortcut to handle zero size arrays + # this differs slightly to numpy in that + # numpy raises errors for some ord values + # and in other cases returns zero. 
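+            # For instance (illustrative): np.linalg.norm(np.ones((0, 3)),
+            # np.inf) raises in numpy, as max() reduces over an empty
+            # row-sum array, whereas this implementation returns 0.0.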
+ if x.size == 0: + return 0.0 + + if ord == np.inf: + # max of sum of abs across rows + # max(sum(abs(x)), axis=1) + global_max = 0.0 + for ii in range(m): + tmp = 0.0 + for jj in range(n): + tmp += abs(x[ii, jj]) + if tmp > global_max: + global_max = tmp + return global_max + + elif ord == -np.inf: + # min of sum of abs across rows + # min(sum(abs(x)), axis=1) + global_min = max_val + for ii in range(m): + tmp = 0.0 + for jj in range(n): + tmp += abs(x[ii, jj]) + if tmp < global_min: + global_min = tmp + return global_min + elif ord == 1: + # max of sum of abs across cols + # max(sum(abs(x)), axis=0) + global_max = 0.0 + for ii in range(n): + tmp = 0.0 + for jj in range(m): + tmp += abs(x[jj, ii]) + if tmp > global_max: + global_max = tmp + return global_max + + elif ord == -1: + # min of sum of abs across cols + # min(sum(abs(x)), axis=0) + global_min = max_val + for ii in range(n): + tmp = 0.0 + for jj in range(m): + tmp += abs(x[jj, ii]) + if tmp < global_min: + global_min = tmp + return global_min + + # Results via SVD, singular values are sorted on return + # by definition. + elif ord == 2: + # max SV + return _compute_singular_values(x)[0] + elif ord == -2: + # min SV + return _compute_singular_values(x)[-1] + else: + # replicate numpy error + raise ValueError("Invalid norm order for matrices.") + + return twoD_impl + else: + assert 0 # unreachable + + +@overload(np.linalg.norm) +def norm_impl(x, ord=None): + ensure_lapack() + + _check_linalg_1_or_2d_matrix(x, "norm") + + return _get_norm_impl(x, ord) + + +@overload(np.linalg.cond) +def cond_impl(x, p=None): + ensure_lapack() + + _check_linalg_matrix(x, "cond") + + def impl(x, p=None): + # This is extracted for performance, numpy does approximately: + # `condition = norm(x) * norm(inv(x))` + # in the cases of `p == 2` or `p ==-2` singular values are used + # for computing norms. This costs numpy an svd of `x` then an + # inversion of `x` and another svd of `x`. + # Below is a different approach, which also gives a more + # accurate answer as there is no inversion involved. + # Recall that the singular values of an inverted matrix are the + # reciprocal of singular values of the original matrix. + # Therefore calling `svd(x)` once yields all the information + # needed about both `x` and `inv(x)` without the cost or + # potential loss of accuracy incurred through inversion. + # For the case of `p == 2`, the result is just the ratio of + # `largest singular value/smallest singular value`, and for the + # case of `p==-2` the result is simply the + # `smallest singular value/largest singular value`. + # As a result of this, numba accepts non-square matrices as + # input when p==+/-2 as well as when p==None. 
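+        # In short (with s sorted largest to smallest, a sketch):
+        #     cond_2(x)  = s[0] / s[-1]
+        #     cond_-2(x) = s[-1] / s[0]
+        # both read off a single svd(x), with no inversion involved.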
+        if p == 2 or p == -2 or p is None:
+            s = _compute_singular_values(x)
+            if p == 2 or p is None:
+                r = np.divide(s[0], s[-1])
+            else:
+                r = np.divide(s[-1], s[0])
+        else:  # cases np.inf, -np.inf, 1, -1
+            norm_x = np.linalg.norm(x, p)
+            norm_inv_x = np.linalg.norm(np.linalg.inv(x), p)
+            r = norm_x * norm_inv_x
+        # NumPy uses a NaN mask: if the input has a NaN, it will return NaN.
+        # Numba bans NaN inputs via _check_finite_matrix, but this check
+        # catches NaNs that arise during the floating point computation.
+        if np.isnan(r):
+            return np.inf
+        else:
+            return r
+
+    return impl
+
+
+@register_jitable
+def _get_rank_from_singular_values(sv, t):
+    """
+    Gets rank from singular values with cut-off at a given tolerance
+    """
+    rank = 0
+    for k in range(len(sv)):
+        if sv[k] > t:
+            rank = rank + 1
+        else:  # sv is ordered big->small so break on condition not met
+            break
+    return rank
+
+
+@overload(np.linalg.matrix_rank)
+def matrix_rank_impl(A, tol=None):
+    """
+    Computes rank for matrices and vectors.
+    The only issue that may arise is that, because numpy uses double
+    precision lapack calls whereas numba uses type specific lapack
+    calls, some singular values may differ; counting the number above
+    a tolerance may then lead to different counts, and therefore a
+    different rank, in some cases.
+    """
+    ensure_lapack()
+
+    _check_linalg_1_or_2d_matrix(A, "matrix_rank")
+
+    def _2d_matrix_rank_impl(A, tol):
+        # handle the tol==None case separately for type inference to work
+        if tol in (None, types.none):
+            nb_type = getattr(A.dtype, "underlying_float", A.dtype)
+            np_type = np_support.as_dtype(nb_type)
+            eps_val = np.finfo(np_type).eps
+
+            def _2d_tol_none_impl(A, tol=None):
+                s = _compute_singular_values(A)
+                # replicate numpy default tolerance calculation
+                r = A.shape[0]
+                c = A.shape[1]
+                l = max(r, c)
+                t = s[0] * l * eps_val
+                return _get_rank_from_singular_values(s, t)
+
+            return _2d_tol_none_impl
+        else:
+
+            def _2d_tol_not_none_impl(A, tol=None):
+                s = _compute_singular_values(A)
+                return _get_rank_from_singular_values(s, tol)
+
+            return _2d_tol_not_none_impl
+
+    def _get_matrix_rank_impl(A, tol):
+        ndim = A.ndim
+        if ndim == 1:
+            # NOTE: Technically, the numpy implementation could be argued as
+            # incorrect for the case of a vector (1D matrix). If a tolerance
+            # is provided and a vector with a singular value below tolerance
+            # is encountered, this should report a rank of zero; the numpy
+            # implementation does not do this and instead elects to report
+            # that if any value in the vector is nonzero then the rank is 1.
+            # An example would be [0, 1e-15, 0, 2e-15] which numpy reports as
+            # rank 1 invariant of `tol`. The singular value for this vector is
+            # obviously sqrt(5)*1e-15 and so a tol of e.g. sqrt(6)*1e-15 should
+            # lead to a reported rank of 0 whereas a tol of 1e-15 should lead
+            # to a reported rank of 1; numpy reports 1 regardless.
+            # The code below replicates the numpy behaviour.
+            def _1d_matrix_rank_impl(A, tol=None):
+                for k in range(len(A)):
+                    if A[k] != 0.0:
+                        return 1
+                return 0
+
+            return _1d_matrix_rank_impl
+        elif ndim == 2:
+            return _2d_matrix_rank_impl(A, tol)
+        else:
+            assert 0  # unreachable
+
+    return _get_matrix_rank_impl(A, tol)
+
+
+@overload(np.linalg.matrix_power)
+def matrix_power_impl(a, n):
+    """
+    Computes matrix power. Only integer powers are supported in numpy.
+ """ + + _check_linalg_matrix(a, "matrix_power") + np_dtype = np_support.as_dtype(a.dtype) + + nt = getattr(n, "dtype", n) + if not isinstance(nt, types.Integer): + raise NumbaTypeError("Exponent must be an integer.") + + def matrix_power_impl(a, n): + if n == 0: + # this should be eye() but it doesn't support + # the dtype kwarg yet so do it manually to save + # the copy required by eye(a.shape[0]).asdtype() + A = np.zeros(a.shape, dtype=np_dtype) + for k in range(a.shape[0]): + A[k, k] = 1.0 + return A + + am, an = a.shape[-1], a.shape[-2] + if am != an: + raise ValueError("input must be a square array") + + # empty, return a copy + if am == 0: + return a.copy() + + # note: to be consistent over contiguousness, C order is + # returned as that is what dot() produces and the most common + # paths through matrix_power will involve that. Therefore + # copies are made here to ensure the data ordering is + # correct for paths not going via dot(). + + if n < 0: + A = np.linalg.inv(a).copy() + if n == -1: # return now + return A + n = -n + else: + if n == 1: # return a copy now + return a.copy() + A = a # this is safe, `a` is only read + + if n < 4: + if n == 2: + return np.dot(A, A) + if n == 3: + return np.dot(np.dot(A, A), A) + else: + acc = A + exp = n + + # Initialise ret, SSA cannot see the loop will execute, without this + # it appears as uninitialised. + ret = acc + # tried a loop split and branchless using identity matrix as + # input but it seems like having a "first entry" flag is quicker + flag = True + while exp != 0: + if exp & 1: + if flag: + ret = acc + flag = False + else: + ret = np.dot(ret, acc) + acc = np.dot(acc, acc) + exp = exp >> 1 + + return ret + + return matrix_power_impl + + +# This is documented under linalg despite not being in the module + + +@overload(np.trace) +def matrix_trace_impl(a, offset=0): + """ + Computes the trace of an array. 
+ """ + + _check_linalg_matrix(a, "trace", la_prefix=False) + + if not isinstance(offset, (int, types.Integer)): + raise NumbaTypeError("integer argument expected, got %s" % offset) + + def matrix_trace_impl(a, offset=0): + rows, cols = a.shape + k = offset + if k < 0: + rows = rows + k + if k > 0: + cols = cols - k + n = max(min(rows, cols), 0) + ret = 0 + if k >= 0: + for i in range(n): + ret += a[i, k + i] + else: + for i in range(n): + ret += a[i - k, i] + return ret + + return matrix_trace_impl + + +def _check_scalar_or_lt_2d_mat(a, func_name, la_prefix=True): + prefix = "np.linalg" if la_prefix else "np" + interp = (prefix, func_name) + # checks that a matrix is 1 or 2D + if isinstance(a, types.Array): + if not a.ndim <= 2: + raise TypingError( + "%s.%s() only supported on 1 and 2-D arrays " % interp, + highlighting=False, + ) + + +@register_jitable +def outer_impl_none(a, b, out): + aa = np.asarray(a) + bb = np.asarray(b) + return np.multiply( + aa.ravel().reshape((aa.size, 1)), bb.ravel().reshape((1, bb.size)) + ) + + +@register_jitable +def outer_impl_arr(a, b, out): + aa = np.asarray(a) + bb = np.asarray(b) + np.multiply( + aa.ravel().reshape((aa.size, 1)), bb.ravel().reshape((1, bb.size)), out + ) + return out + + +def _get_outer_impl(a, b, out): + if out in (None, types.none): + return outer_impl_none + else: + return outer_impl_arr + + +@overload(np.outer) +def outer_impl(a, b, out=None): + _check_scalar_or_lt_2d_mat(a, "outer", la_prefix=False) + _check_scalar_or_lt_2d_mat(b, "outer", la_prefix=False) + + impl = _get_outer_impl(a, b, out) + + def outer_impl(a, b, out=None): + return impl(a, b, out) + + return outer_impl + + +def _kron_normaliser_impl(x): + # makes x into a 2d array + if isinstance(x, types.Array): + if x.layout not in ("C", "F"): + raise TypingError( + "np.linalg.kron only supports 'C' or 'F' layout " + "input arrays. Received an input of " + "layout '{}'.".format(x.layout) + ) + elif x.ndim == 2: + + @register_jitable + def nrm_shape(x): + xn = x.shape[-1] + xm = x.shape[-2] + return x.reshape(xm, xn) + + return nrm_shape + else: + + @register_jitable + def nrm_shape(x): + xn = x.shape[-1] + return x.reshape(1, xn) + + return nrm_shape + else: # assume its a scalar + + @register_jitable + def nrm_shape(x): + a = np.empty((1, 1), type(x)) + a[0] = x + return a + + return nrm_shape + + +def _kron_return(a, b): + # transforms c into something that kron would return + # based on the shapes of a and b + a_is_arr = isinstance(a, types.Array) + b_is_arr = isinstance(b, types.Array) + if a_is_arr and b_is_arr: + if a.ndim == 2 or b.ndim == 2: + + @register_jitable + def ret(a, b, c): + return c + + return ret + else: + + @register_jitable + def ret(a, b, c): + return c.reshape(c.size) + + return ret + else: # at least one of (a, b) is a scalar + if a_is_arr: + + @register_jitable + def ret(a, b, c): + return c.reshape(a.shape) + + return ret + elif b_is_arr: + + @register_jitable + def ret(a, b, c): + return c.reshape(b.shape) + + return ret + else: # both scalars + + @register_jitable + def ret(a, b, c): + return c[0] + + return ret + + +@overload(np.kron) +def kron_impl(a, b): + _check_scalar_or_lt_2d_mat(a, "kron", la_prefix=False) + _check_scalar_or_lt_2d_mat(b, "kron", la_prefix=False) + + fix_a = _kron_normaliser_impl(a) + fix_b = _kron_normaliser_impl(b) + ret_c = _kron_return(a, b) + + # this is fine because the ufunc for the Hadamard product + # will reject differing dtypes in a and b. 
+ dt = getattr(a, "dtype", a) + + def kron_impl(a, b): + aa = fix_a(a) + bb = fix_b(b) + + am = aa.shape[-2] + an = aa.shape[-1] + bm = bb.shape[-2] + bn = bb.shape[-1] + + cm = am * bm + cn = an * bn + + # allocate c + C = np.empty((cm, cn), dtype=dt) + + # In practice this is runs quicker than the more obvious + # `each element of A multiplied by B and assigned to + # a block in C` like alg. + + # loop over rows of A + for i in range(am): + # compute the column offset into C + rjmp = i * bm + # loop over rows of B + for k in range(bm): + # compute row the offset into C + irjmp = rjmp + k + # slice a given row of B + slc = bb[k, :] + # loop over columns of A + for j in range(an): + # vectorized assignment of an element of A + # multiplied by the current row of B into + # a slice of a row of C + cjmp = j * bn + C[irjmp, cjmp : cjmp + bn] = aa[i, j] * slc + + return ret_c(a, b, C) + + return kron_impl diff --git a/numba_cuda/numba/cuda/np/math/__init__.py b/numba_cuda/numba/cuda/np/math/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/numba_cuda/numba/cuda/np/math/cmathimpl.py b/numba_cuda/numba/cuda/np/math/cmathimpl.py new file mode 100644 index 000000000..86b4bb01e --- /dev/null +++ b/numba_cuda/numba/cuda/np/math/cmathimpl.py @@ -0,0 +1,558 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implement the cmath module functions. +""" + +import cmath +import math + +from numba.core.imputils import impl_ret_untracked +from numba.core import types +from numba.cuda.typing import signature +from numba.cuda.cpython import mathimpl + +# registry = Registry('cmathimpl') +# lower = registry.lower + + +def is_nan(builder, z): + return builder.fcmp_unordered("uno", z.real, z.imag) + + +def is_inf(builder, z): + return builder.or_( + mathimpl.is_inf(builder, z.real), mathimpl.is_inf(builder, z.imag) + ) + + +def is_finite(builder, z): + return builder.and_( + mathimpl.is_finite(builder, z.real), mathimpl.is_finite(builder, z.imag) + ) + + +# @lower(cmath.isnan, types.Complex) +def isnan_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_nan(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(cmath.isinf, types.Complex) +def isinf_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_inf(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(cmath.isfinite, types.Complex) +def isfinite_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_finite(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @overload(cmath.rect) +def impl_cmath_rect(r, phi): + if all([isinstance(typ, types.Float) for typ in [r, phi]]): + + def impl(r, phi): + if not math.isfinite(phi): + if not r: + # cmath.rect(0, phi={inf, nan}) = 0 + return abs(r) + if math.isinf(r): + # cmath.rect(inf, phi={inf, nan}) = inf + j phi + return complex(r, phi) + real = math.cos(phi) + imag = math.sin(phi) + if real == 0.0 and math.isinf(r): + # 0 * inf would return NaN, we want to keep 0 but xor the sign + real /= r + else: + real *= r + if imag == 0.0 and math.isinf(r): + # ditto + imag /= r + else: + imag *= r + return complex(real, imag) + + 
+        return impl
+
+
+def intrinsic_complex_unary(inner_func):
+    def wrapper(context, builder, sig, args):
+        [typ] = sig.args
+        [value] = args
+        z = context.make_complex(builder, typ, value=value)
+        x = z.real
+        y = z.imag
+        # Precompute the finiteness of each component here and pass it to
+        # the pure Python implementation (historically needed because
+        # math.isfinite() was unavailable on Python 2).
+        x_is_finite = mathimpl.is_finite(builder, x)
+        y_is_finite = mathimpl.is_finite(builder, y)
+        inner_sig = signature(
+            sig.return_type, *(typ.underlying_float,) * 2 + (types.boolean,) * 2
+        )
+        res = context.compile_internal(
+            builder, inner_func, inner_sig, (x, y, x_is_finite, y_is_finite)
+        )
+        return impl_ret_untracked(context, builder, sig, res)
+
+    return wrapper
+
+
+NAN = float("nan")
+INF = float("inf")
+
+
+# @lower(cmath.exp, types.Complex)
+@intrinsic_complex_unary
+def exp_impl(x, y, x_is_finite, y_is_finite):
+    """cmath.exp(x + y j)"""
+    if x_is_finite:
+        if y_is_finite:
+            c = math.cos(y)
+            s = math.sin(y)
+            r = math.exp(x)
+            return complex(r * c, r * s)
+        else:
+            return complex(NAN, NAN)
+    elif math.isnan(x):
+        if y:
+            return complex(x, x)  # nan + j nan
+        else:
+            return complex(x, y)  # nan + 0j
+    elif x > 0.0:
+        # x == +inf
+        if y_is_finite:
+            real = math.cos(y)
+            imag = math.sin(y)
+            # Avoid NaNs if math.cos(y) or math.sin(y) == 0
+            # (e.g. cmath.exp(inf + 0j) == inf + 0j)
+            if real != 0:
+                real *= x
+            if imag != 0:
+                imag *= x
+            return complex(real, imag)
+        else:
+            return complex(x, NAN)
+    else:
+        # x == -inf
+        if y_is_finite:
+            r = math.exp(x)
+            c = math.cos(y)
+            s = math.sin(y)
+            return complex(r * c, r * s)
+        else:
+            r = 0
+            return complex(r, r)
+
+
+# @lower(cmath.log, types.Complex)
+@intrinsic_complex_unary
+def log_impl(x, y, x_is_finite, y_is_finite):
+    """cmath.log(x + y j)"""
+    a = math.log(math.hypot(x, y))
+    b = math.atan2(y, x)
+    return complex(a, b)
+
+
+# @lower(cmath.log, types.Complex, types.Complex)
+def log_base_impl(context, builder, sig, args):
+    """cmath.log(z, base)"""
+    [z, base] = args
+
+    def log_base(z, base):
+        return cmath.log(z) / cmath.log(base)
+
+    res = context.compile_internal(builder, log_base, sig, args)
+    return impl_ret_untracked(context, builder, sig, res)
+
+
+# @overload(cmath.log10)
+def impl_cmath_log10(z):
+    if not isinstance(z, types.Complex):
+        return
+
+    LN_10 = 2.302585092994045684
+
+    def log10_impl(z):
+        """cmath.log10(z)"""
+        z = cmath.log(z)
+        # This formula gives better results on +/-inf than cmath.log(z, 10)
+        # See http://bugs.python.org/issue22544
+        return complex(z.real / LN_10, z.imag / LN_10)
+
+    return log10_impl
+
+
+# @overload(cmath.phase)
+def phase_impl(x):
+    """cmath.phase(x + y j)"""
+
+    if not isinstance(x, types.Complex):
+        return
+
+    def impl(x):
+        return math.atan2(x.imag, x.real)
+
+    return impl
+
+
+# @overload(cmath.polar)
+def polar_impl(x):
+    if not isinstance(x, types.Complex):
+        return
+
+    def impl(x):
+        r, i = x.real, x.imag
+        return math.hypot(r, i), math.atan2(i, r)
+
+    return impl
+
+
+# @lower(cmath.sqrt, types.Complex)
+def sqrt_impl(context, builder, sig, args):
+    # We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)).
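+    # Added note: for |a| == |b|, a + hypot(a, b) == (1 + sqrt(2)) * |a|,
+    # so the Algorithm 312 intermediate computed below can overflow once a
+    # component reaches MAX / (1 + sqrt(2)); hence THRES and the scaling
+    # by 0.25 applied past it.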
+ + SQRT2 = 1.414213562373095048801688724209698079e0 + ONE_PLUS_SQRT2 = 1.0 + SQRT2 + theargflt = sig.args[0].underlying_float + # Get a type specific maximum value so scaling for overflow is based on that + MAX = mathimpl.DBL_MAX if theargflt.bitwidth == 64 else mathimpl.FLT_MAX + # THRES will be double precision, should not impact typing as it's just + # used for comparison, there *may* be a few values near THRES which + # deviate from e.g. NumPy due to rounding that occurs in the computation + # of this value in the case of a 32bit argument. + THRES = MAX / ONE_PLUS_SQRT2 + + def sqrt_impl(z): + """cmath.sqrt(z)""" + # This is NumPy's algorithm, see npy_csqrt() in npy_math_complex.c.src + a = z.real + b = z.imag + if a == 0.0 and b == 0.0: + return complex(abs(b), b) + if math.isinf(b): + return complex(abs(b), b) + if math.isnan(a): + return complex(a, a) + if math.isinf(a): + if a < 0.0: + return complex(abs(b - b), math.copysign(a, b)) + else: + return complex(a, math.copysign(b - b, b)) + + # The remaining special case (b is NaN) is handled just fine by + # the normal code path below. + + # Scale to avoid overflow + if abs(a) >= THRES or abs(b) >= THRES: + a *= 0.25 + b *= 0.25 + scale = True + else: + scale = False + # Algorithm 312, CACM vol 10, Oct 1967 + if a >= 0: + t = math.sqrt((a + math.hypot(a, b)) * 0.5) + real = t + imag = b / (2 * t) + else: + t = math.sqrt((-a + math.hypot(a, b)) * 0.5) + real = abs(b) / (2 * t) + imag = math.copysign(t, b) + # Rescale + if scale: + return complex(real * 2, imag) + else: + return complex(real, imag) + + res = context.compile_internal(builder, sqrt_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @lower(cmath.cos, types.Complex) +def cos_impl(context, builder, sig, args): + def cos_impl(z): + """cmath.cos(z) = cmath.cosh(z j)""" + return cmath.cosh(complex(-z.imag, z.real)) + + res = context.compile_internal(builder, cos_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @overload(cmath.cosh) +def impl_cmath_cosh(z): + if not isinstance(z, types.Complex): + return + + def cosh_impl(z): + """cmath.cosh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + if math.isnan(y): + # x = +inf, y = NaN => cmath.cosh(x + y j) = inf + Nan * j + real = abs(x) + imag = y + elif y == 0.0: + # x = +inf, y = 0 => cmath.cosh(x + y j) = inf + 0j + real = abs(x) + imag = y + else: + real = math.copysign(x, math.cos(y)) + imag = math.copysign(x, math.sin(y)) + if x < 0.0: + # x = -inf => negate imaginary part of result + imag = -imag + return complex(real, imag) + return complex(math.cos(y) * math.cosh(x), math.sin(y) * math.sinh(x)) + + return cosh_impl + + +# @lower(cmath.sin, types.Complex) +def sin_impl(context, builder, sig, args): + def sin_impl(z): + """cmath.sin(z) = -j * cmath.sinh(z j)""" + r = cmath.sinh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, sin_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @overload(cmath.sinh) +def impl_cmath_sinh(z): + if not isinstance(z, types.Complex): + return + + def sinh_impl(z): + """cmath.sinh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + if math.isnan(y): + # x = +/-inf, y = NaN => cmath.sinh(x + y j) = x + NaN * j + real = x + imag = y + else: + real = math.cos(y) + imag = math.sin(y) + if real != 0.0: + real *= x + if imag != 0.0: + imag *= abs(x) + return complex(real, imag) + return complex(math.cos(y) * math.sinh(x), math.sin(y) * math.cosh(x)) + + 
return sinh_impl + + +# @lower(cmath.tan, types.Complex) +def tan_impl(context, builder, sig, args): + def tan_impl(z): + """cmath.tan(z) = -j * cmath.tanh(z j)""" + r = cmath.tanh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, tan_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @overload(cmath.tanh) +def impl_cmath_tanh(z): + if not isinstance(z, types.Complex): + return + + def tanh_impl(z): + """cmath.tanh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + real = math.copysign(1.0, x) + if math.isinf(y): + imag = 0.0 + else: + imag = math.copysign(0.0, math.sin(2.0 * y)) + return complex(real, imag) + # This is CPython's algorithm (see c_tanh() in cmathmodule.c). + # XXX how to force float constants into single precision? + tx = math.tanh(x) + ty = math.tan(y) + cx = 1.0 / math.cosh(x) + txty = tx * ty + denom = 1.0 + txty * txty + return complex(tx * (1.0 + ty * ty) / denom, ((ty / denom) * cx) * cx) + + return tanh_impl + + +# @lower(cmath.acos, types.Complex) +def acos_impl(context, builder, sig, args): + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def acos_impl(z): + """cmath.acos(z)""" + # CPython's algorithm (see c_acos() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + # Avoid unnecessary overflow for large arguments + # (also handles infinities gracefully) + real = math.atan2(abs(z.imag), z.real) + imag = math.copysign( + math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4, -z.imag + ) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(1.0 - z.real, -z.imag)) + s2 = cmath.sqrt(complex(1.0 + z.real, z.imag)) + real = 2.0 * math.atan2(s1.real, s2.real) + imag = math.asinh(s2.real * s1.imag - s2.imag * s1.real) + return complex(real, imag) + + res = context.compile_internal(builder, acos_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @overload(cmath.acosh) +def impl_cmath_acosh(z): + if not isinstance(z, types.Complex): + return + + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def acosh_impl(z): + """cmath.acosh(z)""" + # CPython's algorithm (see c_acosh() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + # Avoid unnecessary overflow for large arguments + # (also handles infinities gracefully) + real = math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4 + imag = math.atan2(z.imag, z.real) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(z.real - 1.0, z.imag)) + s2 = cmath.sqrt(complex(z.real + 1.0, z.imag)) + real = math.asinh(s1.real * s2.real + s1.imag * s2.imag) + imag = 2.0 * math.atan2(s1.imag, s2.real) + return complex(real, imag) + # Condensed formula (NumPy) + # return cmath.log(z + cmath.sqrt(z + 1.) 
* cmath.sqrt(z - 1.)) + + return acosh_impl + + +# @lower(cmath.asinh, types.Complex) +def asinh_impl(context, builder, sig, args): + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def asinh_impl(z): + """cmath.asinh(z)""" + # CPython's algorithm (see c_asinh() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + real = math.copysign( + math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4, z.real + ) + imag = math.atan2(z.imag, abs(z.real)) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(1.0 + z.imag, -z.real)) + s2 = cmath.sqrt(complex(1.0 - z.imag, z.real)) + real = math.asinh(s1.real * s2.imag - s2.real * s1.imag) + imag = math.atan2(z.imag, s1.real * s2.real - s1.imag * s2.imag) + return complex(real, imag) + + res = context.compile_internal(builder, asinh_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @lower(cmath.asin, types.Complex) +def asin_impl(context, builder, sig, args): + def asin_impl(z): + """cmath.asin(z) = -j * cmath.asinh(z j)""" + r = cmath.asinh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, asin_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @lower(cmath.atan, types.Complex) +def atan_impl(context, builder, sig, args): + def atan_impl(z): + """cmath.atan(z) = -j * cmath.atanh(z j)""" + r = cmath.atanh(complex(-z.imag, z.real)) + if math.isinf(z.real) and math.isnan(z.imag): + # XXX this is odd but necessary + return complex(r.imag, r.real) + else: + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, atan_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +# @lower(cmath.atanh, types.Complex) +def atanh_impl(context, builder, sig, args): + THRES_LARGE = math.sqrt(mathimpl.FLT_MAX / 4) + THRES_SMALL = math.sqrt(mathimpl.FLT_MIN) + PI_12 = math.pi / 2 + + def atanh_impl(z): + """cmath.atanh(z)""" + # CPython's algorithm (see c_atanh() in cmathmodule.c) + if z.real < 0.0: + # Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z). + negate = True + z = -z + else: + negate = False + + ay = abs(z.imag) + if math.isnan(z.real) or z.real > THRES_LARGE or ay > THRES_LARGE: + if math.isinf(z.imag): + real = math.copysign(0.0, z.real) + elif math.isinf(z.real): + real = 0.0 + else: + # may be safe from overflow, depending on hypot's implementation... + h = math.hypot(z.real * 0.5, z.imag * 0.5) + real = z.real / 4.0 / h / h + imag = -math.copysign(PI_12, -z.imag) + elif z.real == 1.0 and ay < THRES_SMALL: + # C99 standard says: atanh(1+/-0.) should be inf +/- 0j + if ay == 0.0: + real = INF + imag = z.imag + else: + real = -math.log(math.sqrt(ay) / math.sqrt(math.hypot(ay, 2.0))) + imag = math.copysign(math.atan2(2.0, -ay) / 2, z.imag) + else: + sqay = ay * ay + zr1 = 1 - z.real + real = math.log1p(4.0 * z.real / (zr1 * zr1 + sqay)) * 0.25 + imag = -math.atan2(-2.0 * z.imag, zr1 * (1 + z.real) - sqay) * 0.5 + + if math.isnan(z.imag): + imag = NAN + if negate: + return complex(-real, -imag) + else: + return complex(real, imag) + + res = context.compile_internal(builder, atanh_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) diff --git a/numba_cuda/numba/cuda/np/math/mathimpl.py b/numba_cuda/numba/cuda/np/math/mathimpl.py new file mode 100644 index 000000000..1c7e8f012 --- /dev/null +++ b/numba_cuda/numba/cuda/np/math/mathimpl.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""
+Provide math calls that use intrinsics or libc math functions.
+"""
+
+import math
+import operator
+import sys
+import numpy as np
+
+import llvmlite.ir
+from llvmlite.ir import Constant
+
+from numba.core.imputils import impl_ret_untracked
+from numba.core import types
+from numba.cuda import cgutils, config
+from numba.cuda.extending import overload
+from numba.cuda.typing import signature
+from numba.cpython.unsafe.numbers import trailing_zeros
+
+
+# registry = Registry('mathimpl')
+# lower = registry.lower
+
+
+# Helpers, shared with cmathimpl.
+_NP_FLT_FINFO = np.finfo(np.dtype("float32"))
+FLT_MAX = _NP_FLT_FINFO.max
+FLT_MIN = _NP_FLT_FINFO.tiny
+
+_NP_DBL_FINFO = np.finfo(np.dtype("float64"))
+DBL_MAX = _NP_DBL_FINFO.max
+DBL_MIN = _NP_DBL_FINFO.tiny
+
+FLOAT_ABS_MASK = 0x7FFFFFFF
+FLOAT_SIGN_MASK = 0x80000000
+DOUBLE_ABS_MASK = 0x7FFFFFFFFFFFFFFF
+DOUBLE_SIGN_MASK = 0x8000000000000000
+
+
+def is_nan(builder, val):
+    """
+    Return a condition testing whether *val* is a NaN.
+    """
+    return builder.fcmp_unordered("uno", val, val)
+
+
+def is_inf(builder, val):
+    """
+    Return a condition testing whether *val* is infinite.
+    """
+    pos_inf = Constant(val.type, float("+inf"))
+    neg_inf = Constant(val.type, float("-inf"))
+    isposinf = builder.fcmp_ordered("==", val, pos_inf)
+    isneginf = builder.fcmp_ordered("==", val, neg_inf)
+    return builder.or_(isposinf, isneginf)
+
+
+def is_finite(builder, val):
+    """
+    Return a condition testing whether *val* is finite.
+    """
+    # is_finite(x)  <=>  x - x is not NaN
+    val_minus_val = builder.fsub(val, val)
+    return builder.fcmp_ordered("ord", val_minus_val, val_minus_val)
+
+
+def f64_as_int64(builder, val):
+    """
+    Bitcast a double into a 64-bit integer.
+    """
+    assert val.type == llvmlite.ir.DoubleType()
+    return builder.bitcast(val, llvmlite.ir.IntType(64))
+
+
+def int64_as_f64(builder, val):
+    """
+    Bitcast a 64-bit integer into a double.
+    """
+    assert val.type == llvmlite.ir.IntType(64)
+    return builder.bitcast(val, llvmlite.ir.DoubleType())
+
+
+def f32_as_int32(builder, val):
+    """
+    Bitcast a float into a 32-bit integer.
+    """
+    assert val.type == llvmlite.ir.FloatType()
+    return builder.bitcast(val, llvmlite.ir.IntType(32))
+
+
+def int32_as_f32(builder, val):
+    """
+    Bitcast a 32-bit integer into a float.
+    """
+    assert val.type == llvmlite.ir.IntType(32)
+    return builder.bitcast(val, llvmlite.ir.FloatType())
+
+
+def negate_real(builder, val):
+    """
+    Negate real number *val*, with proper handling of zeros.
+    """
+    # The negative zero forces LLVM to handle signed zeros properly.
+    return builder.fsub(Constant(val.type, -0.0), val)
+
+
+def call_fp_intrinsic(builder, name, args):
+    """
+    Call an LLVM intrinsic floating-point operation.
+    """
+    mod = builder.module
+    intr = mod.declare_intrinsic(name, [a.type for a in args])
+    return builder.call(intr, args)
+
+
+def _unary_int_input_wrapper_impl(wrapped_impl):
+    """
+    Return an implementation factory to convert the single integral input
+    argument to a float64, then defer to the *wrapped_impl*.
+ """ + + def implementer(context, builder, sig, args): + (val,) = args + input_type = sig.args[0] + fpval = context.cast(builder, val, input_type, types.float64) + inner_sig = signature(types.float64, types.float64) + res = wrapped_impl(context, builder, inner_sig, (fpval,)) + return context.cast(builder, res, types.float64, sig.return_type) + + return implementer + + +def unary_math_int_impl(fn, float_impl): + impl = _unary_int_input_wrapper_impl(float_impl) # noqa: F841 + # lower(fn, types.Integer)(impl) + + +def unary_math_intr(fn, intrcode): + """ + Implement the math function *fn* using the LLVM intrinsic *intrcode*. + """ + + # @lower(fn, types.Float) + def float_impl(context, builder, sig, args): + res = call_fp_intrinsic(builder, intrcode, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + unary_math_int_impl(fn, float_impl) + return float_impl + + +def unary_math_extern(fn, f32extern, f64extern, int_restype=False): + """ + Register implementations of Python function *fn* using the + external function named *f32extern* and *f64extern* (for float32 + and float64 inputs, respectively). + If *int_restype* is true, then the function's return value should be + integral, otherwise floating-point. + """ + f_restype = types.int64 if int_restype else None # noqa: F841 + + def float_impl(context, builder, sig, args): + """ + Implement *fn* for a types.Float input. + """ + [val] = args + mod = builder.module # noqa: F841 + input_type = sig.args[0] + lty = context.get_value_type(input_type) + func_name = { + types.float32: f32extern, + types.float64: f64extern, + }[input_type] + fnty = llvmlite.ir.FunctionType(lty, [lty]) + fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name) + res = builder.call(fn, (val,)) + res = context.cast(builder, res, input_type, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + # lower(fn, types.Float)(float_impl) + + # Implement wrapper for integer inputs + unary_math_int_impl(fn, float_impl) + + return float_impl + + +unary_math_intr(math.fabs, "llvm.fabs") +exp_impl = unary_math_intr(math.exp, "llvm.exp") +log_impl = unary_math_intr(math.log, "llvm.log") +log10_impl = unary_math_intr(math.log10, "llvm.log10") +sin_impl = unary_math_intr(math.sin, "llvm.sin") +cos_impl = unary_math_intr(math.cos, "llvm.cos") + +log1p_impl = unary_math_extern(math.log1p, "log1pf", "log1p") +expm1_impl = unary_math_extern(math.expm1, "expm1f", "expm1") +erf_impl = unary_math_extern(math.erf, "erff", "erf") +erfc_impl = unary_math_extern(math.erfc, "erfcf", "erfc") + +tan_impl = unary_math_extern(math.tan, "tanf", "tan") +asin_impl = unary_math_extern(math.asin, "asinf", "asin") +acos_impl = unary_math_extern(math.acos, "acosf", "acos") +atan_impl = unary_math_extern(math.atan, "atanf", "atan") + +asinh_impl = unary_math_extern(math.asinh, "asinhf", "asinh") +acosh_impl = unary_math_extern(math.acosh, "acoshf", "acosh") +atanh_impl = unary_math_extern(math.atanh, "atanhf", "atanh") +sinh_impl = unary_math_extern(math.sinh, "sinhf", "sinh") +cosh_impl = unary_math_extern(math.cosh, "coshf", "cosh") +tanh_impl = unary_math_extern(math.tanh, "tanhf", "tanh") + +log2_impl = unary_math_extern(math.log2, "log2f", "log2") +ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil", True) +floor_impl = unary_math_extern(math.floor, "floorf", "floor", True) + +gamma_impl = unary_math_extern( + math.gamma, "numba_gammaf", "numba_gamma" +) # work-around +sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt") 
+trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True) +lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma") + + +# @lower(math.isnan, types.Float) +def isnan_float_impl(context, builder, sig, args): + [val] = args + res = is_nan(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.isnan, types.Integer) +def isnan_int_impl(context, builder, sig, args): + res = cgutils.false_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.isinf, types.Float) +def isinf_float_impl(context, builder, sig, args): + [val] = args + res = is_inf(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.isinf, types.Integer) +def isinf_int_impl(context, builder, sig, args): + res = cgutils.false_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.isfinite, types.Float) +def isfinite_float_impl(context, builder, sig, args): + [val] = args + res = is_finite(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.isfinite, types.Integer) +def isfinite_int_impl(context, builder, sig, args): + res = cgutils.true_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.copysign, types.Float, types.Float) +def copysign_float_impl(context, builder, sig, args): + lty = args[0].type + mod = builder.module + fn = cgutils.get_or_insert_function( + mod, + llvmlite.ir.FunctionType(lty, (lty, lty)), + "llvm.copysign.%s" % lty.intrinsic_name, + ) + res = builder.call(fn, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +# @lower(math.frexp, types.Float) +def frexp_impl(context, builder, sig, args): + (val,) = args + fltty = context.get_data_type(sig.args[0]) + intty = context.get_data_type(sig.return_type[1]) + expptr = cgutils.alloca_once(builder, intty, name="exp") + fnty = llvmlite.ir.FunctionType( + fltty, (fltty, llvmlite.ir.PointerType(intty)) + ) + fname = { + "float": "numba_frexpf", + "double": "numba_frexp", + }[str(fltty)] + fn = cgutils.get_or_insert_function(builder.module, fnty, fname) + res = builder.call(fn, (val, expptr)) + res = cgutils.make_anonymous_struct(builder, (res, builder.load(expptr))) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.ldexp, types.Float, types.intc) +def ldexp_impl(context, builder, sig, args): + val, exp = args + fltty, intty = map(context.get_data_type, sig.args) + fnty = llvmlite.ir.FunctionType(fltty, (fltty, intty)) + fname = { + "float": "numba_ldexpf", + "double": "numba_ldexp", + }[str(fltty)] + fn = cgutils.insert_pure_function(builder.module, fnty, name=fname) + res = builder.call(fn, (val, exp)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +# @lower(math.atan2, types.int64, types.int64) +def atan2_s64_impl(context, builder, sig, args): + [y, x] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + return atan2_float_impl(context, builder, fsig, (y, x)) + + +# @lower(math.atan2, types.uint64, types.uint64) +def atan2_u64_impl(context, builder, sig, args): + [y, x] = args + y = builder.uitofp(y, llvmlite.ir.DoubleType()) + x = 
builder.uitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + return atan2_float_impl(context, builder, fsig, (y, x)) + + +# @lower(math.atan2, types.Float, types.Float) +def atan2_float_impl(context, builder, sig, args): + assert len(args) == 2 + mod = builder.module # noqa: F841 + ty = sig.args[0] + lty = context.get_value_type(ty) + func_name = {types.float32: "atan2f", types.float64: "atan2"}[ty] + fnty = llvmlite.ir.FunctionType(lty, (lty, lty)) + fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name) + res = builder.call(fn, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +# @lower(math.hypot, types.int64, types.int64) +def hypot_s64_impl(context, builder, sig, args): + [x, y] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + res = hypot_float_impl(context, builder, fsig, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.hypot, types.uint64, types.uint64) +def hypot_u64_impl(context, builder, sig, args): + [x, y] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + res = hypot_float_impl(context, builder, fsig, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower(math.hypot, types.Float, types.Float) +def hypot_float_impl(context, builder, sig, args): + xty, yty = sig.args + assert xty == yty == sig.return_type + x, y = args + + # Windows has alternate names for hypot/hypotf, see + # https://msdn.microsoft.com/fr-fr/library/a9yb3dbt%28v=vs.80%29.aspx + fname = { + types.float32: "_hypotf" if sys.platform == "win32" else "hypotf", + types.float64: "_hypot" if sys.platform == "win32" else "hypot", + }[xty] + plat_hypot = types.ExternalFunction(fname, sig) + + if sys.platform == "win32" and config.MACHINE_BITS == 32: + inf = xty(float("inf")) + + def hypot_impl(x, y): + if math.isinf(x) or math.isinf(y): + return inf + return plat_hypot(x, y) + else: + + def hypot_impl(x, y): + return plat_hypot(x, y) + + res = context.compile_internal(builder, hypot_impl, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +# @lower(math.radians, types.Float) +def radians_float_impl(context, builder, sig, args): + [x] = args + coef = context.get_constant(sig.return_type, math.pi / 180) + res = builder.fmul(x, coef) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +unary_math_int_impl(math.radians, radians_float_impl) + +# ----------------------------------------------------------------------------- + + +# @lower(math.degrees, types.Float) +def degrees_float_impl(context, builder, sig, args): + [x] = args + coef = context.get_constant(sig.return_type, 180 / math.pi) + res = builder.fmul(x, coef) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +unary_math_int_impl(math.degrees, degrees_float_impl) + +# ----------------------------------------------------------------------------- + + +# @lower(math.pow, types.Float, types.Float) +# @lower(math.pow, types.Float, types.Integer) +def pow_impl(context, builder, sig, args): + impl = 
context.get_function(operator.pow, sig) + return impl(builder, args) + + +# ----------------------------------------------------------------------------- + + +def _unsigned(T): + """Convert integer to unsigned integer of equivalent width.""" + pass + + +@overload(_unsigned) +def _unsigned_impl(T): + if T in types.unsigned_domain: + return lambda T: T + elif T in types.signed_domain: + newT = getattr(types, "uint{}".format(T.bitwidth)) + return lambda T: newT(T) + + +def gcd_impl(context, builder, sig, args): + xty, yty = sig.args + assert xty == yty == sig.return_type + x, y = args + + def gcd(a, b): + """ + Stein's algorithm, heavily cribbed from Julia implementation. + """ + T = type(a) + if a == 0: + return abs(b) + if b == 0: + return abs(a) + za = trailing_zeros(a) + zb = trailing_zeros(b) + k = min(za, zb) + # Uses np.*_shift instead of operators due to return types + u = _unsigned(abs(np.right_shift(a, za))) + v = _unsigned(abs(np.right_shift(b, zb))) + while u != v: + if u > v: + u, v = v, u + v -= u + v = np.right_shift(v, trailing_zeros(v)) + r = np.left_shift(T(u), k) + return r + + res = context.compile_internal(builder, gcd, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# lower(math.gcd, types.Integer, types.Integer)(gcd_impl) diff --git a/numba_cuda/numba/cuda/np/math/numbers.py b/numba_cuda/numba/cuda/np/math/numbers.py new file mode 100644 index 000000000..ecdc95f30 --- /dev/null +++ b/numba_cuda/numba/cuda/np/math/numbers.py @@ -0,0 +1,1461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import math +import numbers + +import numpy as np + +from llvmlite import ir +from llvmlite.ir import Constant + +from numba.core.imputils import impl_ret_untracked +from numba.core import typing, types, errors +from numba.cuda import cgutils +from numba.cpython.unsafe.numbers import viewer + + +def _int_arith_flags(rettype): + """ + Return the modifier flags for integer arithmetic. + """ + if rettype.signed: + # Ignore the effects of signed overflow. This is important for + # optimization of some indexing operations. For example + # array[i+1] could see `i+1` trigger a signed overflow and + # give a negative number. With Python's indexing, a negative + # index is treated differently: its resolution has a runtime cost. + # Telling LLVM to ignore signed overflows allows it to optimize + # away the check for a negative `i+1` if it knows `i` is positive. 
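+        # Added illustration: the add is emitted as e.g.
+        #   %res = add nsw i64 %i, 1
+        # where "nsw" (no signed wrap) makes signed overflow undefined
+        # behaviour, enabling the range reasoning described above.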
+ return ["nsw"] + else: + return [] + + +def int_add_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.add(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sub_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.sub(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_mul_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.mul(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_divmod_signed(context, builder, ty, x, y): + """ + Reference Objects/intobject.c + xdivy = x / y; + xmody = (long)(x - (unsigned long)xdivy * y); + /* If the signs of x and y differ, and the remainder is non-0, + * C89 doesn't define whether xdivy is now the floor or the + * ceiling of the infinitely precise quotient. We want the floor, + * and we have it iff the remainder's sign matches y's. + */ + if (xmody && ((y ^ xmody) < 0) /* i.e. and signs differ */) { + xmody += y; + --xdivy; + assert(xmody && ((y ^ xmody) >= 0)); + } + *p_xdivy = xdivy; + *p_xmody = xmody; + """ + assert x.type == y.type + + ZERO = y.type(0) + ONE = y.type(1) + + # NOTE: On x86 at least, dividing the lowest representable integer + # (e.g. 0x80000000 for int32) by -1 causes a SIFGPE (division overflow), + # causing the process to crash. + # We return 0, 0 instead (more or less like Numpy). + + resdiv = cgutils.alloca_once_value(builder, ZERO) + resmod = cgutils.alloca_once_value(builder, ZERO) + + is_overflow = builder.and_( + builder.icmp_signed("==", x, x.type(ty.minval)), + builder.icmp_signed("==", y, y.type(-1)), + ) + + with builder.if_then(builder.not_(is_overflow), likely=True): + # Note LLVM will optimize this to a single divmod instruction, + # if available on the target CPU (e.g. x86). + xdivy = builder.sdiv(x, y) + xmody = builder.srem(x, y) + + y_xor_xmody_ltz = builder.icmp_signed("<", builder.xor(y, xmody), ZERO) + xmody_istrue = builder.icmp_signed("!=", xmody, ZERO) + cond = builder.and_(xmody_istrue, y_xor_xmody_ltz) + + with builder.if_else(cond) as (if_different_signs, if_same_signs): + with if_same_signs: + builder.store(xdivy, resdiv) + builder.store(xmody, resmod) + + with if_different_signs: + builder.store(builder.sub(xdivy, ONE), resdiv) + builder.store(builder.add(xmody, y), resmod) + + return builder.load(resdiv), builder.load(resmod) + + +def int_divmod(context, builder, ty, x, y): + """ + Integer divmod(x, y). The caller must ensure that y != 0. 
+ """ + if ty.signed: + return int_divmod_signed(context, builder, ty, x, y) + else: + return builder.udiv(x, y), builder.urem(x, y) + + +def _int_divmod_impl(context, builder, sig, args, zerodiv_message): + va, vb = args + ta, tb = sig.args + + ty = sig.return_type + if isinstance(ty, types.UniTuple): + ty = ty.dtype + a = context.cast(builder, va, ta, ty) + b = context.cast(builder, vb, tb, ty) + quot = cgutils.alloca_once(builder, a.type, name="quot") + rem = cgutils.alloca_once(builder, a.type, name="rem") + + with builder.if_else(cgutils.is_scalar_zero(builder, b), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, (zerodiv_message,) + ): + # No exception raised => return 0 + # XXX We should also set the FPU exception status, but + # there's no easy way to do that from LLVM. + builder.store(b, quot) + builder.store(b, rem) + with if_non_zero: + q, r = int_divmod(context, builder, ty, a, b) + builder.store(q, quot) + builder.store(r, rem) + + return quot, rem + + +# @lower_builtin(divmod, types.Integer, types.Integer) +def int_divmod_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer divmod by zero" + ) + + return cgutils.pack_array(builder, (builder.load(quot), builder.load(rem))) + + +# @lower_builtin(operator.floordiv, types.Integer, types.Integer) +# @lower_builtin(operator.ifloordiv, types.Integer, types.Integer) +def int_floordiv_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer division by zero" + ) + return builder.load(quot) + + +# @lower_builtin(operator.truediv, types.Integer, types.Integer) +# @lower_builtin(operator.itruediv, types.Integer, types.Integer) +def int_truediv_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + with cgutils.if_zero(builder, b): + context.error_model.fp_zero_division(builder, ("division by zero",)) + res = builder.fdiv(a, b) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower_builtin(operator.mod, types.Integer, types.Integer) +# @lower_builtin(operator.imod, types.Integer, types.Integer) +def int_rem_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer modulo by zero" + ) + return builder.load(rem) + + +def _get_power_zerodiv_return(context, return_type): + if ( + isinstance(return_type, types.Integer) + and not context.error_model.raise_on_fp_zero_division + ): + # If not raising, return 0x8000... 
when computing 0 ** + return -1 << (return_type.bitwidth - 1) + else: + return False + + +def int_power_impl(context, builder, sig, args): + """ + a ^ b, where a is an integer or real, and b an integer + """ + is_integer = isinstance(sig.args[0], types.Integer) + tp = sig.return_type + zerodiv_return = _get_power_zerodiv_return(context, tp) + + def int_power(a, b): + # Ensure computations are done with a large enough width + r = tp(1) + a = tp(a) + if b < 0: + invert = True + exp = -b + if exp < 0: + raise OverflowError + if is_integer: + if a == 0: + if zerodiv_return: + return zerodiv_return + else: + raise ZeroDivisionError( + "0 cannot be raised to a negative power" + ) + if a != 1 and a != -1: + return 0 + else: + invert = False + exp = b + if exp > 0x10000: + # Optimization cutoff: fallback on the generic algorithm + return math.pow(a, float(b)) + while exp != 0: + if exp & 1: + r *= a + exp >>= 1 + a *= a + + return 1.0 / r if invert else r + + res = context.compile_internal(builder, int_power, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower_builtin(operator.pow, types.Integer, types.IntegerLiteral) +# @lower_builtin(operator.ipow, types.Integer, types.IntegerLiteral) +# @lower_builtin(operator.pow, types.Float, types.IntegerLiteral) +# @lower_builtin(operator.ipow, types.Float, types.IntegerLiteral) +def static_power_impl(context, builder, sig, args): + """ + a ^ b, where a is an integer or real, and b a constant integer + """ + exp = sig.args[1].value + if not isinstance(exp, numbers.Integral): + raise NotImplementedError + if abs(exp) > 0x10000: + # Optimization cutoff: fallback on the generic algorithm above + raise NotImplementedError + invert = exp < 0 + exp = abs(exp) + + tp = sig.return_type + is_integer = isinstance(tp, types.Integer) + zerodiv_return = _get_power_zerodiv_return(context, tp) + + val = context.cast(builder, args[0], sig.args[0], tp) + lty = val.type + + def mul(a, b): + if is_integer: + return builder.mul(a, b) + else: + return builder.fmul(a, b) + + # Unroll the exponentiation loop + res = lty(1) + while exp != 0: + if exp & 1: + res = mul(res, val) + exp >>= 1 + val = mul(val, val) + + if invert: + # If the exponent was negative, fix the result by inverting it + if is_integer: + # Integer inversion + def invert_impl(a): + if a == 0: + if zerodiv_return: + return zerodiv_return + else: + raise ZeroDivisionError( + "0 cannot be raised to a negative power" + ) + if a != 1 and a != -1: + return 0 + else: + return a + + else: + # Real inversion + def invert_impl(a): + return 1.0 / a + + res = context.compile_internal( + builder, invert_impl, typing.signature(tp, tp), (res,) + ) + + return res + + +def int_slt_impl(context, builder, sig, args): + res = builder.icmp_signed("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sle_impl(context, builder, sig, args): + res = builder.icmp_signed("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sgt_impl(context, builder, sig, args): + res = builder.icmp_signed(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sge_impl(context, builder, sig, args): + res = builder.icmp_signed(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ult_impl(context, builder, sig, args): + res = builder.icmp_unsigned("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ule_impl(context, 
builder, sig, args): + res = builder.icmp_unsigned("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ugt_impl(context, builder, sig, args): + res = builder.icmp_unsigned(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_uge_impl(context, builder, sig, args): + res = builder.icmp_unsigned(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_eq_impl(context, builder, sig, args): + res = builder.icmp_unsigned("==", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ne_impl(context, builder, sig, args): + res = builder.icmp_unsigned("!=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_signed_unsigned_cmp(op): + def impl(context, builder, sig, args): + (left, right) = args + # This code is translated from the NumPy source. + # What we're going to do is divide the range of a signed value at zero. + # If the signed value is less than zero, then we can treat zero as the + # unsigned value since the unsigned value is necessarily zero or larger + # and any signed comparison between a negative value and zero/infinity + # will yield the same result. If the signed value is greater than or + # equal to zero, then we can safely cast it to an unsigned value and do + # the expected unsigned-unsigned comparison operation. + # Original: https://github.com/numpy/numpy/pull/23713 + cmp_zero = builder.icmp_signed("<", left, Constant(left.type, 0)) + lt_zero = builder.icmp_signed(op, left, Constant(left.type, 0)) + ge_zero = builder.icmp_unsigned(op, left, right) + res = builder.select(cmp_zero, lt_zero, ge_zero) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return impl + + +def int_unsigned_signed_cmp(op): + def impl(context, builder, sig, args): + (left, right) = args + # See the function `int_signed_unsigned_cmp` for implementation notes. 
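+        # Added worked example: comparing uint32 left == 1 with int32
+        # right == -1 using a raw unsigned compare would reinterpret
+        # right as 0xFFFFFFFF and conclude 1 < 4294967295; the select
+        # below instead routes a negative `right` through the
+        # sign-based compare, giving the correct 1 > -1.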
+ cmp_zero = builder.icmp_signed("<", right, Constant(right.type, 0)) + lt_zero = builder.icmp_signed(op, Constant(right.type, 0), right) + ge_zero = builder.icmp_unsigned(op, left, right) + res = builder.select(cmp_zero, lt_zero, ge_zero) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return impl + + +def int_abs_impl(context, builder, sig, args): + [x] = args + ZERO = Constant(x.type, None) + ltz = builder.icmp_signed("<", x, ZERO) + negated = builder.neg(x) + res = builder.select(ltz, negated, x) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def identity_impl(context, builder, sig, args): + [x] = args + return impl_ret_untracked(context, builder, sig.return_type, x) + + +def uint_abs_impl(context, builder, sig, args): + [x] = args + return impl_ret_untracked(context, builder, sig.return_type, x) + + +def int_shl_impl(context, builder, sig, args): + [valty, amtty] = sig.args + [val, amt] = args + val = context.cast(builder, val, valty, sig.return_type) + amt = context.cast(builder, amt, amtty, sig.return_type) + res = builder.shl(val, amt) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_shr_impl(context, builder, sig, args): + [valty, amtty] = sig.args + [val, amt] = args + val = context.cast(builder, val, valty, sig.return_type) + amt = context.cast(builder, amt, amtty, sig.return_type) + if sig.return_type.signed: + res = builder.ashr(val, amt) + else: + res = builder.lshr(val, amt) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_and_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.and_(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_or_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.or_(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_xor_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.xor(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_negate_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + # Negate before upcasting, for unsigned numbers + res = builder.neg(val) + res = context.cast(builder, res, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_invert_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + # Invert before upcasting, for unsigned numbers + res = builder.xor(val, Constant(val.type, int("1" * val.type.width, 2))) + res = context.cast(builder, res, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sign_impl(context, builder, sig, args): + """ + np.sign(int) + """ + [x] = args + POS = Constant(x.type, 1) + NEG = Constant(x.type, -1) + ZERO = Constant(x.type, 0) + + cmp_zero = builder.icmp_unsigned("==", x, ZERO) + cmp_pos = 
builder.icmp_signed(">", x, ZERO) + + presult = cgutils.alloca_once(builder, x.type) + + bb_zero = builder.append_basic_block(".zero") + bb_postest = builder.append_basic_block(".postest") + bb_pos = builder.append_basic_block(".pos") + bb_neg = builder.append_basic_block(".neg") + bb_exit = builder.append_basic_block(".exit") + + builder.cbranch(cmp_zero, bb_zero, bb_postest) + + with builder.goto_block(bb_zero): + builder.store(ZERO, presult) + builder.branch(bb_exit) + + with builder.goto_block(bb_postest): + builder.cbranch(cmp_pos, bb_pos, bb_neg) + + with builder.goto_block(bb_pos): + builder.store(POS, presult) + builder.branch(bb_exit) + + with builder.goto_block(bb_neg): + builder.store(NEG, presult) + builder.branch(bb_exit) + + builder.position_at_end(bb_exit) + res = builder.load(presult) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def bool_negate_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + res = builder.neg(res) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def bool_unary_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# lower_builtin(operator.eq, types.boolean, types.boolean)(int_eq_impl) +# lower_builtin(operator.ne, types.boolean, types.boolean)(int_ne_impl) +# lower_builtin(operator.lt, types.boolean, types.boolean)(int_ult_impl) +# lower_builtin(operator.le, types.boolean, types.boolean)(int_ule_impl) +# lower_builtin(operator.gt, types.boolean, types.boolean)(int_ugt_impl) +# lower_builtin(operator.ge, types.boolean, types.boolean)(int_uge_impl) +# lower_builtin(operator.neg, types.boolean)(bool_negate_impl) +# lower_builtin(operator.pos, types.boolean)(bool_unary_positive_impl) + + +# def _implement_integer_operators(): +# ty = types.Integer + +# lower_builtin(operator.add, ty, ty)(int_add_impl) +# lower_builtin(operator.iadd, ty, ty)(int_add_impl) +# lower_builtin(operator.sub, ty, ty)(int_sub_impl) +# lower_builtin(operator.isub, ty, ty)(int_sub_impl) +# lower_builtin(operator.mul, ty, ty)(int_mul_impl) +# lower_builtin(operator.imul, ty, ty)(int_mul_impl) +# lower_builtin(operator.eq, ty, ty)(int_eq_impl) +# lower_builtin(operator.ne, ty, ty)(int_ne_impl) + +# lower_builtin(operator.lshift, ty, ty)(int_shl_impl) +# lower_builtin(operator.ilshift, ty, ty)(int_shl_impl) +# lower_builtin(operator.rshift, ty, ty)(int_shr_impl) +# lower_builtin(operator.irshift, ty, ty)(int_shr_impl) + +# lower_builtin(operator.neg, ty)(int_negate_impl) +# lower_builtin(operator.pos, ty)(int_positive_impl) + +# lower_builtin(operator.pow, ty, ty)(int_power_impl) +# lower_builtin(operator.ipow, ty, ty)(int_power_impl) +# lower_builtin(pow, ty, ty)(int_power_impl) + +# for ty in types.unsigned_domain: +# lower_builtin(operator.lt, ty, ty)(int_ult_impl) +# lower_builtin(operator.le, ty, ty)(int_ule_impl) +# lower_builtin(operator.gt, ty, ty)(int_ugt_impl) +# lower_builtin(operator.ge, ty, ty)(int_uge_impl) +# lower_builtin(operator.pow, types.Float, ty)(int_power_impl) +# lower_builtin(operator.ipow, types.Float, ty)(int_power_impl) +# lower_builtin(pow, types.Float, ty)(int_power_impl) +# lower_builtin(abs, ty)(uint_abs_impl) + +# lower_builtin(operator.lt, types.IntegerLiteral, types.IntegerLiteral)(int_slt_impl) +# lower_builtin(operator.gt, types.IntegerLiteral, 
types.IntegerLiteral)(int_slt_impl) +# lower_builtin(operator.le, types.IntegerLiteral, types.IntegerLiteral)(int_slt_impl) +# lower_builtin(operator.ge, types.IntegerLiteral, types.IntegerLiteral)(int_slt_impl) +# for ty in types.signed_domain: +# lower_builtin(operator.lt, ty, ty)(int_slt_impl) +# lower_builtin(operator.le, ty, ty)(int_sle_impl) +# lower_builtin(operator.gt, ty, ty)(int_sgt_impl) +# lower_builtin(operator.ge, ty, ty)(int_sge_impl) +# lower_builtin(operator.pow, types.Float, ty)(int_power_impl) +# lower_builtin(operator.ipow, types.Float, ty)(int_power_impl) +# lower_builtin(pow, types.Float, ty)(int_power_impl) +# lower_builtin(abs, ty)(int_abs_impl) + +# def _implement_bitwise_operators(): +# for ty in (types.Boolean, types.Integer): +# lower_builtin(operator.and_, ty, ty)(int_and_impl) +# lower_builtin(operator.iand, ty, ty)(int_and_impl) +# lower_builtin(operator.or_, ty, ty)(int_or_impl) +# lower_builtin(operator.ior, ty, ty)(int_or_impl) +# lower_builtin(operator.xor, ty, ty)(int_xor_impl) +# lower_builtin(operator.ixor, ty, ty)(int_xor_impl) + +# lower_builtin(operator.invert, ty)(int_invert_impl) + +# _implement_integer_operators() + +# _implement_bitwise_operators() + + +def real_add_impl(context, builder, sig, args): + res = builder.fadd(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_sub_impl(context, builder, sig, args): + res = builder.fsub(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_mul_impl(context, builder, sig, args): + res = builder.fmul(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_div_impl(context, builder, sig, args): + with cgutils.if_zero(builder, args[1]): + context.error_model.fp_zero_division(builder, ("division by zero",)) + res = builder.fdiv(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_divmod(context, builder, x, y): + assert x.type == y.type + floatty = x.type + + module = builder.module + fname = context.mangler(".numba.python.rem", [x.type]) + fnty = ir.FunctionType(floatty, (floatty, floatty, ir.PointerType(floatty))) + fn = cgutils.get_or_insert_function(module, fnty, fname) + + if fn.is_declaration: + fn.linkage = "linkonce_odr" + fnbuilder = ir.IRBuilder(fn.append_basic_block("entry")) + fx, fy, pmod = fn.args + div, mod = real_divmod_func_body(context, fnbuilder, fx, fy) + fnbuilder.store(mod, pmod) + fnbuilder.ret(div) + + pmod = cgutils.alloca_once(builder, floatty) + quotient = builder.call(fn, (x, y, pmod)) + return quotient, builder.load(pmod) + + +def real_divmod_func_body(context, builder, vx, wx): + # Reference Objects/floatobject.c + # + # float_divmod(PyObject *v, PyObject *w) + # { + # double vx, wx; + # double div, mod, floordiv; + # CONVERT_TO_DOUBLE(v, vx); + # CONVERT_TO_DOUBLE(w, wx); + # mod = fmod(vx, wx); + # /* fmod is typically exact, so vx-mod is *mathematically* an + # exact multiple of wx. But this is fp arithmetic, and fp + # vx - mod is an approximation; the result is that div may + # not be an exact integral value after the division, although + # it will always be very close to one. 
+ # */ + # div = (vx - mod) / wx; + # if (mod) { + # /* ensure the remainder has the same sign as the denominator */ + # if ((wx < 0) != (mod < 0)) { + # mod += wx; + # div -= 1.0; + # } + # } + # else { + # /* the remainder is zero, and in the presence of signed zeroes + # fmod returns different results across platforms; ensure + # it has the same sign as the denominator; we'd like to do + # "mod = wx * 0.0", but that may get optimized away */ + # mod *= mod; /* hide "mod = +0" from optimizer */ + # if (wx < 0.0) + # mod = -mod; + # } + # /* snap quotient to nearest integral value */ + # if (div) { + # floordiv = floor(div); + # if (div - floordiv > 0.5) + # floordiv += 1.0; + # } + # else { + # /* div is zero - get the same sign as the true quotient */ + # div *= div; /* hide "div = +0" from optimizers */ + # floordiv = div * vx / wx; /* zero w/ sign of vx/wx */ + # } + # return Py_BuildValue("(dd)", floordiv, mod); + # } + pmod = cgutils.alloca_once(builder, vx.type) + pdiv = cgutils.alloca_once(builder, vx.type) + pfloordiv = cgutils.alloca_once(builder, vx.type) + + mod = builder.frem(vx, wx) + div = builder.fdiv(builder.fsub(vx, mod), wx) + + builder.store(mod, pmod) + builder.store(div, pdiv) + + # Note the use of negative zero for proper negating with `ZERO - x` + ZERO = vx.type(0.0) + NZERO = vx.type(-0.0) + ONE = vx.type(1.0) + mod_istrue = builder.fcmp_unordered("!=", mod, ZERO) + wx_ltz = builder.fcmp_ordered("<", wx, ZERO) + mod_ltz = builder.fcmp_ordered("<", mod, ZERO) + + with builder.if_else(mod_istrue, likely=True) as ( + if_nonzero_mod, + if_zero_mod, + ): + with if_nonzero_mod: + # `mod` is non-zero or NaN + # Ensure the remainder has the same sign as the denominator + wx_ltz_ne_mod_ltz = builder.icmp_unsigned("!=", wx_ltz, mod_ltz) + + with builder.if_then(wx_ltz_ne_mod_ltz): + builder.store(builder.fsub(div, ONE), pdiv) + builder.store(builder.fadd(mod, wx), pmod) + + with if_zero_mod: + # `mod` is zero, select the proper sign depending on + # the denominator's sign + mod = builder.select(wx_ltz, NZERO, ZERO) + builder.store(mod, pmod) + + del mod, div + + div = builder.load(pdiv) + div_istrue = builder.fcmp_ordered("!=", div, ZERO) + + with builder.if_then(div_istrue): + realtypemap = {"float": types.float32, "double": types.float64} + realtype = realtypemap[str(wx.type)] + floorfn = context.get_function( + math.floor, typing.signature(realtype, realtype) + ) + floordiv = floorfn(builder, [div]) + floordivdiff = builder.fsub(div, floordiv) + floordivincr = builder.fadd(floordiv, ONE) + HALF = Constant(wx.type, 0.5) + pred = builder.fcmp_ordered(">", floordivdiff, HALF) + floordiv = builder.select(pred, floordivincr, floordiv) + builder.store(floordiv, pfloordiv) + + with cgutils.ifnot(builder, div_istrue): + div = builder.fmul(div, div) + builder.store(div, pdiv) + floordiv = builder.fdiv(builder.fmul(div, vx), wx) + builder.store(floordiv, pfloordiv) + + return builder.load(pfloordiv), builder.load(pmod) + + +# @lower_builtin(divmod, types.Float, types.Float) +def real_divmod_impl(context, builder, sig, args, loc=None): + x, y = args + quot = cgutils.alloca_once(builder, x.type, name="quot") + rem = cgutils.alloca_once(builder, x.type, name="rem") + + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("modulo by zero",), loc + ): + # No exception raised => compute the nan result, + # and set the FP exception word for Numpy warnings. 
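+                # Editor's sketch of the intended result (assuming IEEE-754
+                # semantics): with y == 0, fdiv gives x/0 -> +/-inf (nan for
+                # 0/0) and frem gives nan, e.g. divmod(1.0, 0.0) lowers to
+                # (inf, nan), matching NumPy's warn-only divmod-by-zero
+                # behaviour.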
+ q = builder.fdiv(x, y) + r = builder.frem(x, y) + builder.store(q, quot) + builder.store(r, rem) + with if_non_zero: + q, r = real_divmod(context, builder, x, y) + builder.store(q, quot) + builder.store(r, rem) + + return cgutils.pack_array(builder, (builder.load(quot), builder.load(rem))) + + +def real_mod_impl(context, builder, sig, args, loc=None): + x, y = args + res = cgutils.alloca_once(builder, x.type) + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("modulo by zero",), loc + ): + # No exception raised => compute the nan result, + # and set the FP exception word for Numpy warnings. + rem = builder.frem(x, y) + builder.store(rem, res) + with if_non_zero: + _, rem = real_divmod(context, builder, x, y) + builder.store(rem, res) + return impl_ret_untracked( + context, builder, sig.return_type, builder.load(res) + ) + + +def real_floordiv_impl(context, builder, sig, args, loc=None): + x, y = args + res = cgutils.alloca_once(builder, x.type) + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("division by zero",), loc + ): + # No exception raised => compute the +/-inf or nan result, + # and set the FP exception word for Numpy warnings. + quot = builder.fdiv(x, y) + builder.store(quot, res) + with if_non_zero: + quot, _ = real_divmod(context, builder, x, y) + builder.store(quot, res) + return impl_ret_untracked( + context, builder, sig.return_type, builder.load(res) + ) + + +def real_power_impl(context, builder, sig, args): + x, y = args + module = builder.module + if context.implement_powi_as_math_call: + imp = context.get_function(math.pow, sig) + res = imp(builder, args) + else: + fn = module.declare_intrinsic("llvm.pow", [y.type]) + res = builder.call(fn, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_lt_impl(context, builder, sig, args): + res = builder.fcmp_ordered("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_le_impl(context, builder, sig, args): + res = builder.fcmp_ordered("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_gt_impl(context, builder, sig, args): + res = builder.fcmp_ordered(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_ge_impl(context, builder, sig, args): + res = builder.fcmp_ordered(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_eq_impl(context, builder, sig, args): + res = builder.fcmp_ordered("==", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_ne_impl(context, builder, sig, args): + res = builder.fcmp_unordered("!=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_abs_impl(context, builder, sig, args): + [ty] = sig.args + sig = typing.signature(ty, ty) + impl = context.get_function(math.fabs, sig) + return impl(builder, args) + + +def real_negate_impl(context, builder, sig, args): + from numba.cuda.cpython import mathimpl + + res = mathimpl.negate_real(builder, args[0]) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + 
return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_sign_impl(context, builder, sig, args): + """ + np.sign(float) + """ + [x] = args + POS = Constant(x.type, 1) + NEG = Constant(x.type, -1) + ZERO = Constant(x.type, 0) + + presult = cgutils.alloca_once(builder, x.type) + + is_pos = builder.fcmp_ordered(">", x, ZERO) + is_neg = builder.fcmp_ordered("<", x, ZERO) + + with builder.if_else(is_pos) as (gt_zero, not_gt_zero): + with gt_zero: + builder.store(POS, presult) + with not_gt_zero: + with builder.if_else(is_neg) as (lt_zero, not_lt_zero): + with lt_zero: + builder.store(NEG, presult) + with not_lt_zero: + # For both NaN and 0, the result of sign() is simply + # the input value. + builder.store(x, presult) + + res = builder.load(presult) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ty = types.Float + +# lower_builtin(operator.add, ty, ty)(real_add_impl) +# lower_builtin(operator.iadd, ty, ty)(real_add_impl) +# lower_builtin(operator.sub, ty, ty)(real_sub_impl) +# lower_builtin(operator.isub, ty, ty)(real_sub_impl) +# lower_builtin(operator.mul, ty, ty)(real_mul_impl) +# lower_builtin(operator.imul, ty, ty)(real_mul_impl) +# lower_builtin(operator.floordiv, ty, ty)(real_floordiv_impl) +# lower_builtin(operator.ifloordiv, ty, ty)(real_floordiv_impl) +# lower_builtin(operator.truediv, ty, ty)(real_div_impl) +# lower_builtin(operator.itruediv, ty, ty)(real_div_impl) +# lower_builtin(operator.mod, ty, ty)(real_mod_impl) +# lower_builtin(operator.imod, ty, ty)(real_mod_impl) +# lower_builtin(operator.pow, ty, ty)(real_power_impl) +# lower_builtin(operator.ipow, ty, ty)(real_power_impl) +# lower_builtin(pow, ty, ty)(real_power_impl) + +# lower_builtin(operator.eq, ty, ty)(real_eq_impl) +# lower_builtin(operator.ne, ty, ty)(real_ne_impl) +# lower_builtin(operator.lt, ty, ty)(real_lt_impl) +# lower_builtin(operator.le, ty, ty)(real_le_impl) +# lower_builtin(operator.gt, ty, ty)(real_gt_impl) +# lower_builtin(operator.ge, ty, ty)(real_ge_impl) + +# lower_builtin(abs, ty)(real_abs_impl) + +# lower_builtin(operator.neg, ty)(real_negate_impl) +# lower_builtin(operator.pos, ty)(real_positive_impl) + +# del ty + + +# @lower_getattr(types.Complex, "real") +def complex_real_impl(context, builder, typ, value): + cplx = context.make_complex(builder, typ, value=value) + res = cplx.real + return impl_ret_untracked(context, builder, typ, res) + + +# @lower_getattr(types.Complex, "imag") +def complex_imag_impl(context, builder, typ, value): + cplx = context.make_complex(builder, typ, value=value) + res = cplx.imag + return impl_ret_untracked(context, builder, typ, res) + + +# @lower_builtin("complex.conjugate", types.Complex) +def complex_conjugate_impl(context, builder, sig, args): + from numba.cuda.cpython import mathimpl + + z = context.make_complex(builder, sig.args[0], args[0]) + z.imag = mathimpl.negate_real(builder, z.imag) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_real_impl(context, builder, typ, value): + return impl_ret_untracked(context, builder, typ, value) + + +def real_imag_impl(context, builder, typ, value): + res = cgutils.get_null_value(value.type) + return impl_ret_untracked(context, builder, typ, res) + + +def real_conjugate_impl(context, builder, sig, args): + return impl_ret_untracked(context, builder, sig.return_type, args[0]) + + +# for cls in (types.Float, types.Integer): +# lower_getattr(cls, "real")(real_real_impl) +# lower_getattr(cls, "imag")(real_imag_impl) +# 
lower_builtin("complex.conjugate", cls)(real_conjugate_impl) + + +# @lower_builtin(operator.pow, types.Complex, types.Complex) +# @lower_builtin(operator.ipow, types.Complex, types.Complex) +# @lower_builtin(pow, types.Complex, types.Complex) +def complex_power_impl(context, builder, sig, args): + [ca, cb] = args + ty = sig.args[0] + fty = ty.underlying_float + a = context.make_helper(builder, ty, value=ca) + b = context.make_helper(builder, ty, value=cb) + c = context.make_helper(builder, ty) + module = builder.module + pa = a._getpointer() + pb = b._getpointer() + pc = c._getpointer() + + # Optimize for square because cpow loses a lot of precision + TWO = context.get_constant(fty, 2) + ZERO = context.get_constant(fty, 0) + + b_real_is_two = builder.fcmp_ordered("==", b.real, TWO) + b_imag_is_zero = builder.fcmp_ordered("==", b.imag, ZERO) + b_is_two = builder.and_(b_real_is_two, b_imag_is_zero) + + with builder.if_else(b_is_two) as (then, otherwise): + with then: + # Lower as multiplication + res = complex_mul_impl(context, builder, sig, (ca, ca)) + cres = context.make_helper(builder, ty, value=res) + c.real = cres.real + c.imag = cres.imag + + with otherwise: + # Lower with call to external function + func_name = { + types.complex64: "numba_cpowf", + types.complex128: "numba_cpow", + }[ty] + fnty = ir.FunctionType(ir.VoidType(), [pa.type] * 3) + cpow = cgutils.get_or_insert_function(module, fnty, func_name) + builder.call(cpow, (pa, pb, pc)) + + res = builder.load(pc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_add_impl(context, builder, sig, args): + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + z.real = builder.fadd(a, c) + z.imag = builder.fadd(b, d) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_sub_impl(context, builder, sig, args): + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + z.real = builder.fsub(a, c) + z.imag = builder.fsub(b, d) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_mul_impl(context, builder, sig, args): + """ + (a+bi)(c+di)=(ac-bd)+i(ad+bc) + """ + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + ac = builder.fmul(a, c) + bd = builder.fmul(b, d) + ad = builder.fmul(a, d) + bc = builder.fmul(b, c) + z.real = builder.fsub(ac, bd) + z.imag = builder.fadd(ad, bc) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +NAN = float("nan") + + +def complex_div_impl(context, builder, sig, args): + def complex_div(a, b): + # This is CPython's algorithm (in _Py_c_quot()). 
+ areal = a.real + aimag = a.imag + breal = b.real + bimag = b.imag + if not breal and not bimag: + raise ZeroDivisionError("complex division by zero") + if abs(breal) >= abs(bimag): + # Divide tops and bottom by b.real + if not breal: + return complex(NAN, NAN) + ratio = bimag / breal + denom = breal + bimag * ratio + return complex( + (areal + aimag * ratio) / denom, (aimag - areal * ratio) / denom + ) + else: + # Divide tops and bottom by b.imag + if not bimag: + return complex(NAN, NAN) + ratio = breal / bimag + denom = breal * ratio + bimag + return complex( + (a.real * ratio + a.imag) / denom, + (a.imag * ratio - a.real) / denom, + ) + + res = context.compile_internal(builder, complex_div, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_negate_impl(context, builder, sig, args): + from numba.cuda.cpython import mathimpl + + [typ] = sig.args + [val] = args + cmplx = context.make_complex(builder, typ, value=val) + res = context.make_complex(builder, typ) + res.real = mathimpl.negate_real(builder, cmplx.real) + res.imag = mathimpl.negate_real(builder, cmplx.imag) + res = res._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_positive_impl(context, builder, sig, args): + [val] = args + return impl_ret_untracked(context, builder, sig.return_type, val) + + +def complex_eq_impl(context, builder, sig, args): + [cx, cy] = args + typ = sig.args[0] + x = context.make_complex(builder, typ, value=cx) + y = context.make_complex(builder, typ, value=cy) + + reals_are_eq = builder.fcmp_ordered("==", x.real, y.real) + imags_are_eq = builder.fcmp_ordered("==", x.imag, y.imag) + res = builder.and_(reals_are_eq, imags_are_eq) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_ne_impl(context, builder, sig, args): + [cx, cy] = args + typ = sig.args[0] + x = context.make_complex(builder, typ, value=cx) + y = context.make_complex(builder, typ, value=cy) + + reals_are_ne = builder.fcmp_unordered("!=", x.real, y.real) + imags_are_ne = builder.fcmp_unordered("!=", x.imag, y.imag) + res = builder.or_(reals_are_ne, imags_are_ne) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_abs_impl(context, builder, sig, args): + """ + abs(z) := hypot(z.real, z.imag) + """ + + def complex_abs(z): + return math.hypot(z.real, z.imag) + + res = context.compile_internal(builder, complex_abs, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ty = types.Complex + +# lower_builtin(operator.add, ty, ty)(complex_add_impl) +# lower_builtin(operator.iadd, ty, ty)(complex_add_impl) +# lower_builtin(operator.sub, ty, ty)(complex_sub_impl) +# lower_builtin(operator.isub, ty, ty)(complex_sub_impl) +# lower_builtin(operator.mul, ty, ty)(complex_mul_impl) +# lower_builtin(operator.imul, ty, ty)(complex_mul_impl) +# lower_builtin(operator.truediv, ty, ty)(complex_div_impl) +# lower_builtin(operator.itruediv, ty, ty)(complex_div_impl) +# lower_builtin(operator.neg, ty)(complex_negate_impl) +# lower_builtin(operator.pos, ty)(complex_positive_impl) +# # Complex modulo is deprecated in python3 + +# lower_builtin(operator.eq, ty, ty)(complex_eq_impl) +# lower_builtin(operator.ne, ty, ty)(complex_ne_impl) + +# lower_builtin(abs, ty)(complex_abs_impl) + +# del ty + + +# @lower_builtin("number.item", types.Boolean) +# @lower_builtin("number.item", types.Number) +def number_item_impl(context, builder, sig, args): + """ + The no-op .item() method on 
booleans and numbers. + """ + return args[0] + + +# ------------------------------------------------------------------------------ + + +def number_not_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + istrue = context.cast(builder, val, typ, sig.return_type) + res = builder.not_(istrue) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# @lower_builtin(bool, types.Boolean) +def bool_as_bool(context, builder, sig, args): + [val] = args + return val + + +# @lower_builtin(bool, types.Integer) +def int_as_bool(context, builder, sig, args): + [val] = args + return builder.icmp_unsigned("!=", val, Constant(val.type, 0)) + + +# @lower_builtin(bool, types.Float) +def float_as_bool(context, builder, sig, args): + [val] = args + return builder.fcmp_unordered("!=", val, Constant(val.type, 0.0)) + + +# @lower_builtin(bool, types.Complex) +def complex_as_bool(context, builder, sig, args): + [typ] = sig.args + [val] = args + cmplx = context.make_complex(builder, typ, val) + real, imag = cmplx.real, cmplx.imag + zero = Constant(real.type, 0.0) + real_istrue = builder.fcmp_unordered("!=", real, zero) + imag_istrue = builder.fcmp_unordered("!=", imag, zero) + return builder.or_(real_istrue, imag_istrue) + + +# for ty in (types.Integer, types.Float, types.Complex): +# lower_builtin(operator.not_, ty)(number_not_impl) + +# lower_builtin(operator.not_, types.boolean)(number_not_impl) + + +# ------------------------------------------------------------------------------ +# Hashing numbers, see hashing.py + +# ------------------------------------------------------------------------------- +# Implicit casts between numerics + + +# @lower_cast(types.IntegerLiteral, types.Integer) +# @lower_cast(types.IntegerLiteral, types.Float) +# @lower_cast(types.IntegerLiteral, types.Complex) +def literal_int_to_number(context, builder, fromty, toty, val): + lit = context.get_constant_generic( + builder, + fromty.literal_type, + fromty.literal_value, + ) + return context.cast(builder, lit, fromty.literal_type, toty) + + +# @lower_cast(types.Integer, types.Integer) +def integer_to_integer(context, builder, fromty, toty, val): + if toty.bitwidth == fromty.bitwidth: + # Just a change of signedness + return val + elif toty.bitwidth < fromty.bitwidth: + # Downcast + return builder.trunc(val, context.get_value_type(toty)) + elif fromty.signed: + # Signed upcast + return builder.sext(val, context.get_value_type(toty)) + else: + # Unsigned upcast + return builder.zext(val, context.get_value_type(toty)) + + +# @lower_cast(types.Integer, types.voidptr) +def integer_to_voidptr(context, builder, fromty, toty, val): + return builder.inttoptr(val, context.get_value_type(toty)) + + +# @lower_cast(types.Float, types.Float) +def float_to_float(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if fromty.bitwidth < toty.bitwidth: + return builder.fpext(val, lty) + else: + return builder.fptrunc(val, lty) + + +# @lower_cast(types.Integer, types.Float) +def integer_to_float(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if fromty.signed: + return builder.sitofp(val, lty) + else: + return builder.uitofp(val, lty) + + +# @lower_cast(types.Float, types.Integer) +def float_to_integer(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if toty.signed: + return builder.fptosi(val, lty) + else: + return builder.fptoui(val, lty) + + +# @lower_cast(types.Float, types.Complex) +# @lower_cast(types.Integer, types.Complex) +def 
non_complex_to_complex(context, builder, fromty, toty, val):
+    real = context.cast(builder, val, fromty, toty.underlying_float)
+    imag = context.get_constant(toty.underlying_float, 0)
+
+    cmplx = context.make_complex(builder, toty)
+    cmplx.real = real
+    cmplx.imag = imag
+    return cmplx._getvalue()
+
+
+# @lower_cast(types.Complex, types.Complex)
+def complex_to_complex(context, builder, fromty, toty, val):
+    srcty = fromty.underlying_float
+    dstty = toty.underlying_float
+
+    src = context.make_complex(builder, fromty, value=val)
+    dst = context.make_complex(builder, toty)
+    dst.real = context.cast(builder, src.real, srcty, dstty)
+    dst.imag = context.cast(builder, src.imag, srcty, dstty)
+    return dst._getvalue()
+
+
+# @lower_cast(types.Any, types.Boolean)
+def any_to_boolean(context, builder, fromty, toty, val):
+    return context.is_true(builder, fromty, val)
+
+
+# @lower_cast(types.Boolean, types.Number)
+def boolean_to_any(context, builder, fromty, toty, val):
+    # Casting from boolean to anything first casts to int32
+    asint = builder.zext(val, ir.IntType(32))
+    return context.cast(builder, asint, types.int32, toty)
+
+
+# @lower_cast(types.IntegerLiteral, types.Boolean)
+# @lower_cast(types.BooleanLiteral, types.Boolean)
+def literal_int_to_boolean(context, builder, fromty, toty, val):
+    lit = context.get_constant_generic(
+        builder,
+        fromty.literal_type,
+        fromty.literal_value,
+    )
+    return context.is_true(builder, fromty.literal_type, lit)
+
+
+# -------------------------------------------------------------------------------
+# Constants
+
+
+# @lower_constant(types.Complex)
+def constant_complex(context, builder, ty, pyval):
+    fty = ty.underlying_float
+    real = context.get_constant_generic(builder, fty, pyval.real)
+    imag = context.get_constant_generic(builder, fty, pyval.imag)
+    return Constant.literal_struct((real, imag))
+
+
+# @lower_constant(types.Integer)
+# @lower_constant(types.Float)
+# @lower_constant(types.Boolean)
+def constant_integer(context, builder, ty, pyval):
+    # See https://github.com/numba/numba/issues/6979
+    # llvmlite ir.IntType specialises the formatting of the constant for a
+    # cpython bool. A NumPy np.bool_ is not a cpython bool so force it to be one
+    # so that the constant renders correctly!
+    if isinstance(pyval, np.bool_):
+        pyval = bool(pyval)
+    lty = context.get_value_type(ty)
+    return lty(pyval)
+
+
+# -------------------------------------------------------------------------------
+# View
+
+
+def scalar_view(scalar, viewty):
+    """Typing for the np scalar 'view' method."""
+    if isinstance(scalar, (types.Float, types.Integer)) and isinstance(
+        viewty, types.abstract.DTypeSpec
+    ):
+        if scalar.bitwidth != viewty.dtype.bitwidth:
+            # TypingError is imported by name at the top of this module;
+            # the bare `errors` module is not.
+            raise TypingError(
+                "Changing the dtype of a 0d array is only supported if the "
+                "itemsize is unchanged"
+            )
+
+        def impl(scalar, viewty):
+            return viewer(scalar, viewty)
+
+        return impl
+
+
+# overload_method(types.Float, 'view')(scalar_view)
+# overload_method(types.Integer, 'view')(scalar_view)
diff --git a/numba_cuda/numba/cuda/np/npdatetime.py b/numba_cuda/numba/cuda/np/npdatetime.py
new file mode 100644
index 000000000..2fd1ab2f8
--- /dev/null
+++ b/numba_cuda/numba/cuda/np/npdatetime.py
@@ -0,0 +1,969 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""
+Implementation of operations on numpy timedelta64 and datetime64.
+""" + +import numpy as np +import operator + +import llvmlite.ir +from llvmlite.ir import Constant + +from numba.core import types +from numba.cuda import cgutils +from numba.cuda.cgutils import create_constant_array +from numba.core.imputils import ( + impl_ret_untracked, + lower_cast, + Registry, +) +from numba.cuda.np import npdatetime_helpers, numpy_support, npyfuncs +from numba.cuda.extending import overload_method +from numba.cuda.core.config import IS_32BITS +from numba.core.errors import LoweringError + +# datetime64 and timedelta64 use the same internal representation +DATETIME64 = TIMEDELTA64 = llvmlite.ir.IntType(64) +NAT = Constant(TIMEDELTA64, npdatetime_helpers.NAT) + +TIMEDELTA_BINOP_SIG = (types.NPTimedelta,) * 2 + +registry = Registry("np.npdatetime") +lower = registry.lower +lower_constant = registry.lower_constant + + +def scale_by_constant(builder, val, factor): + """ + Multiply *val* by the constant *factor*. + """ + return builder.mul(val, Constant(TIMEDELTA64, factor)) + + +def unscale_by_constant(builder, val, factor): + """ + Divide *val* by the constant *factor*. + """ + return builder.sdiv(val, Constant(TIMEDELTA64, factor)) + + +def add_constant(builder, val, const): + """ + Add constant *const* to *val*. + """ + return builder.add(val, Constant(TIMEDELTA64, const)) + + +def scale_timedelta(context, builder, val, srcty, destty): + """ + Scale the timedelta64 *val* from *srcty* to *destty* + (both numba.types.NPTimedelta instances) + """ + factor = npdatetime_helpers.get_timedelta_conversion_factor( + srcty.unit, destty.unit + ) + if factor is None: + # This can happen when using explicit output in a ufunc. + msg = f"cannot convert timedelta64 from {srcty.unit} to {destty.unit}" + raise LoweringError(msg) + return scale_by_constant(builder, val, factor) + + +def normalize_timedeltas(context, builder, left, right, leftty, rightty): + """ + Scale either *left* or *right* to the other's unit, in order to have + homogeneous units. + """ + factor = npdatetime_helpers.get_timedelta_conversion_factor( + leftty.unit, rightty.unit + ) + if factor is not None: + return scale_by_constant(builder, left, factor), right + factor = npdatetime_helpers.get_timedelta_conversion_factor( + rightty.unit, leftty.unit + ) + if factor is not None: + return left, scale_by_constant(builder, right, factor) + # Typing should not let this happen, except on == and != operators + raise RuntimeError("cannot normalize %r and %r" % (leftty, rightty)) + + +def alloc_timedelta_result(builder, name="ret"): + """ + Allocate a NaT-initialized datetime64 (or timedelta64) result slot. + """ + ret = cgutils.alloca_once(builder, TIMEDELTA64, name=name) + builder.store(NAT, ret) + return ret + + +def alloc_boolean_result(builder, name="ret"): + """ + Allocate an uninitialized boolean result slot. + """ + ret = cgutils.alloca_once(builder, llvmlite.ir.IntType(1), name=name) + return ret + + +def is_not_nat(builder, val): + """ + Return a predicate which is true if *val* is not NaT. + """ + return builder.icmp_unsigned("!=", val, NAT) + + +def are_not_nat(builder, vals): + """ + Return a predicate which is true if all of *vals* are not NaT. 
+ """ + assert len(vals) >= 1 + pred = is_not_nat(builder, vals[0]) + for val in vals[1:]: + pred = builder.and_(pred, is_not_nat(builder, val)) + return pred + + +normal_year_months = create_constant_array( + TIMEDELTA64, [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] +) +leap_year_months = create_constant_array( + TIMEDELTA64, [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] +) +normal_year_months_acc = create_constant_array( + TIMEDELTA64, [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334] +) +leap_year_months_acc = create_constant_array( + TIMEDELTA64, [0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335] +) + + +@lower_constant(types.NPDatetime) +@lower_constant(types.NPTimedelta) +def datetime_constant(context, builder, ty, pyval): + return DATETIME64(pyval.astype(np.int64)) + + +# Arithmetic operators on timedelta64 + + +@lower(operator.pos, types.NPTimedelta) +def timedelta_pos_impl(context, builder, sig, args): + res = args[0] + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.neg, types.NPTimedelta) +def timedelta_neg_impl(context, builder, sig, args): + res = builder.neg(args[0]) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(abs, types.NPTimedelta) +def timedelta_abs_impl(context, builder, sig, args): + (val,) = args + ret = alloc_timedelta_result(builder) + with builder.if_else(cgutils.is_scalar_neg(builder, val)) as ( + then, + otherwise, + ): + with then: + builder.store(builder.neg(val), ret) + with otherwise: + builder.store(val, ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def timedelta_sign_impl(context, builder, sig, args): + """ + np.sign(timedelta64) + """ + (val,) = args + ret = alloc_timedelta_result(builder) + zero = Constant(TIMEDELTA64, 0) + with builder.if_else(builder.icmp_signed(">", val, zero)) as ( + gt_zero, + le_zero, + ): + with gt_zero: + builder.store(Constant(TIMEDELTA64, 1), ret) + with le_zero: + with builder.if_else(builder.icmp_unsigned("==", val, zero)) as ( + eq_zero, + lt_zero, + ): + with eq_zero: + builder.store(Constant(TIMEDELTA64, 0), ret) + with lt_zero: + builder.store(Constant(TIMEDELTA64, -1), ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.add, *TIMEDELTA_BINOP_SIG) +@lower(operator.iadd, *TIMEDELTA_BINOP_SIG) +def timedelta_add_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + ret = alloc_timedelta_result(builder) + with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])): + va = scale_timedelta(context, builder, va, ta, sig.return_type) + vb = scale_timedelta(context, builder, vb, tb, sig.return_type) + builder.store(builder.add(va, vb), ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.sub, *TIMEDELTA_BINOP_SIG) +@lower(operator.isub, *TIMEDELTA_BINOP_SIG) +def timedelta_sub_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + ret = alloc_timedelta_result(builder) + with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])): + va = scale_timedelta(context, builder, va, ta, sig.return_type) + vb = scale_timedelta(context, builder, vb, tb, sig.return_type) + builder.store(builder.sub(va, vb), ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def _timedelta_times_number( + context, builder, td_arg, td_type, number_arg, number_type, return_type 
+): + ret = alloc_timedelta_result(builder) + with cgutils.if_likely(builder, is_not_nat(builder, td_arg)): + if isinstance(number_type, types.Float): + val = builder.sitofp(td_arg, number_arg.type) + val = builder.fmul(val, number_arg) + val = _cast_to_timedelta(context, builder, val) + else: + val = builder.mul(td_arg, number_arg) + # The scaling is required for ufunc np.multiply() with an explicit + # output in a different unit. + val = scale_timedelta(context, builder, val, td_type, return_type) + builder.store(val, ret) + return builder.load(ret) + + +@lower(operator.mul, types.NPTimedelta, types.Integer) +@lower(operator.imul, types.NPTimedelta, types.Integer) +@lower(operator.mul, types.NPTimedelta, types.Float) +@lower(operator.imul, types.NPTimedelta, types.Float) +def timedelta_times_number(context, builder, sig, args): + res = _timedelta_times_number( + context, + builder, + args[0], + sig.args[0], + args[1], + sig.args[1], + sig.return_type, + ) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.mul, types.Integer, types.NPTimedelta) +@lower(operator.imul, types.Integer, types.NPTimedelta) +@lower(operator.mul, types.Float, types.NPTimedelta) +@lower(operator.imul, types.Float, types.NPTimedelta) +def number_times_timedelta(context, builder, sig, args): + res = _timedelta_times_number( + context, + builder, + args[1], + sig.args[1], + args[0], + sig.args[0], + sig.return_type, + ) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.truediv, types.NPTimedelta, types.Integer) +@lower(operator.itruediv, types.NPTimedelta, types.Integer) +@lower(operator.floordiv, types.NPTimedelta, types.Integer) +@lower(operator.ifloordiv, types.NPTimedelta, types.Integer) +@lower(operator.truediv, types.NPTimedelta, types.Float) +@lower(operator.itruediv, types.NPTimedelta, types.Float) +@lower(operator.floordiv, types.NPTimedelta, types.Float) +@lower(operator.ifloordiv, types.NPTimedelta, types.Float) +def timedelta_over_number(context, builder, sig, args): + td_arg, number_arg = args + number_type = sig.args[1] + ret = alloc_timedelta_result(builder) + ok = builder.and_( + is_not_nat(builder, td_arg), + builder.not_(cgutils.is_scalar_zero_or_nan(builder, number_arg)), + ) + with cgutils.if_likely(builder, ok): + # Denominator is non-zero, non-NaN + if isinstance(number_type, types.Float): + val = builder.sitofp(td_arg, number_arg.type) + val = builder.fdiv(val, number_arg) + val = _cast_to_timedelta(context, builder, val) + else: + val = builder.sdiv(td_arg, number_arg) + # The scaling is required for ufuncs np.*divide() with an explicit + # output in a different unit. 
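+        # Illustrative (hypothetical) case: np.floor_divide on a
+        # timedelta64[m] input with an explicit timedelta64[s] output
+        # computes the quotient in minutes above, then scale_timedelta
+        # multiplies by the m -> s factor of 60.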
+ val = scale_timedelta( + context, builder, val, sig.args[0], sig.return_type + ) + builder.store(val, ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.truediv, *TIMEDELTA_BINOP_SIG) +@lower(operator.itruediv, *TIMEDELTA_BINOP_SIG) +def timedelta_over_timedelta(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + not_nan = are_not_nat(builder, [va, vb]) + ll_ret_type = context.get_value_type(sig.return_type) + ret = cgutils.alloca_once(builder, ll_ret_type, name="ret") + builder.store(Constant(ll_ret_type, float("nan")), ret) + with cgutils.if_likely(builder, not_nan): + va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb) + va = builder.sitofp(va, ll_ret_type) + vb = builder.sitofp(vb, ll_ret_type) + builder.store(builder.fdiv(va, vb), ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.floordiv, *TIMEDELTA_BINOP_SIG) +def timedelta_floor_div_timedelta(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + ll_ret_type = context.get_value_type(sig.return_type) + not_nan = are_not_nat(builder, [va, vb]) + ret = cgutils.alloca_once(builder, ll_ret_type, name="ret") + zero = Constant(ll_ret_type, 0) + one = Constant(ll_ret_type, 1) + builder.store(zero, ret) + with cgutils.if_likely(builder, not_nan): + va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb) + # is the denominator zero or NaT? + denom_ok = builder.not_(builder.icmp_signed("==", vb, zero)) + with cgutils.if_likely(builder, denom_ok): + # is either arg negative? + vaneg = builder.icmp_signed("<", va, zero) + neg = builder.or_(vaneg, builder.icmp_signed("<", vb, zero)) + with builder.if_else(neg) as (then, otherwise): + with then: # one or more value negative + with builder.if_else(vaneg) as (negthen, negotherwise): + with negthen: + top = builder.sub(va, one) + div = builder.sdiv(top, vb) + builder.store(div, ret) + with negotherwise: + top = builder.add(va, one) + div = builder.sdiv(top, vb) + builder.store(div, ret) + with otherwise: + div = builder.sdiv(va, vb) + builder.store(div, ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def timedelta_mod_timedelta(context, builder, sig, args): + # inspired by https://github.com/numpy/numpy/blob/fe8072a12d65e43bd2e0b0f9ad67ab0108cc54b3/numpy/core/src/umath/loops.c.src#L1424 + # alg is basically as `a % b`: + # if a or b is NaT return NaT + # elseif b is 0 return NaT + # else pretend a and b are int and do pythonic int modulus + + [va, vb] = args + [ta, tb] = sig.args + not_nan = are_not_nat(builder, [va, vb]) + ll_ret_type = context.get_value_type(sig.return_type) + ret = alloc_timedelta_result(builder) + builder.store(NAT, ret) + zero = Constant(ll_ret_type, 0) + with cgutils.if_likely(builder, not_nan): + va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb) + # is the denominator zero or NaT? + denom_ok = builder.not_(builder.icmp_signed("==", vb, zero)) + with cgutils.if_likely(builder, denom_ok): + # is either arg negative? 
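+            # Worked example of the pythonic modulus below: for va = -7 and
+            # vb = 3 (in the same unit), srem yields rem = -1; the operands'
+            # signs differ and rem != 0, so the stored result is
+            # rem + vb = 2, matching Python's -7 % 3 == 2.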
+            vapos = builder.icmp_signed(">", va, zero)
+            vbpos = builder.icmp_signed(">", vb, zero)
+            rem = builder.srem(va, vb)
+            cond = builder.or_(
+                builder.and_(vapos, vbpos), builder.icmp_signed("==", rem, zero)
+            )
+            with builder.if_else(cond) as (then, otherwise):
+                with then:
+                    builder.store(rem, ret)
+                with otherwise:
+                    builder.store(builder.add(rem, vb), ret)
+
+    res = builder.load(ret)
+    return impl_ret_untracked(context, builder, sig.return_type, res)
+
+
+# Comparison operators on timedelta64
+
+
+def _create_timedelta_comparison_impl(ll_op, default_value):
+    def impl(context, builder, sig, args):
+        [va, vb] = args
+        [ta, tb] = sig.args
+        ret = alloc_boolean_result(builder)
+        with builder.if_else(are_not_nat(builder, [va, vb])) as (
+            then,
+            otherwise,
+        ):
+            with then:
+                try:
+                    norm_a, norm_b = normalize_timedeltas(
+                        context, builder, va, vb, ta, tb
+                    )
+                except RuntimeError:
+                    # Cannot normalize units => the values are unequal (except if NaT)
+                    builder.store(default_value, ret)
+                else:
+                    builder.store(
+                        builder.icmp_unsigned(ll_op, norm_a, norm_b), ret
+                    )
+            with otherwise:
+                # A comparison involving NaT is False, except for !=,
+                # which is True.
+                if ll_op == "!=":
+                    builder.store(cgutils.true_bit, ret)
+                else:
+                    builder.store(cgutils.false_bit, ret)
+        res = builder.load(ret)
+        return impl_ret_untracked(context, builder, sig.return_type, res)
+
+    return impl
+
+
+def _create_timedelta_ordering_impl(ll_op):
+    def impl(context, builder, sig, args):
+        [va, vb] = args
+        [ta, tb] = sig.args
+        ret = alloc_boolean_result(builder)
+        with builder.if_else(are_not_nat(builder, [va, vb])) as (
+            then,
+            otherwise,
+        ):
+            with then:
+                norm_a, norm_b = normalize_timedeltas(
+                    context, builder, va, vb, ta, tb
+                )
+                builder.store(builder.icmp_signed(ll_op, norm_a, norm_b), ret)
+            with otherwise:
+                # An ordering comparison (<, <=, >, >=) involving NaT is
+                # always False, as with NaN in float comparisons.
+                builder.store(cgutils.false_bit, ret)
+        res = builder.load(ret)
+        return impl_ret_untracked(context, builder, sig.return_type, res)
+
+    return impl
+
+
+timedelta_eq_timedelta_impl = _create_timedelta_comparison_impl(
+    "==", cgutils.false_bit
+)
+timedelta_ne_timedelta_impl = _create_timedelta_comparison_impl(
+    "!=", cgutils.true_bit
+)
+timedelta_lt_timedelta_impl = _create_timedelta_ordering_impl("<")
+timedelta_le_timedelta_impl = _create_timedelta_ordering_impl("<=")
+timedelta_gt_timedelta_impl = _create_timedelta_ordering_impl(">")
+timedelta_ge_timedelta_impl = _create_timedelta_ordering_impl(">=")
+
+for op_, func in [
+    (operator.eq, timedelta_eq_timedelta_impl),
+    (operator.ne, timedelta_ne_timedelta_impl),
+    (operator.lt, timedelta_lt_timedelta_impl),
+    (operator.le, timedelta_le_timedelta_impl),
+    (operator.gt, timedelta_gt_timedelta_impl),
+    (operator.ge, timedelta_ge_timedelta_impl),
+]:
+    lower(op_, *TIMEDELTA_BINOP_SIG)(func)
+
+
+# Arithmetic on datetime64
+
+
+def is_leap_year(builder, year_val):
+    """
+    Return a predicate indicating whether *year_val* (offset by 1970) is a
+    leap year.
+    """
+    actual_year = builder.add(year_val, Constant(DATETIME64, 1970))
+    multiple_of_4 = cgutils.is_null(
+        builder, builder.and_(actual_year, Constant(DATETIME64, 3))
+    )
+    not_multiple_of_100 = cgutils.is_not_null(
+        builder, builder.srem(actual_year, Constant(DATETIME64, 100))
+    )
+    multiple_of_400 = cgutils.is_null(
+        builder, builder.srem(actual_year, Constant(DATETIME64, 400))
+    )
+    return builder.and_(
+        multiple_of_4, builder.or_(not_multiple_of_100, multiple_of_400)
+    )
+
+
+def year_to_days(builder, year_val):
+    """
+    Given a year *year_val* (offset to 1970), return the number of days
+    since the 1970 epoch.
+    """
+    # The algorithm below is copied from Numpy's get_datetimestruct_days()
+    # (src/multiarray/datetime.c)
+    ret = cgutils.alloca_once(builder, TIMEDELTA64)
+    # First approximation
+    days = scale_by_constant(builder, year_val, 365)
+    # Adjust for leap years
+    with builder.if_else(cgutils.is_neg_int(builder, year_val)) as (
+        if_neg,
+        if_pos,
+    ):
+        with if_pos:
+            # At or after 1970:
+            # 1968 is the closest leap year before 1970.
+            # Exclude the current year, so add 1.
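+            # Worked example: for year_val = 30 (the year 2000),
+            # days = 30 * 365 = 10950, then from_1968 = 31 adds
+            # 31 // 4 = 7 leap days while the 100- and 400-year corrections
+            # below are both zero, giving 10957 days from the epoch to
+            # 2000-01-01.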
+ from_1968 = add_constant(builder, year_val, 1) + # Add one day for each 4 years + p_days = builder.add( + days, unscale_by_constant(builder, from_1968, 4) + ) + # 1900 is the closest previous year divisible by 100 + from_1900 = add_constant(builder, from_1968, 68) + # Subtract one day for each 100 years + p_days = builder.sub( + p_days, unscale_by_constant(builder, from_1900, 100) + ) + # 1600 is the closest previous year divisible by 400 + from_1600 = add_constant(builder, from_1900, 300) + # Add one day for each 400 years + p_days = builder.add( + p_days, unscale_by_constant(builder, from_1600, 400) + ) + builder.store(p_days, ret) + with if_neg: + # Before 1970: + # NOTE `year_val` is negative, and so will be `from_1972` and `from_2000`. + # 1972 is the closest later year after 1970. + # Include the current year, so subtract 2. + from_1972 = add_constant(builder, year_val, -2) + # Subtract one day for each 4 years (`from_1972` is negative) + n_days = builder.add( + days, unscale_by_constant(builder, from_1972, 4) + ) + # 2000 is the closest later year divisible by 100 + from_2000 = add_constant(builder, from_1972, -28) + # Add one day for each 100 years + n_days = builder.sub( + n_days, unscale_by_constant(builder, from_2000, 100) + ) + # 2000 is also the closest later year divisible by 400 + # Subtract one day for each 400 years + n_days = builder.add( + n_days, unscale_by_constant(builder, from_2000, 400) + ) + builder.store(n_days, ret) + return builder.load(ret) + + +def reduce_datetime_for_unit(builder, dt_val, src_unit, dest_unit): + dest_unit_code = npdatetime_helpers.DATETIME_UNITS[dest_unit] + src_unit_code = npdatetime_helpers.DATETIME_UNITS[src_unit] + if dest_unit_code < 2 or src_unit_code >= 2: + return dt_val, src_unit + # Need to compute the day ordinal for *dt_val* + if src_unit_code == 0: + # Years to days + year_val = dt_val + days_val = year_to_days(builder, year_val) + + else: + # Months to days + leap_array = cgutils.global_constant( + builder, "leap_year_months_acc", leap_year_months_acc + ) + normal_array = cgutils.global_constant( + builder, "normal_year_months_acc", normal_year_months_acc + ) + + days = cgutils.alloca_once(builder, TIMEDELTA64) + + # First compute year number and month number + year, month = cgutils.divmod_by_constant(builder, dt_val, 12) + + # Then deduce the number of days + with builder.if_else(is_leap_year(builder, year)) as (then, otherwise): + with then: + addend = builder.load( + cgutils.gep(builder, leap_array, 0, month, inbounds=True) + ) + builder.store(addend, days) + with otherwise: + addend = builder.load( + cgutils.gep(builder, normal_array, 0, month, inbounds=True) + ) + builder.store(addend, days) + + days_val = year_to_days(builder, year) + days_val = builder.add(days_val, builder.load(days)) + + if dest_unit_code == 2: + # Need to scale back to weeks + weeks, _ = cgutils.divmod_by_constant(builder, days_val, 7) + return weeks, "W" + else: + return days_val, "D" + + +def convert_datetime_for_arith(builder, dt_val, src_unit, dest_unit): + """ + Convert datetime *dt_val* from *src_unit* to *dest_unit*. + """ + # First partial conversion to days or weeks, if necessary. + dt_val, dt_unit = reduce_datetime_for_unit( + builder, dt_val, src_unit, dest_unit + ) + # Then multiply by the remaining constant factor. + dt_factor = npdatetime_helpers.get_timedelta_conversion_factor( + dt_unit, dest_unit + ) + if dt_factor is None: + # This can happen when using explicit output in a ufunc. 
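+        # For instance, there is no constant conversion factor from days
+        # ('D') back to months ('M') because month lengths vary, so
+        # get_timedelta_conversion_factor returns None for that pair.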
+ raise LoweringError( + "cannot convert datetime64 from %r to %r" % (src_unit, dest_unit) + ) + return scale_by_constant(builder, dt_val, dt_factor) + + +def _datetime_timedelta_arith(ll_op_name): + def impl(context, builder, dt_arg, dt_unit, td_arg, td_unit, ret_unit): + ret = alloc_timedelta_result(builder) + with cgutils.if_likely(builder, are_not_nat(builder, [dt_arg, td_arg])): + dt_arg = convert_datetime_for_arith( + builder, dt_arg, dt_unit, ret_unit + ) + td_factor = npdatetime_helpers.get_timedelta_conversion_factor( + td_unit, ret_unit + ) + td_arg = scale_by_constant(builder, td_arg, td_factor) + ret_val = getattr(builder, ll_op_name)(dt_arg, td_arg) + builder.store(ret_val, ret) + return builder.load(ret) + + return impl + + +_datetime_plus_timedelta = _datetime_timedelta_arith("add") +_datetime_minus_timedelta = _datetime_timedelta_arith("sub") + +# datetime64 + timedelta64 + + +@lower(operator.add, types.NPDatetime, types.NPTimedelta) +@lower(operator.iadd, types.NPDatetime, types.NPTimedelta) +def datetime_plus_timedelta(context, builder, sig, args): + dt_arg, td_arg = args + dt_type, td_type = sig.args + res = _datetime_plus_timedelta( + context, + builder, + dt_arg, + dt_type.unit, + td_arg, + td_type.unit, + sig.return_type.unit, + ) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(operator.add, types.NPTimedelta, types.NPDatetime) +@lower(operator.iadd, types.NPTimedelta, types.NPDatetime) +def timedelta_plus_datetime(context, builder, sig, args): + td_arg, dt_arg = args + td_type, dt_type = sig.args + res = _datetime_plus_timedelta( + context, + builder, + dt_arg, + dt_type.unit, + td_arg, + td_type.unit, + sig.return_type.unit, + ) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# datetime64 - timedelta64 + + +@lower(operator.sub, types.NPDatetime, types.NPTimedelta) +@lower(operator.isub, types.NPDatetime, types.NPTimedelta) +def datetime_minus_timedelta(context, builder, sig, args): + dt_arg, td_arg = args + dt_type, td_type = sig.args + res = _datetime_minus_timedelta( + context, + builder, + dt_arg, + dt_type.unit, + td_arg, + td_type.unit, + sig.return_type.unit, + ) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# datetime64 - datetime64 + + +@lower(operator.sub, types.NPDatetime, types.NPDatetime) +def datetime_minus_datetime(context, builder, sig, args): + va, vb = args + ta, tb = sig.args + unit_a = ta.unit + unit_b = tb.unit + ret_unit = sig.return_type.unit + ret = alloc_timedelta_result(builder) + with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])): + va = convert_datetime_for_arith(builder, va, unit_a, ret_unit) + vb = convert_datetime_for_arith(builder, vb, unit_b, ret_unit) + ret_val = builder.sub(va, vb) + builder.store(ret_val, ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# datetime64 comparisons + + +def _create_datetime_comparison_impl(ll_op): + def impl(context, builder, sig, args): + va, vb = args + ta, tb = sig.args + unit_a = ta.unit + unit_b = tb.unit + ret_unit = npdatetime_helpers.get_best_unit(unit_a, unit_b) + ret = alloc_boolean_result(builder) + with builder.if_else(are_not_nat(builder, [va, vb])) as ( + then, + otherwise, + ): + with then: + norm_a = convert_datetime_for_arith( + builder, va, unit_a, ret_unit + ) + norm_b = convert_datetime_for_arith( + builder, vb, unit_b, ret_unit + ) + ret_val = builder.icmp_signed(ll_op, norm_a, norm_b) + builder.store(ret_val, ret) + with otherwise: 
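+                # NaT behaves like NaN here: any comparison involving NaT
+                # is False, except for !=, which is True.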
+ if ll_op == "!=": + ret_val = cgutils.true_bit + else: + ret_val = cgutils.false_bit + builder.store(ret_val, ret) + res = builder.load(ret) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return impl + + +datetime_eq_datetime_impl = _create_datetime_comparison_impl("==") +datetime_ne_datetime_impl = _create_datetime_comparison_impl("!=") +datetime_lt_datetime_impl = _create_datetime_comparison_impl("<") +datetime_le_datetime_impl = _create_datetime_comparison_impl("<=") +datetime_gt_datetime_impl = _create_datetime_comparison_impl(">") +datetime_ge_datetime_impl = _create_datetime_comparison_impl(">=") + +for op, func in [ + (operator.eq, datetime_eq_datetime_impl), + (operator.ne, datetime_ne_datetime_impl), + (operator.lt, datetime_lt_datetime_impl), + (operator.le, datetime_le_datetime_impl), + (operator.gt, datetime_gt_datetime_impl), + (operator.ge, datetime_ge_datetime_impl), +]: + lower(op, *[types.NPDatetime] * 2)(func) + + +######################################################################## +# datetime/timedelta fmax/fmin maximum/minimum support + + +def _gen_datetime_max_impl(NAT_DOMINATES): + def datetime_max_impl(context, builder, sig, args): + # note this could be optimizing relying on the actual value of NAT + # but as NumPy doesn't rely on this, this seems more resilient + in1, in2 = args + in1_not_nat = is_not_nat(builder, in1) + in2_not_nat = is_not_nat(builder, in2) + in1_ge_in2 = builder.icmp_signed(">=", in1, in2) + res = builder.select(in1_ge_in2, in1, in2) + if NAT_DOMINATES: + # NaT now dominates, like NaN + in1, in2 = in2, in1 + res = builder.select(in1_not_nat, res, in2) + res = builder.select(in2_not_nat, res, in1) + + return impl_ret_untracked(context, builder, sig.return_type, res) + + return datetime_max_impl + + +datetime_maximum_impl = _gen_datetime_max_impl(True) +datetime_fmax_impl = _gen_datetime_max_impl(False) + + +def _gen_datetime_min_impl(NAT_DOMINATES): + def datetime_min_impl(context, builder, sig, args): + # note this could be optimizing relying on the actual value of NAT + # but as NumPy doesn't rely on this, this seems more resilient + in1, in2 = args + in1_not_nat = is_not_nat(builder, in1) + in2_not_nat = is_not_nat(builder, in2) + in1_le_in2 = builder.icmp_signed("<=", in1, in2) + res = builder.select(in1_le_in2, in1, in2) + if NAT_DOMINATES: + # NaT now dominates, like NaN + in1, in2 = in2, in1 + res = builder.select(in1_not_nat, res, in2) + res = builder.select(in2_not_nat, res, in1) + + return impl_ret_untracked(context, builder, sig.return_type, res) + + return datetime_min_impl + + +datetime_minimum_impl = _gen_datetime_min_impl(True) +datetime_fmin_impl = _gen_datetime_min_impl(False) + + +def _gen_timedelta_max_impl(NAT_DOMINATES): + def timedelta_max_impl(context, builder, sig, args): + # note this could be optimizing relying on the actual value of NAT + # but as NumPy doesn't rely on this, this seems more resilient + in1, in2 = args + in1_not_nat = is_not_nat(builder, in1) + in2_not_nat = is_not_nat(builder, in2) + in1_ge_in2 = builder.icmp_signed(">=", in1, in2) + res = builder.select(in1_ge_in2, in1, in2) + if NAT_DOMINATES: + # NaT now dominates, like NaN + in1, in2 = in2, in1 + res = builder.select(in1_not_nat, res, in2) + res = builder.select(in2_not_nat, res, in1) + + return impl_ret_untracked(context, builder, sig.return_type, res) + + return timedelta_max_impl + + +timedelta_maximum_impl = _gen_timedelta_max_impl(True) +timedelta_fmax_impl = _gen_timedelta_max_impl(False) + + +def 
_gen_timedelta_min_impl(NAT_DOMINATES): + def timedelta_min_impl(context, builder, sig, args): + # note this could be optimizing relying on the actual value of NAT + # but as NumPy doesn't rely on this, this seems more resilient + in1, in2 = args + in1_not_nat = is_not_nat(builder, in1) + in2_not_nat = is_not_nat(builder, in2) + in1_le_in2 = builder.icmp_signed("<=", in1, in2) + res = builder.select(in1_le_in2, in1, in2) + if NAT_DOMINATES: + # NaT now dominates, like NaN + in1, in2 = in2, in1 + res = builder.select(in1_not_nat, res, in2) + res = builder.select(in2_not_nat, res, in1) + + return impl_ret_untracked(context, builder, sig.return_type, res) + + return timedelta_min_impl + + +timedelta_minimum_impl = _gen_timedelta_min_impl(True) +timedelta_fmin_impl = _gen_timedelta_min_impl(False) + + +def _cast_to_timedelta(context, builder, val): + temp = builder.alloca(TIMEDELTA64) + val_is_nan = builder.fcmp_unordered("uno", val, val) + with builder.if_else(val_is_nan) as (then, els): + with then: + # NaN does not guarantee to cast to NAT. + # We should store NAT explicitly. + builder.store(NAT, temp) + with els: + builder.store(builder.fptosi(val, TIMEDELTA64), temp) + return builder.load(temp) + + +@lower(np.isnat, types.NPDatetime) +@lower(np.isnat, types.NPTimedelta) +def _np_isnat_impl(context, builder, sig, args): + return npyfuncs.np_datetime_isnat_impl(context, builder, sig, args) + + +@lower_cast(types.NPDatetime, types.Integer) +@lower_cast(types.NPTimedelta, types.Integer) +def _cast_npdatetime_int64(context, builder, fromty, toty, val): + if toty.bitwidth != 64: # all date time types are 64 bit + msg = f"Cannot cast {fromty} to {toty} as {toty} is not 64 bits wide." + raise ValueError(msg) + return val + + +@overload_method(types.NPTimedelta, "__hash__") +@overload_method(types.NPDatetime, "__hash__") +def ol_hash_npdatetime(x): + if ( + numpy_support.numpy_version >= (2, 2) + and isinstance(x, types.NPTimedelta) + and not x.unit + ): + raise ValueError("Can't hash generic timedelta64") + + if IS_32BITS: + + def impl(x): + x = np.int64(x) + if x < 2**31 - 1: # x < LONG_MAX + y = np.int32(x) + else: + hi = (np.int64(x) & 0xFFFFFFFF00000000) >> 32 + lo = np.int64(x) & 0x00000000FFFFFFFF + y = np.int32(lo + (1000003) * hi) + if y == -1: + y = np.int32(-2) + return y + else: + + def impl(x): + if np.int64(x) == -1: + return np.int64(-2) + return np.int64(x) + + return impl + + +lower(npdatetime_helpers.datetime_minimum, types.NPDatetime, types.NPDatetime)( + datetime_minimum_impl +) +lower( + npdatetime_helpers.datetime_minimum, types.NPTimedelta, types.NPTimedelta +)(datetime_minimum_impl) +lower(npdatetime_helpers.datetime_maximum, types.NPDatetime, types.NPDatetime)( + datetime_maximum_impl +) +lower( + npdatetime_helpers.datetime_maximum, types.NPTimedelta, types.NPTimedelta +)(datetime_maximum_impl) diff --git a/numba_cuda/numba/cuda/np/npyfuncs.py b/numba_cuda/numba/cuda/np/npyfuncs.py index 84eec9ad7..7873d5f35 100644 --- a/numba_cuda/numba/cuda/np/npyfuncs.py +++ b/numba_cuda/numba/cuda/np/npyfuncs.py @@ -16,10 +16,10 @@ from numba.core.imputils import impl_ret_untracked from numba.core import typing, types, errors from numba.cuda import cgutils +from numba.cuda.np import npdatetime from numba.cuda.extending import register_jitable -from numba.np import npdatetime -from numba.np.math import cmathimpl, mathimpl, numbers -from numba.np.numpy_support import numpy_version +from numba.cuda.np.math import cmathimpl, mathimpl, numbers +from numba.cuda.np.numpy_support import 
numpy_version # some NumPy constants. Note that we could generate some of them using # the math library, but having the values copied from npy_math seems to @@ -398,7 +398,7 @@ def _generate_logaddexp(fnoverload, const, log1pfn, expfn): # Code generation for logaddexp and logaddexp2 is based on: # https://github.com/numpy/numpy/blob/12c2b7dd62fc0c14b81c8892ed5f4f59cc94d09c/numpy/core/src/npymath/npy_math_internal.h.src#L467-L507 - @overload(fnoverload, target="cuda") + @overload(fnoverload) def ol_npy_logaddexp(x1, x2): if x1 != x2: return @@ -452,7 +452,7 @@ def npy_log2_1p(x): # https://github.com/numpy/numpy/blob/12c2b7dd62fc0c14b81c8892ed5f4f59cc94d09c/numpy/core/src/npymath/npy_math_internal.h.src#L457-L460 -@overload(npy_log2_1p, target="cuda") +@overload(npy_log2_1p) def ol_npy_log2_1p(x): LOG2E = x(_NPY_LOG2E) diff --git a/numba_cuda/numba/cuda/np/npyimpl.py b/numba_cuda/numba/cuda/np/npyimpl.py new file mode 100644 index 000000000..fe2d015f1 --- /dev/null +++ b/numba_cuda/numba/cuda/np/npyimpl.py @@ -0,0 +1,1027 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implementation of functions in the Numpy package. +""" + +import itertools +from collections import namedtuple + +import llvmlite.ir as ir + +import numpy as np +import operator + +from numba.cuda.np import arrayobj +from numba.cuda.np import ufunc_db +from numba.cuda.np.ufunc.sigparse import parse_signature +from numba.core.imputils import ( + Registry, + impl_ret_new_ref, + force_error_model, + impl_ret_borrowed, +) +from numba.core import typing, types +from numba.cuda import cgutils +from numba.cuda.np.numpy_support import ( + ufunc_find_matching_loop, + select_array_wrapper, + from_dtype, + _ufunc_loop_sig, +) +from numba.cuda.np.arrayobj import _getitem_array_generic +from numba.core.typing import npydecl +from numba.cuda.extending import overload, intrinsic + +from numba.core import errors + +registry = Registry("npyimpl") + + +######################################################################## + +# In the way we generate code, ufuncs work with scalar as well as +# with array arguments. The following helper classes help dealing +# with scalar and array arguments in a regular way. +# +# In short, the classes provide a uniform interface. The interface +# handles the indexing of as many dimensions as the array may have. +# For scalars, all indexing is ignored and when the value is read, +# the scalar is returned. For arrays code for actual indexing is +# generated and reading performs the appropriate indirection. + + +class _ScalarIndexingHelper(object): + def update_indices(self, loop_indices, name): + pass + + def as_values(self): + pass + + +class _ScalarHelper(object): + """Helper class to handle scalar arguments (and result). + Note that store_data is only used when generating code for + a scalar ufunc and to write the output value. + + For loading, the value is directly used without having any + kind of indexing nor memory backing it up. This is the use + for input arguments. + + For storing, a variable is created in the stack where the + value will be written. + + Note that it is not supported (as it is unneeded for our + current use-cases) reading back a stored value. This class + will always "load" the original value it got at its creation. 
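+
+    For example (illustrative), a float64 scalar passed to a ufunc is
+    wrapped by this class: load_data returns the original value directly,
+    while store_data (used only for scalar outputs) writes to the stack
+    slot allocated at construction.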
+ """ + + def __init__(self, ctxt, bld, val, ty): + self.context = ctxt + self.builder = bld + self.val = val + self.base_type = ty + intpty = ctxt.get_value_type(types.intp) + self.shape = [ir.Constant(intpty, 1)] + + lty = ctxt.get_data_type(ty) if ty != types.boolean else ir.IntType(1) + self._ptr = cgutils.alloca_once(bld, lty) + + def create_iter_indices(self): + return _ScalarIndexingHelper() + + def load_data(self, indices): + return self.val + + def store_data(self, indices, val): + self.builder.store(val, self._ptr) + + @property + def return_val(self): + return self.builder.load(self._ptr) + + +class _ArrayIndexingHelper( + namedtuple("_ArrayIndexingHelper", ("array", "indices")) +): + def update_indices(self, loop_indices, name): + bld = self.array.builder + intpty = self.array.context.get_value_type(types.intp) + ONE = ir.Constant(ir.IntType(intpty.width), 1) + + # we are only interested in as many inner dimensions as dimensions + # the indexed array has (the outer dimensions are broadcast, so + # ignoring the outer indices produces the desired result. + indices = loop_indices[len(loop_indices) - len(self.indices) :] + for src, dst, dim in zip(indices, self.indices, self.array.shape): + cond = bld.icmp_unsigned(">", dim, ONE) + with bld.if_then(cond): + bld.store(src, dst) + + def as_values(self): + """ + The indexing helper is built using alloca for each value, so it + actually contains pointers to the actual indices to load. Note + that update_indices assumes the same. This method returns the + indices as values + """ + bld = self.array.builder + return [bld.load(index) for index in self.indices] + + +class _ArrayHelper( + namedtuple( + "_ArrayHelper", + ( + "context", + "builder", + "shape", + "strides", + "data", + "layout", + "base_type", + "ndim", + "return_val", + ), + ) +): + """Helper class to handle array arguments/result. + It provides methods to generate code loading/storing specific + items as well as support code for handling indices. + """ + + def create_iter_indices(self): + intpty = self.context.get_value_type(types.intp) + ZERO = ir.Constant(ir.IntType(intpty.width), 0) + + indices = [] + for i in range(self.ndim): + x = cgutils.alloca_once(self.builder, ir.IntType(intpty.width)) + self.builder.store(ZERO, x) + indices.append(x) + return _ArrayIndexingHelper(self, indices) + + def _load_effective_address(self, indices): + return cgutils.get_item_pointer2( + self.context, + self.builder, + data=self.data, + shape=self.shape, + strides=self.strides, + layout=self.layout, + inds=indices, + ) + + def load_data(self, indices): + model = self.context.data_model_manager[self.base_type] + ptr = self._load_effective_address(indices) + return model.load_from_data_pointer(self.builder, ptr) + + def store_data(self, indices, value): + ctx = self.context + bld = self.builder + store_value = ctx.get_value_as_data(bld, self.base_type, value) + assert ctx.get_data_type(self.base_type) == store_value.type + bld.store(store_value, self._load_effective_address(indices)) + + +class _ArrayGUHelper( + namedtuple( + "_ArrayHelper", + ( + "context", + "builder", + "shape", + "strides", + "data", + "layout", + "base_type", + "ndim", + "inner_arr_ty", + "is_input_arg", + ), + ) +): + """Helper class to handle array arguments/result. + It provides methods to generate code loading/storing specific + items as well as support code for handling indices. 
+ + Contrary to _ArrayHelper, this class can create a view to a subarray + """ + + def create_iter_indices(self): + intpty = self.context.get_value_type(types.intp) + ZERO = ir.Constant(ir.IntType(intpty.width), 0) + + indices = [] + for i in range(self.ndim - self.inner_arr_ty.ndim): + x = cgutils.alloca_once(self.builder, ir.IntType(intpty.width)) + self.builder.store(ZERO, x) + indices.append(x) + return _ArrayIndexingHelper(self, indices) + + def _load_effective_address(self, indices): + context = self.context + builder = self.builder + arr_ty = types.Array(self.base_type, self.ndim, self.layout) + arr = context.make_array(arr_ty)(context, builder, self.data) + + return cgutils.get_item_pointer2( + context, + builder, + data=arr.data, + shape=self.shape, + strides=self.strides, + layout=self.layout, + inds=indices, + ) + + def load_data(self, indices): + context, builder = self.context, self.builder + + if self.inner_arr_ty.ndim == 0 and self.is_input_arg: + # scalar case for input arguments + model = context.data_model_manager[self.base_type] + ptr = self._load_effective_address(indices) + return model.load_from_data_pointer(builder, ptr) + elif self.inner_arr_ty.ndim == 0 and not self.is_input_arg: + # Output arrays are handled as 1d with shape=(1,) when its + # signature represents a scalar. For instance: "(n),(m) -> ()" + intpty = context.get_value_type(types.intp) + one = intpty(1) + + fromty = types.Array(self.base_type, self.ndim, self.layout) + toty = types.Array(self.base_type, 1, self.layout) + itemsize = intpty(arrayobj.get_itemsize(context, fromty)) + + # create a view from the original ndarray to a 1d array + arr_from = self.context.make_array(fromty)( + context, builder, self.data + ) + arr_to = self.context.make_array(toty)(context, builder) + arrayobj.populate_array( + arr_to, + data=self._load_effective_address(indices), + shape=cgutils.pack_array(builder, [one]), + strides=cgutils.pack_array(builder, [itemsize]), + itemsize=arr_from.itemsize, + meminfo=arr_from.meminfo, + parent=arr_from.parent, + ) + return arr_to._getvalue() + else: + # generic case + # getitem n-dim array -> m-dim array, where N > M + index_types = (types.int64,) * (self.ndim - self.inner_arr_ty.ndim) + arrty = types.Array(self.base_type, self.ndim, self.layout) + arr = self.context.make_array(arrty)(context, builder, self.data) + res = _getitem_array_generic( + context, + builder, + self.inner_arr_ty, + arrty, + arr, + index_types, + indices, + ) + return impl_ret_borrowed(context, builder, self.inner_arr_ty, res) + + def guard_shape(self, loopshape): + inner_ndim = self.inner_arr_ty.ndim + + def raise_impl(loop_shape, array_shape): + # This would in fact be a test for broadcasting. + # Broadcast would fail if, ignoring the core dimensions, the + # remaining ones are different than indices given by loop shape. + + remaining = len(array_shape) - inner_ndim + _raise = remaining > len(loop_shape) + if not _raise: + for i in range(remaining): + _raise |= array_shape[i] != loop_shape[i] + if _raise: + # Ideally we should call `np.broadcast_shapes` with loop and + # array shapes. 
But since broadcasting is not supported here,
+                # we just raise an error.
+                # TODO: check why raising a dynamic exception here fails
+                raise ValueError("Loop and array shapes are incompatible")
+
+        context, builder = self.context, self.builder
+        sig = types.none(
+            types.UniTuple(types.intp, len(loopshape)),
+            types.UniTuple(types.intp, len(self.shape)),
+        )
+        tup = (
+            context.make_tuple(builder, sig.args[0], loopshape),
+            context.make_tuple(builder, sig.args[1], self.shape),
+        )
+        context.compile_internal(builder, raise_impl, sig, tup)
+
+    def guard_match_core_dims(self, other: "_ArrayGUHelper", ndims: int):
+        # Arguments with the same signature should match their core
+        # dimensions:
+        #
+        # @guvectorize('(n,m), (n,m) -> (n)')
+        # def foo(x, y, res):
+        #     ...
+        #
+        # x and y should have the same core (2D) dimensions
+        def raise_impl(self_shape, other_shape):
+            same = True
+            a, b = len(self_shape) - ndims, len(other_shape) - ndims
+            for i in range(ndims):
+                same &= self_shape[a + i] == other_shape[b + i]
+            if not same:
+                # NumPy raises the following:
+                # ValueError: gufunc: Input operand 1 has a mismatch in its
+                # core dimension 0, with gufunc signature (n),(n) -> ()
+                # (size 3 is different from 2)
+                # But since we cannot raise a dynamic exception here, we
+                # just try to raise something meaningful.
+                msg = (
+                    "Operand has a mismatch in one of its core dimensions. "
+                    "Please check that all arguments to a @guvectorize "
+                    "function have the same core dimensions."
+                )
+                raise ValueError(msg)
+
+        context, builder = self.context, self.builder
+        sig = types.none(
+            types.UniTuple(types.intp, len(self.shape)),
+            types.UniTuple(types.intp, len(other.shape)),
+        )
+        tup = (
+            context.make_tuple(builder, sig.args[0], self.shape),
+            context.make_tuple(builder, sig.args[1], other.shape),
+        )
+        context.compile_internal(builder, raise_impl, sig, tup)
+
+
+def _prepare_argument(ctxt, bld, inp, tyinp, where="input operand"):
+    """Returns an instance of the appropriate Helper (either
+    _ScalarHelper or _ArrayHelper) class to handle the argument.
+    Using the polymorphic interface of the Helper classes, scalar
+    and array cases can be handled with the same code."""
+
+    # first un-Optional Optionals
+    if isinstance(tyinp, types.Optional):
+        oty = tyinp
+        tyinp = tyinp.type
+        inp = ctxt.cast(bld, inp, oty, tyinp)
+
+    # then prepare the arg for a concrete instance
+    if isinstance(tyinp, types.ArrayCompatible):
+        ary = ctxt.make_array(tyinp)(ctxt, bld, inp)
+        shape = cgutils.unpack_tuple(bld, ary.shape, tyinp.ndim)
+        strides = cgutils.unpack_tuple(bld, ary.strides, tyinp.ndim)
+        return _ArrayHelper(
+            ctxt,
+            bld,
+            shape,
+            strides,
+            ary.data,
+            tyinp.layout,
+            tyinp.dtype,
+            tyinp.ndim,
+            inp,
+        )
+    elif types.unliteral(tyinp) in types.number_domain | {
+        types.boolean
+    } or isinstance(tyinp, types.scalars._NPDatetimeBase):
+        return _ScalarHelper(ctxt, bld, inp, tyinp)
+    else:
+        raise NotImplementedError(
+            "unsupported type for {0}: {1}".format(where, str(tyinp))
+        )
+
+
+_broadcast_onto_sig = types.intp(
+    types.intp,
+    types.CPointer(types.intp),
+    types.intp,
+    types.CPointer(types.intp),
+)
+
+
+def _broadcast_onto(src_ndim, src_shape, dest_ndim, dest_shape):
+    """Low-level utility function used in calculating a shape for
+    an implicit output array. This function assumes that the
+    destination shape is an LLVM pointer to a C-style array that was
+    already initialized to a size of one along all axes.
+
+    Returns an integer value:
+    >= 1 : Succeeded. 
Return value should equal the number of dimensions in + the destination shape. + 0 : Failed to broadcast because source shape is larger than the + destination shape (this case should be weeded out at type + checking). + < 0 : Failed to broadcast onto destination axis, at axis number == + -(return_value + 1). + """ + if src_ndim > dest_ndim: + # This check should have been done during type checking, but + # let's be defensive anyway... + return 0 + else: + src_index = 0 + dest_index = dest_ndim - src_ndim + while src_index < src_ndim: + src_dim_size = src_shape[src_index] + dest_dim_size = dest_shape[dest_index] + # Check to see if we've already mutated the destination + # shape along this axis. + if dest_dim_size != 1: + # If we have mutated the destination shape already, + # then the source axis size must either be one, + # or the destination axis size. + if src_dim_size != dest_dim_size and src_dim_size != 1: + return -(dest_index + 1) + elif src_dim_size != 1: + # If the destination size is still its initial + dest_shape[dest_index] = src_dim_size + src_index += 1 + dest_index += 1 + return dest_index + + +def _build_array(context, builder, array_ty, input_types, inputs): + """Utility function to handle allocation of an implicit output array + given the target context, builder, output array type, and a list of + _ArrayHelper instances. + """ + # First, strip optional types, ufunc loops are typed on concrete types + input_types = [ + x.type if isinstance(x, types.Optional) else x for x in input_types + ] + + intp_ty = context.get_value_type(types.intp) + + def make_intp_const(val): + return context.get_constant(types.intp, val) + + ZERO = make_intp_const(0) # noqa: F841 + ONE = make_intp_const(1) + + src_shape = cgutils.alloca_once( + builder, intp_ty, array_ty.ndim, "src_shape" + ) + dest_ndim = make_intp_const(array_ty.ndim) + dest_shape = cgutils.alloca_once( + builder, intp_ty, array_ty.ndim, "dest_shape" + ) + dest_shape_addrs = tuple( + cgutils.gep_inbounds(builder, dest_shape, index) + for index in range(array_ty.ndim) + ) + + # Initialize the destination shape with all ones. + for dest_shape_addr in dest_shape_addrs: + builder.store(ONE, dest_shape_addr) + + # For each argument, try to broadcast onto the destination shape, + # mutating along any axis where the argument shape is not one and + # the destination shape is one. 
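+    # Illustrative example: a (3, 1) argument broadcasts onto a (3, 4)
+    # destination (every axis is equal or one), while a (2, 4) argument
+    # fails on axis 0 and _broadcast_onto reports it as -(0 + 1) == -1.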
+    for arg_number, arg in enumerate(inputs):
+        if not hasattr(arg, "ndim"):  # Skip scalar arguments
+            continue
+        arg_ndim = make_intp_const(arg.ndim)
+        for index in range(arg.ndim):
+            builder.store(
+                arg.shape[index],
+                cgutils.gep_inbounds(builder, src_shape, index),
+            )
+        arg_result = context.compile_internal(
+            builder,
+            _broadcast_onto,
+            _broadcast_onto_sig,
+            [arg_ndim, src_shape, dest_ndim, dest_shape],
+        )
+        with cgutils.if_unlikely(
+            builder, builder.icmp_signed("<", arg_result, ONE)
+        ):
+            msg = "unable to broadcast argument %d to output array" % (
+                arg_number,
+            )
+
+            loc = errors.loc_info.get("loc", None)
+            if loc is not None:
+                msg += '\nFile "%s", line %d, ' % (loc.filename, loc.line)
+
+            context.call_conv.return_user_exc(builder, ValueError, (msg,))
+
+    real_array_ty = array_ty.as_array
+
+    dest_shape_tup = tuple(
+        builder.load(dest_shape_addr) for dest_shape_addr in dest_shape_addrs
+    )
+    array_val = arrayobj._empty_nd_impl(
+        context, builder, real_array_ty, dest_shape_tup
+    )
+
+    # Get the best argument to call __array_wrap__ on
+    array_wrapper_index = select_array_wrapper(input_types)
+    array_wrapper_ty = input_types[array_wrapper_index]
+    try:
+        # __array_wrap__(source wrapped array, out array) -> out wrapped array
+        array_wrap = context.get_function(
+            "__array_wrap__", array_ty(array_wrapper_ty, real_array_ty)
+        )
+    except NotImplementedError:
+        # If it's the same priority as a regular array, assume we
+        # should use the allocated array unchanged.
+        if array_wrapper_ty.array_priority != types.Array.array_priority:
+            raise
+        out_val = array_val._getvalue()
+    else:
+        wrap_args = (
+            inputs[array_wrapper_index].return_val,
+            array_val._getvalue(),
+        )
+        out_val = array_wrap(builder, wrap_args)
+
+    ndim = array_ty.ndim
+    shape = cgutils.unpack_tuple(builder, array_val.shape, ndim)
+    strides = cgutils.unpack_tuple(builder, array_val.strides, ndim)
+    return _ArrayHelper(
+        context,
+        builder,
+        shape,
+        strides,
+        array_val.data,
+        array_ty.layout,
+        array_ty.dtype,
+        ndim,
+        out_val,
+    )
+
+
+# ufuncs either return a single result when nout == 1, else a tuple of results
+
+
+def _unpack_output_types(ufunc, sig):
+    if ufunc.nout == 1:
+        return [sig.return_type]
+    else:
+        return list(sig.return_type)
+
+
+def _unpack_output_values(ufunc, builder, values):
+    if ufunc.nout == 1:
+        return [values]
+    else:
+        return cgutils.unpack_tuple(builder, values)
+
+
+def _pack_output_values(ufunc, context, builder, typ, values):
+    if ufunc.nout == 1:
+        return values[0]
+    else:
+        return context.make_tuple(builder, typ, values)
+
+
+def numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel_class):
+    # This is the code generator that builds all the looping needed
+    # to execute a NumPy function over several dimensions (including
+    # scalar cases).
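+    # For example (illustrative), np.add(a, 1.0) with a 2-d array `a`
+    # wraps `a` in an _ArrayHelper and 1.0 in a _ScalarHelper, then
+    # emits a 2-d loop nest applying the scalar kernel element-wise.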
+    #
+    # context - the code generation context
+    # builder - the code emitter
+    # sig - signature of the ufunc
+    # args - the args to the ufunc
+    # ufunc - the ufunc itself
+    # kernel_class - a code generating subclass of _Kernel that provides
+    #     the code generation for the element-wise (inner) operation
+
+    arguments = [
+        _prepare_argument(context, builder, arg, tyarg)
+        for arg, tyarg in zip(args, sig.args)
+    ]
+
+    if len(arguments) < ufunc.nin:
+        raise RuntimeError(
+            "Not enough inputs to {}, expected {} got {}".format(
+                ufunc.__name__, ufunc.nin, len(arguments)
+            )
+        )
+
+    for out_i, ret_ty in enumerate(_unpack_output_types(ufunc, sig)):
+        if ufunc.nin + out_i >= len(arguments):
+            # this out argument is not provided
+            if isinstance(ret_ty, types.ArrayCompatible):
+                output = _build_array(
+                    context, builder, ret_ty, sig.args, arguments
+                )
+            else:
+                output = _prepare_argument(
+                    context,
+                    builder,
+                    ir.Constant(context.get_value_type(ret_ty), None),
+                    ret_ty,
+                )
+            arguments.append(output)
+        elif context.enable_nrt:
+            # Incref the output
+            context.nrt.incref(builder, ret_ty, args[ufunc.nin + out_i])
+
+    inputs = arguments[: ufunc.nin]
+    outputs = arguments[ufunc.nin :]
+    assert len(outputs) == ufunc.nout
+
+    outer_sig = _ufunc_loop_sig(
+        [a.base_type for a in outputs], [a.base_type for a in inputs]
+    )
+    kernel = kernel_class(context, builder, outer_sig)
+    intpty = context.get_value_type(types.intp)
+
+    indices = [inp.create_iter_indices() for inp in inputs]
+
+    # assume outputs are all the same size, which NumPy requires
+
+    loopshape = outputs[0].shape
+
+    # count the number of C and F layout arrays, respectively
+    input_layouts = [
+        inp.layout for inp in inputs if isinstance(inp, _ArrayHelper)
+    ]
+    num_c_layout = len([x for x in input_layouts if x == "C"])
+    num_f_layout = len([x for x in input_layouts if x == "F"])
+
+    # Only choose F iteration order if more arrays are in F layout.
+    # Default to C order otherwise.
+    # This is a best effort for performance. NumPy has more fancy logic that
+    # uses array iterators in non-trivial cases.
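+    # e.g. with two F-ordered array inputs and one C-ordered input,
+    # num_f_layout == 2 > num_c_layout == 1, so the loop nest below runs
+    # in Fortran order to match the majority layout.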
+ if num_f_layout > num_c_layout: + order = "F" + else: + order = "C" + + with cgutils.loop_nest( + builder, loopshape, intp=intpty, order=order + ) as loop_indices: + vals_in = [] + for i, (index, arg) in enumerate(zip(indices, inputs)): + index.update_indices(loop_indices, i) + vals_in.append(arg.load_data(index.as_values())) + + vals_out = _unpack_output_values( + ufunc, builder, kernel.generate(*vals_in) + ) + for val_out, output in zip(vals_out, outputs): + output.store_data(loop_indices, val_out) + + out = _pack_output_values( + ufunc, + context, + builder, + sig.return_type, + [o.return_val for o in outputs], + ) + return impl_ret_new_ref(context, builder, sig.return_type, out) + + +def numpy_gufunc_kernel(context, builder, sig, args, ufunc, kernel_class): + arguments = [] + expected_ndims = kernel_class.dufunc.expected_ndims() + expected_ndims = expected_ndims[0] + expected_ndims[1] + is_input = [True] * ufunc.nin + [False] * ufunc.nout + for arg, ty, exp_ndim, is_inp in zip( + args, sig.args, expected_ndims, is_input + ): # noqa: E501 + if isinstance(ty, types.ArrayCompatible): + # Create an array helper that iteration returns a subarray + # with ndim specified by "exp_ndim" + arr = context.make_array(ty)(context, builder, arg) + shape = cgutils.unpack_tuple(builder, arr.shape, ty.ndim) + strides = cgutils.unpack_tuple(builder, arr.strides, ty.ndim) + inner_arr_ty = ty.copy(ndim=exp_ndim) + ndim = ty.ndim + layout = ty.layout + base_type = ty.dtype + array_helper = _ArrayGUHelper( + context, + builder, + shape, + strides, + arg, + layout, + base_type, + ndim, + inner_arr_ty, + is_inp, + ) + arguments.append(array_helper) + else: + scalar_helper = _ScalarHelper(context, builder, arg, ty) + arguments.append(scalar_helper) + kernel = kernel_class(context, builder, sig) + + layouts = [ + arg.layout for arg in arguments if isinstance(arg, _ArrayGUHelper) + ] + num_c_layout = len([x for x in layouts if x == "C"]) + num_f_layout = len([x for x in layouts if x == "F"]) + + # Only choose F iteration order if more arrays are in F layout. + # Default to C order otherwise. + # This is a best effort for performance. NumPy has more fancy logic that + # uses array iterators in non-trivial cases. + if num_f_layout > num_c_layout: + order = "F" + else: + order = "C" + + outputs = arguments[ufunc.nin :] + intpty = context.get_value_type(types.intp) + indices = [inp.create_iter_indices() for inp in arguments] + loopshape_ndim = outputs[0].ndim - outputs[0].inner_arr_ty.ndim + loopshape = outputs[0].shape[:loopshape_ndim] + + _sig = parse_signature(ufunc.gufunc_builder.signature) + for (idx_a, sig_a), (idx_b, sig_b) in itertools.combinations( + zip(range(len(arguments)), _sig[0] + _sig[1]), r=2 + ): + # For each pair of arguments, both inputs and outputs, must match their + # inner dimensions if their signatures are the same. 
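+        # e.g. for the signature "(n),(n)->()", both inputs carry the
+        # symbol group ("n",), so their trailing core dimension must
+        # agree at run time.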
+ arg_a, arg_b = arguments[idx_a], arguments[idx_b] + if sig_a == sig_b and all( + isinstance(x, _ArrayGUHelper) for x in (arg_a, arg_b) + ): + arg_a, arg_b = arguments[idx_a], arguments[idx_b] + arg_a.guard_match_core_dims(arg_b, len(sig_a)) + + for arg in arguments[: ufunc.nin]: + if isinstance(arg, _ArrayGUHelper): + arg.guard_shape(loopshape) + + with cgutils.loop_nest( + builder, loopshape, intp=intpty, order=order + ) as loop_indices: + vals_in = [] + for i, (index, arg) in enumerate(zip(indices, arguments)): + index.update_indices(loop_indices, i) + vals_in.append(arg.load_data(index.as_values())) + + kernel.generate(*vals_in) + + +# Kernels are the code to be executed inside the multidimensional loop. +class _Kernel(object): + def __init__(self, context, builder, outer_sig): + self.context = context + self.builder = builder + self.outer_sig = outer_sig + + def cast(self, val, fromty, toty): + """Numpy uses cast semantics that are different from standard Python + (for example, it does allow casting from complex to float). + + This method acts as a patch to context.cast so that it allows + complex to real/int casts. + + """ + if isinstance(fromty, types.Complex) and not isinstance( + toty, types.Complex + ): + # attempt conversion of the real part to the specified type. + # note that NumPy issues a warning in this kind of conversions + newty = fromty.underlying_float + attr = self.context.get_getattr(fromty, "real") + val = attr(self.context, self.builder, fromty, val, "real") + fromty = newty + # let the regular cast do the rest... + + return self.context.cast(self.builder, val, fromty, toty) + + def generate(self, *args): + isig = self.inner_sig + osig = self.outer_sig + cast_args = [ + self.cast(val, inty, outty) + for val, inty, outty in zip(args, osig.args, isig.args) + ] + if self.cres.objectmode: + func_type = self.context.call_conv.get_function_type( + types.pyobject, [types.pyobject] * len(isig.args) + ) + else: + func_type = self.context.call_conv.get_function_type( + isig.return_type, isig.args + ) + module = self.builder.block.function.module + entry_point = cgutils.get_or_insert_function( + module, func_type, self.cres.fndesc.llvm_func_name + ) + entry_point.attributes.add("alwaysinline") + + _, res = self.context.call_conv.call_function( + self.builder, entry_point, isig.return_type, isig.args, cast_args + ) + return self.cast(res, isig.return_type, osig.return_type) + + +def _ufunc_db_function(ufunc): + """Use the ufunc loop type information to select the code generation + function from the table provided by the dict_of_kernels. The dict + of kernels maps the loop identifier to a function with the + following signature: (context, builder, signature, args). + + The loop type information has the form 'AB->C'. The letters to the + left of '->' are the input types (specified as NumPy letter + types). The letters to the right of '->' are the output + types. There must be 'ufunc.nin' letters to the left of '->', and + 'ufunc.nout' letters to the right. + + For example, a binary float loop resulting in a float, will have + the following signature: 'ff->f'. + + A given ufunc implements many loops. The list of loops implemented + for a given ufunc can be accessed using the 'types' attribute in + the ufunc object. The NumPy machinery selects the first loop that + fits a given calling signature (in our case, what we call the + outer_sig). This logic is mimicked by 'ufunc_find_matching_loop'. 
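+
+    For instance (illustrative), np.add applied to two float64 values
+    matches the 'dd->d' loop, and the corresponding entry generates the
+    scalar double addition.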
+ """ + + class _KernelImpl(_Kernel): + def __init__(self, context, builder, outer_sig): + super(_KernelImpl, self).__init__(context, builder, outer_sig) + loop = ufunc_find_matching_loop( + ufunc, + outer_sig.args + tuple(_unpack_output_types(ufunc, outer_sig)), + ) + self.fn = context.get_ufunc_info(ufunc).get(loop.ufunc_sig) + self.inner_sig = _ufunc_loop_sig(loop.outputs, loop.inputs) + + if self.fn is None: + msg = "Don't know how to lower ufunc '{0}' for loop '{1}'" + raise NotImplementedError(msg.format(ufunc.__name__, loop)) + + def generate(self, *args): + isig = self.inner_sig + osig = self.outer_sig + + cast_args = [ + self.cast(val, inty, outty) + for val, inty, outty in zip(args, osig.args, isig.args) + ] + with force_error_model(self.context, "numpy"): + res = self.fn(self.context, self.builder, isig, cast_args) + dmm = self.context.data_model_manager + res = dmm[isig.return_type].from_return(self.builder, res) + return self.cast(res, isig.return_type, osig.return_type) + + return _KernelImpl + + +################################################################################ +# Helper functions that register the ufuncs + + +def register_ufunc_kernel(ufunc, kernel, lower): + def do_ufunc(context, builder, sig, args): + return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel) + + _any = types.Any + in_args = (_any,) * ufunc.nin + + # Add a lowering for each out argument that is missing. + for n_explicit_out in range(ufunc.nout + 1): + out_args = (types.Array,) * n_explicit_out + lower(ufunc, *in_args, *out_args)(do_ufunc) + + return kernel + + +def register_unary_operator_kernel( + operator, ufunc, kernel, lower, inplace=False +): + assert not inplace # are there any inplace unary operators? + + def lower_unary_operator(context, builder, sig, args): + return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel) + + _arr_kind = types.Array + lower(operator, _arr_kind)(lower_unary_operator) + + +def register_binary_operator_kernel(op, ufunc, kernel, lower, inplace=False): + def lower_binary_operator(context, builder, sig, args): + return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel) + + def lower_inplace_operator(context, builder, sig, args): + # The visible signature is (A, B) -> A + # The implementation's signature (with explicit output) + # is (A, B, A) -> A + args = tuple(args) + (args[0],) + sig = typing.signature(sig.return_type, *sig.args + (sig.args[0],)) + return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel) + + _any = types.Any + _arr_kind = types.Array + formal_sigs = [(_arr_kind, _arr_kind), (_any, _arr_kind), (_arr_kind, _any)] + for sig in formal_sigs: + if not inplace: + lower(op, *sig)(lower_binary_operator) + else: + lower(op, *sig)(lower_inplace_operator) + + +################################################################################ +# Use the contents of ufunc_db to initialize the supported ufuncs + + +@registry.lower(operator.pos, types.Array) +def array_positive_impl(context, builder, sig, args): + """Lowering function for +(array) expressions. Defined here + (numba.targets.npyimpl) since the remaining array-operator + lowering functions are also registered in this module. 
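+
+    A minimal sketch of the lowering: +arr becomes an element-wise
+    identity loop; _UnaryPositiveKernel below simply forwards each
+    input value unchanged.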
+ """ + + class _UnaryPositiveKernel(_Kernel): + def generate(self, *args): + [val] = args + return val + + return numpy_ufunc_kernel( + context, builder, sig, args, np.positive, _UnaryPositiveKernel + ) + + +def register_ufuncs(ufuncs, lower): + kernels = {} + for ufunc in ufuncs: + db_func = _ufunc_db_function(ufunc) + kernels[ufunc] = register_ufunc_kernel(ufunc, db_func, lower) + + for _op_map in ( + npydecl.NumpyRulesUnaryArrayOperator._op_map, + npydecl.NumpyRulesArrayOperator._op_map, + ): + for op, ufunc_name in _op_map.items(): + ufunc = getattr(np, ufunc_name) + kernel = kernels[ufunc] + if ufunc.nin == 1: + register_unary_operator_kernel(op, ufunc, kernel, lower) + elif ufunc.nin == 2: + register_binary_operator_kernel(op, ufunc, kernel, lower) + else: + raise RuntimeError( + "There shouldn't be any non-unary or binary operators" + ) + + for _op_map in (npydecl.NumpyRulesInplaceArrayOperator._op_map,): + for op, ufunc_name in _op_map.items(): + ufunc = getattr(np, ufunc_name) + kernel = kernels[ufunc] + if ufunc.nin == 1: + register_unary_operator_kernel( + op, ufunc, kernel, lower, inplace=True + ) + elif ufunc.nin == 2: + register_binary_operator_kernel( + op, ufunc, kernel, lower, inplace=True + ) + else: + raise RuntimeError( + "There shouldn't be any non-unary or binary operators" + ) + + +register_ufuncs(ufunc_db.get_ufuncs(), registry.lower) + + +@intrinsic +def _make_dtype_object(typingctx, desc): + """Given a string or NumberClass description *desc*, returns the dtype object.""" + + def from_nb_type(nb_type): + return_type = types.DType(nb_type) + sig = return_type(desc) + + def codegen(context, builder, signature, args): + # All dtype objects are dummy values in LLVM. + # They only exist in the type level. + return context.get_dummy_value() + + return sig, codegen + + if isinstance(desc, types.Literal): + # Convert the str description into np.dtype then to numba type. + nb_type = from_dtype(np.dtype(desc.literal_value)) + return from_nb_type(nb_type) + elif isinstance(desc, types.functions.NumberClass): + thestr = str(desc.dtype) + # Convert the str description into np.dtype then to numba type. 
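+        # e.g. types.float64 (a NumberClass) stringifies to "float64",
+        # which np.dtype resolves to dtype('float64').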
+ nb_type = from_dtype(np.dtype(thestr)) + return from_nb_type(nb_type) + + +@overload(np.dtype) +def numpy_dtype(desc): + """Provide an implementation so that numpy.dtype function can be lowered.""" + if isinstance(desc, (types.Literal, types.functions.NumberClass)): + + def imp(desc): + return _make_dtype_object(desc) + + return imp + else: + raise errors.NumbaTypeError("unknown dtype descriptor: {}".format(desc)) diff --git a/numba_cuda/numba/cuda/np/numpy_support.py b/numba_cuda/numba/cuda/np/numpy_support.py index 4fe3f6f54..5f3ffed80 100644 --- a/numba_cuda/numba/cuda/np/numpy_support.py +++ b/numba_cuda/numba/cuda/np/numpy_support.py @@ -2,15 +2,20 @@ # SPDX-License-Identifier: BSD-2-Clause import collections -import numpy as np +import ctypes import re -from numba.core import types, errors +import numpy as np + +from numba.core import errors, types from numba.cuda.typing.templates import signature from numba.cuda.np import npdatetime_helpers +from numba.core.errors import TypingError -numpy_version = tuple(map(int, np.__version__.split(".")[:2])) +# re-export +from numba.cuda.cgutils import is_nonelike # noqa: F401 +numpy_version = tuple(map(int, np.__version__.split(".")[:2])) FROM_DTYPE = { np.dtype("bool"): types.boolean, @@ -30,7 +35,6 @@ np.dtype(object): types.pyobject, } - re_typestr = re.compile(r"[<>=\|]([a-z])(\d+)?$", re.I) re_datetimestr = re.compile(r"[<>=\|]([mM])8?(\[([a-z]+)\])?$", re.I) @@ -117,6 +121,43 @@ def from_dtype(dtype): } +def as_dtype(nbtype): + """ + Return a numpy dtype instance corresponding to the given Numba type. + NotImplementedError is if no correspondence is known. + """ + nbtype = types.unliteral(nbtype) + if isinstance(nbtype, (types.Complex, types.Integer, types.Float)): + return np.dtype(str(nbtype)) + if isinstance(nbtype, (types.Boolean)): + return np.dtype("?") + if isinstance(nbtype, (types.NPDatetime, types.NPTimedelta)): + letter = _as_dtype_letters[type(nbtype)] + if nbtype.unit: + return np.dtype("%s[%s]" % (letter, nbtype.unit)) + else: + return np.dtype(letter) + if isinstance(nbtype, (types.CharSeq, types.UnicodeCharSeq)): + letter = _as_dtype_letters[type(nbtype)] + return np.dtype("%s%d" % (letter, nbtype.count)) + if isinstance(nbtype, types.Record): + return as_struct_dtype(nbtype) + if isinstance(nbtype, types.EnumMember): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.npytypes.DType): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.NumberClass): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.NestedArray): + spec = (as_dtype(nbtype.dtype), tuple(nbtype.shape)) + return np.dtype(spec) + if isinstance(nbtype, types.PyObject): + return np.dtype(object) + + msg = f"{nbtype} cannot be represented as a NumPy dtype" + raise errors.NumbaNotImplementedError(msg) + + def as_struct_dtype(rec): """Convert Numba Record type to NumPy structured dtype""" assert isinstance(rec, types.Record) @@ -158,41 +199,33 @@ def _check_struct_alignment(rec, fields): raise ValueError(msg.format(npy_align, llvm_align, dt)) -def as_dtype(nbtype): - """ - Return a numpy dtype instance corresponding to the given Numba type. - NotImplementedError is if no correspondence is known. 
- """ - nbtype = types.unliteral(nbtype) - if isinstance(nbtype, (types.Complex, types.Integer, types.Float)): - return np.dtype(str(nbtype)) - if isinstance(nbtype, (types.Boolean)): - return np.dtype("?") - if isinstance(nbtype, (types.NPDatetime, types.NPTimedelta)): - letter = _as_dtype_letters[type(nbtype)] - if nbtype.unit: - return np.dtype("%s[%s]" % (letter, nbtype.unit)) - else: - return np.dtype(letter) - if isinstance(nbtype, (types.CharSeq, types.UnicodeCharSeq)): - letter = _as_dtype_letters[type(nbtype)] - return np.dtype("%s%d" % (letter, nbtype.count)) - if isinstance(nbtype, types.Record): - return as_struct_dtype(nbtype) - if isinstance(nbtype, types.EnumMember): - return as_dtype(nbtype.dtype) - if isinstance(nbtype, types.npytypes.DType): - return as_dtype(nbtype.dtype) - if isinstance(nbtype, types.NumberClass): - return as_dtype(nbtype.dtype) - if isinstance(nbtype, types.NestedArray): - spec = (as_dtype(nbtype.dtype), tuple(nbtype.shape)) - return np.dtype(spec) - if isinstance(nbtype, types.PyObject): - return np.dtype(object) +def map_arrayscalar_type(val): + if isinstance(val, np.generic): + # We can't blindly call np.dtype() as it loses information + # on some types, e.g. datetime64 and timedelta64. + dtype = val.dtype + else: + try: + dtype = np.dtype(type(val)) + except TypeError: + raise errors.NumbaNotImplementedError( + "no corresponding numpy dtype for %r" % type(val) + ) + return from_dtype(dtype) - msg = f"{nbtype} cannot be represented as a NumPy dtype" - raise errors.NumbaNotImplementedError(msg) + +def is_array(val): + return isinstance(val, np.ndarray) + + +def map_layout(val): + if val.flags["C_CONTIGUOUS"]: + layout = "C" + elif val.flags["F_CONTIGUOUS"]: + layout = "F" + else: + layout = "A" + return layout def select_array_wrapper(inputs): @@ -255,7 +288,7 @@ def supported_ufunc_loop(ufunc, loop): as it allows for a more fine-grained incremental support. """ # NOTE: Assuming ufunc for the CPUContext - from numba.np import ufunc_db + from numba.cuda.np import ufunc_db loop_sig = loop.ufunc_sig try: @@ -415,7 +448,7 @@ def make_datetime_specific(outputs, dt_unit, td_unit): dt_unit, td_unit ) if unit is None: - raise errors.TypingError( + raise TypingError( f"ufunc '{ufunc_name}' is not " + "supported between " + f"datetime64[{dt_unit}] " @@ -551,3 +584,215 @@ def from_struct_dtype(dtype): aligned = _is_aligned_struct(dtype) return types.Record(fields, size, aligned) + + +def _get_bytes_buffer(ptr, nbytes): + """ + Get a ctypes array of *nbytes* starting at *ptr*. + """ + if isinstance(ptr, ctypes.c_void_p): + ptr = ptr.value + arrty = ctypes.c_byte * nbytes + return arrty.from_address(ptr) + + +def _get_array_from_ptr(ptr, nbytes, dtype): + return np.frombuffer(_get_bytes_buffer(ptr, nbytes), dtype) + + +def carray(ptr, shape, dtype=None): + """ + Return a Numpy array view over the data pointed to by *ptr* with the + given *shape*, in C order. If *dtype* is given, it is used as the + array's dtype, otherwise the array's dtype is inferred from *ptr*'s type. + """ + from numba.core.typing.ctypes_utils import from_ctypes + + try: + # Use ctypes parameter protocol if available + ptr = ptr._as_parameter_ + except AttributeError: + pass + + # Normalize dtype, to accept e.g. 
"int64" or np.int64 + if dtype is not None: + dtype = np.dtype(dtype) + + if isinstance(ptr, ctypes.c_void_p): + if dtype is None: + raise TypeError("explicit dtype required for void* argument") + p = ptr + elif isinstance(ptr, ctypes._Pointer): + ptrty = from_ctypes(ptr.__class__) + assert isinstance(ptrty, types.CPointer) + ptr_dtype = as_dtype(ptrty.dtype) + if dtype is not None and dtype != ptr_dtype: + raise TypeError( + "mismatching dtype '%s' for pointer %s" % (dtype, ptr) + ) + dtype = ptr_dtype + p = ctypes.cast(ptr, ctypes.c_void_p) + else: + raise TypeError("expected a ctypes pointer, got %r" % (ptr,)) + + nbytes = dtype.itemsize * np.prod(shape, dtype=np.intp) + return _get_array_from_ptr(p, nbytes, dtype).reshape(shape) + + +def farray(ptr, shape, dtype=None): + """ + Return a Numpy array view over the data pointed to by *ptr* with the + given *shape*, in Fortran order. If *dtype* is given, it is used as the + array's dtype, otherwise the array's dtype is inferred from *ptr*'s type. + """ + if not isinstance(shape, int): + shape = shape[::-1] + return carray(ptr, shape, dtype).T + + +def is_contiguous(dims, strides, itemsize): + """Is the given shape, strides, and itemsize of C layout? + + Note: The code is usable as a numba-compiled function + """ + nd = len(dims) + # Check and skip 1s or 0s in inner dims + innerax = nd - 1 + while innerax > -1 and dims[innerax] <= 1: + innerax -= 1 + + # Early exit if all axis are 1s or 0s + if innerax < 0: + return True + + # Check itemsize matches innermost stride + if itemsize != strides[innerax]: + return False + + # Check and skip 1s or 0s in outer dims + outerax = 0 + while outerax < innerax and dims[outerax] <= 1: + outerax += 1 + + # Check remaining strides to be contiguous + ax = innerax + while ax > outerax: + if strides[ax] * dims[ax] != strides[ax - 1]: + return False + ax -= 1 + return True + + +def is_fortran(dims, strides, itemsize): + """Is the given shape, strides, and itemsize of F layout? + + Note: The code is usable as a numba-compiled function + """ + nd = len(dims) + # Check and skip 1s or 0s in inner dims + firstax = 0 + while firstax < nd and dims[firstax] <= 1: + firstax += 1 + + # Early exit if all axis are 1s or 0s + if firstax >= nd: + return True + + # Check itemsize matches innermost stride + if itemsize != strides[firstax]: + return False + + # Check and skip 1s or 0s in outer dims + lastax = nd - 1 + while lastax > firstax and dims[lastax] <= 1: + lastax -= 1 + + # Check remaining strides to be contiguous + ax = firstax + while ax < lastax: + if strides[ax] * dims[ax] != strides[ax + 1]: + return False + ax += 1 + return True + + +def type_can_asarray(arr): + """Returns True if the type of 'arr' is supported by the Numba `np.asarray` + implementation, False otherwise. + """ + + ok = ( + types.Array, + types.Sequence, + types.Tuple, + types.StringLiteral, + types.Number, + types.Boolean, + types.containers.ListType, + ) + + return isinstance(arr, ok) + + +def type_is_scalar(typ): + """Returns True if the type of 'typ' is a scalar type, according to + NumPy rules. False otherwise. 
+ https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types + """ + + ok = ( + types.Boolean, + types.Number, + types.UnicodeType, + types.StringLiteral, + types.NPTimedelta, + types.NPDatetime, + ) + return isinstance(typ, ok) + + +def check_is_integer(v, name): + """Raises TypingError if the value is not an integer.""" + if not isinstance(v, (int, types.Integer)): + raise TypingError("{} must be an integer".format(name)) + + +def lt_floats(a, b): + # Adapted from NumPy commit 717c7acf which introduced the behavior of + # putting NaNs at the end. + # The code is later moved to numpy/core/src/npysort/npysort_common.h + # This info is gathered as of NumPy commit d8c09c50 + return a < b or (np.isnan(b) and not np.isnan(a)) + + +def lt_complex(a, b): + if np.isnan(a.real): + if np.isnan(b.real): + if np.isnan(a.imag): + return False + else: + if np.isnan(b.imag): + return True + else: + return a.imag < b.imag + else: + return False + + else: + if np.isnan(b.real): + return True + else: + if np.isnan(a.imag): + if np.isnan(b.imag): + return a.real < b.real + else: + return False + else: + if np.isnan(b.imag): + return True + else: + if a.real < b.real: + return True + elif a.real == b.real: + return a.imag < b.imag + return False diff --git a/numba_cuda/numba/cuda/np/polynomial/__init__.py b/numba_cuda/numba/cuda/np/polynomial/__init__.py new file mode 100644 index 000000000..72e59bd84 --- /dev/null +++ b/numba_cuda/numba/cuda/np/polynomial/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from .polynomial_core import registry # noqa: F401 diff --git a/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py b/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py new file mode 100644 index 000000000..a86777e3a --- /dev/null +++ b/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +from numba.cuda.extending import ( + core_models, + register_model, + type_callable, + unbox, + NativeValue, + make_attribute_wrapper, + box, +) +from numba.core import types +from numba.cuda import cgutils +import warnings +from numba.core.errors import NumbaExperimentalFeatureWarning, NumbaValueError +from numpy.polynomial.polynomial import Polynomial +from contextlib import ExitStack +import numpy as np +from llvmlite import ir +from numba.core.imputils import Registry + +registry = Registry("np.polynomial_core") +lower = registry.lower + + +@register_model(types.PolynomialType) +class PolynomialModel(core_models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("coef", fe_type.coef), + ("domain", fe_type.domain), + ("window", fe_type.window), + # Introduced in NumPy 1.24, maybe leave it out for now + # ('symbol', types.string) + ] + super(PolynomialModel, self).__init__(dmm, fe_type, members) + + +@type_callable(Polynomial) +def type_polynomial(context): + def typer(coef, domain=None, window=None): + default_domain = types.Array(types.int64, 1, "C") + double_domain = types.Array(types.double, 1, "C") + default_window = types.Array(types.int64, 1, "C") + double_window = types.Array(types.double, 1, "C") + double_coef = types.Array(types.double, 1, "C") + + warnings.warn( + "Polynomial class is experimental", + category=NumbaExperimentalFeatureWarning, + ) + + if isinstance(coef, types.Array) and all( + [a is None for a in (domain, window)] + ): + if coef.ndim == 1: + # If Polynomial(coef) is called, coef is cast to double dtype, + # and domain and window are set to equal [-1, 1], i.e. have + # integer dtype + return types.PolynomialType( + double_coef, default_domain, default_window, 1 + ) + else: + msg = "Coefficient array is not 1-d" + raise NumbaValueError(msg) + elif all([isinstance(a, types.Array) for a in (coef, domain, window)]): + if coef.ndim == 1: + if all([a.ndim == 1 for a in (domain, window)]): + # If Polynomial(coef, domain, window) is called, then coef, + # domain and window are cast to double dtype + return types.PolynomialType( + double_coef, double_domain, double_window, 3 + ) + else: + msg = "Coefficient array is not 1-d" + raise NumbaValueError(msg) + + return typer + + +make_attribute_wrapper(types.PolynomialType, "coef", "coef") +make_attribute_wrapper(types.PolynomialType, "domain", "domain") +make_attribute_wrapper(types.PolynomialType, "window", "window") +# Introduced in NumPy 1.24, maybe leave it out for now +# make_attribute_wrapper(types.PolynomialType, 'symbol', 'symbol') + + +@lower(Polynomial, types.Array) +def impl_polynomial1(context, builder, sig, args): + def to_double(arr): + return np.asarray(arr, dtype=np.double) + + def const_impl(): + return np.asarray([-1, 1]) + + typ = sig.return_type + polynomial = cgutils.create_struct_proxy(typ)(context, builder) + sig_coef = sig.args[0].copy(dtype=types.double)(sig.args[0]) + coef_cast = context.compile_internal(builder, to_double, sig_coef, args) + sig_domain = sig.args[0].copy(dtype=types.intp)() + sig_window = sig.args[0].copy(dtype=types.intp)() + domain_cast = context.compile_internal(builder, const_impl, sig_domain, ()) + window_cast = context.compile_internal(builder, const_impl, sig_window, ()) + polynomial.coef = coef_cast + polynomial.domain = domain_cast + polynomial.window = window_cast + + return polynomial._getvalue() + + +@lower(Polynomial, types.Array, types.Array, types.Array) +def impl_polynomial3(context, builder, sig, args): + 
def to_double(coef): + return np.asarray(coef, dtype=np.double) + + typ = sig.return_type + polynomial = cgutils.create_struct_proxy(typ)(context, builder) + + coef_sig = sig.args[0].copy(dtype=types.double)(sig.args[0]) + domain_sig = sig.args[1].copy(dtype=types.double)(sig.args[1]) + window_sig = sig.args[2].copy(dtype=types.double)(sig.args[2]) + coef_cast = context.compile_internal( + builder, to_double, coef_sig, (args[0],) + ) + domain_cast = context.compile_internal( + builder, to_double, domain_sig, (args[1],) + ) + window_cast = context.compile_internal( + builder, to_double, window_sig, (args[2],) + ) + + domain_helper = context.make_helper( + builder, domain_sig.return_type, value=domain_cast + ) + window_helper = context.make_helper( + builder, window_sig.return_type, value=window_cast + ) + + i64 = ir.IntType(64) + two = i64(2) + + s1 = builder.extract_value(domain_helper.shape, 0) + s2 = builder.extract_value(window_helper.shape, 0) + pred1 = builder.icmp_signed("!=", s1, two) + pred2 = builder.icmp_signed("!=", s2, two) + + with cgutils.if_unlikely(builder, pred1): + context.call_conv.return_user_exc( + builder, ValueError, ("Domain has wrong number of elements.",) + ) + + with cgutils.if_unlikely(builder, pred2): + context.call_conv.return_user_exc( + builder, ValueError, ("Window has wrong number of elements.",) + ) + + polynomial.coef = coef_cast + polynomial.domain = domain_helper._getvalue() + polynomial.window = window_helper._getvalue() + + return polynomial._getvalue() + + +@unbox(types.PolynomialType) +def unbox_polynomial(typ, obj, c): + """ + Convert a Polynomial object to a native polynomial structure. + """ + is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) + polynomial = cgutils.create_struct_proxy(typ)(c.context, c.builder) + with ExitStack() as stack: + natives = [] + for name in ("coef", "domain", "window"): + attr = c.pyapi.object_getattr_string(obj, name) + with cgutils.early_exit_if_null(c.builder, stack, attr): + c.builder.store(cgutils.true_bit, is_error_ptr) + t = getattr(typ, name) + native = c.unbox(t, attr) + c.pyapi.decref(attr) + with cgutils.early_exit_if(c.builder, stack, native.is_error): + c.builder.store(cgutils.true_bit, is_error_ptr) + natives.append(native) + + polynomial.coef = natives[0] + polynomial.domain = natives[1] + polynomial.window = natives[2] + + return NativeValue( + polynomial._getvalue(), is_error=c.builder.load(is_error_ptr) + ) + + +@box(types.PolynomialType) +def box_polynomial(typ, val, c): + """ + Convert a native polynomial structure to a Polynomial object. 
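+
+    Boxing allocates Python objects for coef, domain and window, then
+    calls the Polynomial constructor with either one or three of them
+    depending on typ.n_args.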
+ """ + ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj) + fail_obj = c.pyapi.get_null_object() + + with ExitStack() as stack: + polynomial = cgutils.create_struct_proxy(typ)( + c.context, c.builder, value=val + ) + coef_obj = c.box(typ.coef, polynomial.coef) + with cgutils.early_exit_if_null(c.builder, stack, coef_obj): + c.builder.store(fail_obj, ret_ptr) + + domain_obj = c.box(typ.domain, polynomial.domain) + with cgutils.early_exit_if_null(c.builder, stack, domain_obj): + c.builder.store(fail_obj, ret_ptr) + + window_obj = c.box(typ.window, polynomial.window) + with cgutils.early_exit_if_null(c.builder, stack, window_obj): + c.builder.store(fail_obj, ret_ptr) + + class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Polynomial)) + with cgutils.early_exit_if_null(c.builder, stack, class_obj): + c.pyapi.decref(coef_obj) + c.pyapi.decref(domain_obj) + c.pyapi.decref(window_obj) + c.builder.store(fail_obj, ret_ptr) + + if typ.n_args == 1: + res1 = c.pyapi.call_function_objargs(class_obj, (coef_obj,)) + c.builder.store(res1, ret_ptr) + else: + res3 = c.pyapi.call_function_objargs( + class_obj, (coef_obj, domain_obj, window_obj) + ) + c.builder.store(res3, ret_ptr) + + c.pyapi.decref(coef_obj) + c.pyapi.decref(domain_obj) + c.pyapi.decref(window_obj) + c.pyapi.decref(class_obj) + + return c.builder.load(ret_ptr) diff --git a/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py b/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py new file mode 100644 index 000000000..9e1414549 --- /dev/null +++ b/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py @@ -0,0 +1,379 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implementation of operations involving polynomials. +""" + +import numpy as np +from numpy.polynomial import polynomial as poly +from numpy.polynomial import polyutils as pu + +from numba import literal_unroll +from numba.core import types, errors +from numba.cuda.extending import overload +from numba.cuda.np.numpy_support import type_can_asarray, as_dtype, from_dtype + + +@overload(np.roots) +def roots_impl(p): + # cast int vectors to float cf. numpy, this is a bit dicey as + # the roots could be complex which will fail anyway + ty = getattr(p, "dtype", p) + if isinstance(ty, types.Integer): + cast_t = np.float64 + else: + cast_t = as_dtype(ty) + + def roots_impl(p): + # impl based on numpy: + # https://github.com/numpy/numpy/blob/master/numpy/lib/polynomial.py + + if len(p.shape) != 1: + raise ValueError("Input must be a 1d array.") + + non_zero = np.nonzero(p)[0] + + if len(non_zero) == 0: + return np.zeros(0, dtype=cast_t) + + tz = len(p) - non_zero[-1] - 1 + + # pull out the coeffs selecting between possible zero pads + p = p[int(non_zero[0]) : int(non_zero[-1]) + 1] + + n = len(p) + if n > 1: + # construct companion matrix, ensure fortran order + # to give to eigvals, write to upper diag and then + # transpose. 
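+            # e.g. (illustrative) p = [1, -3, 2] (x**2 - 3*x + 2) gives
+            # the companion matrix [[3, -2], [1, 0]], whose eigenvalues
+            # 2 and 1 are the roots.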
+ A = np.diag(np.ones((n - 2,), cast_t), 1).T + A[0, :] = -p[1:] / p[0] # normalize + roots = np.linalg.eigvals(A) + else: + roots = np.zeros(0, dtype=cast_t) + + # add in additional zeros on the end if needed + if tz > 0: + return np.hstack((roots, np.zeros(tz, dtype=cast_t))) + else: + return roots + + return roots_impl + + +@overload(pu.trimseq) +def polyutils_trimseq(seq): + if not type_can_asarray(seq): + msg = 'The argument "seq" must be array-like' + raise errors.TypingError(msg) + + if isinstance(seq, types.BaseTuple): + msg = 'Unsupported type %r for argument "seq"' + raise errors.TypingError(msg % (seq)) + + if np.ndim(seq) > 1: + msg = "Coefficient array is not 1-d" + raise errors.NumbaValueError(msg) + + def impl(seq): + if len(seq) == 0: + return seq + else: + for i in range(len(seq) - 1, -1, -1): + if seq[i] != 0: + break + return seq[: i + 1] + + return impl + + +@overload(pu.as_series) +def polyutils_as_series(alist, trim=True): + if not type_can_asarray(alist): + msg = 'The argument "alist" must be array-like' + raise errors.TypingError(msg) + + if not isinstance(trim, (bool, types.Boolean)): + msg = 'The argument "trim" must be boolean' + raise errors.TypingError(msg) + + res_dtype = np.float64 + + tuple_input = isinstance(alist, types.BaseTuple) + list_input = isinstance(alist, types.List) + if tuple_input: + if np.any(np.array([np.ndim(a) > 1 for a in alist])): + raise errors.NumbaValueError("Coefficient array is not 1-d") + + res_dtype = _poly_result_dtype(*alist) + + elif list_input: + dt = as_dtype(_get_list_type(alist)) + res_dtype = np.result_type(dt, np.float64) + + else: + if np.ndim(alist) <= 2: + res_dtype = np.result_type(res_dtype, as_dtype(alist.dtype)) + else: + # If total dimension has ndim > 2, then coeff arrays are not 1D + raise errors.NumbaValueError("Coefficient array is not 1-d") + + def impl(alist, trim=True): + if tuple_input: + arrays = [] + for item in literal_unroll(alist): + arrays.append(np.atleast_1d(np.asarray(item)).astype(res_dtype)) + + elif list_input: + arrays = [ + np.atleast_1d(np.asarray(a)).astype(res_dtype) for a in alist + ] + + else: + alist_arr = np.asarray(alist) + arrays = [ + np.atleast_1d(np.asarray(a)).astype(res_dtype) + for a in alist_arr + ] + + if min([a.size for a in arrays]) == 0: + raise ValueError("Coefficient array is empty") + + if trim: + arrays = [pu.trimseq(a) for a in arrays] + + ret = arrays + return ret + + return impl + + +def _get_list_type(l): + # A helper function that takes a list (possibly nested) and returns its + # dtype. Returns a Numba type. + dt = l.dtype + if (not isinstance(dt, types.Number)) and type_can_asarray(dt): + return _get_list_type(dt) + else: + return dt + + +def _poly_result_dtype(*args): + # A helper function that takes a tuple of inputs and returns their result + # dtype. Used for poly functions. Returns a NumPy dtype. + res_dtype = np.float64 + for item in args: + if isinstance(item, types.BaseTuple): + s1 = item.types + elif isinstance(item, types.List): + s1 = [_get_list_type(item)] + elif isinstance(item, types.Number): + s1 = [item] + elif isinstance(item, types.Array): + s1 = [item.dtype] + else: + msg = "Input dtype must be scalar" + raise errors.TypingError(msg) + + try: + l = [as_dtype(t) for t in s1] + l.append(res_dtype) + res_dtype = np.result_type(*l) + except errors.NumbaNotImplementedError: + msg = "Input dtype must be scalar." 
+ raise errors.TypingError(msg) + + return from_dtype(res_dtype) + + +@overload(poly.polyadd) +def numpy_polyadd(c1, c2): + if not type_can_asarray(c1): + msg = 'The argument "c1" must be array-like' + raise errors.TypingError(msg) + + if not type_can_asarray(c2): + msg = 'The argument "c2" must be array-like' + raise errors.TypingError(msg) + + def impl(c1, c2): + arr1, arr2 = pu.as_series((c1, c2)) + diff = len(arr2) - len(arr1) + if diff > 0: + zr = np.zeros(diff) + arr1 = np.concatenate((arr1, zr)) + if diff < 0: + zr = np.zeros(-diff) + arr2 = np.concatenate((arr2, zr)) + val = arr1 + arr2 + return pu.trimseq(val) + + return impl + + +@overload(poly.polysub) +def numpy_polysub(c1, c2): + if not type_can_asarray(c1): + msg = 'The argument "c1" must be array-like' + raise errors.TypingError(msg) + + if not type_can_asarray(c2): + msg = 'The argument "c2" must be array-like' + raise errors.TypingError(msg) + + def impl(c1, c2): + arr1, arr2 = pu.as_series((c1, c2)) + diff = len(arr2) - len(arr1) + if diff > 0: + zr = np.zeros(diff) + arr1 = np.concatenate((arr1, zr)) + if diff < 0: + zr = np.zeros(-diff) + arr2 = np.concatenate((arr2, zr)) + val = arr1 - arr2 + return pu.trimseq(val) + + return impl + + +@overload(poly.polymul) +def numpy_polymul(c1, c2): + if not type_can_asarray(c1): + msg = 'The argument "c1" must be array-like' + raise errors.TypingError(msg) + + if not type_can_asarray(c2): + msg = 'The argument "c2" must be array-like' + raise errors.TypingError(msg) + + def impl(c1, c2): + arr1, arr2 = pu.as_series((c1, c2)) + val = np.convolve(arr1, arr2) + return pu.trimseq(val) + + return impl + + +@overload(poly.polyval, prefer_literal=True) +def poly_polyval(x, c, tensor=True): + if not type_can_asarray(x): + msg = 'The argument "x" must be array-like' + raise errors.TypingError(msg) + + if not type_can_asarray(c): + msg = 'The argument "c" must be array-like' + raise errors.TypingError(msg) + + if not isinstance(tensor, (bool, types.BooleanLiteral)): + msg = 'The argument "tensor" must be boolean' + raise errors.RequireLiteralValue(msg) + + res_dtype = _poly_result_dtype(c, x) + + # Simulate new_shape = (1,) * np.ndim(x) in the general case + # If x is a number, new_shape is not used + # If x is a tuple or a list, then it's 1d hence new_shape=(1,) + x_nd_array = not isinstance(x, types.Number) + new_shape = (1,) + if isinstance(x, types.Array): + # If x is a np.array, then take its dimension + new_shape = (1,) * np.ndim(x) + + if isinstance(tensor, bool): + tensor_arg = tensor + else: + tensor_arg = tensor.literal_value + + def impl(x, c, tensor=True): + arr = np.asarray(c).astype(res_dtype) + inputs = np.asarray(x).astype(res_dtype) + if x_nd_array and tensor_arg: + arr = arr.reshape(arr.shape + new_shape) + + l = len(arr) + y = arr[l - 1] + inputs * 0 + + for i in range(l - 1, 0, -1): + y = arr[i - 1] + y * inputs + + return y + + return impl + + +@overload(poly.polyint) +def poly_polyint(c, m=1): + if not type_can_asarray(c): + msg = 'The argument "c" must be array-like' + raise errors.TypingError(msg) + + if not isinstance(m, (int, types.Integer)): + msg = 'The argument "m" must be an integer' + raise errors.TypingError(msg) + + res_dtype = as_dtype(_poly_result_dtype(c)) + + if not np.issubdtype(res_dtype, np.number): + msg = f"Input dtype must be scalar. 
Found {res_dtype} instead"
+        raise errors.TypingError(msg)
+
+    is1D = (np.ndim(c) == 1) or (
+        isinstance(c, (types.List, types.BaseTuple))
+        and isinstance(c.dtype, types.Number)
+    )
+
+    def impl(c, m=1):
+        c = np.asarray(c).astype(res_dtype)
+        cdt = c.dtype
+        for i in range(m):
+            n = len(c)
+
+            tmp = np.empty((n + 1,) + c.shape[1:], dtype=cdt)
+            tmp[0] = c[0] * 0
+            tmp[1] = c[0]
+            for j in range(1, n):
+                tmp[j + 1] = c[j] / (j + 1)
+            c = tmp
+        if is1D:
+            return pu.trimseq(c)
+        else:
+            return c
+
+    return impl
+
+
+@overload(poly.polydiv)
+def numpy_polydiv(c1, c2):
+    if not type_can_asarray(c1):
+        msg = 'The argument "c1" must be array-like'
+        raise errors.TypingError(msg)
+
+    if not type_can_asarray(c2):
+        msg = 'The argument "c2" must be array-like'
+        raise errors.TypingError(msg)
+
+    def impl(c1, c2):
+        arr1, arr2 = pu.as_series((c1, c2))
+        if arr2[-1] == 0:
+            raise ZeroDivisionError()
+
+        l1 = len(arr1)
+        l2 = len(arr2)
+        if l1 < l2:
+            return arr1[:1] * 0, arr1
+        elif l2 == 1:
+            return arr1 / arr2[-1], arr1[:1] * 0
+        else:
+            dlen = l1 - l2
+            scl = arr2[-1]
+            arr2 = arr2[:-1] / scl
+            i = dlen
+            j = l1 - 1
+            while i >= 0:
+                arr1[i:j] -= arr2 * arr1[j]
+                i -= 1
+                j -= 1
+            return arr1[j + 1 :] / scl, pu.trimseq(arr1[: j + 1])
+
+    return impl
diff --git a/numba_cuda/numba/cuda/np/ufunc/sigparse.py b/numba_cuda/numba/cuda/np/ufunc/sigparse.py
new file mode 100644
index 000000000..89831c664
--- /dev/null
+++ b/numba_cuda/numba/cuda/np/ufunc/sigparse.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import tokenize
+import string
+
+
+def parse_signature(sig):
+    """Parse a generalized ufunc signature.
+
+    NOTE: ',' (COMMA) is a delimiter, not a separator.
+    This means a trailing comma is legal.
+    """
+
+    def stripws(s):
+        return "".join(c for c in s if c not in string.whitespace)
+
+    def tokenizer(src):
+        def readline():
+            yield src
+
+        gen = readline()
+        return tokenize.generate_tokens(lambda: next(gen))
+
+    def parse(src):
+        tokgen = tokenizer(src)
+        while True:
+            tok = next(tokgen)
+            if tok[1] == "(":
+                symbols = []
+                while True:
+                    tok = next(tokgen)
+                    if tok[1] == ")":
+                        break
+                    elif tok[0] == tokenize.NAME:
+                        symbols.append(tok[1])
+                    elif tok[1] == ",":
+                        continue
+                    else:
+                        raise ValueError('bad token in signature "%s"' % tok[1])
+                yield tuple(symbols)
+                tok = next(tokgen)
+                if tok[1] == ",":
+                    continue
+                elif tokenize.ISEOF(tok[0]):
+                    break
+            elif tokenize.ISEOF(tok[0]):
+                break
+            else:
+                raise ValueError('bad token in signature "%s"' % tok[1])
+
+    ins, _, outs = stripws(sig).partition("->")
+    inputs = list(parse(ins))
+    outputs = list(parse(outs))
+
+    # check that all output symbols are defined in the inputs
+    isym = set()
+    osym = set()
+    for grp in inputs:
+        isym |= set(grp)
+    for grp in outputs:
+        osym |= set(grp)
+
+    diff = osym.difference(isym)
+    if diff:
+        raise NameError("undefined output symbols: %s" % ",".join(sorted(diff)))
+
+    return inputs, outputs
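+
+
+# Illustrative example (derived from the parser above, not part of the
+# public API): a matmul-like signature parses into per-operand symbol
+# tuples for the inputs and outputs:
+#
+#     parse_signature("(m,n),(n,p)->(m,p)")
+#     # -> ([("m", "n"), ("n", "p")], [("m", "p")])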
diff --git a/numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py b/numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py
index edef490d5..c12e38c52 100644
--- a/numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py
+++ b/numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py
@@ -2,10 +2,16 @@
 # SPDX-License-Identifier: BSD-2-Clause

 from numba.cuda.core import sigutils
-from numba.np.ufunc import _internal

 # Utility functions
+# HACK: These are explicitly defined here to avoid having a CExt just to
+# import these constants. NumPy does not expose them in the Python API.
+PyUFunc_Zero = 0
+PyUFunc_One = 1
+PyUFunc_None = -1
+PyUFunc_ReorderableNone = -2
+

 def _compile_element_wise_function(nb_func, targetoptions, sig):
     # Do compilation
@@ -40,10 +46,10 @@ def disable_compile(self):
 _identities = {
-    0: _internal.PyUFunc_Zero,
-    1: _internal.PyUFunc_One,
-    None: _internal.PyUFunc_None,
-    "reorderable": _internal.PyUFunc_ReorderableNone,
+    0: PyUFunc_Zero,
+    1: PyUFunc_One,
+    None: PyUFunc_None,
+    "reorderable": PyUFunc_ReorderableNone,
 }
diff --git a/numba_cuda/numba/cuda/np/ufunc_db.py b/numba_cuda/numba/cuda/np/ufunc_db.py
new file mode 100644
index 000000000..1bd76e960
--- /dev/null
+++ b/numba_cuda/numba/cuda/np/ufunc_db.py
@@ -0,0 +1,1282 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""This file contains information on how to translate different ufuncs
+into Numba. It is a database of different ufuncs and of how each of
+their loops maps to a function that implements the inner kernel of that
+ufunc (the inner kernel being the per-element function).
+
+Use the function get_ufunc_info to get the information related to a
+given ufunc.
+"""
+
+import numpy as np
+import sys
+
+IS_WIN32 = sys.platform.startswith("win32")
+numpy_version = tuple(map(int, np.__version__.split(".")[:2]))
+
+# this is lazily initialized to avoid circular imports
+_ufunc_db = None
+
+
+def _lazy_init_db():
+    global _ufunc_db
+
+    if _ufunc_db is None:
+        _ufunc_db = {}
+        _fill_ufunc_db(_ufunc_db)
+
+
+def get_ufuncs():
+    """obtain a list of supported ufuncs in the db"""
+    _lazy_init_db()
+    return _ufunc_db.keys()
+
+
+def get_ufunc_info(ufunc_key):
+    """get the lowering information for the ufunc with key ufunc_key.
+
+    The lowering information is a dictionary that maps from a numpy
+    loop string (as given by the ufunc types attribute) to a function
+    that handles code generation for a scalar version of the ufunc
+    (that is, generates the "per element" operation).
+
+    Raises a KeyError if the ufunc is not in the ufunc_db.
+    """
+    _lazy_init_db()
+    return _ufunc_db[ufunc_key]
+
+
+def _fill_ufunc_db(ufunc_db):
+    # Some of these imports would cause circular-import problems if done
+    # at global scope when importing the numba module.
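+    # They run at most once per process, the first time _lazy_init_db()
+    # builds the database.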
+ from numba.cuda.np import npyfuncs + from numba.cuda.np.math import cmathimpl, mathimpl, numbers + from numba.cuda.np.numpy_support import numpy_version + + ufunc_db[np.isnat] = { + # datetime & timedelta + "M->?": npyfuncs.np_datetime_isnat_impl, + "m->?": npyfuncs.np_datetime_isnat_impl, + } + + ufunc_db[np.negative] = { + "?->?": numbers.int_invert_impl, + "b->b": numbers.int_negate_impl, + "B->B": numbers.int_negate_impl, + "h->h": numbers.int_negate_impl, + "H->H": numbers.int_negate_impl, + "i->i": numbers.int_negate_impl, + "I->I": numbers.int_negate_impl, + "l->l": numbers.int_negate_impl, + "L->L": numbers.int_negate_impl, + "q->q": numbers.int_negate_impl, + "Q->Q": numbers.int_negate_impl, + "f->f": numbers.real_negate_impl, + "d->d": numbers.real_negate_impl, + "F->F": numbers.complex_negate_impl, + "D->D": numbers.complex_negate_impl, + } + + ufunc_db[np.positive] = { + "?->?": numbers.int_positive_impl, + "b->b": numbers.int_positive_impl, + "B->B": numbers.int_positive_impl, + "h->h": numbers.int_positive_impl, + "H->H": numbers.int_positive_impl, + "i->i": numbers.int_positive_impl, + "I->I": numbers.int_positive_impl, + "l->l": numbers.int_positive_impl, + "L->L": numbers.int_positive_impl, + "q->q": numbers.int_positive_impl, + "Q->Q": numbers.int_positive_impl, + "f->f": numbers.real_positive_impl, + "d->d": numbers.real_positive_impl, + "F->F": numbers.complex_positive_impl, + "D->D": numbers.complex_positive_impl, + } + + ufunc_db[np.absolute] = { + "?->?": numbers.int_abs_impl, + "b->b": numbers.int_abs_impl, + "B->B": numbers.uint_abs_impl, + "h->h": numbers.int_abs_impl, + "H->H": numbers.uint_abs_impl, + "i->i": numbers.int_abs_impl, + "I->I": numbers.uint_abs_impl, + "l->l": numbers.int_abs_impl, + "L->L": numbers.uint_abs_impl, + "q->q": numbers.int_abs_impl, + "Q->Q": numbers.uint_abs_impl, + "f->f": numbers.real_abs_impl, + "d->d": numbers.real_abs_impl, + "F->f": numbers.complex_abs_impl, + "D->d": numbers.complex_abs_impl, + } + + ufunc_db[np.sign] = { + "b->b": numbers.int_sign_impl, + "B->B": numbers.int_sign_impl, + "h->h": numbers.int_sign_impl, + "H->H": numbers.int_sign_impl, + "i->i": numbers.int_sign_impl, + "I->I": numbers.int_sign_impl, + "l->l": numbers.int_sign_impl, + "L->L": numbers.int_sign_impl, + "q->q": numbers.int_sign_impl, + "Q->Q": numbers.int_sign_impl, + "f->f": numbers.real_sign_impl, + "d->d": numbers.real_sign_impl, + "F->F": npyfuncs.np_complex_sign_impl, + "D->D": npyfuncs.np_complex_sign_impl, + } + + ufunc_db[np.add] = { + "??->?": numbers.int_or_impl, + "bb->b": numbers.int_add_impl, + "BB->B": numbers.int_add_impl, + "hh->h": numbers.int_add_impl, + "HH->H": numbers.int_add_impl, + "ii->i": numbers.int_add_impl, + "II->I": numbers.int_add_impl, + "ll->l": numbers.int_add_impl, + "LL->L": numbers.int_add_impl, + "qq->q": numbers.int_add_impl, + "QQ->Q": numbers.int_add_impl, + "ff->f": numbers.real_add_impl, + "dd->d": numbers.real_add_impl, + "FF->F": numbers.complex_add_impl, + "DD->D": numbers.complex_add_impl, + } + + ufunc_db[np.subtract] = { + "??->?": numbers.int_xor_impl, + "bb->b": numbers.int_sub_impl, + "BB->B": numbers.int_sub_impl, + "hh->h": numbers.int_sub_impl, + "HH->H": numbers.int_sub_impl, + "ii->i": numbers.int_sub_impl, + "II->I": numbers.int_sub_impl, + "ll->l": numbers.int_sub_impl, + "LL->L": numbers.int_sub_impl, + "qq->q": numbers.int_sub_impl, + "QQ->Q": numbers.int_sub_impl, + "ff->f": numbers.real_sub_impl, + "dd->d": numbers.real_sub_impl, + "FF->F": numbers.complex_sub_impl, + "DD->D": 
numbers.complex_sub_impl, + } + + ufunc_db[np.multiply] = { + "??->?": numbers.int_and_impl, + "bb->b": numbers.int_mul_impl, + "BB->B": numbers.int_mul_impl, + "hh->h": numbers.int_mul_impl, + "HH->H": numbers.int_mul_impl, + "ii->i": numbers.int_mul_impl, + "II->I": numbers.int_mul_impl, + "ll->l": numbers.int_mul_impl, + "LL->L": numbers.int_mul_impl, + "qq->q": numbers.int_mul_impl, + "QQ->Q": numbers.int_mul_impl, + "ff->f": numbers.real_mul_impl, + "dd->d": numbers.real_mul_impl, + "FF->F": numbers.complex_mul_impl, + "DD->D": numbers.complex_mul_impl, + } + + if np.divide != np.true_divide: + ufunc_db[np.divide] = { + "bb->b": npyfuncs.np_int_sdiv_impl, + "BB->B": npyfuncs.np_int_udiv_impl, + "hh->h": npyfuncs.np_int_sdiv_impl, + "HH->H": npyfuncs.np_int_udiv_impl, + "ii->i": npyfuncs.np_int_sdiv_impl, + "II->I": npyfuncs.np_int_udiv_impl, + "ll->l": npyfuncs.np_int_sdiv_impl, + "LL->L": npyfuncs.np_int_udiv_impl, + "qq->q": npyfuncs.np_int_sdiv_impl, + "QQ->Q": npyfuncs.np_int_udiv_impl, + "ff->f": npyfuncs.np_real_div_impl, + "dd->d": npyfuncs.np_real_div_impl, + "FF->F": npyfuncs.np_complex_div_impl, + "DD->D": npyfuncs.np_complex_div_impl, + } + + ufunc_db[np.true_divide] = { + "bb->d": npyfuncs.np_int_truediv_impl, + "BB->d": npyfuncs.np_int_truediv_impl, + "hh->d": npyfuncs.np_int_truediv_impl, + "HH->d": npyfuncs.np_int_truediv_impl, + "ii->d": npyfuncs.np_int_truediv_impl, + "II->d": npyfuncs.np_int_truediv_impl, + "ll->d": npyfuncs.np_int_truediv_impl, + "LL->d": npyfuncs.np_int_truediv_impl, + "qq->d": npyfuncs.np_int_truediv_impl, + "QQ->d": npyfuncs.np_int_truediv_impl, + "ff->f": npyfuncs.np_real_div_impl, + "dd->d": npyfuncs.np_real_div_impl, + "FF->F": npyfuncs.np_complex_div_impl, + "DD->D": npyfuncs.np_complex_div_impl, + } + + ufunc_db[np.floor_divide] = { + "bb->b": npyfuncs.np_int_sdiv_impl, + "BB->B": npyfuncs.np_int_udiv_impl, + "hh->h": npyfuncs.np_int_sdiv_impl, + "HH->H": npyfuncs.np_int_udiv_impl, + "ii->i": npyfuncs.np_int_sdiv_impl, + "II->I": npyfuncs.np_int_udiv_impl, + "ll->l": npyfuncs.np_int_sdiv_impl, + "LL->L": npyfuncs.np_int_udiv_impl, + "qq->q": npyfuncs.np_int_sdiv_impl, + "QQ->Q": npyfuncs.np_int_udiv_impl, + "ff->f": npyfuncs.np_real_floor_div_impl, + "dd->d": npyfuncs.np_real_floor_div_impl, + } + + ufunc_db[np.remainder] = { + "bb->b": npyfuncs.np_int_srem_impl, + "BB->B": npyfuncs.np_int_urem_impl, + "hh->h": npyfuncs.np_int_srem_impl, + "HH->H": npyfuncs.np_int_urem_impl, + "ii->i": npyfuncs.np_int_srem_impl, + "II->I": npyfuncs.np_int_urem_impl, + "ll->l": npyfuncs.np_int_srem_impl, + "LL->L": npyfuncs.np_int_urem_impl, + "qq->q": npyfuncs.np_int_srem_impl, + "QQ->Q": npyfuncs.np_int_urem_impl, + "ff->f": npyfuncs.np_real_mod_impl, + "dd->d": npyfuncs.np_real_mod_impl, + } + + ufunc_db[np.divmod] = { + "bb->bb": npyfuncs.np_int_sdivrem_impl, + "BB->BB": npyfuncs.np_int_udivrem_impl, + "hh->hh": npyfuncs.np_int_sdivrem_impl, + "HH->HH": npyfuncs.np_int_udivrem_impl, + "ii->ii": npyfuncs.np_int_sdivrem_impl, + "II->II": npyfuncs.np_int_udivrem_impl, + "ll->ll": npyfuncs.np_int_sdivrem_impl, + "LL->LL": npyfuncs.np_int_udivrem_impl, + "qq->qq": npyfuncs.np_int_sdivrem_impl, + "QQ->QQ": npyfuncs.np_int_udivrem_impl, + "ff->ff": npyfuncs.np_real_divmod_impl, + "dd->dd": npyfuncs.np_real_divmod_impl, + } + + ufunc_db[np.fmod] = { + "bb->b": npyfuncs.np_int_fmod_impl, + "BB->B": npyfuncs.np_int_fmod_impl, + "hh->h": npyfuncs.np_int_fmod_impl, + "HH->H": npyfuncs.np_int_fmod_impl, + "ii->i": npyfuncs.np_int_fmod_impl, + "II->I": 
npyfuncs.np_int_fmod_impl, + "ll->l": npyfuncs.np_int_fmod_impl, + "LL->L": npyfuncs.np_int_fmod_impl, + "qq->q": npyfuncs.np_int_fmod_impl, + "QQ->Q": npyfuncs.np_int_fmod_impl, + "ff->f": npyfuncs.np_real_fmod_impl, + "dd->d": npyfuncs.np_real_fmod_impl, + } + + ufunc_db[np.logaddexp] = { + "ff->f": npyfuncs.np_real_logaddexp_impl, + "dd->d": npyfuncs.np_real_logaddexp_impl, + } + + ufunc_db[np.logaddexp2] = { + "ff->f": npyfuncs.np_real_logaddexp2_impl, + "dd->d": npyfuncs.np_real_logaddexp2_impl, + } + + ufunc_db[np.power] = { + "bb->b": numbers.int_power_impl, + "BB->B": numbers.int_power_impl, + "hh->h": numbers.int_power_impl, + "HH->H": numbers.int_power_impl, + "ii->i": numbers.int_power_impl, + "II->I": numbers.int_power_impl, + "ll->l": numbers.int_power_impl, + "LL->L": numbers.int_power_impl, + "qq->q": numbers.int_power_impl, + "QQ->Q": numbers.int_power_impl, + # XXX we would like to use `int_power_impl` for real ** integer + # as well (for better performance), but the current ufunc typing + # rules forbid that + "ff->f": numbers.real_power_impl, + "dd->d": numbers.real_power_impl, + "FF->F": npyfuncs.np_complex_power_impl, + "DD->D": npyfuncs.np_complex_power_impl, + } + + ufunc_db[np.float_power] = { + "ff->f": npyfuncs.real_float_power_impl, + "dd->d": npyfuncs.real_float_power_impl, + "FF->F": npyfuncs.np_complex_float_power_impl, + "DD->D": npyfuncs.np_complex_float_power_impl, + } + + ufunc_db[np.gcd] = { + "bb->b": npyfuncs.np_gcd_impl, + "BB->B": npyfuncs.np_gcd_impl, + "hh->h": npyfuncs.np_gcd_impl, + "HH->H": npyfuncs.np_gcd_impl, + "ii->i": npyfuncs.np_gcd_impl, + "II->I": npyfuncs.np_gcd_impl, + "ll->l": npyfuncs.np_gcd_impl, + "LL->L": npyfuncs.np_gcd_impl, + "qq->q": npyfuncs.np_gcd_impl, + "QQ->Q": npyfuncs.np_gcd_impl, + } + + ufunc_db[np.lcm] = { + "bb->b": npyfuncs.np_lcm_impl, + "BB->B": npyfuncs.np_lcm_impl, + "hh->h": npyfuncs.np_lcm_impl, + "HH->H": npyfuncs.np_lcm_impl, + "ii->i": npyfuncs.np_lcm_impl, + "II->I": npyfuncs.np_lcm_impl, + "ll->l": npyfuncs.np_lcm_impl, + "LL->L": npyfuncs.np_lcm_impl, + "qq->q": npyfuncs.np_lcm_impl, + "QQ->Q": npyfuncs.np_lcm_impl, + } + + ufunc_db[np.rint] = { + "f->f": npyfuncs.np_real_rint_impl, + "d->d": npyfuncs.np_real_rint_impl, + "F->F": npyfuncs.np_complex_rint_impl, + "D->D": npyfuncs.np_complex_rint_impl, + } + + ufunc_db[np.conjugate] = { + "b->b": numbers.real_conjugate_impl, + "B->B": numbers.real_conjugate_impl, + "h->h": numbers.real_conjugate_impl, + "H->H": numbers.real_conjugate_impl, + "i->i": numbers.real_conjugate_impl, + "I->I": numbers.real_conjugate_impl, + "l->l": numbers.real_conjugate_impl, + "L->L": numbers.real_conjugate_impl, + "q->q": numbers.real_conjugate_impl, + "Q->Q": numbers.real_conjugate_impl, + "f->f": numbers.real_conjugate_impl, + "d->d": numbers.real_conjugate_impl, + "F->F": numbers.complex_conjugate_impl, + "D->D": numbers.complex_conjugate_impl, + } + + ufunc_db[np.exp] = { + "f->f": npyfuncs.np_real_exp_impl, + "d->d": npyfuncs.np_real_exp_impl, + "F->F": npyfuncs.np_complex_exp_impl, + "D->D": npyfuncs.np_complex_exp_impl, + } + + ufunc_db[np.exp2] = { + "f->f": npyfuncs.np_real_exp2_impl, + "d->d": npyfuncs.np_real_exp2_impl, + "F->F": npyfuncs.np_complex_exp2_impl, + "D->D": npyfuncs.np_complex_exp2_impl, + } + + ufunc_db[np.log] = { + "f->f": npyfuncs.np_real_log_impl, + "d->d": npyfuncs.np_real_log_impl, + "F->F": npyfuncs.np_complex_log_impl, + "D->D": npyfuncs.np_complex_log_impl, + } + + ufunc_db[np.log2] = { + "f->f": npyfuncs.np_real_log2_impl, + "d->d": 
npyfuncs.np_real_log2_impl, + "F->F": npyfuncs.np_complex_log2_impl, + "D->D": npyfuncs.np_complex_log2_impl, + } + + ufunc_db[np.log10] = { + "f->f": npyfuncs.np_real_log10_impl, + "d->d": npyfuncs.np_real_log10_impl, + "F->F": npyfuncs.np_complex_log10_impl, + "D->D": npyfuncs.np_complex_log10_impl, + } + + ufunc_db[np.expm1] = { + "f->f": npyfuncs.np_real_expm1_impl, + "d->d": npyfuncs.np_real_expm1_impl, + "F->F": npyfuncs.np_complex_expm1_impl, + "D->D": npyfuncs.np_complex_expm1_impl, + } + + ufunc_db[np.log1p] = { + "f->f": npyfuncs.np_real_log1p_impl, + "d->d": npyfuncs.np_real_log1p_impl, + "F->F": npyfuncs.np_complex_log1p_impl, + "D->D": npyfuncs.np_complex_log1p_impl, + } + + ufunc_db[np.sqrt] = { + "f->f": npyfuncs.np_real_sqrt_impl, + "d->d": npyfuncs.np_real_sqrt_impl, + "F->F": npyfuncs.np_complex_sqrt_impl, + "D->D": npyfuncs.np_complex_sqrt_impl, + } + + ufunc_db[np.square] = { + "b->b": npyfuncs.np_int_square_impl, + "B->B": npyfuncs.np_int_square_impl, + "h->h": npyfuncs.np_int_square_impl, + "H->H": npyfuncs.np_int_square_impl, + "i->i": npyfuncs.np_int_square_impl, + "I->I": npyfuncs.np_int_square_impl, + "l->l": npyfuncs.np_int_square_impl, + "L->L": npyfuncs.np_int_square_impl, + "q->q": npyfuncs.np_int_square_impl, + "Q->Q": npyfuncs.np_int_square_impl, + "f->f": npyfuncs.np_real_square_impl, + "d->d": npyfuncs.np_real_square_impl, + "F->F": npyfuncs.np_complex_square_impl, + "D->D": npyfuncs.np_complex_square_impl, + } + + ufunc_db[np.cbrt] = { + "f->f": npyfuncs.np_real_cbrt_impl, + "d->d": npyfuncs.np_real_cbrt_impl, + } + + ufunc_db[np.reciprocal] = { + "b->b": npyfuncs.np_int_reciprocal_impl, + "B->B": npyfuncs.np_int_reciprocal_impl, + "h->h": npyfuncs.np_int_reciprocal_impl, + "H->H": npyfuncs.np_int_reciprocal_impl, + "i->i": npyfuncs.np_int_reciprocal_impl, + "I->I": npyfuncs.np_int_reciprocal_impl, + "l->l": npyfuncs.np_int_reciprocal_impl, + "L->L": npyfuncs.np_int_reciprocal_impl, + "q->q": npyfuncs.np_int_reciprocal_impl, + "Q->Q": npyfuncs.np_int_reciprocal_impl, + "f->f": npyfuncs.np_real_reciprocal_impl, + "d->d": npyfuncs.np_real_reciprocal_impl, + "F->F": npyfuncs.np_complex_reciprocal_impl, + "D->D": npyfuncs.np_complex_reciprocal_impl, + } + + ufunc_db[np.sin] = { + "f->f": npyfuncs.np_real_sin_impl, + "d->d": npyfuncs.np_real_sin_impl, + "F->F": npyfuncs.np_complex_sin_impl, + "D->D": npyfuncs.np_complex_sin_impl, + } + + ufunc_db[np.cos] = { + "f->f": npyfuncs.np_real_cos_impl, + "d->d": npyfuncs.np_real_cos_impl, + "F->F": npyfuncs.np_complex_cos_impl, + "D->D": npyfuncs.np_complex_cos_impl, + } + + tan_impl = cmathimpl.tan_impl + + ufunc_db[np.tan] = { + "f->f": npyfuncs.np_real_tan_impl, + "d->d": npyfuncs.np_real_tan_impl, + "F->F": tan_impl, + "D->D": tan_impl, + } + + arcsin_impl = cmathimpl.asin_impl + + ufunc_db[np.arcsin] = { + "f->f": npyfuncs.np_real_asin_impl, + "d->d": npyfuncs.np_real_asin_impl, + "F->F": arcsin_impl, + "D->D": arcsin_impl, + } + + ufunc_db[np.arccos] = { + "f->f": npyfuncs.np_real_acos_impl, + "d->d": npyfuncs.np_real_acos_impl, + "F->F": cmathimpl.acos_impl, + "D->D": cmathimpl.acos_impl, + } + + arctan_impl = cmathimpl.atan_impl + + ufunc_db[np.arctan] = { + "f->f": npyfuncs.np_real_atan_impl, + "d->d": npyfuncs.np_real_atan_impl, + "F->F": arctan_impl, + "D->D": arctan_impl, + } + + ufunc_db[np.arctan2] = { + "ff->f": npyfuncs.np_real_atan2_impl, + "dd->d": npyfuncs.np_real_atan2_impl, + } + + ufunc_db[np.hypot] = { + "ff->f": npyfuncs.np_real_hypot_impl, + "dd->d": npyfuncs.np_real_hypot_impl, + } + + 
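# hyperbolic functions
+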
ufunc_db[np.sinh] = { + "f->f": npyfuncs.np_real_sinh_impl, + "d->d": npyfuncs.np_real_sinh_impl, + "F->F": npyfuncs.np_complex_sinh_impl, + "D->D": npyfuncs.np_complex_sinh_impl, + } + + ufunc_db[np.cosh] = { + "f->f": npyfuncs.np_real_cosh_impl, + "d->d": npyfuncs.np_real_cosh_impl, + "F->F": npyfuncs.np_complex_cosh_impl, + "D->D": npyfuncs.np_complex_cosh_impl, + } + + ufunc_db[np.tanh] = { + "f->f": npyfuncs.np_real_tanh_impl, + "d->d": npyfuncs.np_real_tanh_impl, + "F->F": npyfuncs.np_complex_tanh_impl, + "D->D": npyfuncs.np_complex_tanh_impl, + } + + arcsinh_impl = cmathimpl.asinh_impl + + ufunc_db[np.arcsinh] = { + "f->f": npyfuncs.np_real_asinh_impl, + "d->d": npyfuncs.np_real_asinh_impl, + "F->F": arcsinh_impl, + "D->D": arcsinh_impl, + } + + ufunc_db[np.arccosh] = { + "f->f": npyfuncs.np_real_acosh_impl, + "d->d": npyfuncs.np_real_acosh_impl, + "F->F": npyfuncs.np_complex_acosh_impl, + "D->D": npyfuncs.np_complex_acosh_impl, + } + + arctanh_impl = cmathimpl.atanh_impl + + ufunc_db[np.arctanh] = { + "f->f": npyfuncs.np_real_atanh_impl, + "d->d": npyfuncs.np_real_atanh_impl, + "F->F": arctanh_impl, + "D->D": arctanh_impl, + } + + ufunc_db[np.deg2rad] = { + "f->f": mathimpl.radians_float_impl, + "d->d": mathimpl.radians_float_impl, + } + + ufunc_db[np.radians] = ufunc_db[np.deg2rad] + + ufunc_db[np.rad2deg] = { + "f->f": mathimpl.degrees_float_impl, + "d->d": mathimpl.degrees_float_impl, + } + + ufunc_db[np.degrees] = ufunc_db[np.rad2deg] + + ufunc_db[np.floor] = { + "f->f": npyfuncs.np_real_floor_impl, + "d->d": npyfuncs.np_real_floor_impl, + } + if numpy_version >= (2, 1): + ufunc_db[np.floor].update( + { + "?->?": numbers.identity_impl, + "b->b": numbers.identity_impl, + "B->B": numbers.identity_impl, + "h->h": numbers.identity_impl, + "H->H": numbers.identity_impl, + "i->i": numbers.identity_impl, + "I->I": numbers.identity_impl, + "l->l": numbers.identity_impl, + "L->L": numbers.identity_impl, + "q->q": numbers.identity_impl, + "Q->Q": numbers.identity_impl, + } + ) + + ufunc_db[np.ceil] = { + "f->f": npyfuncs.np_real_ceil_impl, + "d->d": npyfuncs.np_real_ceil_impl, + } + if numpy_version >= (2, 1): + ufunc_db[np.ceil].update( + { + "?->?": numbers.identity_impl, + "b->b": numbers.identity_impl, + "B->B": numbers.identity_impl, + "h->h": numbers.identity_impl, + "H->H": numbers.identity_impl, + "i->i": numbers.identity_impl, + "I->I": numbers.identity_impl, + "l->l": numbers.identity_impl, + "L->L": numbers.identity_impl, + "q->q": numbers.identity_impl, + "Q->Q": numbers.identity_impl, + } + ) + + ufunc_db[np.trunc] = { + "f->f": npyfuncs.np_real_trunc_impl, + "d->d": npyfuncs.np_real_trunc_impl, + } + if numpy_version >= (2, 1): + ufunc_db[np.trunc].update( + { + "?->?": numbers.identity_impl, + "b->b": numbers.identity_impl, + "B->B": numbers.identity_impl, + "h->h": numbers.identity_impl, + "H->H": numbers.identity_impl, + "i->i": numbers.identity_impl, + "I->I": numbers.identity_impl, + "l->l": numbers.identity_impl, + "L->L": numbers.identity_impl, + "q->q": numbers.identity_impl, + "Q->Q": numbers.identity_impl, + } + ) + + ufunc_db[np.fabs] = { + "f->f": npyfuncs.np_real_fabs_impl, + "d->d": npyfuncs.np_real_fabs_impl, + } + + # logical ufuncs + ufunc_db[np.greater] = { + "??->?": numbers.int_ugt_impl, + "bb->?": numbers.int_sgt_impl, + "BB->?": numbers.int_ugt_impl, + "hh->?": numbers.int_sgt_impl, + "HH->?": numbers.int_ugt_impl, + "ii->?": numbers.int_sgt_impl, + "II->?": numbers.int_ugt_impl, + "ll->?": numbers.int_sgt_impl, + "LL->?": numbers.int_ugt_impl, + 
"qq->?": numbers.int_sgt_impl, + "QQ->?": numbers.int_ugt_impl, + "ff->?": numbers.real_gt_impl, + "dd->?": numbers.real_gt_impl, + "FF->?": npyfuncs.np_complex_gt_impl, + "DD->?": npyfuncs.np_complex_gt_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.greater].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp(">"), + "Qq->?": numbers.int_unsigned_signed_cmp(">"), + } + ) + + ufunc_db[np.greater_equal] = { + "??->?": numbers.int_uge_impl, + "bb->?": numbers.int_sge_impl, + "BB->?": numbers.int_uge_impl, + "hh->?": numbers.int_sge_impl, + "HH->?": numbers.int_uge_impl, + "ii->?": numbers.int_sge_impl, + "II->?": numbers.int_uge_impl, + "ll->?": numbers.int_sge_impl, + "LL->?": numbers.int_uge_impl, + "qq->?": numbers.int_sge_impl, + "QQ->?": numbers.int_uge_impl, + "ff->?": numbers.real_ge_impl, + "dd->?": numbers.real_ge_impl, + "FF->?": npyfuncs.np_complex_ge_impl, + "DD->?": npyfuncs.np_complex_ge_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.greater_equal].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp(">="), + "Qq->?": numbers.int_unsigned_signed_cmp(">="), + } + ) + + ufunc_db[np.less] = { + "??->?": numbers.int_ult_impl, + "bb->?": numbers.int_slt_impl, + "BB->?": numbers.int_ult_impl, + "hh->?": numbers.int_slt_impl, + "HH->?": numbers.int_ult_impl, + "ii->?": numbers.int_slt_impl, + "II->?": numbers.int_ult_impl, + "ll->?": numbers.int_slt_impl, + "LL->?": numbers.int_ult_impl, + "qq->?": numbers.int_slt_impl, + "QQ->?": numbers.int_ult_impl, + "ff->?": numbers.real_lt_impl, + "dd->?": numbers.real_lt_impl, + "FF->?": npyfuncs.np_complex_lt_impl, + "DD->?": npyfuncs.np_complex_lt_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.less].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp("<"), + "Qq->?": numbers.int_unsigned_signed_cmp("<"), + } + ) + + ufunc_db[np.less_equal] = { + "??->?": numbers.int_ule_impl, + "bb->?": numbers.int_sle_impl, + "BB->?": numbers.int_ule_impl, + "hh->?": numbers.int_sle_impl, + "HH->?": numbers.int_ule_impl, + "ii->?": numbers.int_sle_impl, + "II->?": numbers.int_ule_impl, + "ll->?": numbers.int_sle_impl, + "LL->?": numbers.int_ule_impl, + "qq->?": numbers.int_sle_impl, + "QQ->?": numbers.int_ule_impl, + "ff->?": numbers.real_le_impl, + "dd->?": numbers.real_le_impl, + "FF->?": npyfuncs.np_complex_le_impl, + "DD->?": npyfuncs.np_complex_le_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.less_equal].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp("<="), + "Qq->?": numbers.int_unsigned_signed_cmp("<="), + } + ) + + ufunc_db[np.not_equal] = { + "??->?": numbers.int_ne_impl, + "bb->?": numbers.int_ne_impl, + "BB->?": numbers.int_ne_impl, + "hh->?": numbers.int_ne_impl, + "HH->?": numbers.int_ne_impl, + "ii->?": numbers.int_ne_impl, + "II->?": numbers.int_ne_impl, + "ll->?": numbers.int_ne_impl, + "LL->?": numbers.int_ne_impl, + "qq->?": numbers.int_ne_impl, + "QQ->?": numbers.int_ne_impl, + "ff->?": numbers.real_ne_impl, + "dd->?": numbers.real_ne_impl, + "FF->?": npyfuncs.np_complex_ne_impl, + "DD->?": npyfuncs.np_complex_ne_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.not_equal].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp("!="), + "Qq->?": numbers.int_unsigned_signed_cmp("!="), + } + ) + + ufunc_db[np.equal] = { + "??->?": numbers.int_eq_impl, + "bb->?": numbers.int_eq_impl, + "BB->?": numbers.int_eq_impl, + "hh->?": numbers.int_eq_impl, + "HH->?": numbers.int_eq_impl, + "ii->?": numbers.int_eq_impl, + "II->?": numbers.int_eq_impl, + "ll->?": numbers.int_eq_impl, + "LL->?": 
numbers.int_eq_impl, + "qq->?": numbers.int_eq_impl, + "QQ->?": numbers.int_eq_impl, + "ff->?": numbers.real_eq_impl, + "dd->?": numbers.real_eq_impl, + "FF->?": npyfuncs.np_complex_eq_impl, + "DD->?": npyfuncs.np_complex_eq_impl, + } + if numpy_version >= (1, 25): + ufunc_db[np.equal].update( + { + "qQ->?": numbers.int_signed_unsigned_cmp("=="), + "Qq->?": numbers.int_unsigned_signed_cmp("=="), + } + ) + + ufunc_db[np.logical_and] = { + "??->?": npyfuncs.np_logical_and_impl, + "bb->?": npyfuncs.np_logical_and_impl, + "BB->?": npyfuncs.np_logical_and_impl, + "hh->?": npyfuncs.np_logical_and_impl, + "HH->?": npyfuncs.np_logical_and_impl, + "ii->?": npyfuncs.np_logical_and_impl, + "II->?": npyfuncs.np_logical_and_impl, + "ll->?": npyfuncs.np_logical_and_impl, + "LL->?": npyfuncs.np_logical_and_impl, + "qq->?": npyfuncs.np_logical_and_impl, + "QQ->?": npyfuncs.np_logical_and_impl, + "ff->?": npyfuncs.np_logical_and_impl, + "dd->?": npyfuncs.np_logical_and_impl, + "FF->?": npyfuncs.np_complex_logical_and_impl, + "DD->?": npyfuncs.np_complex_logical_and_impl, + } + + ufunc_db[np.logical_or] = { + "??->?": npyfuncs.np_logical_or_impl, + "bb->?": npyfuncs.np_logical_or_impl, + "BB->?": npyfuncs.np_logical_or_impl, + "hh->?": npyfuncs.np_logical_or_impl, + "HH->?": npyfuncs.np_logical_or_impl, + "ii->?": npyfuncs.np_logical_or_impl, + "II->?": npyfuncs.np_logical_or_impl, + "ll->?": npyfuncs.np_logical_or_impl, + "LL->?": npyfuncs.np_logical_or_impl, + "qq->?": npyfuncs.np_logical_or_impl, + "QQ->?": npyfuncs.np_logical_or_impl, + "ff->?": npyfuncs.np_logical_or_impl, + "dd->?": npyfuncs.np_logical_or_impl, + "FF->?": npyfuncs.np_complex_logical_or_impl, + "DD->?": npyfuncs.np_complex_logical_or_impl, + } + + ufunc_db[np.logical_xor] = { + "??->?": npyfuncs.np_logical_xor_impl, + "bb->?": npyfuncs.np_logical_xor_impl, + "BB->?": npyfuncs.np_logical_xor_impl, + "hh->?": npyfuncs.np_logical_xor_impl, + "HH->?": npyfuncs.np_logical_xor_impl, + "ii->?": npyfuncs.np_logical_xor_impl, + "II->?": npyfuncs.np_logical_xor_impl, + "ll->?": npyfuncs.np_logical_xor_impl, + "LL->?": npyfuncs.np_logical_xor_impl, + "qq->?": npyfuncs.np_logical_xor_impl, + "QQ->?": npyfuncs.np_logical_xor_impl, + "ff->?": npyfuncs.np_logical_xor_impl, + "dd->?": npyfuncs.np_logical_xor_impl, + "FF->?": npyfuncs.np_complex_logical_xor_impl, + "DD->?": npyfuncs.np_complex_logical_xor_impl, + } + + ufunc_db[np.logical_not] = { + "?->?": npyfuncs.np_logical_not_impl, + "b->?": npyfuncs.np_logical_not_impl, + "B->?": npyfuncs.np_logical_not_impl, + "h->?": npyfuncs.np_logical_not_impl, + "H->?": npyfuncs.np_logical_not_impl, + "i->?": npyfuncs.np_logical_not_impl, + "I->?": npyfuncs.np_logical_not_impl, + "l->?": npyfuncs.np_logical_not_impl, + "L->?": npyfuncs.np_logical_not_impl, + "q->?": npyfuncs.np_logical_not_impl, + "Q->?": npyfuncs.np_logical_not_impl, + "f->?": npyfuncs.np_logical_not_impl, + "d->?": npyfuncs.np_logical_not_impl, + "F->?": npyfuncs.np_complex_logical_not_impl, + "D->?": npyfuncs.np_complex_logical_not_impl, + } + + ufunc_db[np.maximum] = { + "??->?": npyfuncs.np_logical_or_impl, + "bb->b": npyfuncs.np_int_smax_impl, + "BB->B": npyfuncs.np_int_umax_impl, + "hh->h": npyfuncs.np_int_smax_impl, + "HH->H": npyfuncs.np_int_umax_impl, + "ii->i": npyfuncs.np_int_smax_impl, + "II->I": npyfuncs.np_int_umax_impl, + "ll->l": npyfuncs.np_int_smax_impl, + "LL->L": npyfuncs.np_int_umax_impl, + "qq->q": npyfuncs.np_int_smax_impl, + "QQ->Q": npyfuncs.np_int_umax_impl, + "ff->f": npyfuncs.np_real_maximum_impl, + "dd->d": 
npyfuncs.np_real_maximum_impl, + "FF->F": npyfuncs.np_complex_maximum_impl, + "DD->D": npyfuncs.np_complex_maximum_impl, + } + + ufunc_db[np.minimum] = { + "??->?": npyfuncs.np_logical_and_impl, + "bb->b": npyfuncs.np_int_smin_impl, + "BB->B": npyfuncs.np_int_umin_impl, + "hh->h": npyfuncs.np_int_smin_impl, + "HH->H": npyfuncs.np_int_umin_impl, + "ii->i": npyfuncs.np_int_smin_impl, + "II->I": npyfuncs.np_int_umin_impl, + "ll->l": npyfuncs.np_int_smin_impl, + "LL->L": npyfuncs.np_int_umin_impl, + "qq->q": npyfuncs.np_int_smin_impl, + "QQ->Q": npyfuncs.np_int_umin_impl, + "ff->f": npyfuncs.np_real_minimum_impl, + "dd->d": npyfuncs.np_real_minimum_impl, + "FF->F": npyfuncs.np_complex_minimum_impl, + "DD->D": npyfuncs.np_complex_minimum_impl, + } + + ufunc_db[np.fmax] = { + "??->?": npyfuncs.np_logical_or_impl, + "bb->b": npyfuncs.np_int_smax_impl, + "BB->B": npyfuncs.np_int_umax_impl, + "hh->h": npyfuncs.np_int_smax_impl, + "HH->H": npyfuncs.np_int_umax_impl, + "ii->i": npyfuncs.np_int_smax_impl, + "II->I": npyfuncs.np_int_umax_impl, + "ll->l": npyfuncs.np_int_smax_impl, + "LL->L": npyfuncs.np_int_umax_impl, + "qq->q": npyfuncs.np_int_smax_impl, + "QQ->Q": npyfuncs.np_int_umax_impl, + "ff->f": npyfuncs.np_real_fmax_impl, + "dd->d": npyfuncs.np_real_fmax_impl, + "FF->F": npyfuncs.np_complex_fmax_impl, + "DD->D": npyfuncs.np_complex_fmax_impl, + } + + ufunc_db[np.fmin] = { + "??->?": npyfuncs.np_logical_and_impl, + "bb->b": npyfuncs.np_int_smin_impl, + "BB->B": npyfuncs.np_int_umin_impl, + "hh->h": npyfuncs.np_int_smin_impl, + "HH->H": npyfuncs.np_int_umin_impl, + "ii->i": npyfuncs.np_int_smin_impl, + "II->I": npyfuncs.np_int_umin_impl, + "ll->l": npyfuncs.np_int_smin_impl, + "LL->L": npyfuncs.np_int_umin_impl, + "qq->q": npyfuncs.np_int_smin_impl, + "QQ->Q": npyfuncs.np_int_umin_impl, + "ff->f": npyfuncs.np_real_fmin_impl, + "dd->d": npyfuncs.np_real_fmin_impl, + "FF->F": npyfuncs.np_complex_fmin_impl, + "DD->D": npyfuncs.np_complex_fmin_impl, + } + + # misc floating functions + ufunc_db[np.isnan] = { + "f->?": npyfuncs.np_real_isnan_impl, + "d->?": npyfuncs.np_real_isnan_impl, + "F->?": npyfuncs.np_complex_isnan_impl, + "D->?": npyfuncs.np_complex_isnan_impl, + # int8 + "b->?": npyfuncs.np_int_isnan_impl, + "B->?": npyfuncs.np_int_isnan_impl, + # int16 + "h->?": npyfuncs.np_int_isnan_impl, + "H->?": npyfuncs.np_int_isnan_impl, + # int32 + "i->?": npyfuncs.np_int_isnan_impl, + "I->?": npyfuncs.np_int_isnan_impl, + # int64 + "l->?": npyfuncs.np_int_isnan_impl, + "L->?": npyfuncs.np_int_isnan_impl, + # intp + "q->?": npyfuncs.np_int_isnan_impl, + "Q->?": npyfuncs.np_int_isnan_impl, + # boolean + "?->?": npyfuncs.np_int_isnan_impl, + # datetime & timedelta + "m->?": npyfuncs.np_datetime_isnat_impl, + "M->?": npyfuncs.np_datetime_isnat_impl, + } + + ufunc_db[np.isinf] = { + "f->?": npyfuncs.np_real_isinf_impl, + "d->?": npyfuncs.np_real_isinf_impl, + "F->?": npyfuncs.np_complex_isinf_impl, + "D->?": npyfuncs.np_complex_isinf_impl, + # int8 + "b->?": npyfuncs.np_int_isinf_impl, + "B->?": npyfuncs.np_int_isinf_impl, + # int16 + "h->?": npyfuncs.np_int_isinf_impl, + "H->?": npyfuncs.np_int_isinf_impl, + # int32 + "i->?": npyfuncs.np_int_isinf_impl, + "I->?": npyfuncs.np_int_isinf_impl, + # int64 + "l->?": npyfuncs.np_int_isinf_impl, + "L->?": npyfuncs.np_int_isinf_impl, + # intp + "q->?": npyfuncs.np_int_isinf_impl, + "Q->?": npyfuncs.np_int_isinf_impl, + # boolean + "?->?": npyfuncs.np_int_isinf_impl, + # datetime & timedelta + "m->?": npyfuncs.np_int_isinf_impl, + "M->?": 
npyfuncs.np_int_isinf_impl, + } + + ufunc_db[np.isfinite] = { + "f->?": npyfuncs.np_real_isfinite_impl, + "d->?": npyfuncs.np_real_isfinite_impl, + "F->?": npyfuncs.np_complex_isfinite_impl, + "D->?": npyfuncs.np_complex_isfinite_impl, + # int8 + "b->?": npyfuncs.np_int_isfinite_impl, + "B->?": npyfuncs.np_int_isfinite_impl, + # int16 + "h->?": npyfuncs.np_int_isfinite_impl, + "H->?": npyfuncs.np_int_isfinite_impl, + # int32 + "i->?": npyfuncs.np_int_isfinite_impl, + "I->?": npyfuncs.np_int_isfinite_impl, + # int64 + "l->?": npyfuncs.np_int_isfinite_impl, + "L->?": npyfuncs.np_int_isfinite_impl, + # intp + "q->?": npyfuncs.np_int_isfinite_impl, + "Q->?": npyfuncs.np_int_isfinite_impl, + # boolean + "?->?": npyfuncs.np_int_isfinite_impl, + # datetime & timedelta + "M->?": npyfuncs.np_datetime_isfinite_impl, + "m->?": npyfuncs.np_datetime_isfinite_impl, + } + + ufunc_db[np.signbit] = { + "f->?": npyfuncs.np_real_signbit_impl, + "d->?": npyfuncs.np_real_signbit_impl, + } + + ufunc_db[np.copysign] = { + "ff->f": npyfuncs.np_real_copysign_impl, + "dd->d": npyfuncs.np_real_copysign_impl, + } + + ufunc_db[np.nextafter] = { + "ff->f": npyfuncs.np_real_nextafter_impl, + "dd->d": npyfuncs.np_real_nextafter_impl, + } + + ufunc_db[np.spacing] = { + "f->f": npyfuncs.np_real_spacing_impl, + "d->d": npyfuncs.np_real_spacing_impl, + } + + ufunc_db[np.ldexp] = { + "fi->f": npyfuncs.np_real_ldexp_impl, + "fl->f": npyfuncs.np_real_ldexp_impl, + "di->d": npyfuncs.np_real_ldexp_impl, + "dl->d": npyfuncs.np_real_ldexp_impl, + } + if numpy_version >= (2, 0) and IS_WIN32: + ufunc_db[np.ldexp]["fq->f"] = ufunc_db[np.ldexp].pop("fl->f") + ufunc_db[np.ldexp]["dq->d"] = ufunc_db[np.ldexp].pop("dl->d") + + # bit twiddling functions + ufunc_db[np.bitwise_and] = { + "??->?": numbers.int_and_impl, + "bb->b": numbers.int_and_impl, + "BB->B": numbers.int_and_impl, + "hh->h": numbers.int_and_impl, + "HH->H": numbers.int_and_impl, + "ii->i": numbers.int_and_impl, + "II->I": numbers.int_and_impl, + "ll->l": numbers.int_and_impl, + "LL->L": numbers.int_and_impl, + "qq->q": numbers.int_and_impl, + "QQ->Q": numbers.int_and_impl, + } + + ufunc_db[np.bitwise_or] = { + "??->?": numbers.int_or_impl, + "bb->b": numbers.int_or_impl, + "BB->B": numbers.int_or_impl, + "hh->h": numbers.int_or_impl, + "HH->H": numbers.int_or_impl, + "ii->i": numbers.int_or_impl, + "II->I": numbers.int_or_impl, + "ll->l": numbers.int_or_impl, + "LL->L": numbers.int_or_impl, + "qq->q": numbers.int_or_impl, + "QQ->Q": numbers.int_or_impl, + } + + ufunc_db[np.bitwise_xor] = { + "??->?": numbers.int_xor_impl, + "bb->b": numbers.int_xor_impl, + "BB->B": numbers.int_xor_impl, + "hh->h": numbers.int_xor_impl, + "HH->H": numbers.int_xor_impl, + "ii->i": numbers.int_xor_impl, + "II->I": numbers.int_xor_impl, + "ll->l": numbers.int_xor_impl, + "LL->L": numbers.int_xor_impl, + "qq->q": numbers.int_xor_impl, + "QQ->Q": numbers.int_xor_impl, + } + + ufunc_db[np.invert] = { # aka np.bitwise_not + "?->?": numbers.int_invert_impl, + "b->b": numbers.int_invert_impl, + "B->B": numbers.int_invert_impl, + "h->h": numbers.int_invert_impl, + "H->H": numbers.int_invert_impl, + "i->i": numbers.int_invert_impl, + "I->I": numbers.int_invert_impl, + "l->l": numbers.int_invert_impl, + "L->L": numbers.int_invert_impl, + "q->q": numbers.int_invert_impl, + "Q->Q": numbers.int_invert_impl, + } + + ufunc_db[np.left_shift] = { + "bb->b": numbers.int_shl_impl, + "BB->B": numbers.int_shl_impl, + "hh->h": numbers.int_shl_impl, + "HH->H": numbers.int_shl_impl, + "ii->i": numbers.int_shl_impl, 
+ "II->I": numbers.int_shl_impl, + "ll->l": numbers.int_shl_impl, + "LL->L": numbers.int_shl_impl, + "qq->q": numbers.int_shl_impl, + "QQ->Q": numbers.int_shl_impl, + } + + ufunc_db[np.right_shift] = { + "bb->b": numbers.int_shr_impl, + "BB->B": numbers.int_shr_impl, + "hh->h": numbers.int_shr_impl, + "HH->H": numbers.int_shr_impl, + "ii->i": numbers.int_shr_impl, + "II->I": numbers.int_shr_impl, + "ll->l": numbers.int_shr_impl, + "LL->L": numbers.int_shr_impl, + "qq->q": numbers.int_shr_impl, + "QQ->Q": numbers.int_shr_impl, + } + + # Inject datetime64 support + from numba.cuda.np import npdatetime + + ufunc_db[np.negative].update( + { + "m->m": npdatetime.timedelta_neg_impl, + } + ) + ufunc_db[np.positive].update( + { + "m->m": npdatetime.timedelta_pos_impl, + } + ) + ufunc_db[np.absolute].update( + { + "m->m": npdatetime.timedelta_abs_impl, + } + ) + ufunc_db[np.sign].update( + { + "m->m": npdatetime.timedelta_sign_impl, + } + ) + ufunc_db[np.add].update( + { + "mm->m": npdatetime.timedelta_add_impl, + "Mm->M": npdatetime.datetime_plus_timedelta, + "mM->M": npdatetime.timedelta_plus_datetime, + } + ) + ufunc_db[np.subtract].update( + { + "mm->m": npdatetime.timedelta_sub_impl, + "Mm->M": npdatetime.datetime_minus_timedelta, + "MM->m": npdatetime.datetime_minus_datetime, + } + ) + ufunc_db[np.multiply].update( + { + "mq->m": npdatetime.timedelta_times_number, + "md->m": npdatetime.timedelta_times_number, + "qm->m": npdatetime.number_times_timedelta, + "dm->m": npdatetime.number_times_timedelta, + } + ) + if np.divide != np.true_divide: + ufunc_db[np.divide].update( + { + "mq->m": npdatetime.timedelta_over_number, + "md->m": npdatetime.timedelta_over_number, + "mm->d": npdatetime.timedelta_over_timedelta, + } + ) + ufunc_db[np.true_divide].update( + { + "mq->m": npdatetime.timedelta_over_number, + "md->m": npdatetime.timedelta_over_number, + "mm->d": npdatetime.timedelta_over_timedelta, + } + ) + ufunc_db[np.floor_divide].update( + { + "mq->m": npdatetime.timedelta_over_number, + "md->m": npdatetime.timedelta_over_number, + } + ) + + ufunc_db[np.floor_divide].update( + { + "mm->q": npdatetime.timedelta_floor_div_timedelta, + } + ) + + ufunc_db[np.equal].update( + { + "MM->?": npdatetime.datetime_eq_datetime_impl, + "mm->?": npdatetime.timedelta_eq_timedelta_impl, + } + ) + ufunc_db[np.not_equal].update( + { + "MM->?": npdatetime.datetime_ne_datetime_impl, + "mm->?": npdatetime.timedelta_ne_timedelta_impl, + } + ) + ufunc_db[np.less].update( + { + "MM->?": npdatetime.datetime_lt_datetime_impl, + "mm->?": npdatetime.timedelta_lt_timedelta_impl, + } + ) + ufunc_db[np.less_equal].update( + { + "MM->?": npdatetime.datetime_le_datetime_impl, + "mm->?": npdatetime.timedelta_le_timedelta_impl, + } + ) + ufunc_db[np.greater].update( + { + "MM->?": npdatetime.datetime_gt_datetime_impl, + "mm->?": npdatetime.timedelta_gt_timedelta_impl, + } + ) + ufunc_db[np.greater_equal].update( + { + "MM->?": npdatetime.datetime_ge_datetime_impl, + "mm->?": npdatetime.timedelta_ge_timedelta_impl, + } + ) + ufunc_db[np.maximum].update( + { + "MM->M": npdatetime.datetime_maximum_impl, + "mm->m": npdatetime.timedelta_maximum_impl, + } + ) + ufunc_db[np.minimum].update( + { + "MM->M": npdatetime.datetime_minimum_impl, + "mm->m": npdatetime.timedelta_minimum_impl, + } + ) + # there is no difference for datetime/timedelta in maximum/fmax + # and minimum/fmin + ufunc_db[np.fmax].update( + { + "MM->M": npdatetime.datetime_fmax_impl, + "mm->m": npdatetime.timedelta_fmax_impl, + } + ) + ufunc_db[np.fmin].update( + { + 
"MM->M": npdatetime.datetime_fmin_impl, + "mm->m": npdatetime.timedelta_fmin_impl, + } + ) + + ufunc_db[np.remainder].update( + { + "mm->m": npdatetime.timedelta_mod_timedelta, + } + ) diff --git a/numba_cuda/numba/cuda/np/unsafe/__init__.py b/numba_cuda/numba/cuda/np/unsafe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/numba_cuda/numba/cuda/np/unsafe/ndarray.py b/numba_cuda/numba/cuda/np/unsafe/ndarray.py new file mode 100644 index 000000000..58bffe805 --- /dev/null +++ b/numba_cuda/numba/cuda/np/unsafe/ndarray.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +This file provides internal compiler utilities that support certain special +operations with numpy. +""" + +from numba.core import types +from numba.cuda.cgutils import unpack_tuple +from numba.cuda.extending import intrinsic +from numba.cuda import typing +from numba.core.imputils import impl_ret_new_ref +from numba.core.errors import RequireLiteralValue, TypingError + +from numba.cpython.unsafe.tuple import tuple_setitem + + +@intrinsic +def empty_inferred(typingctx, shape): + """A version of numpy.empty whose dtype is inferred by the type system. + + Expects `shape` to be a int-tuple. + + There is special logic in the type-inferencer to handle the "refine"-ing + of undefined dtype. + """ + from numba.cuda.np.arrayobj import _empty_nd_impl + + def codegen(context, builder, signature, args): + # check that the return type is now defined + arrty = signature.return_type + assert arrty.is_precise() + shapes = unpack_tuple(builder, args[0]) + # redirect implementation to np.empty + res = _empty_nd_impl(context, builder, arrty, shapes) + return impl_ret_new_ref(context, builder, arrty, res._getvalue()) + + # make function signature + nd = len(shape) + array_ty = types.Array(ndim=nd, layout="C", dtype=types.undefined) + sig = array_ty(shape) + return sig, codegen + + +@intrinsic +def to_fixed_tuple(typingctx, array, length): + """Convert *array* into a tuple of *length* + + Returns ``UniTuple(array.dtype, length)`` + + ** Warning ** + - No boundchecking. + If *length* is longer than *array.size*, the behavior is undefined. 
+ """ + if not isinstance(length, types.IntegerLiteral): + raise RequireLiteralValue("*length* argument must be a constant") + + if array.ndim != 1: + raise TypingError("Not supported on array.ndim={}".format(array.ndim)) + + # Determine types + tuple_size = int(length.literal_value) + tuple_type = types.UniTuple(dtype=array.dtype, count=tuple_size) + sig = tuple_type(array, length) + + def codegen(context, builder, signature, args): + def impl(array, length, empty_tuple): + out = empty_tuple + for i in range(length): + out = tuple_setitem(out, i, array[i]) + return out + + inner_argtypes = [signature.args[0], types.intp, tuple_type] + inner_sig = typing.signature(tuple_type, *inner_argtypes) + ll_idx_type = context.get_value_type(types.intp) + # Allocate an empty tuple + empty_tuple = context.get_constant_undef(tuple_type) + inner_args = [args[0], ll_idx_type(tuple_size), empty_tuple] + + res = context.compile_internal(builder, impl, inner_sig, inner_args) + return res + + return sig, codegen diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 201e4beb6..1b4d79b70 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -169,8 +169,11 @@ def load_additional_registries(self): from numba.cpython import rangeobj, enumimpl # noqa: F401 from numba.cuda.core import optional # noqa: F401 from numba.cuda.misc import cffiimpl - from numba.np import arrayobj # noqa: F401 - from numba.np import npdatetime # noqa: F401 + from numba.cuda.np import ( + arrayobj, + npdatetime, + polynomial, + ) from . import ( cudaimpl, fp16, @@ -182,7 +185,7 @@ def load_additional_registries(self): ) # fix for #8940 - from numba.np.unsafe import ndarray # noqa F401 + from numba.cuda.np.unsafe import ndarray # noqa F401 self.install_registry(cudaimpl.registry) self.install_registry(cffiimpl.registry) @@ -202,6 +205,11 @@ def load_additional_registries(self): self.install_registry(unicode.registry) self.install_registry(charseq.registry) + # install np registries + self.install_registry(polynomial.registry) + self.install_registry(npdatetime.registry) + self.install_registry(arrayobj.registry) + def codegen(self): return self._internal_codegen diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py index 8bc57b24d..c9f8de1cb 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py @@ -8,7 +8,7 @@ from numba import cuda, types, njit, typeof from numba.cuda import config -from numba.np import numpy_support +from numba.cuda.np import numpy_support from numba.cuda.tests.support import TestCase from numba.cuda.tests.support import MemoryLeakMixin diff --git a/numba_cuda/numba/cuda/tests/nocuda/test_import.py b/numba_cuda/numba/cuda/tests/nocuda/test_import.py index c38c4a9ce..a6ab3c9b3 100644 --- a/numba_cuda/numba/cuda/tests/nocuda/test_import.py +++ b/numba_cuda/numba/cuda/tests/nocuda/test_import.py @@ -43,6 +43,11 @@ def test_no_impl_import(self): "numba.np.arraymath", "numba.np.npdatetime", "numba.np.npyimpl", + "numba.cuda.np.linalg", + "numba.cuda.np.polynomial", + "numba.cuda.np.arraymath", + "numba.cuda.np.npdatetime", + "numba.cuda.np.npyimpl", "numba.typed.typeddict", "numba.typed.typedlist", )