diff --git a/numba_cuda/numba/cuda/np/numpy_support.py b/numba_cuda/numba/cuda/np/numpy_support.py new file mode 100644 index 000000000..8f81e7cdc --- /dev/null +++ b/numba_cuda/numba/cuda/np/numpy_support.py @@ -0,0 +1,244 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import numpy as np +import re +from numba.core import types, errors, config + + +numpy_version = tuple(map(int, np.__version__.split(".")[:2])) + + +if getattr(config, "USE_LEGACY_TYPE_SYSTEM", True): + FROM_DTYPE = { + np.dtype("bool"): types.boolean, + np.dtype("int8"): types.int8, + np.dtype("int16"): types.int16, + np.dtype("int32"): types.int32, + np.dtype("int64"): types.int64, + np.dtype("uint8"): types.uint8, + np.dtype("uint16"): types.uint16, + np.dtype("uint32"): types.uint32, + np.dtype("uint64"): types.uint64, + np.dtype("float32"): types.float32, + np.dtype("float64"): types.float64, + np.dtype("float16"): types.float16, + np.dtype("complex64"): types.complex64, + np.dtype("complex128"): types.complex128, + np.dtype(object): types.pyobject, + } +else: + FROM_DTYPE = { + np.dtype("bool"): types.np_bool_, + np.dtype("int8"): types.np_int8, + np.dtype("int16"): types.np_int16, + np.dtype("int32"): types.np_int32, + np.dtype("int64"): types.np_int64, + np.dtype("uint8"): types.np_uint8, + np.dtype("uint16"): types.np_uint16, + np.dtype("uint32"): types.np_uint32, + np.dtype("uint64"): types.np_uint64, + np.dtype("float32"): types.np_float32, + np.dtype("float64"): types.np_float64, + np.dtype("float16"): types.np_float16, + np.dtype("complex64"): types.np_complex64, + np.dtype("complex128"): types.np_complex128, + np.dtype(object): types.pyobject, + } + + +re_typestr = re.compile(r"[<>=\|]([a-z])(\d+)?$", re.I) +re_datetimestr = re.compile(r"[<>=\|]([mM])8?(\[([a-z]+)\])?$", re.I) + +sizeof_unicode_char = np.dtype("U1").itemsize + + +def _from_str_dtype(dtype): + m = re_typestr.match(dtype.str) + if not m: + raise errors.NumbaNotImplementedError(dtype) + groups = m.groups() + typecode = groups[0] + if typecode == "U": + # unicode + if dtype.byteorder not in "=|": + raise errors.NumbaNotImplementedError( + "Does not support non-native byteorder" + ) + count = dtype.itemsize // sizeof_unicode_char + assert count == int(groups[1]), "Unicode char size mismatch" + return types.UnicodeCharSeq(count) + + elif typecode == "S": + # char + count = dtype.itemsize + assert count == int(groups[1]), "Char size mismatch" + return types.CharSeq(count) + + else: + raise errors.NumbaNotImplementedError(dtype) + + +def _from_datetime_dtype(dtype): + m = re_datetimestr.match(dtype.str) + if not m: + raise errors.NumbaNotImplementedError(dtype) + groups = m.groups() + typecode = groups[0] + unit = groups[2] or "" + if typecode == "m": + return types.NPTimedelta(unit) + elif typecode == "M": + return types.NPDatetime(unit) + else: + raise errors.NumbaNotImplementedError(dtype) + + +def from_dtype(dtype): + """ + Return a Numba Type instance corresponding to the given Numpy *dtype*. + NumbaNotImplementedError is raised on unsupported Numpy dtypes. + """ + if type(dtype) is type and issubclass(dtype, np.generic): + dtype = np.dtype(dtype) + elif getattr(dtype, "fields", None) is not None: + return from_struct_dtype(dtype) + + try: + return FROM_DTYPE[dtype] + except KeyError: + pass + + try: + char = dtype.char + except AttributeError: + pass + else: + if char in "SU": + return _from_str_dtype(dtype) + if char in "mM": + return _from_datetime_dtype(dtype) + if char in "V" and dtype.subdtype is not None: + subtype = from_dtype(dtype.subdtype[0]) + return types.NestedArray(subtype, dtype.shape) + + raise errors.NumbaNotImplementedError(dtype) + + +_as_dtype_letters = { + types.NPDatetime: "M8", + types.NPTimedelta: "m8", + types.CharSeq: "S", + types.UnicodeCharSeq: "U", +} + + +def as_struct_dtype(rec): + """Convert Numba Record type to NumPy structured dtype""" + assert isinstance(rec, types.Record) + names = [] + formats = [] + offsets = [] + titles = [] + # Fill the fields if they are not a title. + for k, t in rec.members: + if not rec.is_title(k): + names.append(k) + formats.append(as_dtype(t)) + offsets.append(rec.offset(k)) + titles.append(rec.fields[k].title) + + fields = { + "names": names, + "formats": formats, + "offsets": offsets, + "itemsize": rec.size, + "titles": titles, + } + _check_struct_alignment(rec, fields) + return np.dtype(fields, align=rec.aligned) + + +def _check_struct_alignment(rec, fields): + """Check alignment compatibility with Numpy""" + if rec.aligned: + for k, dt in zip(fields["names"], fields["formats"]): + llvm_align = rec.alignof(k) + npy_align = dt.alignment + if llvm_align is not None and npy_align != llvm_align: + msg = ( + "NumPy is using a different alignment ({}) " + "than Numba/LLVM ({}) for {}. " + "This is likely a NumPy bug." + ) + raise ValueError(msg.format(npy_align, llvm_align, dt)) + + +def as_dtype(nbtype): + """ + Return a numpy dtype instance corresponding to the given Numba type. + NotImplementedError is if no correspondence is known. + """ + nbtype = types.unliteral(nbtype) + if isinstance(nbtype, (types.Complex, types.Integer, types.Float)): + return np.dtype(str(nbtype)) + if isinstance(nbtype, (types.Boolean)): + return np.dtype("?") + if isinstance(nbtype, (types.NPDatetime, types.NPTimedelta)): + letter = _as_dtype_letters[type(nbtype)] + if nbtype.unit: + return np.dtype("%s[%s]" % (letter, nbtype.unit)) + else: + return np.dtype(letter) + if isinstance(nbtype, (types.CharSeq, types.UnicodeCharSeq)): + letter = _as_dtype_letters[type(nbtype)] + return np.dtype("%s%d" % (letter, nbtype.count)) + if isinstance(nbtype, types.Record): + return as_struct_dtype(nbtype) + if isinstance(nbtype, types.EnumMember): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.npytypes.DType): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.NumberClass): + return as_dtype(nbtype.dtype) + if isinstance(nbtype, types.NestedArray): + spec = (as_dtype(nbtype.dtype), tuple(nbtype.shape)) + return np.dtype(spec) + if isinstance(nbtype, types.PyObject): + return np.dtype(object) + + msg = f"{nbtype} cannot be represented as a NumPy dtype" + raise errors.NumbaNotImplementedError(msg) + + +def _is_aligned_struct(struct): + return struct.isalignedstruct + + +def from_struct_dtype(dtype): + """Convert a NumPy structured dtype to Numba Record type""" + if dtype.hasobject: + msg = "dtypes that contain object are not supported." + raise errors.NumbaNotImplementedError(msg) + + fields = [] + for name, info in dtype.fields.items(): + # *info* may have 3 element + [elemdtype, offset] = info[:2] + title = info[2] if len(info) == 3 else None + + ty = from_dtype(elemdtype) + infos = { + "type": ty, + "offset": offset, + "title": title, + } + fields.append((name, infos)) + + # Note: dtype.alignment is not consistent. + # It is different after passing into a recarray. + # recarray(N, dtype=mydtype).dtype.alignment != mydtype.alignment + size = dtype.itemsize + aligned = _is_aligned_struct(dtype) + + return types.Record(fields, size, aligned) diff --git a/numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py b/numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py index c177ca860..b1b314b6d 100644 --- a/numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +++ b/numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py @@ -7,7 +7,7 @@ """ from contextlib import contextmanager -from numba.np.numpy_support import numpy_version +from numba.cuda.np.numpy_support import numpy_version import numpy as np diff --git a/numba_cuda/numba/cuda/simulator/kernelapi.py b/numba_cuda/numba/cuda/simulator/kernelapi.py index c4e186431..b25b0f293 100644 --- a/numba_cuda/numba/cuda/simulator/kernelapi.py +++ b/numba_cuda/numba/cuda/simulator/kernelapi.py @@ -13,7 +13,7 @@ from numba.core import types import numpy as np -from numba.np import numpy_support +from numba.cuda.np import numpy_support from .vector_types import vector_types diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py b/numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py index d69c3ef47..36f21cc92 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py @@ -10,7 +10,7 @@ ) from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.testing import skip_on_cudasim -from numba.np import numpy_support +from numba.cuda.np import numpy_support from numba import cuda N_CHARS = 5 diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_complex.py b/numba_cuda/numba/cuda/tests/cudapy/test_complex.py index 1b4efe55c..027497954 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_complex.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_complex.py @@ -37,7 +37,7 @@ sinh_usecase, tanh_usecase, ) -from numba.np import numpy_support +from numba.cuda.np import numpy_support def compile_scalar_func(pyfunc, argtypes, restype): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_datetime.py b/numba_cuda/numba/cuda/tests/cudapy/test_datetime.py index 73ebfba75..2c513dd0a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_datetime.py @@ -4,7 +4,7 @@ import numpy as np from numba import cuda, vectorize, guvectorize -from numba.np.numpy_support import from_dtype +from numba.cuda.np.numpy_support import from_dtype from numba.cuda.testing import CUDATestCase, skip_on_cudasim import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_math.py b/numba_cuda/numba/cuda/tests/cudapy/test_math.py index e672098cd..57f0e3d97 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_math.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_math.py @@ -8,7 +8,7 @@ CUDATestCase, skip_on_cudasim, ) -from numba.np import numpy_support +from numba.cuda.np import numpy_support from numba import cuda, float32, float64, int32, vectorize, void, int64 import math diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_operator.py b/numba_cuda/numba/cuda/tests/cudapy/test_operator.py index bddee4662..ee496242f 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_operator.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_operator.py @@ -15,7 +15,7 @@ from numba.cuda.typing import signature import operator import itertools -from numba.np.numpy_support import from_dtype +from numba.cuda.np.numpy_support import from_dtype def simple_fp16_div_scalar(ary, a, b): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py b/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py index 619a03f7d..ab419762c 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py @@ -6,7 +6,7 @@ from numba.core import types from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest -from numba.np import numpy_support +from numba.cuda.np import numpy_support def set_a(ary, i, v): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py b/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py index 324a8555a..e8b62646e 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py @@ -7,7 +7,7 @@ from numba.core import types from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest -from numba.np import numpy_support +from numba.cuda.np import numpy_support @skip_on_cudasim("pickling not supported in CUDASIM") diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_sm.py b/numba_cuda/numba/cuda/tests/cudapy/test_sm.py index 3faa68b98..b5e9bfdeb 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_sm.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_sm.py @@ -7,7 +7,7 @@ from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim import numpy as np -from numba.np import numpy_support as nps +from numba.cuda.np import numpy_support as nps from .extensions_usecases import test_struct_model_type, TestStruct diff --git a/numba_cuda/numba/cuda/tests/support.py b/numba_cuda/numba/cuda/tests/support.py index 7869d20ed..1c79c0fe3 100644 --- a/numba_cuda/numba/cuda/tests/support.py +++ b/numba_cuda/numba/cuda/tests/support.py @@ -31,7 +31,7 @@ NativeValue, ) from numba.core.datamodel.models import OpaqueModel -from numba.np import numpy_support +from numba.cuda.np import numpy_support class EnableNRTStatsMixin(object): diff --git a/numba_cuda/numba/cuda/ufuncs.py b/numba_cuda/numba/cuda/ufuncs.py index b9c05741f..6b28a4a3c 100644 --- a/numba_cuda/numba/cuda/ufuncs.py +++ b/numba_cuda/numba/cuda/ufuncs.py @@ -28,7 +28,7 @@ def ufunc_db(): # Imports here are at function scope to avoid circular imports from numba.cpython import cmathimpl, mathimpl, numbers from numba.np import npyfuncs - from numba.np.numpy_support import numpy_version + from numba.cuda.np.numpy_support import numpy_version def np_unary_impl(fn, context, builder, sig, args): npyfuncs._check_arity_and_homogeneity(sig, args, 1)