diff --git a/numba_cuda/numba/cuda/__init__.py b/numba_cuda/numba/cuda/__init__.py index 64dd06e65..ad92b20ac 100644 --- a/numba_cuda/numba/cuda/__init__.py +++ b/numba_cuda/numba/cuda/__init__.py @@ -7,6 +7,12 @@ import warnings import sys +# Re-export types itself +import numba.cuda.types as types + +# Re-export all type names +from numba.cuda.types import * + # Require NVIDIA CUDA bindings at import time if not ( diff --git a/numba_cuda/numba/cuda/_internal/cuda_bf16.py b/numba_cuda/numba/cuda/_internal/cuda_bf16.py index 009ba85a7..62cacb25b 100644 --- a/numba_cuda/numba/cuda/_internal/cuda_bf16.py +++ b/numba_cuda/numba/cuda/_internal/cuda_bf16.py @@ -18,8 +18,8 @@ import numba from llvmlite import ir -from numba import types -from numba.core.datamodel import PrimitiveModel, StructModel +from numba.cuda import types +from numba.cuda.datamodel import PrimitiveModel, StructModel from numba.cuda.extending import ( lower_cast, make_attribute_wrapper, @@ -41,7 +41,7 @@ from numba.cuda import CUSource, declare_device from numba.cuda.vector_types import vector_types from numba.cuda.extending import as_numba_type -from numba.types import ( +from numba.cuda.types import ( CPointer, Function, Number, @@ -60,7 +60,7 @@ uint64, void, ) -from numba.cuda.types import bfloat16 +from numba.cuda.ext_types import bfloat16 float32x2 = vector_types["float32x2"] __half = float16 diff --git a/numba_cuda/numba/cuda/_internal/cuda_fp16.py b/numba_cuda/numba/cuda/_internal/cuda_fp16.py index 1bcf47543..7ef7d0473 100644 --- a/numba_cuda/numba/cuda/_internal/cuda_fp16.py +++ b/numba_cuda/numba/cuda/_internal/cuda_fp16.py @@ -18,9 +18,9 @@ import numba from llvmlite import ir -from numba import types +from numba.cuda import types from numba.cuda.cudadrv.driver import _have_nvjitlink -from numba.core.datamodel import PrimitiveModel, StructModel +from numba.cuda.datamodel import PrimitiveModel, StructModel from numba.core.errors import NumbaPerformanceWarning from numba.cuda.extending 
import ( lower_cast, @@ -40,7 +40,7 @@ from numba.cuda.typing.templates import Registry as TypingRegistry from numba.cuda.vector_types import vector_types from numba.cuda.extending import as_numba_type -from numba.types import ( +from numba.cuda.types import ( CPointer, Function, Number, @@ -221,7 +221,7 @@ class _ctor_template_unnamed1362180(ConcreteTemplate): register_global(unnamed1362180, Function(_ctor_template_unnamed1362180)) -__half = _type___half = numba.core.types.float16 +__half = _type___half = numba.cuda.types.float16 setattr(__half, "alignof_", 2) setattr(__half, "align", 2) diff --git a/numba_cuda/numba/cuda/cg.py b/numba_cuda/numba/cuda/cg.py index 8b4599167..d82a7c7f0 100644 --- a/numba_cuda/numba/cuda/cg.py +++ b/numba_cuda/numba/cuda/cg.py @@ -1,12 +1,12 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba.core import types +from numba.cuda import types from numba.cuda.extending import overload, overload_method from numba.cuda.typing import signature from numba.cuda import nvvmutils from numba.cuda.extending import intrinsic -from numba.cuda.types import grid_group, GridGroup as GridGroupClass +from numba.cuda.ext_types import grid_group, GridGroup as GridGroupClass class GridGroup: diff --git a/numba_cuda/numba/cuda/cgutils.py b/numba_cuda/numba/cuda/cgutils.py index 9864a863d..9fbadfff8 100644 --- a/numba_cuda/numba/cuda/cgutils.py +++ b/numba_cuda/numba/cuda/cgutils.py @@ -11,9 +11,9 @@ from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import config, utils, debuginfo -import numba.core.datamodel +import numba.cuda.datamodel bool_t = ir.IntType(1) @@ -104,7 +104,7 @@ class _StructProxy(object): def __init__(self, context, builder, value=None, ref=None): self._context = context self._datamodel = self._context.data_model_manager[self._fe_type] - if not isinstance(self._datamodel, 
numba.core.datamodel.StructModel): + if not isinstance(self._datamodel, numba.cuda.datamodel.StructModel): raise TypeError( "Not a structure model: {0}".format(self._datamodel) ) diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py index bb6cd1c67..480d5a03e 100644 --- a/numba_cuda/numba/cuda/compiler.py +++ b/numba_cuda/numba/cuda/compiler.py @@ -7,10 +7,8 @@ import copy from numba.core import ir as numba_ir -from numba.core import ( - types, - bytecode, -) +from numba.core import bytecode +from numba.cuda import types from numba.cuda.core.options import ParallelOptions from numba.core.compiler_lock import global_compiler_lock from numba.core.errors import NumbaWarning, NumbaInvalidConfigWarning diff --git a/numba_cuda/numba/cuda/core/analysis.py b/numba_cuda/numba/cuda/core/analysis.py index aa70a800d..13ca899f8 100644 --- a/numba_cuda/numba/cuda/core/analysis.py +++ b/numba_cuda/numba/cuda/core/analysis.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from collections import namedtuple, defaultdict -from numba import types +from numba.cuda import types from numba.core import ir, errors from numba.cuda.core import consts import operator diff --git a/numba_cuda/numba/cuda/core/base.py b/numba_cuda/numba/cuda/core/base.py index 42516cf08..da5023fc1 100644 --- a/numba_cuda/numba/cuda/core/base.py +++ b/numba_cuda/numba/cuda/core/base.py @@ -11,13 +11,9 @@ from llvmlite import ir as llvmir from llvmlite.ir import Constant -from numba.core import ( - types, - datamodel, -) -from numba.cuda import cgutils, debuginfo, utils, config +from numba.cuda.core import imputils, targetconfig, funcdesc +from numba.cuda import cgutils, debuginfo, types, utils, datamodel, config from numba.core import errors -from numba.cuda.core import targetconfig, funcdesc, imputils from numba.core.compiler_lock import global_compiler_lock from numba.cuda.core.pythonapi import PythonAPI from numba.cuda.core.imputils import ( diff --git 
a/numba_cuda/numba/cuda/core/boxing.py b/numba_cuda/numba/cuda/core/boxing.py index ea7e774b3..513e4d907 100644 --- a/numba_cuda/numba/cuda/core/boxing.py +++ b/numba_cuda/numba/cuda/core/boxing.py @@ -7,8 +7,7 @@ from llvmlite import ir -from numba.core import types -from numba.cuda import cgutils +from numba.cuda import types, cgutils from numba.cuda.core.pythonapi import box, unbox, reflect, NativeValue from numba.core.errors import NumbaNotImplementedError, TypingError from numba.cuda.typing.typeof import typeof, Purpose diff --git a/numba_cuda/numba/cuda/core/callconv.py b/numba_cuda/numba/cuda/core/callconv.py index 827315d31..2ff231934 100644 --- a/numba_cuda/numba/cuda/core/callconv.py +++ b/numba_cuda/numba/cuda/core/callconv.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from collections import namedtuple diff --git a/numba_cuda/numba/cuda/core/funcdesc.py b/numba_cuda/numba/cuda/core/funcdesc.py index d4e0a26ed..bcb238d00 100644 --- a/numba_cuda/numba/cuda/core/funcdesc.py +++ b/numba_cuda/numba/cuda/core/funcdesc.py @@ -8,7 +8,7 @@ from collections import defaultdict import importlib -from numba.core import types +from numba.cuda import types from numba.cuda import itanium_mangler from numba.cuda.utils import _dynamic_modname, _dynamic_module diff --git a/numba_cuda/numba/cuda/core/generators.py b/numba_cuda/numba/cuda/core/generators.py new file mode 100644 index 000000000..f296d6b04 --- /dev/null +++ b/numba_cuda/numba/cuda/core/generators.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Support for lowering generators. 
+""" + +import llvmlite.ir +from llvmlite.ir import Constant, IRBuilder + +from numba.cuda import types, config, cgutils +from numba.cuda.core.funcdesc import FunctionDescriptor + + +class GeneratorDescriptor(FunctionDescriptor): + """ + The descriptor for a generator's next function. + """ + + __slots__ = () + + @classmethod + def from_generator_fndesc(cls, func_ir, fndesc, gentype, mangler): + """ + Build a GeneratorDescriptor for the generator returned by the + function described by *fndesc*, with type *gentype*. + + The generator inherits the env_name from the *fndesc*. + All emitted functions for the generator shares the same Env. + """ + assert isinstance(gentype, types.Generator) + restype = gentype.yield_type + args = ["gen"] + argtypes = (gentype,) + qualname = fndesc.qualname + ".next" + unique_name = fndesc.unique_name + ".next" + self = cls( + fndesc.native, + fndesc.modname, + qualname, + unique_name, + fndesc.doc, + fndesc.typemap, + restype, + fndesc.calltypes, + args, + fndesc.kws, + argtypes=argtypes, + mangler=mangler, + inline=False, + env_name=fndesc.env_name, + ) + return self + + @property + def llvm_finalizer_name(self): + """ + The LLVM name of the generator's finalizer function + (if .has_finalizer is true). + """ + return "finalize_" + self.mangled_name + + +class BaseGeneratorLower(object): + """ + Base support class for lowering generators. 
+ """ + + def __init__(self, lower): + self.context = lower.context + self.fndesc = lower.fndesc + self.library = lower.library + self.func_ir = lower.func_ir + self.lower = lower + + self.geninfo = lower.generator_info + self.gentype = self.get_generator_type() + self.gendesc = GeneratorDescriptor.from_generator_fndesc( + lower.func_ir, self.fndesc, self.gentype, self.context.mangler + ) + # Helps packing non-omitted arguments into a structure + self.arg_packer = self.context.get_data_packer(self.fndesc.argtypes) + + self.resume_blocks = {} + + @property + def call_conv(self): + return self.lower.call_conv + + def get_args_ptr(self, builder, genptr): + return cgutils.gep_inbounds(builder, genptr, 0, 1) + + def get_resume_index_ptr(self, builder, genptr): + return cgutils.gep_inbounds( + builder, genptr, 0, 0, name="gen.resume_index" + ) + + def get_state_ptr(self, builder, genptr): + return cgutils.gep_inbounds(builder, genptr, 0, 2, name="gen.state") + + def lower_init_func(self, lower): + """ + Lower the generator's initialization function (which will fill up + the passed-by-reference generator structure). + """ + lower.setup_function(self.fndesc) + + builder = lower.builder + + # Insert the generator into the target context in order to allow + # calling from other Numba-compiled functions. + lower.context.insert_generator( + self.gentype, self.gendesc, [self.library] + ) + + # Init argument values + lower.extract_function_arguments() + + lower.pre_lower() + + # Initialize the return structure (i.e. the generator structure). 
+ retty = self.context.get_return_type(self.gentype) + # Structure index #0: the initial resume index (0 == start of generator) + resume_index = self.context.get_constant(types.int32, 0) + # Structure index #2: the states + statesty = retty.elements[2] + + lower.debug_print("# low_init_func incref") + # Incref all NRT arguments before storing into generator states + if self.context.enable_nrt: + for argty, argval in zip(self.fndesc.argtypes, lower.fnargs): + self.context.nrt.incref(builder, argty, argval) + + # Filter out omitted arguments + argsval = self.arg_packer.as_data(builder, lower.fnargs) + + # Zero initialize states + statesval = Constant(statesty, None) + gen_struct = cgutils.make_anonymous_struct( + builder, [resume_index, argsval, statesval], retty + ) + + retval = self.box_generator_struct(lower, gen_struct) + + lower.debug_print("# low_init_func before return") + self.call_conv.return_value(builder, retval) + lower.post_lower() + + def lower_next_func(self, lower): + """ + Lower the generator's next() function (which takes the + passed-by-reference generator structure and returns the next + yielded value). 
+ """ + lower.setup_function(self.gendesc) + lower.debug_print( + "# lower_next_func: {0}".format(self.gendesc.unique_name) + ) + assert self.gendesc.argtypes[0] == self.gentype + builder = lower.builder + function = lower.function + + # Extract argument values and other information from generator struct + (genptr,) = self.call_conv.get_arguments(function) + self.arg_packer.load_into( + builder, self.get_args_ptr(builder, genptr), lower.fnargs + ) + + self.resume_index_ptr = self.get_resume_index_ptr(builder, genptr) + self.gen_state_ptr = self.get_state_ptr(builder, genptr) + + prologue = function.append_basic_block("generator_prologue") + + # Lower the generator's Python code + entry_block_tail = lower.lower_function_body() + + # Add block for StopIteration on entry + stop_block = function.append_basic_block("stop_iteration") + builder.position_at_end(stop_block) + self.call_conv.return_stop_iteration(builder) + + # Add prologue switch to resume blocks + builder.position_at_end(prologue) + # First Python block is also the resume point on first next() call + self.resume_blocks[0] = lower.blkmap[lower.firstblk] + + # Create front switch to resume points + switch = builder.switch(builder.load(self.resume_index_ptr), stop_block) + for index, block in self.resume_blocks.items(): + switch.add_case(index, block) + + # Close tail of entry block + builder.position_at_end(entry_block_tail) + builder.branch(prologue) + + def lower_finalize_func(self, lower): + """ + Lower the generator's finalizer. 
+ """ + fnty = llvmlite.ir.FunctionType( + llvmlite.ir.VoidType(), [self.context.get_value_type(self.gentype)] + ) + function = cgutils.get_or_insert_function( + lower.module, fnty, self.gendesc.llvm_finalizer_name + ) + entry_block = function.append_basic_block("entry") + builder = IRBuilder(entry_block) + + genptrty = self.context.get_value_type(self.gentype) + genptr = builder.bitcast(function.args[0], genptrty) + self.lower_finalize_func_body(builder, genptr) + + def return_from_generator(self, lower): + """ + Emit a StopIteration at generator end and mark the generator exhausted. + """ + indexval = Constant(self.resume_index_ptr.type.pointee, -1) + lower.builder.store(indexval, self.resume_index_ptr) + self.call_conv.return_stop_iteration(lower.builder) + + def create_resumption_block(self, lower, index): + block_name = "generator_resume%d" % (index,) + block = lower.function.append_basic_block(block_name) + lower.builder.position_at_end(block) + self.resume_blocks[index] = block + + def debug_print(self, builder, msg): + if config.DEBUG_JIT: + self.context.debug_print(builder, "DEBUGJIT: {0}".format(msg)) + + +class GeneratorLower(BaseGeneratorLower): + """ + Support class for lowering nopython generators. + """ + + def get_generator_type(self): + return self.fndesc.restype + + def box_generator_struct(self, lower, gen_struct): + return gen_struct + + def lower_finalize_func_body(self, builder, genptr): + """ + Lower the body of the generator's finalizer: decref all live + state variables. 
+ """ + self.debug_print(builder, "# generator: finalize") + if self.context.enable_nrt: + # Always dereference all arguments + # self.debug_print(builder, "# generator: clear args") + args_ptr = self.get_args_ptr(builder, genptr) + for ty, val in self.arg_packer.load(builder, args_ptr): + self.context.nrt.decref(builder, ty, val) + + self.debug_print(builder, "# generator: finalize end") + builder.ret_void() + + +class PyGeneratorLower(BaseGeneratorLower): + """ + Support class for lowering object mode generators. + """ + + def get_generator_type(self): + """ + Compute the actual generator type (the generator function's return + type is simply "pyobject"). + """ + return types.Generator( + gen_func=self.func_ir.func_id.func, + yield_type=types.pyobject, + arg_types=(types.pyobject,) * self.func_ir.arg_count, + state_types=(types.pyobject,) * len(self.geninfo.state_vars), + has_finalizer=True, + ) + + def box_generator_struct(self, lower, gen_struct): + """ + Box the raw *gen_struct* as a Python object. + """ + gen_ptr = cgutils.alloca_once_value(lower.builder, gen_struct) + return lower.pyapi.from_native_generator( + gen_ptr, self.gentype, lower.envarg + ) + + def init_generator_state(self, lower): + """ + NULL-initialize all generator state variables, to avoid spurious + decref's on cleanup. + """ + lower.builder.store( + Constant(self.gen_state_ptr.type.pointee, None), self.gen_state_ptr + ) + + def lower_finalize_func_body(self, builder, genptr): + """ + Lower the body of the generator's finalizer: decref all live + state variables. 
+ """ + pyapi = self.context.get_python_api(builder) + resume_index_ptr = self.get_resume_index_ptr(builder, genptr) + resume_index = builder.load(resume_index_ptr) + # If resume_index is 0, next() was never called + # If resume_index is -1, generator terminated cleanly + # (note function arguments are saved in state variables, + # so they don't need a separate cleanup step) + need_cleanup = builder.icmp_signed( + ">", resume_index, Constant(resume_index.type, 0) + ) + + with cgutils.if_unlikely(builder, need_cleanup): + # Decref all live vars (some may be NULL) + gen_state_ptr = self.get_state_ptr(builder, genptr) + for state_index in range(len(self.gentype.state_types)): + state_slot = cgutils.gep_inbounds( + builder, gen_state_ptr, 0, state_index + ) + ty = self.gentype.state_types[state_index] + val = self.context.unpack_value(builder, ty, state_slot) + pyapi.decref(val) + + builder.ret_void() + + +class LowerYield(object): + """ + Support class for lowering a particular yield point. 
+ """ + + def __init__(self, lower, yield_point, live_vars): + self.lower = lower + self.context = lower.context + self.builder = lower.builder + self.genlower = lower.genlower + self.gentype = self.genlower.gentype + + self.gen_state_ptr = self.genlower.gen_state_ptr + self.resume_index_ptr = self.genlower.resume_index_ptr + self.yp = yield_point + self.inst = self.yp.inst + self.live_vars = live_vars + self.live_var_indices = [ + lower.generator_info.state_vars.index(v) for v in live_vars + ] + + def lower_yield_suspend(self): + self.lower.debug_print("# generator suspend") + # Save live vars in state + for state_index, name in zip(self.live_var_indices, self.live_vars): + state_slot = cgutils.gep_inbounds( + self.builder, self.gen_state_ptr, 0, state_index + ) + ty = self.gentype.state_types[state_index] + # The yield might be in a loop, in which case the state might + # contain a predicate var that branches back to the loop head, in + # this case the var is live but in sequential lowering won't have + # been alloca'd yet, so do this here. 
+ fetype = self.lower.typeof(name) + self.lower._alloca_var(name, fetype) + val = self.lower.loadvar(name) + # IncRef newly stored value + if self.context.enable_nrt: + self.context.nrt.incref(self.builder, ty, val) + + self.context.pack_value(self.builder, ty, val, state_slot) + # Save resume index + indexval = Constant(self.resume_index_ptr.type.pointee, self.inst.index) + self.builder.store(indexval, self.resume_index_ptr) + self.lower.debug_print("# generator suspend end") + + def lower_yield_resume(self): + # Emit resumption point + self.genlower.create_resumption_block(self.lower, self.inst.index) + self.lower.debug_print("# generator resume") + # Reload live vars from state + for state_index, name in zip(self.live_var_indices, self.live_vars): + state_slot = cgutils.gep_inbounds( + self.builder, self.gen_state_ptr, 0, state_index + ) + ty = self.gentype.state_types[state_index] + val = self.context.unpack_value(self.builder, ty, state_slot) + self.lower.storevar(val, name) + # Previous storevar is making an extra incref + if self.context.enable_nrt: + self.context.nrt.decref(self.builder, ty, val) + self.lower.debug_print("# generator resume end") diff --git a/numba_cuda/numba/cuda/core/imputils.py b/numba_cuda/numba/cuda/core/imputils.py index 2ab4924c7..fdd3f22fd 100644 --- a/numba_cuda/numba/cuda/core/imputils.py +++ b/numba_cuda/numba/cuda/core/imputils.py @@ -10,7 +10,7 @@ from enum import Enum from numba.cuda import typing, cgutils -from numba.core import types +from numba.cuda import types from numba.cuda.typing.templates import BaseRegistryLoader diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index aabe535dd..4153f9936 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -5,8 +5,9 @@ import copy import ctypes import numba.core.analysis -from numba.core import types, ir, errors -from numba.cuda import utils, cgutils, 
typing, config +from numba.cuda import types, config, cgutils +from numba.core import ir, errors +from numba.cuda import typing, utils from numba.cuda.core.ir_utils import ( next_label, add_offset_to_labels, diff --git a/numba_cuda/numba/cuda/core/ir_utils.py b/numba_cuda/numba/cuda/core/ir_utils.py index 35587534d..96139a0bc 100644 --- a/numba_cuda/numba/cuda/core/ir_utils.py +++ b/numba_cuda/numba/cuda/core/ir_utils.py @@ -10,7 +10,8 @@ import warnings import numba -from numba.core import types, ir +from numba.cuda import types +from numba.core import ir from numba.cuda import typing from numba.cuda.core import analysis, postproc, rewrites, config from numba.cuda.typing.templates import signature diff --git a/numba_cuda/numba/cuda/core/optional.py b/numba_cuda/numba/cuda/core/optional.py index c4f56a49c..8cad47935 100644 --- a/numba_cuda/numba/cuda/core/optional.py +++ b/numba_cuda/numba/cuda/core/optional.py @@ -3,8 +3,8 @@ import operator -from numba.core import types -from numba.cuda import cgutils, typing +from numba.cuda import types, typing +from numba.cuda import cgutils from numba.cuda.core.imputils import Registry, impl_ret_untracked diff --git a/numba_cuda/numba/cuda/core/pythonapi.py b/numba_cuda/numba/cuda/core/pythonapi.py index 8a0db624d..f7b5e03b6 100644 --- a/numba_cuda/numba/cuda/core/pythonapi.py +++ b/numba_cuda/numba/cuda/core/pythonapi.py @@ -11,14 +11,10 @@ import ctypes from numba.cuda.cext import _helperlib -from numba.core import ( - errors, - types, -) -from numba.cuda import cgutils, lowering, config, serialize +from numba.core import errors from numba.cuda.core import imputils from numba.cuda.utils import PYVERSION - +from numba.cuda import config, types, lowering, cgutils, serialize PY_UNICODE_1BYTE_KIND = _helperlib.py_unicode_1byte_kind PY_UNICODE_2BYTE_KIND = _helperlib.py_unicode_2byte_kind diff --git a/numba_cuda/numba/cuda/core/removerefctpass.py b/numba_cuda/numba/cuda/core/removerefctpass.py new file mode 100644 index 
000000000..11898de84 --- /dev/null +++ b/numba_cuda/numba/cuda/core/removerefctpass.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implement a rewrite pass on an LLVM module to remove unnecessary +refcount operations. +""" + +from llvmlite.ir.transforms import CallVisitor + +from numba.cuda import types + + +class _MarkNrtCallVisitor(CallVisitor): + """ + A pass to mark all NRT_incref and NRT_decref. + """ + + def __init__(self): + self.marked = set() + + def visit_Call(self, instr): + if getattr(instr.callee, "name", "") in _accepted_nrtfns: + self.marked.add(instr) + + +def _rewrite_function(function): + # Mark NRT usage + markpass = _MarkNrtCallVisitor() + markpass.visit_Function(function) + # Remove NRT usage + for bb in function.basic_blocks: + for inst in list(bb.instructions): + if inst in markpass.marked: + bb.instructions.remove(inst) + + +_accepted_nrtfns = "NRT_incref", "NRT_decref" + + +def _legalize(module, dmm, fndesc): + """ + Legalize the code in the module. + Returns True if the module is legal for the rewrite pass that removes + unnecessary refcounts. + """ + + def valid_output(ty): + """ + A valid output is any type that does not need refcounting. + """ + model = dmm[ty] + return not model.contains_nrt_meminfo() + + def valid_input(ty): + """ + A valid input is any type that does not need refcounting, or an Array. + """ + return valid_output(ty) or isinstance(ty, types.Array) + + # Ensure no reference to function marked as + # "numba_args_may_always_need_nrt" + try: + nmd = module.get_named_metadata("numba_args_may_always_need_nrt") + except KeyError: + # Nothing marked + pass + else: + # Has functions marked as "numba_args_may_always_need_nrt" + if len(nmd.operands) > 0: + # The pass is illegal for this compilation unit.
+ return False + + # More legalization based on function type + argtypes = fndesc.argtypes + restype = fndesc.restype + calltypes = fndesc.calltypes + + # Legalize function arguments + for argty in argtypes: + if not valid_input(argty): + return False + + # Legalize function return + if not valid_output(restype): + return False + + # Legalize all called functions + for callty in calltypes.values(): + if callty is not None and not valid_output(callty.return_type): + return False + + # Ensure no allocation + for fn in module.functions: + if fn.name.startswith("NRT_"): + if fn.name not in _accepted_nrtfns: + return False + + return True + + +def remove_unnecessary_nrt_usage(function, context, fndesc): + """ + Remove unnecessary NRT incref/decref in the given LLVM function. + It uses high-level type info to determine if the function does not need NRT. + Such a function does not: + + - return array object(s); + - take arguments that need refcounting except array; + - call function(s) that return refcounted object. + + In effect, the function will not capture or create references that extend + the lifetime of any refcounted objects beyond the lifetime of the function. + + The rewrite is performed in place. + If rewrite has happened, this function returns True, otherwise, it returns False.
+ """ + dmm = context.data_model_manager + if _legalize(function.module, dmm, fndesc): + _rewrite_function(function) + return True + else: + return False diff --git a/numba_cuda/numba/cuda/core/rewrites/registry.py b/numba_cuda/numba/cuda/core/rewrites/registry.py index bc2a20371..2ed67dcaa 100644 --- a/numba_cuda/numba/cuda/core/rewrites/registry.py +++ b/numba_cuda/numba/cuda/core/rewrites/registry.py @@ -3,7 +3,7 @@ from collections import defaultdict -from numba.core import config +from numba.cuda import config class Rewrite(object): diff --git a/numba_cuda/numba/cuda/core/rewrites/static_getitem.py b/numba_cuda/numba/cuda/core/rewrites/static_getitem.py index a91ccdf13..88ce6927a 100644 --- a/numba_cuda/numba/cuda/core/rewrites/static_getitem.py +++ b/numba_cuda/numba/cuda/core/rewrites/static_getitem.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba.core import errors, types, ir +from numba.core import errors, ir +from numba.cuda import types from numba.cuda.core.rewrites import register_rewrite, Rewrite diff --git a/numba_cuda/numba/cuda/core/sigutils.py b/numba_cuda/numba/cuda/core/sigutils.py index 62040fab1..04fa54c45 100644 --- a/numba_cuda/numba/cuda/core/sigutils.py +++ b/numba_cuda/numba/cuda/core/sigutils.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba.core import types -from numba.cuda import typing +from numba.cuda import types, typing try: from numba.core.typing import Signature as CoreSignature diff --git a/numba_cuda/numba/cuda/core/typed_passes.py b/numba_cuda/numba/cuda/core/typed_passes.py index 2a76bd35f..c81dbcbf4 100644 --- a/numba_cuda/numba/cuda/core/typed_passes.py +++ b/numba_cuda/numba/cuda/core/typed_passes.py @@ -10,10 +10,9 @@ from numba.cuda.core import typeinfer from numba.core import ( errors, - types, ir, ) -from numba.cuda import typing, lowering +from numba.cuda import typing, types, lowering from numba.cuda.core.compiler_machinery import ( FunctionPass, LoweringPass, diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index a35d9c197..3702985bb 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -35,7 +35,8 @@ from collections import OrderedDict, defaultdict from functools import reduce -from numba.core import types, utils, config, ir +from numba.cuda import types, utils, config +from numba.core import ir from numba.cuda import typing from numba.cuda.typing.templates import Signature from numba.core.errors import ( diff --git a/numba_cuda/numba/cuda/core/unsafe/bytes.py b/numba_cuda/numba/cuda/core/unsafe/bytes.py index b205b1a16..c83f682fb 100644 --- a/numba_cuda/numba/cuda/core/unsafe/bytes.py +++ b/numba_cuda/numba/cuda/core/unsafe/bytes.py @@ -8,7 +8,7 @@ from numba.cuda.extending import intrinsic from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils diff --git a/numba_cuda/numba/cuda/core/unsafe/eh.py b/numba_cuda/numba/cuda/core/unsafe/eh.py index 416f818ec..851ec3337 100644 --- a/numba_cuda/numba/cuda/core/unsafe/eh.py +++ b/numba_cuda/numba/cuda/core/unsafe/eh.py @@ -5,7 +5,8 @@ Exception handling intrinsics. 
""" -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import cgutils from numba.cuda.extending import intrinsic diff --git a/numba_cuda/numba/cuda/core/unsafe/refcount.py b/numba_cuda/numba/cuda/core/unsafe/refcount.py index 844367394..8064e7264 100644 --- a/numba_cuda/numba/cuda/core/unsafe/refcount.py +++ b/numba_cuda/numba/cuda/core/unsafe/refcount.py @@ -7,7 +7,7 @@ from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.extending import intrinsic diff --git a/numba_cuda/numba/cuda/core/untyped_passes.py b/numba_cuda/numba/cuda/core/untyped_passes.py index 8a8956266..9bbd3f332 100644 --- a/numba_cuda/numba/cuda/core/untyped_passes.py +++ b/numba_cuda/numba/cuda/core/untyped_passes.py @@ -15,9 +15,9 @@ from numba.cuda.core import postproc, bytecode, transforms, inline_closurecall from numba.core import ( errors, - types, ir, ) +from numba.cuda import types from numba.cuda.core import consts, rewrites, config from numba.cuda.core.interpreter import Interpreter diff --git a/numba_cuda/numba/cuda/cpython/builtins.py b/numba_cuda/numba/cuda/cpython/builtins.py index 8c532602f..615af3723 100644 --- a/numba_cuda/numba/cuda/cpython/builtins.py +++ b/numba_cuda/numba/cuda/cpython/builtins.py @@ -18,8 +18,8 @@ numba_typeref_ctor, Registry, ) -from numba.core import types -from numba.cuda import cgutils, typing +from numba.cuda import typing, types +from numba.cuda import cgutils from numba.cuda.extending import overload, intrinsic, register_jitable from numba.core.errors import ( TypingError, diff --git a/numba_cuda/numba/cuda/cpython/charseq.py b/numba_cuda/numba/cuda/cpython/charseq.py index 607e212b5..d6738a6f1 100644 --- a/numba_cuda/numba/cuda/cpython/charseq.py +++ b/numba_cuda/numba/cuda/cpython/charseq.py @@ -7,7 +7,7 @@ import numpy as np from llvmlite import ir -from numba.core import types +from numba.cuda import types 
from numba.cuda import cgutils from numba.cuda.extending import ( overload, diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index b1784a7cc..2ba2abe09 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -9,7 +9,7 @@ import math from numba.cuda.core.imputils import impl_ret_untracked, Registry -from numba.core import types +from numba.cuda import types from numba.cuda.typing import signature from numba.cuda.cpython import mathimpl from numba.cuda.extending import overload diff --git a/numba_cuda/numba/cuda/cpython/enumimpl.py b/numba_cuda/numba/cuda/cpython/enumimpl.py index 26444ba7f..4baf512ba 100644 --- a/numba_cuda/numba/cuda/cpython/enumimpl.py +++ b/numba_cuda/numba/cuda/cpython/enumimpl.py @@ -8,7 +8,7 @@ import operator from numba.cuda.core.imputils import Registry, impl_ret_untracked -from numba.core import types +from numba.cuda import types from numba.cuda.extending import overload_method registry = Registry("enumimpl") diff --git a/numba_cuda/numba/cuda/cpython/iterators.py b/numba_cuda/numba/cuda/cpython/iterators.py index dcf6914f6..df7a87b75 100644 --- a/numba_cuda/numba/cuda/cpython/iterators.py +++ b/numba_cuda/numba/cuda/cpython/iterators.py @@ -5,7 +5,7 @@ Implementation of various iterable and iterator types. 
""" -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.core.imputils import ( iternext_impl, diff --git a/numba_cuda/numba/cuda/cpython/listobj.py b/numba_cuda/numba/cuda/cpython/listobj.py index 0d3f2322c..b396f5c60 100644 --- a/numba_cuda/numba/cuda/cpython/listobj.py +++ b/numba_cuda/numba/cuda/cpython/listobj.py @@ -8,7 +8,8 @@ import operator from llvmlite import ir -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import cgutils from numba.cuda.core.imputils import ( Registry, diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index d449b1ec8..0fe129917 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -14,7 +14,7 @@ from llvmlite.ir import Constant from numba.cuda.core.imputils import impl_ret_untracked, Registry -from numba.core import types +from numba.cuda import types from numba.cuda.core import config from numba.cuda.extending import overload from numba.cuda.typing import signature diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py index 91f3ff189..e5dd6e224 100644 --- a/numba_cuda/numba/cuda/cpython/numbers.py +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -11,7 +11,8 @@ from llvmlite.ir import Constant from numba.cuda.core.imputils import impl_ret_untracked, Registry -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda.extending import overload_method from numba.cuda.cpython.unsafe.numbers import viewer from numba.cuda import cgutils, typing diff --git a/numba_cuda/numba/cuda/cpython/rangeobj.py b/numba_cuda/numba/cuda/cpython/rangeobj.py index 7bdbc8250..00a75cba8 100644 --- a/numba_cuda/numba/cuda/cpython/rangeobj.py +++ b/numba_cuda/numba/cuda/cpython/rangeobj.py @@ -7,7 +7,7 @@ import operator -from numba.core import 
types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.core.imputils import ( Registry, diff --git a/numba_cuda/numba/cuda/cpython/slicing.py b/numba_cuda/numba/cuda/cpython/slicing.py index d42e3c0ab..0899db994 100644 --- a/numba_cuda/numba/cuda/cpython/slicing.py +++ b/numba_cuda/numba/cuda/cpython/slicing.py @@ -6,7 +6,7 @@ """ from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.core.imputils import impl_ret_untracked, Registry diff --git a/numba_cuda/numba/cuda/cpython/tupleobj.py b/numba_cuda/numba/cuda/cpython/tupleobj.py index 96bfb84d8..5292d412d 100644 --- a/numba_cuda/numba/cuda/cpython/tupleobj.py +++ b/numba_cuda/numba/cuda/cpython/tupleobj.py @@ -14,8 +14,8 @@ impl_ret_untracked, RefType, ) -from numba.core import types -from numba.cuda import cgutils, typing +from numba.cuda import typing, types +from numba.cuda import cgutils from numba.cuda.extending import overload_method, overload registry = Registry("tupleobj") diff --git a/numba_cuda/numba/cuda/cpython/unicode.py b/numba_cuda/numba/cuda/cpython/unicode.py index f4283abc4..7e19841c2 100644 --- a/numba_cuda/numba/cuda/cpython/unicode.py +++ b/numba_cuda/numba/cuda/cpython/unicode.py @@ -13,8 +13,8 @@ overload, overload_method, register_jitable, - core_models, ) +from numba.cuda.extending import models from numba.cuda.core.pythonapi import box, unbox from numba.cuda.extending import make_attribute_wrapper, intrinsic from numba.cuda.models import register_model @@ -24,8 +24,8 @@ RefType, Registry, ) -from numba.core.datamodel import register_default, StructModel -from numba.core import types +from numba.cuda.datamodel import register_default, StructModel +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.utils import PYVERSION from numba.cuda.core.pythonapi import ( @@ -83,7 +83,7 @@ lower_getattr = registry.lower_getattr if PYVERSION in ((3, 9), (3, 10), (3, 11)): - 
from numba.core.pythonapi import PY_UNICODE_WCHAR_KIND + from numba.cuda.core.pythonapi import PY_UNICODE_WCHAR_KIND # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L84-L85 # noqa: E501 _MAX_UNICODE = 0x10FFFF @@ -95,7 +95,7 @@ @register_model(types.UnicodeType) -class UnicodeModel(core_models.StructModel): +class UnicodeModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ ("data", types.voidptr), @@ -107,7 +107,7 @@ def __init__(self, dmm, fe_type): # A pointer to the owner python str/unicode object ("parent", types.pyobject), ] - core_models.StructModel.__init__(self, dmm, fe_type, members) + models.StructModel.__init__(self, dmm, fe_type, members) make_attribute_wrapper(types.UnicodeType, "data", "_data") diff --git a/numba_cuda/numba/cuda/cpython/unicode_support.py b/numba_cuda/numba/cuda/cpython/unicode_support.py index c3d58c3c2..6c54ec94f 100644 --- a/numba_cuda/numba/cuda/cpython/unicode_support.py +++ b/numba_cuda/numba/cuda/cpython/unicode_support.py @@ -14,7 +14,7 @@ import llvmlite.ir import numpy as np -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.core.imputils import impl_ret_untracked diff --git a/numba_cuda/numba/cuda/cpython/unsafe/numbers.py b/numba_cuda/numba/cuda/cpython/unsafe/numbers.py index 10fb8af24..4cff57200 100644 --- a/numba_cuda/numba/cuda/cpython/unsafe/numbers.py +++ b/numba_cuda/numba/cuda/cpython/unsafe/numbers.py @@ -3,7 +3,8 @@ """This module provides the unsafe things for targets/numbers.py""" -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda.extending import intrinsic from llvmlite import ir diff --git a/numba_cuda/numba/cuda/cpython/unsafe/tuple.py b/numba_cuda/numba/cuda/cpython/unsafe/tuple.py index c52a67f76..50f1d709e 100644 --- a/numba_cuda/numba/cuda/cpython/unsafe/tuple.py +++ 
b/numba_cuda/numba/cuda/cpython/unsafe/tuple.py @@ -6,8 +6,8 @@ operations with tuple and workarounds for limitations enforced in userland. """ -from numba.core import types, errors -from numba.cuda import typing +from numba.cuda import types, typing +from numba.core import errors from numba.cuda.cgutils import alloca_once from numba.cuda.extending import intrinsic diff --git a/numba_cuda/numba/cuda/cudadecl.py b/numba_cuda/numba/cuda/cudadecl.py index 88d907bac..6b805514b 100644 --- a/numba_cuda/numba/cuda/cudadecl.py +++ b/numba_cuda/numba/cuda/cudadecl.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba.core import errors, types +from numba.core import errors +from numba.cuda import types from numba.cuda.typing.npydecl import ( parse_dtype, parse_shape, @@ -20,7 +21,7 @@ signature, Registry, ) -from numba.cuda.types import dim3 +from numba.cuda.ext_types import dim3 from numba import cuda registry = Registry() diff --git a/numba_cuda/numba/cuda/cudadrv/devicearray.py b/numba_cuda/numba/cuda/cudadrv/devicearray.py index 6cdb12424..2b799239f 100644 --- a/numba_cuda/numba/cuda/cudadrv/devicearray.py +++ b/numba_cuda/numba/cuda/cudadrv/devicearray.py @@ -18,7 +18,7 @@ from numba.cuda.cext import _devicearray from numba.cuda.cudadrv import devices, dummyarray from numba.cuda.cudadrv import driver as _driver -from numba.core import types +from numba.cuda import types from numba.cuda.core import config from numba.cuda.np.unsafe.ndarray import to_fixed_tuple from numba.cuda.np.numpy_support import numpy_version diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py index 46b3077e4..34e94d667 100644 --- a/numba_cuda/numba/cuda/cudaimpl.py +++ b/numba_cuda/numba/cuda/cudaimpl.py @@ -11,17 +11,17 @@ from numba.cuda.core.imputils import Registry from numba.cuda.typing.npydecl import parse_dtype -from numba.core.datamodel import 
models -from numba.core import types +from numba.cuda.datamodel import models +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.np import ufunc_db from numba.cuda.np.npyimpl import register_ufuncs from .cudadrv import nvvm from numba import cuda from numba.cuda import nvvmutils, stubs -from numba.cuda.types import dim3, CUDADispatcher +from numba.cuda.ext_types import dim3, CUDADispatcher -registry = Registry() +registry = Registry("cudaimpl") lower = registry.lower lower_attr = registry.lower_getattr lower_constant = registry.lower_constant diff --git a/numba_cuda/numba/cuda/cudamath.py b/numba_cuda/numba/cuda/cudamath.py index 1cc5a5ac8..e02290489 100644 --- a/numba_cuda/numba/cuda/cudamath.py +++ b/numba_cuda/numba/cuda/cudamath.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import math -from numba.core import types +from numba.cuda import types from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry diff --git a/numba_cuda/numba/cuda/datamodel/__init__.py b/numba_cuda/numba/cuda/datamodel/__init__.py new file mode 100644 index 000000000..2a9278ca3 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from .manager import DataModelManager +from .packer import ArgPacker, DataPacker +from .registry import register_default, default_manager, register +from .models import PrimitiveModel, CompositeModel, StructModel # type: ignore diff --git a/numba_cuda/numba/cuda/datamodel/cuda_manager.py b/numba_cuda/numba/cuda/datamodel/cuda_manager.py new file mode 100644 index 000000000..90d0cebec --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/cuda_manager.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class DataModelManager(object):
    """Manages mapping of FE types to their corresponding data model.

    Two mappings are kept:

    - ``_handlers``: frontend type *class* -> model factory, called as
      ``factory(manager, fetype)``.
    - ``_cache``: frontend type *instance* -> model instance, held through
      weak keys so that cached models do not keep their types alive.
    """

    def __init__(self, handlers=None):
        """
        Parameters
        -----------
        handlers: Mapping[Type, DataModel] or None
            Optionally provide the initial handlers mapping. The mapping is
            adopted as-is (not copied), even when it is currently empty.
        """
        # { numba type class -> model factory }
        # Compare against None instead of relying on truthiness: an empty
        # mapping (for example the ChainMap built by chain() before anything
        # is registered) is falsy, and replacing it with a new dict would
        # silently cut the link to the caller's mapping.
        self._handlers = {} if handlers is None else handlers
        # { numba type instance -> model instance }
        self._cache = weakref.WeakKeyDictionary()

    def register(self, fetypecls, handler):
        """Register the datamodel factory corresponding to a frontend-type class"""
        assert issubclass(fetypecls, types.Type)
        self._handlers[fetypecls] = handler

    def lookup(self, fetype):
        """Returns the corresponding datamodel given the frontend-type instance

        The model is created on first use by calling the handler registered
        for ``type(fetype)`` and is cached for later lookups. A ``KeyError``
        propagates from the handlers mapping when no handler is registered
        for that exact type class.
        """
        try:
            return self._cache[fetype]
        except KeyError:
            pass
        handler = self._handlers[type(fetype)]
        model = self._cache[fetype] = handler(self, fetype)
        return model

    def __getitem__(self, fetype):
        """Shorthand for lookup()"""
        return self.lookup(fetype)

    def copy(self):
        """
        Make a copy of the manager.
        Use this to inherit from the default data model and specialize it
        for custom target.

        Only the handlers mapping is copied; the new manager starts with an
        empty model cache.
        """
        return DataModelManager(self._handlers.copy())

    def chain(self, other_manager):
        """Create a new DataModelManager whose handlers mapping layers this
        manager's handlers on top of the handlers of `other_manager`.

        Lookups consult this manager's handlers first and fall back to
        `other_manager`. Any existing and new handlers inserted to either
        manager will be visible to the new manager. Handlers registered on
        the new manager are written to the first layer (this manager's
        handlers mapping), so they can override handlers of `other_manager`
        without actually mutating `other_manager`.

        Parameters
        ----------
        other_manager: DataModelManager
        """
        chained = ChainMap(self._handlers, other_manager._handlers)
        return DataModelManager(chained)
+ + """ + + def __init__(self, dmm, fe_type): + self._dmm = dmm + self._fe_type = fe_type + + @property + def fe_type(self): + return self._fe_type + + def get_value_type(self): + raise NotImplementedError(self) + + def get_data_type(self): + return self.get_value_type() + + def get_argument_type(self): + """Return a LLVM type or nested tuple of LLVM type""" + return self.get_value_type() + + def get_return_type(self): + return self.get_value_type() + + def as_data(self, builder, value): + raise NotImplementedError(self) + + def as_argument(self, builder, value): + """ + Takes one LLVM value + Return a LLVM value or nested tuple of LLVM value + """ + raise NotImplementedError(self) + + def as_return(self, builder, value): + raise NotImplementedError(self) + + def from_data(self, builder, value): + raise NotImplementedError(self) + + def from_argument(self, builder, value): + """ + Takes a LLVM value or nested tuple of LLVM value + Returns one LLVM value + """ + raise NotImplementedError(self) + + def from_return(self, builder, value): + raise NotImplementedError(self) + + def load_from_data_pointer(self, builder, ptr, align=None): + """ + Load value from a pointer to data. + This is the default implementation, sufficient for most purposes. + """ + return self.from_data(builder, builder.load(ptr, align=align)) + + def traverse(self, builder): + """ + Traverse contained members. + Returns a iterable of contained (types, getters). + Each getter is a one-argument function accepting a LLVM value. + """ + return [] + + def traverse_models(self): + """ + Recursively list all models involved in this model. + """ + return [self._dmm[t] for t in self.traverse_types()] + + def traverse_types(self): + """ + Recursively list all frontend types involved in this model. 
+ """ + types = [self._fe_type] + queue = deque([self]) + while len(queue) > 0: + dm = queue.popleft() + + for i_dm in dm.inner_models(): + if i_dm._fe_type not in types: + queue.append(i_dm) + types.append(i_dm._fe_type) + + return types + + def inner_models(self): + """ + List all *inner* models. + """ + return [] + + def get_nrt_meminfo(self, builder, value): + """ + Returns the MemInfo object or None if it is not tracked. + It is only defined for types.meminfo_pointer + """ + return None + + def has_nrt_meminfo(self): + return False + + def contains_nrt_meminfo(self): + """ + Recursively check all contained types for need for NRT meminfo. + """ + return any(model.has_nrt_meminfo() for model in self.traverse_models()) + + def _compared_fields(self): + return (type(self), self._fe_type) + + def __hash__(self): + return hash(tuple(self._compared_fields())) + + def __eq__(self, other): + if type(self) is type(other): + return self._compared_fields() == other._compared_fields() + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +@register_default(types.Omitted) +class OmittedArgDataModel(DataModel): + """ + A data model for omitted arguments. Only the "argument" representation + is defined, other representations raise a NotImplementedError. + """ + + # Omitted arguments are using a dummy value type + def get_value_type(self): + return ir.LiteralStructType([]) + + # Omitted arguments don't produce any LLVM function argument. 
+ def get_argument_type(self): + return () + + def as_argument(self, builder, val): + return () + + def from_argument(self, builder, val): + assert val == (), val + return None + + +@register_default(types.Boolean) +@register_default(types.BooleanLiteral) +class BooleanModel(DataModel): + _bit_type = ir.IntType(1) + _byte_type = ir.IntType(8) + + def get_value_type(self): + return self._bit_type + + def get_data_type(self): + return self._byte_type + + def get_return_type(self): + return self.get_data_type() + + def get_argument_type(self): + return self.get_data_type() + + def as_data(self, builder, value): + return builder.zext(value, self.get_data_type()) + + def as_argument(self, builder, value): + return self.as_data(builder, value) + + def as_return(self, builder, value): + return self.as_data(builder, value) + + def from_data(self, builder, value): + ty = self.get_value_type() + resalloca = cgutils.alloca_once(builder, ty) + cond = builder.icmp_unsigned("==", value, value.type(0)) + with builder.if_else(cond) as (then, otherwise): + with then: + builder.store(ty(0), resalloca) + with otherwise: + builder.store(ty(1), resalloca) + return builder.load(resalloca) + + def from_argument(self, builder, value): + return self.from_data(builder, value) + + def from_return(self, builder, value): + return self.from_data(builder, value) + + +class PrimitiveModel(DataModel): + """A primitive type can be represented natively in the target in all + usage contexts. 
+ """ + + def __init__(self, dmm, fe_type, be_type): + super(PrimitiveModel, self).__init__(dmm, fe_type) + self.be_type = be_type + + def get_value_type(self): + return self.be_type + + def as_data(self, builder, value): + return value + + def as_argument(self, builder, value): + return value + + def as_return(self, builder, value): + return value + + def from_data(self, builder, value): + return value + + def from_argument(self, builder, value): + return value + + def from_return(self, builder, value): + return value + + +class ProxyModel(DataModel): + """ + Helper class for models which delegate to another model. + """ + + def get_value_type(self): + return self._proxied_model.get_value_type() + + def get_data_type(self): + return self._proxied_model.get_data_type() + + def get_return_type(self): + return self._proxied_model.get_return_type() + + def get_argument_type(self): + return self._proxied_model.get_argument_type() + + def as_data(self, builder, value): + return self._proxied_model.as_data(builder, value) + + def as_argument(self, builder, value): + return self._proxied_model.as_argument(builder, value) + + def as_return(self, builder, value): + return self._proxied_model.as_return(builder, value) + + def from_data(self, builder, value): + return self._proxied_model.from_data(builder, value) + + def from_argument(self, builder, value): + return self._proxied_model.from_argument(builder, value) + + def from_return(self, builder, value): + return self._proxied_model.from_return(builder, value) + + +@register_default(types.EnumMember) +@register_default(types.IntEnumMember) +class EnumModel(ProxyModel): + """ + Enum members are represented exactly like their values. 
+ """ + + def __init__(self, dmm, fe_type): + super(EnumModel, self).__init__(dmm, fe_type) + self._proxied_model = dmm.lookup(fe_type.dtype) + + +@register_default(types.Opaque) +@register_default(types.PyObject) +@register_default(types.RawPointer) +@register_default(types.NoneType) +@register_default(types.StringLiteral) +@register_default(types.EllipsisType) +@register_default(types.Function) +@register_default(types.Type) +@register_default(types.Object) +@register_default(types.Module) +@register_default(types.Phantom) +@register_default(types.UndefVar) +@register_default(types.ContextManager) +@register_default(types.Dispatcher) +@register_default(types.ObjModeDispatcher) +@register_default(types.ExceptionClass) +@register_default(types.Dummy) +@register_default(types.ExceptionInstance) +@register_default(types.ExternalFunction) +@register_default(types.EnumClass) +@register_default(types.IntEnumClass) +@register_default(types.NumberClass) +@register_default(types.TypeRef) +@register_default(types.NamedTupleClass) +@register_default(types.DType) +@register_default(types.RecursiveCall) +@register_default(types.MakeFunctionLiteral) +@register_default(types.Poison) +class OpaqueModel(PrimitiveModel): + """ + Passed as opaque pointers + """ + + _ptr_type = ir.IntType(8).as_pointer() + + def __init__(self, dmm, fe_type): + be_type = self._ptr_type + super(OpaqueModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.MemInfoPointer) +class MemInfoModel(OpaqueModel): + def inner_models(self): + return [self._dmm.lookup(self._fe_type.dtype)] + + def has_nrt_meminfo(self): + return True + + def get_nrt_meminfo(self, builder, value): + return value + + +@register_default(types.Integer) +@register_default(types.IntegerLiteral) +class IntegerModel(PrimitiveModel): + def __init__(self, dmm, fe_type): + be_type = ir.IntType(fe_type.bitwidth) + super(IntegerModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.Float) +class 
FloatModel(PrimitiveModel): + def __init__(self, dmm, fe_type): + if fe_type == types.float32: + be_type = ir.FloatType() + elif fe_type == types.float64: + be_type = ir.DoubleType() + else: + raise NotImplementedError(fe_type) + super(FloatModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.CPointer) +class PointerModel(PrimitiveModel): + def __init__(self, dmm, fe_type): + self._pointee_model = dmm.lookup(fe_type.dtype) + self._pointee_be_type = self._pointee_model.get_data_type() + be_type = self._pointee_be_type.as_pointer() + super(PointerModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.EphemeralPointer) +class EphemeralPointerModel(PointerModel): + def get_data_type(self): + return self._pointee_be_type + + def as_data(self, builder, value): + value = builder.load(value) + return self._pointee_model.as_data(builder, value) + + def from_data(self, builder, value): + raise NotImplementedError("use load_from_data_pointer() instead") + + def load_from_data_pointer(self, builder, ptr, align=None): + return builder.bitcast(ptr, self.get_value_type()) + + +@register_default(types.EphemeralArray) +class EphemeralArrayModel(PointerModel): + def __init__(self, dmm, fe_type): + super(EphemeralArrayModel, self).__init__(dmm, fe_type) + self._data_type = ir.ArrayType( + self._pointee_be_type, self._fe_type.count + ) + + def get_data_type(self): + return self._data_type + + def as_data(self, builder, value): + values = [ + builder.load(cgutils.gep_inbounds(builder, value, i)) + for i in range(self._fe_type.count) + ] + return cgutils.pack_array(builder, values) + + def from_data(self, builder, value): + raise NotImplementedError("use load_from_data_pointer() instead") + + def load_from_data_pointer(self, builder, ptr, align=None): + return builder.bitcast(ptr, self.get_value_type()) + + +@register_default(types.ExternalFunctionPointer) +class ExternalFuncPointerModel(PrimitiveModel): + def __init__(self, dmm, fe_type): + sig = 
fe_type.sig + # Since the function is non-Numba, there is no adaptation + # of arguments and return value, hence get_value_type(). + retty = dmm.lookup(sig.return_type).get_value_type() + args = [dmm.lookup(t).get_value_type() for t in sig.args] + be_type = ir.PointerType(ir.FunctionType(retty, args)) + super(ExternalFuncPointerModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.UniTuple) +@register_default(types.NamedUniTuple) +@register_default(types.StarArgUniTuple) +class UniTupleModel(DataModel): + def __init__(self, dmm, fe_type): + super(UniTupleModel, self).__init__(dmm, fe_type) + self._elem_model = dmm.lookup(fe_type.dtype) + self._count = len(fe_type) + self._value_type = ir.ArrayType( + self._elem_model.get_value_type(), self._count + ) + self._data_type = ir.ArrayType( + self._elem_model.get_data_type(), self._count + ) + + def get_value_type(self): + return self._value_type + + def get_data_type(self): + return self._data_type + + def get_return_type(self): + return self.get_value_type() + + def get_argument_type(self): + return (self._elem_model.get_argument_type(),) * self._count + + def as_argument(self, builder, value): + out = [] + for i in range(self._count): + v = builder.extract_value(value, [i]) + v = self._elem_model.as_argument(builder, v) + out.append(v) + return out + + def from_argument(self, builder, value): + out = ir.Constant(self.get_value_type(), ir.Undefined) + for i, v in enumerate(value): + v = self._elem_model.from_argument(builder, v) + out = builder.insert_value(out, v, [i]) + return out + + def as_data(self, builder, value): + out = ir.Constant(self.get_data_type(), ir.Undefined) + for i in range(self._count): + val = builder.extract_value(value, [i]) + dval = self._elem_model.as_data(builder, val) + out = builder.insert_value(out, dval, [i]) + return out + + def from_data(self, builder, value): + out = ir.Constant(self.get_value_type(), ir.Undefined) + for i in range(self._count): + val = 
class StructModel(CompositeModel):
    """Data model for a frontend type represented as an LLVM literal struct.

    Subclasses pass ``members``, a sequence of ``(field_name, fe_type)``
    pairs. Each member's own data model is looked up through the manager and
    used to convert the field between the value, data, argument and return
    representations.
    """

    # Class-level defaults; the LLVM struct types are built lazily on first
    # request and then cached on the instance.
    _value_type = None
    _data_type = None

    def __init__(self, dmm, fe_type, members):
        super(StructModel, self).__init__(dmm, fe_type)
        if members:
            # Split [(name, type), ...] into parallel tuples.
            self._fields, self._members = zip(*members)
        else:
            self._fields = self._members = ()
        self._models = tuple([self._dmm.lookup(t) for t in self._members])

    def get_member_fe_type(self, name):
        """
        StructModel-specific: get the Numba type of the field named *name*.
        """
        pos = self.get_field_position(name)
        return self._members[pos]

    def get_value_type(self):
        """Return the (cached) LLVM struct type of the value representation."""
        if self._value_type is None:
            self._value_type = ir.LiteralStructType(
                [t.get_value_type() for t in self._models]
            )
        return self._value_type

    def get_data_type(self):
        """Return the (cached) LLVM struct type of the data representation."""
        if self._data_type is None:
            self._data_type = ir.LiteralStructType(
                [t.get_data_type() for t in self._models]
            )
        return self._data_type

    def get_argument_type(self):
        """Return a tuple with the argument type of every member, in order."""
        return tuple([t.get_argument_type() for t in self._models])

    def get_return_type(self):
        # Structs are returned in their data representation.
        return self.get_data_type()

    def _as(self, methname, builder, value):
        """Apply the member models' ``methname`` conversion to each field
        extracted from the struct ``value``; return the results as a tuple.
        """
        extracted = []
        for i, dm in enumerate(self._models):
            extracted.append(
                getattr(dm, methname)(builder, self.get(builder, value, i))
            )
        return tuple(extracted)

    def _from(self, methname, builder, value):
        """Build a value-representation struct from the per-field values in
        ``value``, converting each with the member models' ``methname``.
        """
        struct = ir.Constant(self.get_value_type(), ir.Undefined)

        for i, (dm, val) in enumerate(zip(self._models, value)):
            v = getattr(dm, methname)(builder, val)
            struct = self.set(builder, struct, v, i)

        return struct

    def as_data(self, builder, value):
        """
        Converts the LLVM struct in `value` into a representation suited for
        storing into arrays.

        Note
        ----
        Current implementation rarely changes how types are represented for
        "value" and "data".  This is usually a pointless rebuild of the
        immutable LLVM struct value.  Luckily, LLVM optimization removes all
        redundancy.

        Sample usecase: Structures nested with pointers to other structures
        that can be serialized into a flat representation when storing into
        array.
        """
        elems = self._as("as_data", builder, value)
        struct = ir.Constant(self.get_data_type(), ir.Undefined)
        for i, el in enumerate(elems):
            struct = builder.insert_value(struct, el, [i])
        return struct

    def from_data(self, builder, value):
        """
        Convert from "data" representation back into "value" representation.
        Usually invoked when loading from array.

        See notes in `as_data()`
        """
        vals = [
            builder.extract_value(value, [i]) for i in range(len(self._members))
        ]
        return self._from("from_data", builder, vals)

    def load_from_data_pointer(self, builder, ptr, align=None):
        """Load every field through its own model from the struct pointed to
        by ``ptr`` and assemble the value-representation struct.
        """
        values = []
        for i, model in enumerate(self._models):
            elem_ptr = cgutils.gep_inbounds(builder, ptr, 0, i)
            val = model.load_from_data_pointer(builder, elem_ptr, align)
            values.append(val)

        struct = ir.Constant(self.get_value_type(), ir.Undefined)
        for i, val in enumerate(values):
            struct = self.set(builder, struct, val, i)
        return struct

    def as_argument(self, builder, value):
        """Return a tuple with each field converted by its ``as_argument``."""
        return self._as("as_argument", builder, value)

    def from_argument(self, builder, value):
        """Rebuild the value struct from per-field argument values."""
        return self._from("from_argument", builder, value)

    def as_return(self, builder, value):
        # The return representation is the data representation (see
        # get_return_type), so fields are converted with "as_data".
        elems = self._as("as_data", builder, value)
        struct = ir.Constant(self.get_data_type(), ir.Undefined)
        for i, el in enumerate(elems):
            struct = builder.insert_value(struct, el, [i])
        return struct

    def from_return(self, builder, value):
        # Mirror of as_return(): fields come back in data representation.
        vals = [
            builder.extract_value(value, [i]) for i in range(len(self._members))
        ]
        return self._from("from_data", builder, vals)

    def get(self, builder, val, pos):
        """Get a field at the given position or the fieldname

        Args
        ----
        builder:
            LLVM IRBuilder
        val:
            LLVM struct value to extract the field from
        pos: int or str
            field index or field name

        Returns
        -------
        Extracted value
        """
        if isinstance(pos, str):
            pos = self.get_field_position(pos)
        return builder.extract_value(
            val, [pos], name="extracted." + self._fields[pos]
        )

    def set(self, builder, stval, val, pos):
        """Set a field at the given position or the fieldname

        Args
        ----
        builder:
            LLVM IRBuilder
        stval:
            LLVM struct value
        val:
            value to be inserted
        pos: int or str
            field index or field name

        Returns
        -------
        A new LLVM struct with the value inserted
        """
        if isinstance(pos, str):
            pos = self.get_field_position(pos)
        return builder.insert_value(
            stval, val, [pos], name="inserted." + self._fields[pos]
        )

    def get_field_position(self, field):
        """Return the index of the field named ``field``.

        Raises
        ------
        KeyError
            If the struct has no field with that name.
        """
        try:
            return self._fields.index(field)
        except ValueError:
            raise KeyError(
                "%s does not have a field named %r"
                % (self.__class__.__name__, field)
            )

    @property
    def field_count(self):
        """Number of fields in the struct."""
        return len(self._fields)

    def get_type(self, pos):
        """Get the frontend type (numba type) of a field given the position
        or the fieldname

        Args
        ----
        pos: int or str
            field index or field name
        """
        if isinstance(pos, str):
            pos = self.get_field_position(pos)
        return self._members[pos]

    def get_model(self, pos):
        """
        Get the datamodel of a field given the position or the fieldname.

        Args
        ----
        pos: int or str
            field index or field name
        """
        # Resolve field names like get_type() does; indexing the models
        # tuple directly with a str would raise TypeError.
        if isinstance(pos, str):
            pos = self.get_field_position(pos)
        return self._models[pos]

    def traverse(self, builder):
        """Return ``(fe_type, getter)`` pairs, one per field.

        Each getter takes a struct value in value representation and returns
        the extracted field; it raises TypeError when the value's LLVM type
        is not this model's value type.
        """

        def getter(k, value):
            if value.type != self.get_value_type():
                args = self.get_value_type(), value.type
                raise TypeError("expecting {0} but got {1}".format(*args))
            return self.get(builder, value, k)

        return [(self.get_type(k), partial(getter, k)) for k in self._fields]

    def inner_models(self):
        """Return the data models of the struct's members."""
        return self._models
+ members = [ + ("size", types.intp), + ("allocated", types.intp), + # This member is only used for reflected lists + ("dirty", types.boolean), + # Actually an inlined var-sized array + ("data", fe_type.container.dtype), + ] + super(ListPayloadModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.List) +class ListModel(StructModel): + def __init__(self, dmm, fe_type): + payload_type = types.ListPayload(fe_type) + members = [ + # The meminfo data points to a ListPayload + ("meminfo", types.MemInfoPointer(payload_type)), + # This member is only used for reflected lists + ("parent", types.pyobject), + ] + super(ListModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.ListIter) +class ListIterModel(StructModel): + def __init__(self, dmm, fe_type): + payload_type = types.ListPayload(fe_type.container) + members = [ + # The meminfo data points to a ListPayload (shared with the + # original list object) + ("meminfo", types.MemInfoPointer(payload_type)), + ("index", types.EphemeralPointer(types.intp)), + ] + super(ListIterModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.SetEntry) +class SetEntryModel(StructModel): + def __init__(self, dmm, fe_type): + dtype = fe_type.set_type.dtype + members = [ + # -1 = empty, -2 = deleted + ("hash", types.intp), + ("key", dtype), + ] + super(SetEntryModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.SetPayload) +class SetPayloadModel(StructModel): + def __init__(self, dmm, fe_type): + entry_type = types.SetEntry(fe_type.container) + members = [ + # Number of active + deleted entries + ("fill", types.intp), + # Number of active entries + ("used", types.intp), + # Allocated size - 1 (size being a power of 2) + ("mask", types.intp), + # Search finger + ("finger", types.intp), + # This member is only used for reflected sets + ("dirty", types.boolean), + # Actually an inlined var-sized array + ("entries", entry_type), + ] + 
super(SetPayloadModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.Set) +class SetModel(StructModel): + def __init__(self, dmm, fe_type): + payload_type = types.SetPayload(fe_type) + members = [ + # The meminfo data points to a SetPayload + ("meminfo", types.MemInfoPointer(payload_type)), + # This member is only used for reflected sets + ("parent", types.pyobject), + ] + super(SetModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.SetIter) +class SetIterModel(StructModel): + def __init__(self, dmm, fe_type): + payload_type = types.SetPayload(fe_type.container) + members = [ + # The meminfo data points to a SetPayload (shared with the + # original set object) + ("meminfo", types.MemInfoPointer(payload_type)), + # The index into the entries table + ("index", types.EphemeralPointer(types.intp)), + ] + super(SetIterModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.Array) +@register_default(types.Buffer) +@register_default(types.ByteArray) +@register_default(types.Bytes) +@register_default(types.MemoryView) +@register_default(types.PyArray) +class ArrayModel(StructModel): + def __init__(self, dmm, fe_type): + ndim = fe_type.ndim + members = [ + ("meminfo", types.MemInfoPointer(fe_type.dtype)), + ("parent", types.pyobject), + ("nitems", types.intp), + ("itemsize", types.intp), + ("data", types.CPointer(fe_type.dtype)), + ("shape", types.UniTuple(types.intp, ndim)), + ("strides", types.UniTuple(types.intp, ndim)), + ] + super(ArrayModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.ArrayFlags) +class ArrayFlagsModel(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("parent", fe_type.array_type), + ] + super(ArrayFlagsModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.NestedArray) +class NestedArrayModel(ArrayModel): + def __init__(self, dmm, fe_type): + self._be_type = dmm.lookup(fe_type.dtype).get_data_type() + super(NestedArrayModel, 
self).__init__(dmm, fe_type) + + def as_storage_type(self): + """Return the LLVM type representation for the storage of + the nestedarray. + """ + ret = ir.ArrayType(self._be_type, self._fe_type.nitems) + return ret + + +@register_default(types.Optional) +class OptionalModel(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.type), + ("valid", types.boolean), + ] + self._value_model = dmm.lookup(fe_type.type) + super(OptionalModel, self).__init__(dmm, fe_type, members) + + def get_return_type(self): + return self._value_model.get_return_type() + + def as_return(self, builder, value): + raise NotImplementedError + + def from_return(self, builder, value): + return self._value_model.from_return(builder, value) + + def traverse(self, builder): + def get_data(value): + valid = get_valid(value) + data = self.get(builder, value, "data") + return builder.select(valid, data, ir.Constant(data.type, None)) + + def get_valid(value): + return self.get(builder, value, "valid") + + return [ + (self.get_type("data"), get_data), + (self.get_type("valid"), get_valid), + ] + + +@register_default(types.Record) +class RecordModel(CompositeModel): + def __init__(self, dmm, fe_type): + super(RecordModel, self).__init__(dmm, fe_type) + self._models = [self._dmm.lookup(t) for _, t in fe_type.members] + self._be_type = ir.ArrayType(ir.IntType(8), fe_type.size) + self._be_ptr_type = self._be_type.as_pointer() + + def get_value_type(self): + """Passed around as reference to underlying data""" + return self._be_ptr_type + + def get_argument_type(self): + return self._be_ptr_type + + def get_return_type(self): + return self._be_ptr_type + + def get_data_type(self): + return self._be_type + + def as_data(self, builder, value): + return builder.load(value) + + def from_data(self, builder, value): + raise NotImplementedError("use load_from_data_pointer() instead") + + def as_argument(self, builder, value): + return value + + def from_argument(self, builder, value): 
+ return value + + def as_return(self, builder, value): + return value + + def from_return(self, builder, value): + return value + + def load_from_data_pointer(self, builder, ptr, align=None): + return builder.bitcast(ptr, self.get_value_type()) + + +@register_default(types.UnicodeCharSeq) +class UnicodeCharSeq(DataModel): + def __init__(self, dmm, fe_type): + super(UnicodeCharSeq, self).__init__(dmm, fe_type) + charty = ir.IntType(numpy_support.sizeof_unicode_char * 8) + self._be_type = ir.ArrayType(charty, fe_type.count) + + def get_value_type(self): + return self._be_type + + def get_data_type(self): + return self._be_type + + def as_data(self, builder, value): + return value + + def from_data(self, builder, value): + return value + + def as_return(self, builder, value): + return value + + def from_return(self, builder, value): + return value + + def as_argument(self, builder, value): + return value + + def from_argument(self, builder, value): + return value + + +@register_default(types.CharSeq) +class CharSeq(DataModel): + def __init__(self, dmm, fe_type): + super(CharSeq, self).__init__(dmm, fe_type) + charty = ir.IntType(8) + self._be_type = ir.ArrayType(charty, fe_type.count) + + def get_value_type(self): + return self._be_type + + def get_data_type(self): + return self._be_type + + def as_data(self, builder, value): + return value + + def from_data(self, builder, value): + return value + + def as_return(self, builder, value): + return value + + def from_return(self, builder, value): + return value + + def as_argument(self, builder, value): + return value + + def from_argument(self, builder, value): + return value + + +class CContiguousFlatIter(StructModel): + def __init__(self, dmm, fe_type, need_indices): + assert fe_type.array_type.layout == "C" + array_type = fe_type.array_type + ndim = array_type.ndim + members = [ + ("array", array_type), + ("stride", types.intp), + ("index", types.EphemeralPointer(types.intp)), + ] + if need_indices: + # For 
ndenumerate() + members.append(("indices", types.EphemeralArray(types.intp, ndim))) + super(CContiguousFlatIter, self).__init__(dmm, fe_type, members) + + +class FlatIter(StructModel): + def __init__(self, dmm, fe_type): + array_type = fe_type.array_type + dtype = array_type.dtype + ndim = array_type.ndim + members = [ + ("array", array_type), + ("pointers", types.EphemeralArray(types.CPointer(dtype), ndim)), + ("indices", types.EphemeralArray(types.intp, ndim)), + ("exhausted", types.EphemeralPointer(types.boolean)), + ] + super(FlatIter, self).__init__(dmm, fe_type, members) + + +@register_default(types.UniTupleIter) +class UniTupleIter(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("index", types.EphemeralPointer(types.intp)), + ( + "tuple", + fe_type.container, + ), + ] + super(UniTupleIter, self).__init__(dmm, fe_type, members) + + +@register_default(types.misc.SliceLiteral) +@register_default(types.SliceType) +class SliceModel(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("start", types.intp), + ("stop", types.intp), + ("step", types.intp), + ] + super(SliceModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.NPDatetime) +@register_default(types.NPTimedelta) +class NPDatetimeModel(PrimitiveModel): + def __init__(self, dmm, fe_type): + be_type = ir.IntType(64) + super(NPDatetimeModel, self).__init__(dmm, fe_type, be_type) + + +@register_default(types.ArrayIterator) +class ArrayIterator(StructModel): + def __init__(self, dmm, fe_type): + # We use an unsigned index to avoid the cost of negative index tests. 
+ members = [ + ("index", types.EphemeralPointer(types.uintp)), + ("array", fe_type.array_type), + ] + super(ArrayIterator, self).__init__(dmm, fe_type, members) + + +@register_default(types.EnumerateType) +class EnumerateType(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("count", types.EphemeralPointer(types.intp)), + ("iter", fe_type.source_type), + ] + + super(EnumerateType, self).__init__(dmm, fe_type, members) + + +@register_default(types.ZipType) +class ZipType(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("iter%d" % i, source_type.iterator_type) + for i, source_type in enumerate(fe_type.source_types) + ] + super(ZipType, self).__init__(dmm, fe_type, members) + + +@register_default(types.RangeIteratorType) +class RangeIteratorType(StructModel): + def __init__(self, dmm, fe_type): + int_type = fe_type.yield_type + members = [ + ("iter", types.EphemeralPointer(int_type)), + ("stop", int_type), + ("step", int_type), + ("count", types.EphemeralPointer(int_type)), + ] + super(RangeIteratorType, self).__init__(dmm, fe_type, members) + + +@register_default(types.Generator) +class GeneratorModel(CompositeModel): + def __init__(self, dmm, fe_type): + super(GeneratorModel, self).__init__(dmm, fe_type) + # XXX Fold this in DataPacker? 
+ self._arg_models = [ + self._dmm.lookup(t) + for t in fe_type.arg_types + if not isinstance(t, types.Omitted) + ] + self._state_models = [self._dmm.lookup(t) for t in fe_type.state_types] + + self._args_be_type = ir.LiteralStructType( + [t.get_data_type() for t in self._arg_models] + ) + self._state_be_type = ir.LiteralStructType( + [t.get_data_type() for t in self._state_models] + ) + # The whole generator closure + self._be_type = ir.LiteralStructType( + [ + self._dmm.lookup(types.int32).get_value_type(), + self._args_be_type, + self._state_be_type, + ] + ) + self._be_ptr_type = self._be_type.as_pointer() + + def get_value_type(self): + """ + The generator closure is passed around as a reference. + """ + return self._be_ptr_type + + def get_argument_type(self): + return self._be_ptr_type + + def get_return_type(self): + return self._be_type + + def get_data_type(self): + return self._be_type + + def as_argument(self, builder, value): + return value + + def from_argument(self, builder, value): + return value + + def as_return(self, builder, value): + return self.as_data(builder, value) + + def from_return(self, builder, value): + return self.from_data(builder, value) + + def as_data(self, builder, value): + return builder.load(value) + + def from_data(self, builder, value): + stack = cgutils.alloca_once(builder, value.type) + builder.store(value, stack) + return stack + + +@register_default(types.ArrayCTypes) +class ArrayCTypesModel(StructModel): + def __init__(self, dmm, fe_type): + # ndim = fe_type.ndim + members = [ + ("data", types.CPointer(fe_type.dtype)), + ("meminfo", types.MemInfoPointer(fe_type.dtype)), + ] + super(ArrayCTypesModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.RangeType) +class RangeModel(StructModel): + def __init__(self, dmm, fe_type): + int_type = fe_type.iterator_type.yield_type + members = [("start", int_type), ("stop", int_type), ("step", int_type)] + super(RangeModel, self).__init__(dmm, fe_type, members) + 
+ +# ============================================================================= + + +@register_default(types.NumpyNdIndexType) +class NdIndexModel(StructModel): + def __init__(self, dmm, fe_type): + ndim = fe_type.ndim + members = [ + ("shape", types.UniTuple(types.intp, ndim)), + ("indices", types.EphemeralArray(types.intp, ndim)), + ("exhausted", types.EphemeralPointer(types.boolean)), + ] + super(NdIndexModel, self).__init__(dmm, fe_type, members) + + +@register_default(types.NumpyFlatType) +def handle_numpy_flat_type(dmm, ty): + if ty.array_type.layout == "C": + return CContiguousFlatIter(dmm, ty, need_indices=False) + else: + return FlatIter(dmm, ty) + + +@register_default(types.NumpyNdEnumerateType) +def handle_numpy_ndenumerate_type(dmm, ty): + if ty.array_type.layout == "C": + return CContiguousFlatIter(dmm, ty, need_indices=True) + else: + return FlatIter(dmm, ty) + + +@register_default(types.BoundFunction) +def handle_bound_function(dmm, ty): + # The same as the underlying type + return dmm[ty.this] + + +@register_default(types.NumpyNdIterType) +class NdIter(StructModel): + def __init__(self, dmm, fe_type): + array_types = fe_type.arrays + ndim = fe_type.ndim + shape_len = ndim if fe_type.need_shaped_indexing else 1 + members = [ + ("exhausted", types.EphemeralPointer(types.boolean)), + ("arrays", types.Tuple(array_types)), + # The iterator's main shape and indices + ("shape", types.UniTuple(types.intp, shape_len)), + ("indices", types.EphemeralArray(types.intp, shape_len)), + ] + # Indexing state for the various sub-iterators + # XXX use a tuple instead? 
+ for i, sub in enumerate(fe_type.indexers): + kind, start_dim, end_dim, _ = sub + member_name = "index%d" % i + if kind == "flat": + # A single index into the flattened array + members.append( + (member_name, types.EphemeralPointer(types.intp)) + ) + elif kind in ("scalar", "indexed", "0d"): + # Nothing required + pass + else: + assert 0 + # Slots holding values of the scalar args + # XXX use a tuple instead? + for i, ty in enumerate(fe_type.arrays): + if not isinstance(ty, types.Array): + member_name = "scalar%d" % i + members.append((member_name, types.EphemeralPointer(ty))) + + super(NdIter, self).__init__(dmm, fe_type, members) + + +@register_default(types.DeferredType) +class DeferredStructModel(CompositeModel): + def __init__(self, dmm, fe_type): + super(DeferredStructModel, self).__init__(dmm, fe_type) + self.typename = "deferred.{0}".format(id(fe_type)) + self.actual_fe_type = fe_type.get() + + def get_value_type(self): + return ir.global_context.get_identified_type(self.typename + ".value") + + def get_data_type(self): + return ir.global_context.get_identified_type(self.typename + ".data") + + def get_argument_type(self): + return self._actual_model.get_argument_type() + + def as_argument(self, builder, value): + inner = self.get(builder, value) + return self._actual_model.as_argument(builder, inner) + + def from_argument(self, builder, value): + res = self._actual_model.from_argument(builder, value) + return self.set(builder, self.make_uninitialized(), res) + + def from_data(self, builder, value): + self._define() + elem = self.get(builder, value) + value = self._actual_model.from_data(builder, elem) + out = self.make_uninitialized() + return self.set(builder, out, value) + + def as_data(self, builder, value): + self._define() + elem = self.get(builder, value) + value = self._actual_model.as_data(builder, elem) + out = self.make_uninitialized(kind="data") + return self.set(builder, out, value) + + def from_return(self, builder, value): + return value + + 
def as_return(self, builder, value): + return value + + def get(self, builder, value): + return builder.extract_value(value, [0]) + + def set(self, builder, value, content): + return builder.insert_value(value, content, [0]) + + def make_uninitialized(self, kind="value"): + self._define() + if kind == "value": + ty = self.get_value_type() + else: + ty = self.get_data_type() + return ir.Constant(ty, ir.Undefined) + + def _define(self): + valty = self.get_value_type() + self._define_value_type(valty) + datty = self.get_data_type() + self._define_data_type(datty) + + def _define_value_type(self, value_type): + if value_type.is_opaque: + value_type.set_body(self._actual_model.get_value_type()) + + def _define_data_type(self, data_type): + if data_type.is_opaque: + data_type.set_body(self._actual_model.get_data_type()) + + @property + def _actual_model(self): + return self._dmm.lookup(self.actual_fe_type) + + def traverse(self, builder): + return [ + ( + self.actual_fe_type, + lambda value: builder.extract_value(value, [0]), + ) + ] + + +@register_default(types.StructRefPayload) +class StructPayloadModel(StructModel): + """Model for the payload of a mutable struct""" + + def __init__(self, dmm, fe_typ): + members = tuple(fe_typ.field_dict.items()) + super().__init__(dmm, fe_typ, members) + + +class StructRefModel(StructModel): + """Model for a mutable struct. + A reference to the payload + """ + + def __init__(self, dmm, fe_typ): + dtype = fe_typ.get_data_type() + members = [ + ("meminfo", types.MemInfoPointer(dtype)), + ] + super().__init__(dmm, fe_typ, members) diff --git a/numba_cuda/numba/cuda/datamodel/cuda_packer.py b/numba_cuda/numba/cuda/datamodel/cuda_packer.py new file mode 100644 index 000000000..c879971ec --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/cuda_packer.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +from collections import deque + +from numba.cuda import types +from numba.cuda import cgutils + + +class DataPacker(object): + """ + A helper to pack a number of typed arguments into a data structure. + Omitted arguments (i.e. values with the type `Omitted`) are automatically + skipped. + """ + + # XXX should DataPacker be a model for a dedicated type? + + def __init__(self, dmm, fe_types): + self._dmm = dmm + self._fe_types = fe_types + self._models = [dmm.lookup(ty) for ty in fe_types] + + self._pack_map = [] + self._be_types = [] + for i, ty in enumerate(fe_types): + if not isinstance(ty, types.Omitted): + self._pack_map.append(i) + self._be_types.append(self._models[i].get_data_type()) + + def as_data(self, builder, values): + """ + Return the given values packed as a data structure. + """ + elems = [ + self._models[i].as_data(builder, values[i]) for i in self._pack_map + ] + return cgutils.make_anonymous_struct(builder, elems) + + def _do_load(self, builder, ptr, formal_list=None): + res = [] + for i, i_formal in enumerate(self._pack_map): + elem_ptr = cgutils.gep_inbounds(builder, ptr, 0, i) + val = self._models[i_formal].load_from_data_pointer( + builder, elem_ptr + ) + if formal_list is None: + res.append((self._fe_types[i_formal], val)) + else: + formal_list[i_formal] = val + return res + + def load(self, builder, ptr): + """ + Load the packed values and return a list of (type, value) tuples. + """ + return self._do_load(builder, ptr) + + def load_into(self, builder, ptr, formal_list): + """ + Load the packed values into a sequence indexed by formal + argument number (skipping any Omitted position). + """ + self._do_load(builder, ptr, formal_list) + + +class ArgPacker(object): + """ + Compute the position for each high-level typed argument. + It flattens every composite argument into primitive types. + It maintains a position map for unflattening the arguments. + + Since struct (esp. 
nested struct) have specific ABI requirements (e.g. + alignment, pointer address-space, ...) in different architectures (e.g. + OpenCL, CUDA), flattening composite argument types simplifies the call + setup from the Python side. Functions are receiving simple primitive + types and there are only a handful of these. + """ + + def __init__(self, dmm, fe_args): + self._dmm = dmm + self._fe_args = fe_args + self._nargs = len(fe_args) + + self._dm_args = [] + argtys = [] + for ty in fe_args: + dm = self._dmm.lookup(ty) + self._dm_args.append(dm) + argtys.append(dm.get_argument_type()) + self._unflattener = _Unflattener(argtys) + self._be_args = list(_flatten(argtys)) + + def as_arguments(self, builder, values): + """Flatten all argument values""" + if len(values) != self._nargs: + raise TypeError( + "invalid number of args: expected %d, got %d" + % (self._nargs, len(values)) + ) + + if not values: + return () + + args = [ + dm.as_argument(builder, val) + for dm, val in zip(self._dm_args, values) + ] + + args = tuple(_flatten(args)) + return args + + def from_arguments(self, builder, args): + """Unflatten all argument values""" + + valtree = self._unflattener.unflatten(args) + values = [ + dm.from_argument(builder, val) + for dm, val in zip(self._dm_args, valtree) + ] + + return values + + def assign_names(self, args, names): + """Assign a name to each flattened argument value.""" + + valtree = self._unflattener.unflatten(args) + for aval, aname in zip(valtree, names): + self._assign_names(aval, aname) + + def _assign_names(self, val_or_nested, name, depth=()): + if isinstance(val_or_nested, (tuple, list)): + for pos, aval in enumerate(val_or_nested): + self._assign_names(aval, name, depth=depth + (pos,)) + else: + postfix = ".".join(map(str, depth)) + parts = [name, postfix] + val_or_nested.name = ".".join(filter(bool, parts)) + + @property + def argument_types(self): + """Return a tuple of LLVM types that are results of flattening + composite types. 
+ """ + return tuple(ty for ty in self._be_args if ty != ()) + + +def _flatten(iterable): + """ + Flatten nested iterable of (tuple, list). + """ + + def rec(iterable): + for i in iterable: + if isinstance(i, (tuple, list)): + for j in rec(i): + yield j + else: + yield i + + return rec(iterable) + + +_PUSH_LIST = 1 +_APPEND_NEXT_VALUE = 2 +_APPEND_EMPTY_TUPLE = 3 +_POP = 4 + + +class _Unflattener(object): + """ + An object used to unflatten nested sequences after a given pattern + (an arbitrarily nested sequence). + The pattern shows the nested sequence shape desired when unflattening; + the values it contains are irrelevant. + """ + + def __init__(self, pattern): + self._code = self._build_unflatten_code(pattern) + + def _build_unflatten_code(self, iterable): + """Build the unflatten opcode sequence for the given *iterable* structure + (an iterable of nested sequences). + """ + code = [] + + def rec(iterable): + for i in iterable: + if isinstance(i, (tuple, list)): + if len(i) > 0: + code.append(_PUSH_LIST) + rec(i) + code.append(_POP) + else: + code.append(_APPEND_EMPTY_TUPLE) + else: + code.append(_APPEND_NEXT_VALUE) + + rec(iterable) + return code + + def unflatten(self, flatiter): + """Rebuild a nested tuple structure.""" + vals = deque(flatiter) + + res = [] + cur = res + stack = [] + for op in self._code: + if op is _PUSH_LIST: + stack.append(cur) + cur.append([]) + cur = cur[-1] + elif op is _APPEND_NEXT_VALUE: + cur.append(vals.popleft()) + elif op is _APPEND_EMPTY_TUPLE: + cur.append(()) + elif op is _POP: + cur = stack.pop() + + assert not stack, stack + assert not vals, vals + + return res diff --git a/numba_cuda/numba/cuda/datamodel/cuda_registry.py b/numba_cuda/numba/cuda/datamodel/cuda_registry.py new file mode 100644 index 000000000..39b66fe4a --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/cuda_registry.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import functools +from .manager import DataModelManager + + +def register(dmm, typecls): + """Used as decorator to simplify datamodel registration. + Returns the object being decorated so that chaining is possible. + """ + + def wraps(fn): + dmm.register(typecls, fn) + return fn + + return wraps + + +default_manager = DataModelManager() + +register_default = functools.partial(register, default_manager) diff --git a/numba_cuda/numba/cuda/datamodel/cuda_testing.py b/numba_cuda/numba/cuda/datamodel/cuda_testing.py new file mode 100644 index 000000000..35dcc11d9 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/cuda_testing.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from llvmlite import ir +from llvmlite import binding as ll + +from numba.cuda import datamodel +import unittest + + +class DataModelTester(unittest.TestCase): + """ + Test the implementation of a DataModel for a frontend type. + """ + + fe_type = NotImplemented + + def setUp(self): + self.module = ir.Module() + self.datamodel = datamodel.default_manager[self.fe_type] + + def test_as_arg(self): + """ + - Are as_argument() and from_argument() implemented? + - Are they the inverse of each other? 
+ """ + fnty = ir.FunctionType(ir.VoidType(), []) + function = ir.Function(self.module, fnty, name="test_as_arg") + builder = ir.IRBuilder() + builder.position_at_end(function.append_basic_block()) + + undef_value = ir.Constant(self.datamodel.get_value_type(), None) + args = self.datamodel.as_argument(builder, undef_value) + self.assertIsNot( + args, NotImplemented, "as_argument returned NotImplementedError" + ) + + if isinstance(args, (tuple, list)): + + def recur_tuplize(args, func=None): + for arg in args: + if isinstance(arg, (tuple, list)): + yield tuple(recur_tuplize(arg, func=func)) + else: + if func is None: + yield arg + else: + yield func(arg) + + argtypes = tuple(recur_tuplize(args, func=lambda x: x.type)) + exptypes = tuple(recur_tuplize(self.datamodel.get_argument_type())) + self.assertEqual(exptypes, argtypes) + else: + self.assertEqual(args.type, self.datamodel.get_argument_type()) + + rev_value = self.datamodel.from_argument(builder, args) + self.assertEqual(rev_value.type, self.datamodel.get_value_type()) + + builder.ret_void() # end function + + # Ensure valid LLVM generation + materialized = ll.parse_assembly(str(self.module)) + str(materialized) + + def test_as_return(self): + """ + - Are as_return() and from_return() implemented? + - Are they the inverse of each other? 
+ """ + fnty = ir.FunctionType(ir.VoidType(), []) + function = ir.Function(self.module, fnty, name="test_as_return") + builder = ir.IRBuilder() + builder.position_at_end(function.append_basic_block()) + + undef_value = ir.Constant(self.datamodel.get_value_type(), None) + ret = self.datamodel.as_return(builder, undef_value) + self.assertIsNot( + ret, NotImplemented, "as_return returned NotImplementedError" + ) + + self.assertEqual(ret.type, self.datamodel.get_return_type()) + + rev_value = self.datamodel.from_return(builder, ret) + self.assertEqual(rev_value.type, self.datamodel.get_value_type()) + + builder.ret_void() # end function + + # Ensure valid LLVM generation + materialized = ll.parse_assembly(str(self.module)) + str(materialized) + + +class SupportAsDataMixin(object): + """Test as_data() and from_data()""" + + # XXX test load_from_data_pointer() as well + + def test_as_data(self): + fnty = ir.FunctionType(ir.VoidType(), []) + function = ir.Function(self.module, fnty, name="test_as_data") + builder = ir.IRBuilder() + builder.position_at_end(function.append_basic_block()) + + undef_value = ir.Constant(self.datamodel.get_value_type(), None) + data = self.datamodel.as_data(builder, undef_value) + self.assertIsNot( + data, NotImplemented, "as_data returned NotImplemented" + ) + + self.assertEqual(data.type, self.datamodel.get_data_type()) + + rev_value = self.datamodel.from_data(builder, data) + self.assertEqual(rev_value.type, self.datamodel.get_value_type()) + + builder.ret_void() # end function + + # Ensure valid LLVM generation + materialized = ll.parse_assembly(str(self.module)) + str(materialized) + + +class NotSupportAsDataMixin(object): + """Ensure as_data() and from_data() raise NotImplementedError.""" + + def test_as_data_not_supported(self): + fnty = ir.FunctionType(ir.VoidType(), []) + function = ir.Function(self.module, fnty, name="test_as_data") + builder = ir.IRBuilder() + builder.position_at_end(function.append_basic_block()) + + undef_value = 
ir.Constant(self.datamodel.get_value_type(), None) + with self.assertRaises(NotImplementedError): + data = self.datamodel.as_data(builder, undef_value) # noqa: F841 + with self.assertRaises(NotImplementedError): + rev_data = self.datamodel.from_data(builder, undef_value) # noqa: F841 + + +class DataModelTester_SupportAsDataMixin(DataModelTester, SupportAsDataMixin): + pass + + +class DataModelTester_NotSupportAsDataMixin( + DataModelTester, NotSupportAsDataMixin +): + pass + + +def test_factory(support_as_data=True): + """A helper for returning a unittest TestCase for testing""" + if support_as_data: + return DataModelTester_SupportAsDataMixin + else: + return DataModelTester_NotSupportAsDataMixin diff --git a/numba_cuda/numba/cuda/datamodel/manager.py b/numba_cuda/numba/cuda/datamodel/manager.py new file mode 100644 index 000000000..d87636d21 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/manager.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), + "numba.core.datamodel.manager", + "numba.cuda.datamodel.cuda_manager", +) diff --git a/numba_cuda/numba/cuda/datamodel/models.py b/numba_cuda/numba/cuda/datamodel/models.py new file mode 100644 index 000000000..c37da6956 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/models.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.datamodel.models", "numba.cuda.datamodel.cuda_models" +) diff --git a/numba_cuda/numba/cuda/datamodel/packer.py b/numba_cuda/numba/cuda/datamodel/packer.py new file mode 100644 index 000000000..bb3855bce --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/packer.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.datamodel.packer", "numba.cuda.datamodel.cuda_packer" +) diff --git a/numba_cuda/numba/cuda/datamodel/registry.py b/numba_cuda/numba/cuda/datamodel/registry.py new file mode 100644 index 000000000..637c7a0c5 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/registry.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), + "numba.core.datamodel.registry", + "numba.cuda.datamodel.cuda_registry", +) diff --git a/numba_cuda/numba/cuda/datamodel/testing.py b/numba_cuda/numba/cuda/datamodel/testing.py new file mode 100644 index 000000000..9fcb8afa1 --- /dev/null +++ b/numba_cuda/numba/cuda/datamodel/testing.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), + "numba.core.datamodel.testing", + "numba.cuda.datamodel.cuda_testing", +) diff --git a/numba_cuda/numba/cuda/debuginfo.py b/numba_cuda/numba/cuda/debuginfo.py index 349ea4c5e..25c637af1 100644 --- a/numba_cuda/numba/cuda/debuginfo.py +++ b/numba_cuda/numba/cuda/debuginfo.py @@ -6,11 +6,11 @@ from contextlib import contextmanager from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda.core import config from numba.cuda import cgutils -from numba.core.datamodel.models import ComplexModel, UnionModel, UniTupleModel -from numba.cuda.types import GridGroup +from numba.cuda.datamodel.models import ComplexModel, UnionModel, UniTupleModel +from numba.cuda.ext_types import GridGroup @contextmanager diff --git a/numba_cuda/numba/cuda/decorators.py b/numba_cuda/numba/cuda/decorators.py index 5fa800a1b..1f541ef8e 100644 --- a/numba_cuda/numba/cuda/decorators.py +++ b/numba_cuda/numba/cuda/decorators.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from warnings import warn -from numba.core import types +from numba.cuda import types from numba.core.errors import DeprecationError, NumbaInvalidConfigWarning from numba.cuda.compiler import declare_device_function from numba.cuda.core import sigutils, config diff --git a/numba_cuda/numba/cuda/deviceufunc.py b/numba_cuda/numba/cuda/deviceufunc.py index 433213961..ce8172378 100644 --- a/numba_cuda/numba/cuda/deviceufunc.py +++ b/numba_cuda/numba/cuda/deviceufunc.py @@ -16,7 +16,7 @@ import numpy as np from numba.cuda.np.ufunc.ufuncbuilder import _BaseUFuncBuilder, parse_identity -from numba.core import types +from numba.cuda import types from numba.cuda.typing import signature from numba.cuda.core import sigutils diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 
3d9f45ddf..3f9fe6e0f 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -15,17 +15,16 @@ import re from warnings import warn -from numba.core import types, errors +from numba.core import errors from numba.cuda import serialize, utils from numba import cuda from numba.core.compiler_lock import global_compiler_lock -from numba.core.typeconv.rules import default_type_manager +from numba.cuda.typeconv.rules import default_type_manager from numba.cuda.typing.templates import fold_arguments from numba.cuda.typing.typeof import Purpose, typeof -from numba.cuda import typing -from numba.cuda import types as cuda_types +from numba.cuda import typing, types, ext_types from numba.cuda.api import get_current_device from numba.cuda.args import wrap_arg from numba.core.bytecode import get_code_object @@ -1537,7 +1536,7 @@ def dump(self, tab=""): @property def _numba_type_(self): - return cuda_types.CUDADispatcher(self) + return ext_types.CUDADispatcher(self) def enable_caching(self): self._cache = CUDACache(self.py_func) diff --git a/numba_cuda/numba/cuda/types.py b/numba_cuda/numba/cuda/ext_types.py similarity index 99% rename from numba_cuda/numba/cuda/types.py rename to numba_cuda/numba/cuda/ext_types.py index 7e407ac81..346effab7 100644 --- a/numba_cuda/numba/cuda/types.py +++ b/numba_cuda/numba/cuda/ext_types.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba.core import types +from numba.cuda import types from numba.cuda.typeconv import Conversion diff --git a/numba_cuda/numba/cuda/extending.py b/numba_cuda/numba/cuda/extending.py index 9141cdf36..a9642917f 100644 --- a/numba_cuda/numba/cuda/extending.py +++ b/numba_cuda/numba/cuda/extending.py @@ -10,8 +10,8 @@ import collections import functools -from numba.core import types, errors -from numba.cuda import utils, config +from numba.core import errors +from numba.cuda import types, utils, config # # Exported symbols from numba.cuda.typing.typeof import typeof_impl # noqa: F401 @@ -28,7 +28,7 @@ ) # noqa: F401 from numba.cuda.core.pythonapi import box, unbox, reflect, NativeValue # noqa: F401 from numba.cuda.serialize import ReduceMixin -from numba.core.datamodel import models as core_models # noqa: F401 +from numba.cuda.datamodel import models as core_models # noqa: F401 from numba.cuda.models import register_model # noqa: F401 @@ -45,11 +45,9 @@ def make_attribute_wrapper(typeclass, struct_attr, python_attr): model manager. 
""" from numba.cuda.typing.templates import AttributeTemplate - - from numba.core.datamodel import default_manager - from numba.core.datamodel.models import StructModel + from numba.cuda.datamodel import default_manager + from numba.cuda.datamodel.models import StructModel from numba.cuda.core.imputils import impl_ret_borrowed - from numba.core import types from numba.cuda import cgutils from numba.cuda.models import cuda_data_manager @@ -59,7 +57,9 @@ def make_attribute_wrapper(typeclass, struct_attr, python_attr): data_model_manager = cuda_data_manager.chain(default_manager) if not isinstance(typeclass, type) or not issubclass(typeclass, types.Type): - raise TypeError(f"typeclass should be a Type subclass, got {typeclass}") + raise TypeError( + "typeclass should be a Type subclass, got %s" % (typeclass,) + ) def get_attr_fe_type(typ): """ @@ -68,7 +68,8 @@ def get_attr_fe_type(typ): model = data_model_manager.lookup(typ) if not isinstance(model, StructModel): raise TypeError( - f"make_attribute_wrapper() needs a type with a StructModel, but got {model}" + "make_attribute_wrapper() needs a type " + "with a StructModel, but got %s" % (model,) ) return model.get_member_fe_type(struct_attr) @@ -223,6 +224,7 @@ def decorate(overload_func): infer as core_infer, infer_global as core_infer_global, ) + from numba.core import types as core_types core_template = core_make_overload_template( func, @@ -235,7 +237,7 @@ def decorate(overload_func): ) core_infer(core_template) if callable(func): - core_infer_global(func, types.Function(core_template)) + core_infer_global(func, core_types.Function(core_template)) except ImportError: pass diff --git a/numba_cuda/numba/cuda/fp16.py b/numba_cuda/numba/cuda/fp16.py index d3e3c0183..9ad106e01 100644 --- a/numba_cuda/numba/cuda/fp16.py +++ b/numba_cuda/numba/cuda/fp16.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -import numba.core.types as types +import numba.cuda.types as types from numba.cuda._internal.cuda_fp16 import ( typing_registry, target_registry, diff --git a/numba_cuda/numba/cuda/intrinsics.py b/numba_cuda/numba/cuda/intrinsics.py index d2ae8d37f..2701ba5d2 100644 --- a/numba_cuda/numba/cuda/intrinsics.py +++ b/numba_cuda/numba/cuda/intrinsics.py @@ -3,7 +3,8 @@ from llvmlite import ir -from numba import cuda, types +from numba import cuda +from numba.cuda import types from numba.cuda import cgutils from numba.core.errors import RequireLiteralValue, TypingError from numba.cuda.typing import signature diff --git a/numba_cuda/numba/cuda/itanium_mangler.py b/numba_cuda/numba/cuda/itanium_mangler.py index f86f5ddf4..3f347d5e8 100644 --- a/numba_cuda/numba/cuda/itanium_mangler.py +++ b/numba_cuda/numba/cuda/itanium_mangler.py @@ -33,7 +33,7 @@ import re -from numba.core import types +from numba.cuda import types # According the scheme, valid characters for mangled names are [a-zA-Z0-9_]. 
diff --git a/numba_cuda/numba/cuda/libdevicefuncs.py b/numba_cuda/numba/cuda/libdevicefuncs.py index 50b7ac5e9..0ca50da87 100644 --- a/numba_cuda/numba/cuda/libdevicefuncs.py +++ b/numba_cuda/numba/cuda/libdevicefuncs.py @@ -4,7 +4,7 @@ from collections import namedtuple from textwrap import indent -from numba.types import float32, float64, int16, int32, int64, void, Tuple +from numba.cuda.types import float32, float64, int16, int32, int64, void, Tuple from numba.cuda.typing.templates import signature arg = namedtuple("arg", ("name", "ty", "is_ptr")) diff --git a/numba_cuda/numba/cuda/libdeviceimpl.py b/numba_cuda/numba/cuda/libdeviceimpl.py index dbd3fb623..092c684a4 100644 --- a/numba_cuda/numba/cuda/libdeviceimpl.py +++ b/numba_cuda/numba/cuda/libdeviceimpl.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: BSD-2-Clause from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.core.imputils import Registry from numba.cuda import libdevice, libdevicefuncs -registry = Registry() +registry = Registry("libdeviceimpl") lower = registry.lower diff --git a/numba_cuda/numba/cuda/lowering.py b/numba_cuda/numba/cuda/lowering.py index 1ba29d713..ef7fa4a2a 100644 --- a/numba_cuda/numba/cuda/lowering.py +++ b/numba_cuda/numba/cuda/lowering.py @@ -8,14 +8,16 @@ from llvmlite import ir as llvm_ir -from numba.core import ( - types, - ir, +from numba.core import ir +from numba.cuda import debuginfo, cgutils, utils, typing, types +from numba.cuda.core import ( + ir_utils, + targetconfig, + funcdesc, + config, generators, removerefctpass, ) -from numba.cuda import debuginfo, cgutils, utils, typing -from numba.cuda.core import ir_utils, targetconfig, funcdesc, config from numba.core.errors import ( LoweringError, @@ -1237,7 +1239,9 @@ def _lower_call_normal(self, fnty, expr, signature): ) tname = expr.target if tname is not None: - from numba.core.target_extension import resolve_dispatcher_from_str + from 
numba.core.target_extension import ( + resolve_dispatcher_from_str, + ) disp = resolve_dispatcher_from_str(tname) hw_ctx = disp.targetdescr.target_context diff --git a/numba_cuda/numba/cuda/mathimpl.py b/numba_cuda/numba/cuda/mathimpl.py index 5f3e634b8..7ea567cc9 100644 --- a/numba_cuda/numba/cuda/mathimpl.py +++ b/numba_cuda/numba/cuda/mathimpl.py @@ -4,14 +4,14 @@ import math import operator from llvmlite import ir -from numba.core import types -from numba.cuda import cgutils, typing +from numba.cuda import types, typing +from numba.cuda import cgutils from numba.cuda.core.imputils import Registry -from numba.types import float32, float64, int64, uint64 +from numba.cuda.types import float32, float64, int64, uint64 from numba.cuda import libdevice from numba.cuda.core import targetconfig -registry = Registry() +registry = Registry("mathimpl") lower = registry.lower diff --git a/numba_cuda/numba/cuda/memory_management/nrt.py b/numba_cuda/numba/cuda/memory_management/nrt.py index 70ad01fe4..cf2db24b3 100644 --- a/numba_cuda/numba/cuda/memory_management/nrt.py +++ b/numba_cuda/numba/cuda/memory_management/nrt.py @@ -7,8 +7,8 @@ import numpy as np from collections import namedtuple -from numba import cuda, types -from numba.cuda import config +from numba import cuda +from numba.cuda import config, types from numba.cuda.cudadrv.driver import ( _Linker, diff --git a/numba_cuda/numba/cuda/memory_management/nrt_context.py b/numba_cuda/numba/cuda/memory_management/nrt_context.py index f52747088..dd8bd3c66 100644 --- a/numba_cuda/numba/cuda/memory_management/nrt_context.py +++ b/numba_cuda/numba/cuda/memory_management/nrt_context.py @@ -5,7 +5,8 @@ from collections import namedtuple from llvmlite import ir -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import cgutils, config from numba.cuda.utils import PYVERSION diff --git a/numba_cuda/numba/cuda/misc/cffiimpl.py b/numba_cuda/numba/cuda/misc/cffiimpl.py 
index bea920a5b..6c2aa8421 100644 --- a/numba_cuda/numba/cuda/misc/cffiimpl.py +++ b/numba_cuda/numba/cuda/misc/cffiimpl.py @@ -6,8 +6,7 @@ """ from numba.cuda.core.imputils import Registry -from numba.core import types -from numba.cuda.np import arrayobj +from numba.cuda import types registry = Registry("cffiimpl") @@ -21,5 +20,5 @@ def from_buffer(context, builder, sig, args): # Type inference should have prevented passing a buffer from an # array to a pointer of the wrong type assert fromty.dtype == sig.return_type.dtype - ary = arrayobj.make_array(fromty)(context, builder, val) + ary = context.make_array(fromty)(context, builder, val) return ary.data diff --git a/numba_cuda/numba/cuda/misc/coverage_support.py b/numba_cuda/numba/cuda/misc/coverage_support.py index c857039db..fb70a7afd 100644 --- a/numba_cuda/numba/cuda/misc/coverage_support.py +++ b/numba_cuda/numba/cuda/misc/coverage_support.py @@ -11,7 +11,8 @@ from typing import Optional, Sequence, Callable from abc import ABC, abstractmethod -from numba.core import ir, config +from numba.core import ir +from numba.cuda import config _the_registry: Callable[[], Optional["NotifyLocBase"]] = [] diff --git a/numba_cuda/numba/cuda/misc/dump_style.py b/numba_cuda/numba/cuda/misc/dump_style.py index 0a074468e..73416d114 100644 --- a/numba_cuda/numba/cuda/misc/dump_style.py +++ b/numba_cuda/numba/cuda/misc/dump_style.py @@ -7,7 +7,7 @@ msg = "Please install pygments to see highlighted dumps" raise ImportError(msg) -import numba.core.config +import numba.cuda.config from pygments.styles.manni import ManniStyle from pygments.styles.monokai import MonokaiStyle from pygments.styles.native import NativeStyle @@ -38,4 +38,4 @@ def by_colorscheme(): "jupyter_nb": DefaultStyle, } - return style_map[numba.core.config.COLOR_SCHEME] + return style_map[numba.cuda.config.COLOR_SCHEME] diff --git a/numba_cuda/numba/cuda/misc/gdb_hook.py b/numba_cuda/numba/cuda/misc/gdb_hook.py index ad7535956..ee945a3b0 100644 --- 
a/numba_cuda/numba/cuda/misc/gdb_hook.py +++ b/numba_cuda/numba/cuda/misc/gdb_hook.py @@ -6,7 +6,8 @@ from llvmlite import ir -from numba.core import types, config, errors +from numba.cuda import types, config +from numba.core import errors from numba.cuda import cgutils, utils from numba.cuda.misc.special import gdb, gdb_init, gdb_breakpoint from numba.cuda.extending import overload, intrinsic diff --git a/numba_cuda/numba/cuda/misc/literal.py b/numba_cuda/numba/cuda/misc/literal.py index 425f10154..e19e3a3b1 100644 --- a/numba_cuda/numba/cuda/misc/literal.py +++ b/numba_cuda/numba/cuda/misc/literal.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from numba.cuda.extending import overload -from numba.core import types +from numba.cuda import types from numba.cuda.misc.special import literally, literal_unroll from numba.core.errors import TypingError diff --git a/numba_cuda/numba/cuda/misc/llvm_pass_timings.py b/numba_cuda/numba/cuda/misc/llvm_pass_timings.py index 77d9ef3d7..219d70abe 100644 --- a/numba_cuda/numba/cuda/misc/llvm_pass_timings.py +++ b/numba_cuda/numba/cuda/misc/llvm_pass_timings.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from functools import cached_property -from numba.core import config +from numba.cuda import config import llvmlite.binding as llvm diff --git a/numba_cuda/numba/cuda/models.py b/numba_cuda/numba/cuda/models.py index b6d7ca180..64bb68899 100644 --- a/numba_cuda/numba/cuda/models.py +++ b/numba_cuda/numba/cuda/models.py @@ -5,11 +5,12 @@ from llvmlite import ir -from numba.core.datamodel.registry import DataModelManager, register -from numba.core.datamodel import PrimitiveModel -from numba.cuda.extending import core_models -from numba.core import types -from numba.cuda.types import Dim3, GridGroup, CUDADispatcher, Bfloat16 +from numba.cuda.datamodel.registry import DataModelManager, register +from numba.cuda.datamodel import PrimitiveModel +from numba.cuda.datamodel.models import StructModel +from 
numba.cuda.extending import core_models as models +from numba.cuda import types +from numba.cuda.ext_types import Dim3, GridGroup, CUDADispatcher, Bfloat16 cuda_data_manager = DataModelManager() @@ -18,21 +19,21 @@ @register_model(Dim3) -class Dim3Model(core_models.StructModel): +class Dim3Model(StructModel): def __init__(self, dmm, fe_type): members = [("x", types.int32), ("y", types.int32), ("z", types.int32)] super().__init__(dmm, fe_type, members) @register_model(GridGroup) -class GridGroupModel(core_models.PrimitiveModel): +class GridGroupModel(models.PrimitiveModel): def __init__(self, dmm, fe_type): be_type = ir.IntType(64) super().__init__(dmm, fe_type, be_type) @register_model(types.Float) -class FloatModel(core_models.PrimitiveModel): +class FloatModel(models.PrimitiveModel): def __init__(self, dmm, fe_type): if fe_type == types.float16: be_type = ir.IntType(16) @@ -45,7 +46,7 @@ def __init__(self, dmm, fe_type): super(FloatModel, self).__init__(dmm, fe_type, be_type) -register_model(CUDADispatcher)(core_models.OpaqueModel) +register_model(CUDADispatcher)(models.OpaqueModel) @register_model(Bfloat16) diff --git a/numba_cuda/numba/cuda/np/arraymath.py b/numba_cuda/numba/cuda/np/arraymath.py index f8957272a..83d2a4420 100644 --- a/numba_cuda/numba/cuda/np/arraymath.py +++ b/numba_cuda/numba/cuda/np/arraymath.py @@ -12,7 +12,7 @@ import llvmlite.ir import numpy as np -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.extending import overload, overload_method, register_jitable from numba.cuda.np.numpy_support import ( diff --git a/numba_cuda/numba/cuda/np/arrayobj.py b/numba_cuda/numba/cuda/np/arrayobj.py index 3e06d84cd..fcf979f4f 100644 --- a/numba_cuda/numba/cuda/np/arrayobj.py +++ b/numba_cuda/numba/cuda/np/arrayobj.py @@ -17,8 +17,8 @@ import numpy as np from numba import pndindex, literal_unroll -from numba.core import types, errors -from numba.cuda import typing +from numba.cuda import types, 
typing +from numba.core import errors from numba.cuda import cgutils, extending from numba.cuda.np.numpy_support import ( as_dtype, @@ -46,7 +46,7 @@ Registry, ) from numba.cuda.typing import signature -from numba.core.types import StringLiteral +from numba.cuda.types import StringLiteral from numba.cuda.extending import ( register_jitable, overload, diff --git a/numba_cuda/numba/cuda/np/linalg.py b/numba_cuda/numba/cuda/np/linalg.py index 52f13a7e6..f3f7d39c5 100644 --- a/numba_cuda/numba/cuda/np/linalg.py +++ b/numba_cuda/numba/cuda/np/linalg.py @@ -15,7 +15,7 @@ from numba.cuda.core.imputils import impl_ret_borrowed, impl_ret_new_ref from numba.cuda.typing import signature from numba.cuda.extending import intrinsic, overload, register_jitable -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.core.errors import ( TypingError, diff --git a/numba_cuda/numba/cuda/np/math/cmathimpl.py b/numba_cuda/numba/cuda/np/math/cmathimpl.py index 965e41d65..ba7b74d31 100644 --- a/numba_cuda/numba/cuda/np/math/cmathimpl.py +++ b/numba_cuda/numba/cuda/np/math/cmathimpl.py @@ -9,7 +9,7 @@ import math from numba.cuda.core.imputils import impl_ret_untracked -from numba.core import types +from numba.cuda import types from numba.cuda.typing import signature from numba.cuda.cpython import mathimpl diff --git a/numba_cuda/numba/cuda/np/math/mathimpl.py b/numba_cuda/numba/cuda/np/math/mathimpl.py index f56b7169f..098eb49c1 100644 --- a/numba_cuda/numba/cuda/np/math/mathimpl.py +++ b/numba_cuda/numba/cuda/np/math/mathimpl.py @@ -14,7 +14,7 @@ from llvmlite.ir import Constant from numba.cuda.core.imputils import impl_ret_untracked -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils, config from numba.cuda.extending import overload from numba.cuda.typing import signature diff --git a/numba_cuda/numba/cuda/np/math/numbers.py b/numba_cuda/numba/cuda/np/math/numbers.py index cc8bce417..8dd3a92b0 100644 
--- a/numba_cuda/numba/cuda/np/math/numbers.py +++ b/numba_cuda/numba/cuda/np/math/numbers.py @@ -10,8 +10,8 @@ from llvmlite.ir import Constant from numba.cuda.core.imputils import impl_ret_untracked -from numba.core import types, errors -from numba.cuda import cgutils, typing +from numba.cuda import typing, types, cgutils +from numba.core import errors from numba.cuda.cpython.unsafe.numbers import viewer diff --git a/numba_cuda/numba/cuda/np/npdatetime.py b/numba_cuda/numba/cuda/np/npdatetime.py index a556368df..6dd1de800 100644 --- a/numba_cuda/numba/cuda/np/npdatetime.py +++ b/numba_cuda/numba/cuda/np/npdatetime.py @@ -11,7 +11,7 @@ import llvmlite.ir from llvmlite.ir import Constant -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.cuda.cgutils import create_constant_array from numba.cuda.core.imputils import ( diff --git a/numba_cuda/numba/cuda/np/npyfuncs.py b/numba_cuda/numba/cuda/np/npyfuncs.py index fa2b948c2..5164d861f 100644 --- a/numba_cuda/numba/cuda/np/npyfuncs.py +++ b/numba_cuda/numba/cuda/np/npyfuncs.py @@ -14,7 +14,8 @@ from numba.cuda.extending import overload from numba.cuda.core.imputils import impl_ret_untracked -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import cgutils, typing from numba.cuda.np import npdatetime from numba.cuda.extending import register_jitable diff --git a/numba_cuda/numba/cuda/np/npyimpl.py b/numba_cuda/numba/cuda/np/npyimpl.py index b56b36082..fb42cee8c 100644 --- a/numba_cuda/numba/cuda/np/npyimpl.py +++ b/numba_cuda/numba/cuda/np/npyimpl.py @@ -22,8 +22,7 @@ force_error_model, impl_ret_borrowed, ) -from numba.core import types -from numba.cuda import typing +from numba.cuda import typing, types from numba.cuda import cgutils from numba.cuda.np.numpy_support import ( ufunc_find_matching_loop, diff --git a/numba_cuda/numba/cuda/np/numpy_support.py b/numba_cuda/numba/cuda/np/numpy_support.py index 
532143a6f..7925e8319 100644 --- a/numba_cuda/numba/cuda/np/numpy_support.py +++ b/numba_cuda/numba/cuda/np/numpy_support.py @@ -4,10 +4,10 @@ import collections import ctypes import re - import numpy as np -from numba.core import errors, types +from numba.cuda import types +from numba.core import errors from numba.cuda.typing.templates import signature from numba.cuda.np import npdatetime_helpers from numba.core.errors import TypingError diff --git a/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py b/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py index 81d8a93ef..6e0f3113c 100644 --- a/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +++ b/numba_cuda/numba/cuda/np/polynomial/polynomial_core.py @@ -10,7 +10,7 @@ make_attribute_wrapper, box, ) -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils import warnings from numba.core.errors import NumbaExperimentalFeatureWarning, NumbaValueError diff --git a/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py b/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py index 9e1414549..4cd48baf0 100644 --- a/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +++ b/numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py @@ -10,7 +10,8 @@ from numpy.polynomial import polyutils as pu from numba import literal_unroll -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda.extending import overload from numba.cuda.np.numpy_support import type_can_asarray, as_dtype, from_dtype diff --git a/numba_cuda/numba/cuda/np/unsafe/ndarray.py b/numba_cuda/numba/cuda/np/unsafe/ndarray.py index 94a9b212f..37626095c 100644 --- a/numba_cuda/numba/cuda/np/unsafe/ndarray.py +++ b/numba_cuda/numba/cuda/np/unsafe/ndarray.py @@ -6,7 +6,7 @@ operations with numpy. 
""" -from numba.core import types +from numba.cuda import types from numba.cuda.cgutils import unpack_tuple from numba.cuda.extending import intrinsic from numba.cuda import typing diff --git a/numba_cuda/numba/cuda/printimpl.py b/numba_cuda/numba/cuda/printimpl.py index 89e081c99..46da0f159 100644 --- a/numba_cuda/numba/cuda/printimpl.py +++ b/numba_cuda/numba/cuda/printimpl.py @@ -3,15 +3,15 @@ from functools import singledispatch from llvmlite import ir -from numba.core import types +from numba.cuda import types from numba.cuda import cgutils from numba.core.errors import NumbaWarning from numba.cuda.core.imputils import Registry from numba.cuda import nvvmutils -from numba.cuda.types import Dim3, Bfloat16 +from numba.cuda.ext_types import Dim3, Bfloat16 from warnings import warn -registry = Registry() +registry = Registry("printimpl") lower = registry.lower voidptr = ir.PointerType(ir.IntType(8)) diff --git a/numba_cuda/numba/cuda/simulator/kernelapi.py b/numba_cuda/numba/cuda/simulator/kernelapi.py index b25b0f293..f67f2221f 100644 --- a/numba_cuda/numba/cuda/simulator/kernelapi.py +++ b/numba_cuda/numba/cuda/simulator/kernelapi.py @@ -10,7 +10,7 @@ import sys import threading import traceback -from numba.core import types +from numba.cuda import types import numpy as np from numba.cuda.np import numpy_support diff --git a/numba_cuda/numba/cuda/simulator/vector_types.py b/numba_cuda/numba/cuda/simulator/vector_types.py index c9b4a49d7..9268bbe84 100644 --- a/numba_cuda/numba/cuda/simulator/vector_types.py +++ b/numba_cuda/numba/cuda/simulator/vector_types.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba import types +from numba.cuda import types from numba.cuda.stubs import _vector_type_stubs diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 31e5f0f18..42cc3f78a 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -8,13 +8,13 @@ import warnings import importlib.util -from numba.core import types +from numba.cuda import types from numba.core.compiler_lock import global_compiler_lock from numba.core.errors import NumbaWarning from numba.cuda.core.base import BaseContext from numba.cuda.typing import cmathdecl -from numba.core import datamodel +from numba.cuda import datamodel from .cudadrv import nvvm from numba.cuda import ( diff --git a/numba_cuda/numba/cuda/tests/core/serialize_usecases.py b/numba_cuda/numba/cuda/tests/core/serialize_usecases.py index 5f5fb6c2f..bd0fcfc51 100644 --- a/numba_cuda/numba/cuda/tests/core/serialize_usecases.py +++ b/numba_cuda/numba/cuda/tests/core/serialize_usecases.py @@ -12,10 +12,9 @@ import numpy.random as nprand from numba import jit -from numba.core import types -@jit((types.int32, types.int32)) +@jit("int32(int32, int32)") def add_with_sig(a, b): return a + b diff --git a/numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py b/numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py index e52da4960..112e374f8 100644 --- a/numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +++ b/numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause # -*- coding: utf-8 -*- -from numba import int32, int64, uint32, uint64, float32, float64 -from numba.core.types import 
range_iter32_type +from numba.cuda import int32, int64, uint32, uint64, float32, float64 +from numba.cuda.types import range_iter32_type from numba.cuda import itanium_mangler import unittest diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_linker.py b/numba_cuda/numba/cuda/tests/cudadrv/test_linker.py index 972c869e9..028bd85b9 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_linker.py @@ -13,7 +13,9 @@ from numba.cuda.cudadrv.driver import CudaAPIError, _Linker, LinkerError from numba.cuda import require_context from numba.cuda.tests.support import ignore_internal_warnings -from numba import cuda, void, float64, int64, int32, typeof, float32 +from numba import cuda +from numba.cuda import void, float64, int64, int32, float32 +from numba.cuda.typing.typeof import typeof CONST1D = np.arange(10, dtype=np.float64) diff --git a/numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py b/numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py index f11838949..0c56b9c98 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +++ b/numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba import types +from numba.cuda import types from numba.cuda.core import config @@ -20,7 +20,7 @@ def __init__(self): if not config.ENABLE_CUDASIM: - from numba import int32 + from numba.cuda import int32 from numba.cuda.extending import ( core_models, typeof_impl, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_alignment.py b/numba_cuda/numba/cuda/tests/cudapy/test_alignment.py index e5732e879..3e74e61d4 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_alignment.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import from_dtype, cuda +from numba import cuda +from numba.cuda.np.numpy_support import from_dtype from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_atomics.py b/numba_cuda/numba/cuda/tests/cudapy/test_atomics.py index ba98c5467..40836a2e4 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_atomics.py @@ -4,7 +4,8 @@ import numpy as np from textwrap import dedent -from numba import cuda, uint32, uint64, float32, float64 +from numba import cuda +from numba.cuda import uint32, uint64, float32, float64 from numba.cuda.testing import unittest, CUDATestCase, cc_X_or_above from numba.cuda.core import config diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py b/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py index b759bbaa8..52fa9e5bb 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py @@ -3,9 +3,8 @@ import numpy as np from ml_dtypes import bfloat16 as mldtypes_bf16 - -from numba import ( - cuda, +from numba import cuda +from numba.cuda import ( float32, float64, int16, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py 
b/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py index ea147aec0..4916059a8 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py @@ -10,7 +10,7 @@ import operator from numba.cuda.testing import skip_if_nvjitlink_missing -from numba import ( +from numba.cuda import ( int16, int32, int64, @@ -20,7 +20,7 @@ float32, float64, ) -from numba.types import float16 +from numba.cuda.types import float16 from numba.cuda import config if not config.ENABLE_CUDASIM: diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py b/numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py index 7a0675c50..3612a542e 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py @@ -3,7 +3,8 @@ import numpy as np import math -from numba import cuda, double, void +from numba import cuda +from numba.cuda import double, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_casting.py b/numba_cuda/numba/cuda/tests/cudapy/test_casting.py index 547a4af00..5ae3d34d9 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_casting.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_casting.py @@ -4,11 +4,11 @@ import numpy as np from numba.cuda import compile_ptx -from numba.core.types import f2, i1, i2, i4, i8, u1, u2, u4, u8 +from numba.cuda.types import f2, i1, i2, i4, i8, u1, u2, u4, u8 from numba import cuda -from numba.core import types +from numba.cuda import types from numba.cuda.testing import CUDATestCase, skip_on_cudasim, skip_unless_cc_53 -from numba.types import float16, float32 +from numba.cuda.types import float16, float32 import itertools import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_cffi.py b/numba_cuda/numba/cuda/tests/cudapy/test_cffi.py index ec287ce5a..59e6ba04b 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +++ 
b/numba_cuda/numba/cuda/tests/cudapy/test_cffi.py @@ -3,7 +3,8 @@ import numpy as np -from numba import cuda, types +from numba import cuda +from numba.cuda import types from numba.cuda.testing import ( skip_on_cudasim, test_data_dir, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py b/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py index 45f5634cf..0aa190118 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py @@ -3,18 +3,8 @@ import os from math import sqrt - - -from numba import ( - cuda, - float32, - int16, - int32, - int64, - types, - uint32, - void, -) +from numba import cuda +from numba.cuda import float32, int16, int32, int64, types, uint32, void from numba.cuda import ( compile, compile_for_current_device, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_complex.py b/numba_cuda/numba/cuda/tests/cudapy/test_complex.py index 027497954..c5951a4fe 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_complex.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_complex.py @@ -7,7 +7,7 @@ import numpy as np from numba.cuda.testing import unittest, CUDATestCase -from numba.core import types +from numba.cuda import types from numba import cuda from numba.cuda.tests.complex_usecases import ( real_usecase, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_constmem.py b/numba_cuda/numba/cuda/tests/cudapy/test_constmem.py index f410e2b24..dae53b9ed 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_constmem.py @@ -3,7 +3,8 @@ import numpy as np -from numba import cuda, complex64, int32, float64 +from numba import cuda +from numba.cuda import complex64, int32, float64 from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.core.config import ENABLE_CUDASIM diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py b/numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py index 
8fba351b1..4c2f78f0e 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py @@ -9,9 +9,10 @@ import numpy as np -from numba import cuda, int32 +from numba import cuda +from numba.cuda import int32 from numba.cuda import config -from numba.types import CPointer +from numba.cuda.types import CPointer from numba.cuda.testing import ( unittest, CUDATestCase, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py b/numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py index 40ec30fe7..468ee4b58 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py @@ -4,7 +4,8 @@ # SPDX-License-Identifier: BSD-2-Clause # -from numba.core import types, ir, config +from numba.cuda import types, config +from numba.core import ir from numba.cuda import compiler from numba.cuda.core.annotations import type_annotations from numba.cuda.core.ir_utils import ( diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debug.py b/numba_cuda/numba/cuda/tests/cudapy/test_debug.py index 0c4acc086..0b7cb1821 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_debug.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_debug.py @@ -10,7 +10,8 @@ captured_stderr, captured_stdout, ) -from numba import cuda, float64 +from numba import cuda +from numba.cuda import float64 import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py index 7d240299b..77f81d410 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py @@ -5,7 +5,7 @@ from numba.cuda.tests.support import override_config, captured_stdout from numba.cuda.testing import skip_on_cudasim from numba import cuda -from numba.core import types +from numba.cuda import types from numba.cuda.testing import CUDATestCase from textwrap 
import dedent import math diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py index e5ca2b57e..e6e48f7b4 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py @@ -4,7 +4,7 @@ import numba.cuda as cuda from numba.cuda.testing import CUDATestCase, skip_on_cudasim import llvmlite -from numba import types +from numba.cuda import types """ llvmlite pre 45 left redundant metadata nodes for debug info diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py b/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py index 9775b35d0..b617459e0 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py @@ -13,7 +13,8 @@ unittest, CUDATestCase, ) -from numba import cuda, jit, float32, int32, types +from numba import cuda, jit +from numba.cuda import float32, int32, types from numba.core.errors import TypingError from numba.cuda.tests.support import skip_unless_cffi from types import ModuleType diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py index 4e2e2c292..a8a324202 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py @@ -5,18 +5,17 @@ import numpy as np import threading -from numba import ( +from numba.cuda.types import ( boolean, - cuda, float32, float64, int32, int64, - types, uint32, void, ) -from numba.cuda import config +from numba import cuda +from numba.cuda import config, types from numba.core.errors import TypingError from numba.cuda.testing import ( cc_X_or_above, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_enums.py b/numba_cuda/numba/cuda/tests/cudapy/test_enums.py index 03b4f3cb1..12c5161b4 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_enums.py +++ 
b/numba_cuda/numba/cuda/tests/cudapy/test_enums.py @@ -7,9 +7,9 @@ import numpy as np -from numba import int16, int32 +from numba.cuda import int16, int32 from numba import cuda, vectorize, njit -from numba.core import types +from numba.cuda import types from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim from numba.cuda.tests.enum_usecases import ( Color, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_extending.py b/numba_cuda/numba/cuda/tests/cudapy/test_extending.py index bb711884b..f0e920b83 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_extending.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_extending.py @@ -7,7 +7,8 @@ import numpy as np import os -from numba import cuda, njit, types +from numba import cuda, njit +from numba.cuda import types from numba.cuda import config from numba.cuda.extending import overload diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py index c1c58dffa..262ea704a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py @@ -3,7 +3,8 @@ from typing import List from dataclasses import dataclass, field -from numba import cuda, float32 +from numba import cuda +from numba.cuda import float32 from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx from math import cos, sin, tan, exp, log, log10, log2, pow, tanh from operator import truediv diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_freevar.py b/numba_cuda/numba/cuda/tests/cudapy/test_freevar.py index 07bcb1579..4d638bde5 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_freevar.py @@ -12,7 +12,7 @@ def test_freevar(self): """Make sure we can compile the following kernel with freevar reference in arguments to shared.array """ - from numba import float32 + from numba.cuda import float32 size = 1024 nbtype = float32 diff --git 
a/numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py b/numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py index 10a83ab99..2c536e706 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py @@ -4,7 +4,7 @@ import numpy as np import math from numba import cuda -from numba.types import float32, float64, int32, void +from numba.cuda.types import float32, float64, int32, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_globals.py b/numba_cuda/numba/cuda/tests/cudapy/test_globals.py index c1891b2e1..8420688b0 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_globals.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_globals.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, int32, float32 +from numba import cuda +from numba.cuda import int32, float32 from numba.cuda.testing import unittest, CUDATestCase N = 100 diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py b/numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py index 835b1a2e0..975267cfb 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py @@ -4,7 +4,7 @@ import numpy as np from collections import namedtuple -from numba import void, int32, float32, float64 +from numba.cuda import void, int32, float32, float64 from numba import guvectorize from numba import cuda from numba.cuda.testing import skip_on_cudasim, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_idiv.py b/numba_cuda/numba/cuda/tests/cudapy/test_idiv.py index 52d35a4c4..9550e3e3b 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_idiv.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, float32, float64, int32, void +from numba import cuda +from numba.cuda import 
float32, float64, int32, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_inline.py b/numba_cuda/numba/cuda/tests/cudapy/test_inline.py index bb5efb45a..a9e141bce 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_inline.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_inline.py @@ -3,7 +3,8 @@ import re import numpy as np -from numba import cuda, types +from numba import cuda +from numba.cuda import types from numba.cuda.testing import ( unittest, CUDATestCase, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py index 311bf1270..9c2b47d03 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py @@ -7,8 +7,9 @@ import numpy as np from io import StringIO -from numba import cuda, float32, float64, int32, intp -from numba.types import float16, CPointer +from numba import cuda +from numba.cuda import float32, float64, int32, intp +from numba.cuda.types import float16, CPointer from numba.cuda import declare_device from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.testing import ( diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py b/numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py index 5c2c85d88..e01f87a22 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py @@ -5,9 +5,10 @@ import numpy as np import operator import re -from numba import cuda, int64 +from numba import cuda +from numba.cuda import int64 from numba.core.errors import TypingError -from numba.core.types import f2 +from numba.cuda.types import f2 from numba.cuda.testing import ( unittest, CUDATestCase, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py b/numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py index b03e9c2e3..56547f05f 100644 --- 
a/numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py @@ -6,7 +6,8 @@ from numba.cuda.flags import Flags from numba.cuda.core.compiler_machinery import PassManager from numba.cuda.core import ir_utils -from numba.core import types, ir, bytecode +from numba.cuda import types +from numba.core import ir, bytecode from numba.cuda import compiler from numba.cuda.core.untyped_passes import ( ExtractByteCode, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_lang.py b/numba_cuda/numba/cuda/tests/cudapy/test_lang.py index 291f579f7..c8ea278ab 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_lang.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_lang.py @@ -7,7 +7,8 @@ """ import numpy as np -from numba import cuda, float64 +from numba import cuda +from numba.cuda import float64 from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_laplace.py b/numba_cuda/numba/cuda/tests/cudapy/test_laplace.py index 1973df1d8..8874c449d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_laplace.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, float64, void +from numba import cuda +from numba.cuda import float64, void from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.core import config diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py b/numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py index eeb23348d..1bd73fcbc 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba.core import types +from numba.cuda import types from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase from numba import cuda from numba.cuda import libdevice, compile_ptx diff --git 
a/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py index 763f3240d..cff6c7294 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba import cuda, float32, int32 +from numba import cuda +from numba.cuda import float32, int32 from numba.core.errors import NumbaInvalidConfigWarning from numba.cuda.testing import CUDATestCase, skip_on_cudasim from numba.cuda.tests.support import ignore_internal_warnings diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_localmem.py b/numba_cuda/numba/cuda/tests/cudapy/test_localmem.py index 1cbd58f2f..0099a532d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_localmem.py @@ -3,8 +3,9 @@ import numpy as np -from numba import cuda, int32, complex128, void -from numba.core import types +from numba import cuda +from numba.cuda import int32, complex128, void +from numba.cuda import types from numba.core.errors import TypingError from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim from .extensions_usecases import test_struct_model_type, TestStruct diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_mandel.py b/numba_cuda/numba/cuda/tests/cudapy/test_mandel.py index f4d352544..1ca3bfe8f 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_mandel.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba import float64, uint32 +from numba.cuda import float64, uint32 from numba.cuda.compiler import compile_ptx from numba.cuda.testing import skip_on_cudasim, unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_math.py b/numba_cuda/numba/cuda/tests/cudapy/test_math.py index 57f0e3d97..34776c707 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_math.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_math.py @@ -9,7 +9,8 @@ skip_on_cudasim, ) from numba.cuda.np import numpy_support -from numba import cuda, float32, float64, int32, vectorize, void, int64 +from numba import cuda, vectorize +from numba.cuda import float32, float64, int32, void, int64 import math diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_matmul.py b/numba_cuda/numba/cuda/tests/cudapy/test_matmul.py index 9785932cc..00f0256bd 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_matmul.py @@ -3,7 +3,8 @@ import numpy as np -from numba import cuda, float32, void +from numba import cuda +from numba.cuda import float32, void from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.core import config diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_minmax.py b/numba_cuda/numba/cuda/tests/cudapy/test_minmax.py index 6317d1530..26f2ab830 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_minmax.py @@ -3,7 +3,8 @@ import numpy as np -from numba import cuda, float64 +from numba import cuda +from numba.cuda import float64 from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_nondet.py b/numba_cuda/numba/cuda/tests/cudapy/test_nondet.py index 343c1c5e5..2becebd0a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_nondet.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as 
np -from numba import cuda, float32, void +from numba import cuda +from numba.cuda import float32, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_operator.py b/numba_cuda/numba/cuda/tests/cudapy/test_operator.py index ee496242f..6c8790972 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_operator.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_operator.py @@ -10,8 +10,8 @@ skip_if_nvjitlink_missing, ) from numba import cuda -from numba.core import types -from numba.core.types import f2, b1 +from numba.cuda import types +from numba.cuda.types import f2, b1 from numba.cuda.typing import signature import operator import itertools diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py b/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py index 0856f05dc..4c3cc37ba 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py @@ -4,7 +4,8 @@ import numpy as np from numba.cuda.testing import skip_on_cudasim, CUDATestCase -from numba import cuda, float64 +from numba import cuda +from numba.cuda import float64 import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_overload.py b/numba_cuda/numba/cuda/tests/cudapy/test_overload.py index 0c2e32796..ff8190671 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_overload.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_overload.py @@ -5,6 +5,7 @@ from numba.core.errors import TypingError from numba.cuda.extending import overload, overload_attribute from numba.cuda.typing.typeof import typeof +from numba.core.typing.typeof import typeof as cpu_typeof from numba.cuda.testing import CUDATestCase, skip_on_cudasim, unittest import numpy as np @@ -329,7 +330,8 @@ def kernel(x): def test_overload_attribute_target(self): MyDummy, MyDummyType = self.make_dummy_type() - mydummy_type = typeof(MyDummy()) + mydummy_type_cpu = cpu_typeof(MyDummy()) # For @njit (cpu) + 
mydummy_type = typeof(MyDummy()) # For @cuda.jit (CUDA) @overload_attribute(MyDummyType, "cuda_only", target="cuda") def ov_dummy_cuda_attr(obj): @@ -351,7 +353,7 @@ def imp(obj): with self.assertRaisesRegex(TypingError, msg): - @njit(types.int64(mydummy_type)) + @njit(types.int64(mydummy_type_cpu)) def illegal_target_attr_use(x): return x.cuda_only diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_powi.py b/numba_cuda/numba/cuda/tests/cudapy/test_powi.py index 30ebf793d..4bf80bb1a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_powi.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_powi.py @@ -3,7 +3,8 @@ import math import numpy as np -from numba import cuda, float64, int8, int32, void +from numba import cuda +from numba.cuda import float64, int8, int32, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_print.py b/numba_cuda/numba/cuda/tests/cudapy/test_print.py index 3b700316c..ef3ef1d9c 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_print.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_print.py @@ -110,7 +110,7 @@ def print_too_many(r): def print_bfloat16(): # 0.9375 is a dyadic rational, it's integer significand can expand within 7 digits. # printing this should not give any rounding error. 
- a = cuda.types.bfloat16(0.9375) + a = cuda.ext_types.bfloat16(0.9375) print(a, a, a) print_bfloat16[1, 1]() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py b/numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py index f32909527..2ffb31a13 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, float32, int32, void +from numba import cuda +from numba.cuda import float32, int32, void from numba.cuda.testing import unittest, CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py b/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py index ab419762c..2a18e836a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py @@ -3,7 +3,7 @@ import numpy as np from numba import cuda -from numba.core import types +from numba.cuda import types from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest from numba.cuda.np import numpy_support diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py b/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py index e8b62646e..f82abfc67 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_serialize.py @@ -4,7 +4,7 @@ import pickle import numpy as np from numba import cuda, vectorize -from numba.core import types +from numba.cuda import types from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest from numba.cuda.np import numpy_support diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_sm.py b/numba_cuda/numba/cuda/tests/cudapy/test_sm.py index b5e9bfdeb..fd00d1d3a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_sm.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_sm.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from numba import cuda, int32, float64, void +from numba import cuda +from numba.cuda import int32, float64, void from numba.core.errors import TypingError -from numba.core import types +from numba.cuda import types from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim import numpy as np diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py b/numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py index 89d674a3e..6e2f51bff 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, float32, int32, void +from numba import cuda +from numba.cuda import float32, int32, void from numba.core.errors import TypingError from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.testing import skip_on_cudasim diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py b/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py index c49f0fd90..d53060d0d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py @@ -10,7 +10,8 @@ import numpy as np -from numba import types, cuda +from numba.cuda import types +from numba import cuda from numba.cuda import jit from numba.core import errors diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_sync.py b/numba_cuda/numba/cuda/tests/cudapy/test_sync.py index b33d0e6f8..5df063a05 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_sync.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_sync.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba import cuda, int32, float32 +from numba import cuda +from numba.cuda import int32, float32 from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase from numba.cuda.core.config import ENABLE_CUDASIM 
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py index 16aa2b4d7..82c25d723 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py @@ -3,7 +3,7 @@ import itertools -from numba.core import types +from numba.cuda import types from numba.cuda.typeconv.typeconv import TypeManager, TypeCastingRules from numba.cuda.typeconv import rules from numba.cuda.typeconv import castgraph, Conversion diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py index 73bdeaa92..9769c1f83 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py @@ -3,8 +3,8 @@ import itertools -from numba.core import errors, types -from numba.cuda import typing +from numba.core import errors +from numba.cuda import types, typing from numba.cuda.typeconv import Conversion from numba.cuda.testing import CUDATestCase, skip_on_cudasim diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py index 4474ce34c..f09205c1f 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py @@ -6,8 +6,9 @@ import numpy as np import unittest -from numba import cuda, types, njit, typeof -from numba.cuda import config +from numba import cuda, njit +from numba.cuda import config, types +from numba.cuda.typing.typeof import typeof from numba.cuda.np import numpy_support from numba.cuda.tests.support import TestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py index 1167303e9..3a41b1234 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py @@ -7,7 +7,8 @@ from functools import partial from itertools 
import product from numba.cuda import vectorize as cuda_vectorize -from numba import cuda, int32, float32, float64, vectorize as numba_vectorize +from numba import cuda, vectorize as numba_vectorize +from numba.cuda.types import int32, float32, float64 from numba.cuda.cudadrv.driver import CudaAPIError, driver from numba.cuda.testing import skip_on_cudasim from numba.cuda.testing import CUDATestCase diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py index 78fc82059..d9d7c78ec 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py @@ -4,8 +4,8 @@ import numpy as np import math -from numba import cuda, int32, uint32, float32, float64 -from numba.cuda import vectorize +from numba import cuda +from numba.cuda import vectorize, int32, uint32, float32, float64 from numba.cuda.testing import skip_on_cudasim, CUDATestCase from numba.cuda.tests.support import CheckWarningsMixin diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py index d7777337e..bd782994d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: BSD-2-Clause from numba.cuda import vectorize -from numba import cuda, float32 +from numba import cuda +from numba.cuda import float32 import numpy as np from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py index 46c00394a..169e2c6b8 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py @@ -3,7 +3,8 @@ import numpy as np from numba.cuda import 
vectorize -from numba import cuda, float64 +from numba import cuda +from numba.cuda import float64 from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_warning.py b/numba_cuda/numba/cuda/tests/cudapy/test_warning.py index fb2f71707..0f8f79301 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_warning.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_warning.py @@ -24,11 +24,11 @@ class TestWarnings(CUDATestCase): def test_float16_warn_if_lto_missing(self): fp16_kernel_invocation = """ import math -from numba import cuda, core +from numba import cuda @cuda.jit def kernel(): - x = core.types.float16(1.0) + x = cuda.types.float16(1.0) y = math.sin(x) kernel[1,1]() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py b/numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py index 283b71c78..0a5bc3709 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py @@ -4,7 +4,8 @@ import re import numpy as np -from numba import cuda, int32, int64, float32, float64 +from numba import cuda +from numba.cuda import int32, int64, float32, float64 from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim from numba.cuda.compiler import compile_ptx from numba.cuda.core import config diff --git a/numba_cuda/numba/cuda/tests/doc_examples/test_cg.py b/numba_cuda/numba/cuda/tests/doc_examples/test_cg.py index 5834f6df0..be4699279 100644 --- a/numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +++ b/numba_cuda/numba/cuda/tests/doc_examples/test_cg.py @@ -19,7 +19,8 @@ class TestCooperativeGroups(CUDATestCase): def test_ex_grid_sync(self): # magictoken.ex_grid_sync_kernel.begin - from numba import cuda, int32 + from numba import cuda + from numba.cuda import int32 import numpy as np sig = (int32[:, ::1],) diff --git a/numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py 
b/numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py index aa95268cc..2d85c1241 100644 --- a/numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +++ b/numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py @@ -28,7 +28,8 @@ def tearDown(self): def test_ex_cpointer(self): # ex_cpointer.sig.begin import numpy as np - from numba import cuda, types + from numba import cuda + from numba.cuda import types # The first kernel argument is a pointer to a uint8 array. # The second argument holds the length as a uint32. diff --git a/numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py b/numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py index 7eaad2829..ee8ed61a8 100644 --- a/numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +++ b/numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py @@ -36,7 +36,8 @@ def tearDown(self): def test_ex_matmul(self): """Test of matrix multiplication on various cases.""" # magictoken.ex_import.begin - from numba import cuda, float32 + from numba import cuda + from numba.cuda import float32 import numpy as np import math # magictoken.ex_import.end diff --git a/numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py b/numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py index da177f890..dbb9dc079 100644 --- a/numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +++ b/numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py @@ -28,7 +28,7 @@ def test_ex_reduction(self): # ex_reduction.import.begin import numpy as np from numba import cuda - from numba.types import int32 + from numba.cuda.types import int32 # ex_reduction.import.end # ex_reduction.allocate.begin diff --git a/numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py b/numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py index 9f82a29f9..14d49fdaf 100644 --- a/numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +++ b/numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py @@ -3,8 +3,7 @@ from 
numba.cuda.testing import unittest, skip_on_cudasim import operator -from numba.core import types -from numba.cuda import typing +from numba.cuda import types, typing from numba.cuda.cudadrv import nvvm diff --git a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py index 7505660b2..1efeed919 100644 --- a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py +++ b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py @@ -10,7 +10,7 @@ from numba.cuda.tests.support import run_in_subprocess, override_config from numba.cuda import get_current_device from numba.cuda.cudadrv.nvrtc import compile -from numba import types +from numba.cuda import types from numba.cuda.typing import signature from numba import cuda from numba.cuda import config diff --git a/numba_cuda/numba/cuda/tests/support.py b/numba_cuda/numba/cuda/tests/support.py index fa1d5f3f3..1a7afeb2b 100644 --- a/numba_cuda/numba/cuda/tests/support.py +++ b/numba_cuda/numba/cuda/tests/support.py @@ -22,7 +22,7 @@ import numpy as np -from numba import types +from numba.cuda import types from numba.core import errors from numba.cuda.core import config from numba.cuda.typing import cffi_utils @@ -33,9 +33,16 @@ NativeValue, ) from numba.cuda.core.pythonapi import unbox -from numba.core.datamodel.models import OpaqueModel +from numba.cuda.datamodel.models import OpaqueModel from numba.cuda.np import numpy_support +try: + from numba.core.extending import typeof_impl as upstream_typeof_impl + from numba.core import types as upstream_types +except ImportError: + upstream_typeof_impl = None + upstream_types = None + class EnableNRTStatsMixin(object): """Mixin to enable the NRT statistics counters.""" @@ -761,6 +768,17 @@ class Dummy(object): def typeof_dummy(val, c): return dummy_type + # Dual registration for cross-target tests + if upstream_typeof_impl is not None and upstream_types is not None: + UpstreamDummyType = type( + "DummyTypeFor{}".format(test_id), (upstream_types.Opaque,), {} + ) + 
upstream_dummy_type = UpstreamDummyType("my_dummy") + + @upstream_typeof_impl.register(Dummy) + def typeof_dummy_core(val, c): + return upstream_dummy_type + @unbox(DummyType) def unbox_dummy(typ, obj, c): return NativeValue(c.context.get_dummy_value()) diff --git a/numba_cuda/numba/cuda/tests/test_analysis.py b/numba_cuda/numba/cuda/tests/test_analysis.py index fe84ac5e7..47321a36f 100644 --- a/numba_cuda/numba/cuda/tests/test_analysis.py +++ b/numba_cuda/numba/cuda/tests/test_analysis.py @@ -10,7 +10,8 @@ from numba.cuda.flags import Flags from numba.cuda.core.compiler import StateDict from numba.cuda import jit -from numba.core import types, errors, ir +from numba.cuda import types +from numba.core import errors, ir from numba.cuda.utils import PYVERSION from numba.cuda.core import postproc, rewrites, ir_utils from numba.cuda.core.options import ParallelOptions diff --git a/numba_cuda/numba/cuda/tests/test_extending.py b/numba_cuda/numba/cuda/tests/test_extending.py index 15da78e74..69681e463 100644 --- a/numba_cuda/numba/cuda/tests/test_extending.py +++ b/numba_cuda/numba/cuda/tests/test_extending.py @@ -10,7 +10,8 @@ import numba from numba.cuda import jit -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda.tests.support import ( TestCase, ) diff --git a/numba_cuda/numba/cuda/tests/test_extending_types.py b/numba_cuda/numba/cuda/tests/test_extending_types.py index 0170156c4..08dfd4da2 100644 --- a/numba_cuda/numba/cuda/tests/test_extending_types.py +++ b/numba_cuda/numba/cuda/tests/test_extending_types.py @@ -6,7 +6,7 @@ """ from numba.cuda import jit -from numba.core import types +from numba.cuda import types from numba.core.errors import TypingError, NumbaTypeError from numba.cuda.extending import make_attribute_wrapper from numba.cuda.extending import overload diff --git a/numba_cuda/numba/cuda/tests/test_flow_control.py b/numba_cuda/numba/cuda/tests/test_flow_control.py index 
49710819d..3fa45a09e 100644 --- a/numba_cuda/numba/cuda/tests/test_flow_control.py +++ b/numba_cuda/numba/cuda/tests/test_flow_control.py @@ -6,7 +6,7 @@ import unittest from numba.cuda import jit from numba.cuda.core.controlflow import CFGraph, ControlFlowAnalysis -from numba.core import types +from numba.cuda import types from numba.cuda.core.bytecode import ( FunctionIdentity, ByteCode, diff --git a/numba_cuda/numba/cuda/typeconv/rules.py b/numba_cuda/numba/cuda/typeconv/rules.py index 8f6513957..484f83254 100644 --- a/numba_cuda/numba/cuda/typeconv/rules.py +++ b/numba_cuda/numba/cuda/typeconv/rules.py @@ -3,7 +3,7 @@ import itertools from .typeconv import TypeManager, TypeCastingRules -from numba.core import types +from numba.cuda import types default_type_manager = TypeManager() diff --git a/numba_cuda/numba/cuda/typeconv/typeconv.py b/numba_cuda/numba/cuda/typeconv/typeconv.py index a87ce5043..7510700f2 100644 --- a/numba_cuda/numba/cuda/typeconv/typeconv.py +++ b/numba_cuda/numba/cuda/typeconv/typeconv.py @@ -3,7 +3,7 @@ from numba.cuda.cext import _typeconv from numba.cuda.typeconv import castgraph, Conversion -from numba.core import types +from numba.cuda import types class TypeManager(object): diff --git a/numba_cuda/numba/cuda/types/__init__.py b/numba_cuda/numba/cuda/types/__init__.py new file mode 100644 index 000000000..4e082ba20 --- /dev/null +++ b/numba_cuda/numba/cuda/types/__init__.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import struct + +import numpy as np +from numba.cuda import utils + +from .abstract import * +from .containers import * +from .functions import * +from .iterators import * +from .misc import * +from .npytypes import * +from .scalars import * +from .function_type import * + +numpy_version = tuple(map(int, np.__version__.split(".")[:2])) + +# Short names + +pyobject = PyObject("pyobject") +ffi_forced_object = Opaque("ffi_forced_object") +ffi = Opaque("ffi") +none = NoneType("none") +ellipsis = EllipsisType("...") +Any = Phantom("any") +undefined = Undefined("undefined") +py2_string_type = Opaque("str") +unicode_type = UnicodeType("unicode_type") +string = unicode_type +unknown = Dummy("unknown") +npy_rng = NumPyRandomGeneratorType("rng") +npy_bitgen = NumPyRandomBitGeneratorType("bitgen") + +# _undef_var is used to represent undefined variables in the type system. +_undef_var = UndefVar("_undef_var") + +code_type = Opaque("code") +pyfunc_type = Opaque("pyfunc") + +# No operation is defined on voidptr +# Can only pass it around +voidptr = RawPointer("void*") + +# optional types +optional = Optional +deferred_type = DeferredType +slice2_type = SliceType("slice", 2) +slice3_type = SliceType("slice", 3) +void = none + +# Need to ignore mypy errors because mypy cannot unify types for both +# the type systems even if they're logically mutually exclusive. 
+# mypy: ignore-errors + +boolean = bool_ = Boolean("bool") +if numpy_version >= (2, 0): + bool = bool_ + +byte = uint8 = Integer("uint8") +uint16 = Integer("uint16") +uint32 = Integer("uint32") +uint64 = Integer("uint64") + +int8 = Integer("int8") +int16 = Integer("int16") +int32 = Integer("int32") +int64 = Integer("int64") +intp = int32 if utils.MACHINE_BITS == 32 else int64 +uintp = uint32 if utils.MACHINE_BITS == 32 else uint64 +intc = int32 if struct.calcsize("i") == 4 else int64 +uintc = uint32 if struct.calcsize("I") == 4 else uint64 +ssize_t = int32 if struct.calcsize("n") == 4 else int64 +size_t = uint32 if struct.calcsize("N") == 4 else uint64 + +float32 = Float("float32") +float64 = Float("float64") +float16 = Float("float16") + +complex64 = Complex("complex64", float32) +complex128 = Complex("complex128", float64) + +range_iter32_type = RangeIteratorType(int32) +range_iter64_type = RangeIteratorType(int64) +unsigned_range_iter64_type = RangeIteratorType(uint64) +range_state32_type = RangeType(int32) +range_state64_type = RangeType(int64) +unsigned_range_state64_type = RangeType(uint64) + +signed_domain = frozenset([int8, int16, int32, int64]) +unsigned_domain = frozenset([uint8, uint16, uint32, uint64]) +integer_domain = signed_domain | unsigned_domain +real_domain = frozenset([float32, float64]) +complex_domain = frozenset([complex64, complex128]) +number_domain = real_domain | integer_domain | complex_domain + +# Integer Aliases +c_bool = py_bool = np_bool_ = boolean + +c_uint8 = np_uint8 = uint8 +c_uint16 = np_uint16 = uint16 +c_uint32 = np_uint32 = uint32 +c_uint64 = np_uint64 = uint64 +c_uintp = np_uintp = uintp + +c_int8 = np_int8 = int8 +c_int16 = np_int16 = int16 +c_int32 = np_int32 = int32 +c_int64 = np_int64 = int64 +c_intp = py_int = np_intp = intp + +c_float16 = np_float16 = float16 +c_float32 = np_float32 = float32 +c_float64 = py_float = np_float64 = float64 + +np_complex64 = complex64 +py_complex = np_complex128 = complex128 + +# Domain 
Aliases +py_signed_domain = np_signed_domain = signed_domain +np_unsigned_domain = unsigned_domain +py_integer_domain = np_integer_domain = integer_domain +py_real_domain = np_real_domain = real_domain +py_complex_domain = np_complex_domain = complex_domain +py_number_domain = np_number_domain = number_domain + +# Aliases to NumPy type names + +b1 = bool_ +i1 = int8 +i2 = int16 +i4 = int32 +i8 = int64 +u1 = uint8 +u2 = uint16 +u4 = uint32 +u8 = uint64 + +f2 = float16 +f4 = float32 +f8 = float64 + +c8 = complex64 +c16 = complex128 + +np_float_ = float32 +np_double = double = float64 +if numpy_version < (2, 0): + float_ = float32 + +_make_signed = lambda x: globals()["int%d" % (np.dtype(x).itemsize * 8)] +_make_unsigned = lambda x: globals()["uint%d" % (np.dtype(x).itemsize * 8)] + +char = np_char = _make_signed(np.byte) +uchar = np_uchar = byte = _make_unsigned(np.byte) +short = np_short = _make_signed(np.short) +ushort = np_ushort = _make_unsigned(np.short) +int_ = np_int_ = _make_signed(np.int_) +uint = np_uint = _make_unsigned(np.int_) +intc = np_intc = _make_signed(np.intc) # C-compat int +uintc = np_uintc = _make_unsigned(np.uintc) # C-compat uint +long_ = np_long = _make_signed(np.int_) # C-compat long +ulong = np_ulong = _make_unsigned(np.int_) # C-compat ulong +longlong = np_longlong = _make_signed(np.longlong) +ulonglong = np_ulonglong = _make_unsigned(np.longlong) + +all_str = """ + int8 + int16 + int32 + int64 + uint8 + uint16 + uint32 + uint64 + intp + uintp + intc + uintc + ssize_t + size_t + boolean + float32 + float64 + complex64 + complex128 + bool_ + byte + char + uchar + short + ushort + int_ + uint + long_ + ulong + longlong + ulonglong + float_ + double + void + none + b1 + i1 + i2 + i4 + i8 + u1 + u2 + u4 + u8 + f4 + f8 + c8 + c16 + optional + ffi_forced_object + ffi + deferred_type +""" + + +__all__ = all_str.split() +if numpy_version >= (2, 0): + __all__.remove("float_") + __all__.append("bool") diff --git 
a/numba_cuda/numba/cuda/types/__init__.pyi b/numba_cuda/numba/cuda/types/__init__.pyi new file mode 100644 index 000000000..f3a92c7de --- /dev/null +++ b/numba_cuda/numba/cuda/types/__init__.pyi @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from .abstract import * +from .common import Opaque +from .containers import * +from .function_type import * +from .functions import * +from .iterators import * +from .misc import * +from .npytypes import * +from .scalars import * + +__all__ = [ + "b1", + "bool", + "bool_", + "boolean", + "byte", + "c8", + "c16", + "char", + "complex64", + "complex128", + "deferred_type", + "double", + "f4", + "f8", + "ffi", + "ffi_forced_object", + "float32", + "float64", + "i1", + "i2", + "i4", + "i8", + "int8", + "int16", + "int32", + "int64", + "int_", + "intc", + "intp", + "long_", + "longlong", + "none", + "optional", + "short", + "size_t", + "ssize_t", + "u1", + "u2", + "u4", + "u8", + "uchar", + "uint", + "uint8", + "uint16", + "uint32", + "uint64", + "uintc", + "uintp", + "ulong", + "ulonglong", + "ushort", + "void", +] + +# TODO: Final + +pyobject: PyObject = ... +ffi_forced_object: Opaque = ... +ffi: Opaque = ... +none: NoneType = ... +ellipsis: EllipsisType = ... +Any: Phantom = ... +undefined: Undefined = ... +py2_string_type: Opaque = ... +unicode_type: UnicodeType = ... +string: UnicodeType = ... +unknown: Dummy = ... +npy_rng: NumPyRandomGeneratorType = ... +npy_bitgen: NumPyRandomBitGeneratorType = ... + +_undef_var: UndefVar = ... + +code_type: Opaque = ... +pyfunc_type: Opaque = ... + +voidptr: RawPointer = ... + +optional = Optional +deferred_type = DeferredType +slice2_type: SliceType = ... +slice3_type: SliceType = ... +void: NoneType = ... + +boolean: Boolean = ... +bool_: Boolean = ... +bool: Boolean = ... # numpy>=2 + +int8: Integer = ... +int16: Integer = ... +int32: Integer = ... +int64: Integer = ... 
+intp: Integer = ... +intc: Integer = ... +ssize_t: Integer = ... +char: Integer = ... +short: Integer = ... +int_: Integer = ... +long_: Integer = ... +longlong: Integer = ... + +byte: Integer = ... +uint8: Integer = ... +uint16: Integer = ... +uint32: Integer = ... +uint64: Integer = ... +uintp: Integer = ... +uintc: Integer = ... +size_t: Integer = ... +uchar: Integer = ... +ushort: Integer = ... +uint: Integer = ... +ulong: Integer = ... +ulonglong: Integer = ... + +float16: Float = ... +float32: Float = ... +float_ = float32 +float64: Float = ... +double = float64 + +# TODO: make generic in the wrapped `Float` type +complex64: Complex = ... +complex128: Complex = ... + +range_iter32_type: RangeIteratorType = ... +range_iter64_type: RangeIteratorType = ... +unsigned_range_iter64_type: RangeIteratorType = ... +range_state32_type: RangeType = ... +range_state64_type: RangeType = ... +unsigned_range_state64_type: RangeType = ... + +signed_domain: frozenset[Integer] = ... +unsigned_domain: frozenset[Integer] = ... +integer_domain: frozenset[Integer] = ... +real_domain: frozenset[Float] = ... +complex_domain: frozenset[Complex] = ... +number_domain: frozenset[Integer | Float | Complex] = ... + +np_float_ = float32 +b1 = bool_ +i1 = int8 +i2 = int16 +i4 = int32 +i8 = int64 +u1 = uint8 +u2 = uint16 +u4 = uint32 +u8 = uint64 +f2 = float16 +f4 = float32 +f8 = float64 +c8 = complex64 +c16 = complex128 diff --git a/numba_cuda/numba/cuda/types/abstract.py b/numba_cuda/numba/cuda/types/abstract.py new file mode 100644 index 000000000..51b27d9bf --- /dev/null +++ b/numba_cuda/numba/cuda/types/abstract.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.abstract", "numba.cuda.types.cuda_abstract" +) diff --git a/numba_cuda/numba/cuda/types/common.py b/numba_cuda/numba/cuda/types/common.py new file mode 100644 index 000000000..5098c4827 --- /dev/null +++ b/numba_cuda/numba/cuda/types/common.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.common", "numba.cuda.types.cuda_common" +) diff --git a/numba_cuda/numba/cuda/types/containers.py b/numba_cuda/numba/cuda/types/containers.py new file mode 100644 index 000000000..8a46c45e8 --- /dev/null +++ b/numba_cuda/numba/cuda/types/containers.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.containers", "numba.cuda.types.cuda_containers" +) diff --git a/numba_cuda/numba/cuda/types/cuda_abstract.py b/numba_cuda/numba/cuda/types/cuda_abstract.py new file mode 100644 index 000000000..470f9966d --- /dev/null +++ b/numba_cuda/numba/cuda/types/cuda_abstract.py @@ -0,0 +1,533 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def _autoincr():
    """
    Return the next unique integer type code.

    Codes are drawn from the module-level ``_typecodes`` counter and must
    stay below 2**32.

    Raises
    ------
    AssertionError
        If the counter has reached 2**32.
    """
    n = next(_typecodes)
    # 4 billion types should be enough, right?
    # The limit is raised explicitly instead of through an ``assert``
    # statement so that it is still enforced when Python runs with -O,
    # which removes assert statements. The exception type is unchanged.
    if n >= 2**32:
        raise AssertionError("Limited to 4 billion types")
    return n
class Type(metaclass=_TypeMetaclass):
    """
    The base class for all Numba types.
    It is essential that proper equality comparison is implemented.  The
    default implementation uses the "key" property (overridable in subclasses)
    for both comparison and hashing, to ensure sane behaviour.
    """

    mutable = False
    # Whether the type is reflected at the python<->nopython boundary
    reflected = False

    def __init__(self, name):
        """Create a type whose display name (and default key) is *name*."""
        self.name = name

    @property
    def key(self):
        """
        A property used for __eq__, __ne__ and __hash__.  Can be overridden
        in subclasses.
        """
        return self.name

    @property
    def mangling_args(self):
        """
        Returns `(basename, args)` where `basename` is the name of the type
        and `args` is a sequence of parameters of the type.

        Subclass should override to specialize the behavior.
        By default, this returns `(self.name, ())`.
        """
        return self.name, ()

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __hash__(self):
        return hash(self.key)

    def __eq__(self, other):
        # Equal only when both the exact class and the key match.
        return self.__class__ is other.__class__ and self.key == other.key

    def __ne__(self, other):
        return not (self == other)

    def __reduce__(self):
        # Route unpickling through _type_reconstructor instead of the
        # default reconstructor.
        reconstructor, args, state = super(Type, self).__reduce__()
        return (_type_reconstructor, (reconstructor, args, state))

    def unify(self, typingctx, other):
        """
        Try to unify this type with the *other*.  A third type must
        be returned, or None if unification is not possible.
        Only override this if the coercion logic cannot be expressed
        as simple casting rules.
        """
        return None

    def can_convert_to(self, typingctx, other):
        """
        Check whether this type can be converted to the *other*.
        If successful, must return a string describing the conversion, e.g.
        "exact", "promote", "unsafe", "safe"; otherwise None is returned.
        """
        return None

    def can_convert_from(self, typingctx, other):
        """
        Similar to *can_convert_to*, but in reverse.  Only needed if
        the type provides conversion from other types.
        """
        return None

    def is_precise(self):
        """
        Whether this type is precise, i.e. can be part of a successful
        type inference.  Default implementation returns True.
        """
        return True

    def augment(self, other):
        """
        Augment this type with the *other*.  Return the augmented type,
        or None if not supported.
        """
        return None

    # User-facing helpers.  These are not part of the core Type API but
    # are provided so that users can write e.g. `numba.boolean(1.5)`
    # (returns True) or `types.int32(types.int32[:])` (returns something
    # usable as a function signature).

    def __call__(self, *args):
        """
        Cast a Python value or build a signature.

        With exactly one argument that is not a ``Type``, the argument is
        passed to ``cast_python_value``.  Otherwise a signature is returned
        with this type as the return type and *args* as the argument types.
        """
        from numba.cuda.typing import signature

        if len(args) == 1 and not isinstance(args[0], Type):
            return self.cast_python_value(args[0])
        return signature(
            self,  # return_type
            *args,
        )

    def __getitem__(self, args):
        """
        Return an array of this type.
        """
        from numba.cuda.types import Array

        ndim, layout = self._determine_array_spec(args)
        return Array(dtype=self, ndim=ndim, layout=layout)

    def _determine_array_spec(self, args):
        """
        Translate an indexing expression into ``(ndim, layout)``.

        *args* must be a slice, or a non-empty tuple/list of slices, none
        of which has a start or stop.  For a tuple/list the layout is "F"
        if the first slice has step 1, else "C" if the last slice has
        step 1, else "A".  For a single slice the layout is "C" if its
        step is 1, else "A".

        Raises
        ------
        KeyError
            If *args* is anything else, including an empty tuple or list.
        """

        # XXX non-contiguous by default, even for 1d arrays,
        # doesn't sound very intuitive
        def validate_slice(s):
            return isinstance(s, slice) and s.start is None and s.stop is None

        # An empty tuple/list must not take this branch: all() over an
        # empty sequence is True, and args[0] below would then fail with
        # an unrelated IndexError instead of the KeyError raised below.
        if (
            isinstance(args, (tuple, list))
            and args
            and all(map(validate_slice, args))
        ):
            ndim = len(args)
            if args[0].step == 1:
                layout = "F"
            elif args[-1].step == 1:
                layout = "C"
            else:
                layout = "A"
        elif validate_slice(args):
            ndim = 1
            if args.step == 1:
                layout = "C"
            else:
                layout = "A"
        else:
            # Raise a KeyError to not be handled by collection constructors (e.g. list).
            raise KeyError(
                f"Can only index numba types with slices with no start or stop, got {args}."
            )

        return ndim, layout

    def cast_python_value(self, args):
        """
        Hook used by ``__call__`` when the type is called with a single
        non-Type argument.  The base implementation is not implemented.
        """
        raise NotImplementedError

    @property
    def is_internal(self):
        """Returns True if this class is an internally defined Numba type by
        virtue of the module in which it is instantiated, False else."""
        return self._is_internal

    def dump(self, tab=""):
        """Print a one-line description (class name, type code, name),
        prefixed by *tab*."""
        print(
            f"{tab}DUMP {type(self).__name__}[code={self._code}, name={self.name}]"
        )
class Number(Hashable):
    """
    Common base for all numeric types.
    """

    def unify(self, typingctx, other):
        """
        Unify this number type with *other* following NumPy's promotion
        rules.  Returns None when *other* is not a Number.
        """
        from numba.cuda.np import numpy_support

        if not isinstance(other, Number):
            return None

        # XXX: this can produce unsafe conversions,
        # e.g. would unify {int64, uint64} to float64
        own_dtype = numpy_support.as_dtype(self)
        other_dtype = numpy_support.as_dtype(other)
        promoted = np.promote_types(own_dtype, other_dtype)
        return numpy_support.from_dtype(promoted)
class Literal(Type):
    """Base class for literal types.

    A literal type carries the original Python value inside the type.
    Instances should always be obtained through the `literal(val)`
    function rather than by calling this class directly.
    """

    # Maps Python types to the Literal subclass used to build a numba type
    # for values of that Python type.  The `literal(val)` function consults
    # it; register a new entry here to add a Literal subclass.
    ctor_map: ptDict[type, ptType["Literal"]] = {}

    # Filled in on first access of `literal_type`.
    _literal_type_cache = None

    def __init__(self, value):
        if type(self) is Literal:
            raise TypeError(
                "Cannot be constructed directly. "
                "Use `numba.cuda.types.literal(value)` instead",
            )
        self._literal_init(value)
        value_type_name = type(value).__name__
        super().__init__("Literal[{}]({})".format(value_type_name, value))

    def _literal_init(self, value):
        """Record *value* and derive the key used for hashing/equality."""
        self._literal_value = value
        # Non-hashable constants are supported as well, so the key falls
        # back on the value's id() when necessary.
        self._key = get_hashable_key(value)

    @property
    def literal_value(self):
        """The Python value wrapped by this literal type."""
        return self._literal_value

    @property
    def literal_type(self):
        """The numba type resolved for the wrapped value (computed once)."""
        cached = self._literal_type_cache
        if cached is not None:
            return cached

        from numba.cuda import typing

        context = typing.Context()
        try:
            resolved = context.resolve_value_type(self.literal_value)
        except ValueError as e:
            text = str(e)
            if "Int value is too large" in text:
                # A string literal that cannot become an IntegerLiteral
                # because of overflow ends up here.
                raise TypeError(f"Cannot create literal type. {text}")
            # Some literal types hold a value that cannot be resolved to a
            # type at all, e.g. LiteralStrKeyDict wraps a python dict for
            # which there's no `typeof` support.
            raise AttributeError(
                "{} has no attribute 'literal_type'".format(self)
            )
        self._literal_type_cache = resolved
        return resolved
class InitialValue:
    """
    Mixin for a type that may carry an initial value.

    The value handed to the constructor is kept on the instance and exposed
    read-only through the ``initial_value`` attribute.
    """

    def __init__(self, initial_value):
        # Kept under a private name; the public accessor is the property below.
        self._initial_value = initial_value

    @property
    def initial_value(self):
        """The initial value supplied at construction time (may be None)."""
        return self._initial_value
class Buffer(IterableType, ArrayCompatible):
    """
    Type class for objects that provide the buffer protocol.
    More specific cases are handled by derived classes.
    """

    mutable = True
    slice_is_copy = False
    aligned = True

    # CS and FS are not reserved for inner contig but strided
    LAYOUTS = frozenset(("C", "F", "CS", "FS", "A"))

    def __init__(self, dtype, ndim, layout, readonly=False, name=None):
        """
        Build a buffer type with element type *dtype*, *ndim* dimensions and
        the given *layout* (one of ``LAYOUTS``).

        Raises NumbaTypeError when *dtype* is itself a Buffer, and
        NumbaValueError when *layout* is not a known layout. When *name* is
        omitted it is derived from the class name, *dtype*, *ndim* and
        *layout*.
        """
        from .misc import unliteral

        if isinstance(dtype, Buffer):
            template = (
                "The dtype of a Buffer type cannot itself be a Buffer type, "
                "this is unsupported behaviour."
                "\nThe dtype requested for the unsupported Buffer was: {}."
            )
            raise NumbaTypeError(template.format(dtype))
        if layout not in self.LAYOUTS:
            raise NumbaValueError("Invalid layout '%s'" % layout)

        self.dtype = unliteral(dtype)
        self.ndim = ndim
        self.layout = layout
        if readonly:
            # Shadow the class-level default on this instance only.
            self.mutable = False

        if name is None:
            kind = self.__class__.__name__.lower()
            if readonly:
                kind = "readonly %s" % kind
            # The name uses the dtype as passed in, not the unliteral'd one.
            name = "%s(%s, %sd, %s)" % (kind, dtype, ndim, layout)
        super().__init__(name)

    @property
    def iterator_type(self):
        from .iterators import ArrayIterator

        return ArrayIterator(self)

    @property
    def as_array(self):
        # A buffer is its own array-compatible representation.
        return self

    def copy(self, dtype=None, ndim=None, layout=None):
        """
        Return a new instance of the same class; any argument left as None
        is taken from this instance. Read-only-ness is carried over.
        """
        return self.__class__(
            dtype=self.dtype if dtype is None else dtype,
            ndim=self.ndim if ndim is None else ndim,
            layout=self.layout if layout is None else layout,
            readonly=not self.mutable,
        )

    @property
    def key(self):
        return self.dtype, self.ndim, self.layout, self.mutable

    @property
    def is_c_contig(self):
        if self.layout == "C":
            return True
        return self.ndim <= 1 and self.layout in "CF"

    @property
    def is_f_contig(self):
        if self.layout == "F":
            return True
        return self.ndim <= 1 and self.layout in "CF"

    @property
    def is_contig(self):
        return self.layout in "CF"
class Pair(Type):
    """
    A two-element pair whose members may have different types.
    """

    def __init__(self, first_type, second_type):
        self.first_type = first_type
        self.second_type = second_type
        super().__init__(name="pair<%s, %s>" % (first_type, second_type))

    @property
    def key(self):
        return self.first_type, self.second_type

    def unify(self, typingctx, other):
        """
        Unify member-wise with another Pair.

        Returns a new Pair when both members unify, otherwise None.
        """
        if not isinstance(other, Pair):
            return None

        # Both members are always unified before either result is inspected.
        unified_first = typingctx.unify_pairs(self.first_type, other.first_type)
        unified_second = typingctx.unify_pairs(
            self.second_type, other.second_type
        )
        if unified_first is None or unified_second is None:
            return None
        return Pair(unified_first, unified_second)
def is_homogeneous(*tys):
    """Return True when at least one type is given and all are equal.

    With no arguments the answer is False.
    """
    if not tys:
        # *tys* is empty.
        return False

    head = tys[0]
    for candidate in tys[1:]:
        if candidate != head:
            return False
    return True
class BaseAnonymousTuple(BaseTuple):
    """
    Mixin shared by the tuple types that are not named tuples.
    """

    def can_convert_to(self, typingctx, other):
        """
        Rate the conversion of this tuple to *other*.

        Returns None when *other* is not an anonymous tuple, when the
        lengths differ, or when any element cannot be converted. Empty
        tuples convert with ``Conversion.safe``; otherwise the result is
        the maximum of the element-wise conversion kinds. Note named tuples
        are rejected.
        """
        if not isinstance(other, BaseAnonymousTuple):
            return None
        if len(self) != len(other):
            return None
        if len(self) == 0:
            return Conversion.safe
        if not isinstance(other, BaseTuple):
            return None

        # Every element pair is rated before any result is inspected.
        element_kinds = [
            typingctx.can_convert(src, dst) for src, dst in zip(self, other)
        ]
        for kind in element_kinds:
            if kind is None:
                return None
        return max(element_kinds)

    def __unliteral__(self):
        stripped = [unliteral(member) for member in self]
        return type(self).from_types(stripped)
class Tuple(BaseAnonymousTuple, _HeterogeneousTuple):
    """
    Type class for anonymous tuples whose element types may differ.

    Construction returns a ``UniTuple`` instead when the element types
    turn out to be uniform.
    """

    def __new__(cls, types):
        unified_fn = utils.unified_function_type(types, require_precise=True)
        if unified_fn is not None:
            return UniTuple(dtype=unified_fn, count=len(types))

        _HeterogeneousTuple.is_types_iterable(types)

        if types and all(ty == types[0] for ty in types[1:]):
            return UniTuple(dtype=types[0], count=len(types))
        return object.__new__(Tuple)

    def __init__(self, types):
        self.types = tuple(types)
        self.count = len(self.types)
        self.dtype = UnionType(types)
        joined = ", ".join(map(str, self.types))
        super().__init__("%s(%s)" % (self.__class__.__name__, joined))

    @property
    def mangling_args(self):
        return self.__class__.__name__, tuple(self.types)

    @property
    def key(self):
        return self.types

    def unify(self, typingctx, other):
        """
        Unify element-wise with another tuple of the same length.

        Returns a new Tuple when every element pair unifies, else None.
        """
        # Other is UniTuple or Tuple
        if not isinstance(other, BaseTuple) or len(self) != len(other):
            return None

        merged = [typingctx.unify_pairs(a, b) for a, b in zip(self, other)]
        for member in merged:
            if member is None:
                return None
        return Tuple(merged)

    def __repr__(self):
        return f"Tuple({tuple(self.types)})"
class List(MutableSequence, InitialValue):
    """
    Type class for (arbitrary-sized) homogeneous lists.

    Parameters
    ----------
    dtype : Type
        Element type; any literal wrapper is stripped with ``unliteral``.
    reflected : bool
        Whether this is a reflected list; only affects the name and key.
    initial_value : object, optional
        Initial value carried by the type (see ``InitialValue``).
    """

    def __init__(self, dtype, reflected=False, initial_value=None):
        dtype = unliteral(dtype)
        self.dtype = dtype
        self.reflected = reflected
        cls_name = "reflected list" if reflected else "list"
        # The format string must consume all three arguments; with only two
        # placeholders the `%` operator raises TypeError. Including the
        # initial value also keeps the name in step with `key`, which
        # depends on it as well.
        name = "%s(%s)<iv=%s>" % (cls_name, self.dtype, initial_value)
        super(List, self).__init__(name=name)
        InitialValue.__init__(self, initial_value)

    def copy(self, dtype=None, reflected=None):
        """Return a List with *dtype* / *reflected* overridden when given.

        The initial value of this instance is carried over unchanged.
        """
        if dtype is None:
            dtype = self.dtype
        if reflected is None:
            reflected = self.reflected
        return List(dtype, reflected, self.initial_value)

    def unify(self, typingctx, other):
        """Unify with another List.

        Returns None when *other* is not a List or the element types do not
        unify. The result is reflected if either input is. It keeps this
        list's initial value only when both lists carry one; otherwise the
        result has no initial value.
        """
        if not isinstance(other, List):
            return None
        dtype = typingctx.unify_pairs(self.dtype, other.dtype)
        reflected = self.reflected or other.reflected
        if dtype is None:
            return None
        siv = self.initial_value
        oiv = other.initial_value
        if siv is not None and oiv is not None:
            return List(dtype, reflected, siv)
        return List(dtype, reflected)

    @property
    def key(self):
        return self.dtype, self.reflected, str(self.initial_value)

    @property
    def iterator_type(self):
        return ListIter(self)

    def is_precise(self):
        return self.dtype.is_precise()

    def __getitem__(self, args):
        """
        Overrides the default __getitem__ from Type.

        Every index yields the element type; *args* is ignored.
        """
        return self.dtype

    def __unliteral__(self):
        # The non-literal form drops the initial value.
        return List(self.dtype, reflected=self.reflected, initial_value=None)

    def __repr__(self):
        return f"List({self.dtype}, {self.reflected})"
class SetEntry(Type):
    """
    Internal type class describing one entry of a Set's hash table.
    """

    def __init__(self, set_type):
        self.set_type = set_type
        super().__init__("entry(%s)" % set_type)

    @property
    def key(self):
        # Entries are identified by the set type they belong to.
        return self.set_type
class DictType(IterableType, InitialValue):
    """Dictionary type

    Parameters
    ----------
    keyty, valty : Type
        Key and value types; literal wrappers are stripped with
        ``unliteral``. Neither may be a ``TypeRef``.
    initial_value : object, optional
        Initial value carried by the type (see ``InitialValue``).

    Raises
    ------
    TypingError
        If the key or value type is ``Optional`` or ``NoneType``, or is
        rejected by ``_sentry_forbidden_types``.
    """

    def __init__(self, keyty, valty, initial_value=None):
        assert not isinstance(keyty, TypeRef)
        assert not isinstance(valty, TypeRef)
        keyty = unliteral(keyty)
        valty = unliteral(valty)
        if isinstance(keyty, (Optional, NoneType)):
            fmt = "Dict.key_type cannot be of type {}"
            raise TypingError(fmt.format(keyty))
        if isinstance(valty, (Optional, NoneType)):
            fmt = "Dict.value_type cannot be of type {}"
            raise TypingError(fmt.format(valty))
        _sentry_forbidden_types(keyty, valty)
        self.key_type = keyty
        self.value_type = valty
        self.keyvalue_type = Tuple([keyty, valty])
        # `initial_value` was passed to format() without a placeholder and
        # silently dropped, so types with different keys shared one name.
        # Give it a placeholder so the name tracks `key`.
        name = "{}[{},{}]<iv={}>".format(
            self.__class__.__name__, keyty, valty, initial_value
        )
        super(DictType, self).__init__(name)
        InitialValue.__init__(self, initial_value)

    def is_precise(self):
        """True when neither the key nor the value type is Undefined."""
        return not any(
            (
                isinstance(self.key_type, Undefined),
                isinstance(self.value_type, Undefined),
            )
        )

    @property
    def iterator_type(self):
        # Iterating a dictionary iterates its keys.
        return DictKeysIterableType(self).iterator_type

    @classmethod
    def refine(cls, keyty, valty):
        """Refine to a precise dictionary type"""
        res = cls(keyty, valty)
        assert res.is_precise()
        return res

    def unify(self, typingctx, other):
        """
        Unify this with the *other* dictionary.

        Returns this type when *other* is an imprecise DictType. When both
        key and value types are equal, returns a DictType that keeps the
        initial value only if both sides carry an equal, non-None one.
        Returns None in every other case.
        """
        # If other is dict
        if isinstance(other, DictType):
            if not other.is_precise():
                return self
            else:
                ukey_type = self.key_type == other.key_type
                uvalue_type = self.value_type == other.value_type
                if ukey_type and uvalue_type:
                    siv = self.initial_value
                    oiv = other.initial_value
                    siv_none = siv is None
                    oiv_none = oiv is None
                    if not siv_none and not oiv_none:
                        if siv == oiv:
                            return DictType(
                                self.key_type, other.value_type, siv
                            )
                    return DictType(self.key_type, other.value_type)

    @property
    def key(self):
        return self.key_type, self.value_type, str(self.initial_value)

    def __unliteral__(self):
        # The non-literal form drops the initial value.
        return DictType(self.key_type, self.value_type)

    def __repr__(self):
        return f"DictType({self.key_type}, {self.value_type})"
+ def __init__(self, name, keys): + self.__name__ = name + self._fields = tuple(keys) + super(LiteralStrKeyDict.FakeNamedTuple, self).__init__() + + def __len__(self): + return len(self._fields) + + def __getitem__(self, key): + return self._fields[key] + + mutable = False + + def __init__(self, literal_value, value_index=None): + self._literal_init(literal_value) + self.value_index = value_index + strkeys = [x.literal_value for x in literal_value.keys()] + self.tuple_ty = self.FakeNamedTuple("_ntclazz", strkeys) + tys = [x for x in literal_value.values()] + self.types = tuple(tys) + self.count = len(self.types) + self.fields = tuple(self.tuple_ty._fields) + self.instance_class = self.tuple_ty + self.name = "LiteralStrKey[Dict]({})".format(literal_value) + + def __unliteral__(self): + return Poison(self) + + def unify(self, typingctx, other): + """ + Unify this with the *other* one. + """ + if isinstance(other, LiteralStrKeyDict): + tys = [] + for (k1, v1), (k2, v2) in zip( + self.literal_value.items(), other.literal_value.items() + ): + if k1 != k2: # keys must be same + break + tys.append(typingctx.unify_pairs(v1, v2)) + else: + if all(tys): + d = {k: v for k, v in zip(self.literal_value.keys(), tys)} + return LiteralStrKeyDict(d) + + def __len__(self): + return len(self.types) + + def __iter__(self): + return iter(self.types) + + @property + def key(self): + # use the namedtuple fields not the namedtuple itself as it's created + # locally in the ctor and comparison would always be False. 
class StructRef(Type):
    """A mutable struct."""

    def __init__(self, fields):
        """
        Parameters
        ----------
        fields : Sequence
            Field descriptions, each a 2-tuple-like ``(name, type)`` where
            `name` is a `str` and `type` is a numba type.

        Raises
        ------
        ValueError
            If a field name is not a `str` or a field type is not a `Type`.
        """

        def check_field_pair(fieldpair):
            field_name, field_type = fieldpair
            if not isinstance(field_name, str):
                raise ValueError("expecting a str for field name")
            if not isinstance(field_type, Type):
                raise ValueError("expecting a Numba Type for field type")
            return field_name, field_type

        # Validate once before the subclass hook sees the fields and once
        # more on whatever the hook returns.
        validated = tuple(check_field_pair(pair) for pair in fields)
        processed = self.preprocess_fields(validated)
        self._fields = tuple(check_field_pair(pair) for pair in processed)
        self._typename = self.__class__.__qualname__
        super().__init__(name=f"numba.{self._typename}{self._fields}")

    def preprocess_fields(self, fields):
        """Hook for subclasses to clean up *fields* further.

        The base implementation returns *fields* unchanged.

        Parameters:
        -----------
        fields : Sequence[Tuple[str, Type]]
        """
        return fields

    @property
    def field_dict(self):
        """A read-only mapping from field name to field type."""
        return MappingProxyType(dict(self._fields))

    def get_data_type(self):
        """Return the payload type of the structure this reference refers to.

        See also: `ClassInstanceType.get_data_type`
        """
        payload_name = self.__class__.__name__
        return StructRefPayload(typename=payload_name, fields=self._fields)
class FunctionType(Type):
    """
    First-class function type.

    Wraps a signature (with literals stripped) and a ``FunctionPrototype``
    built from it; the prototype's key is used as this type's key.
    """

    cconv = None

    def __init__(self, signature):
        # NOTE(review): Type.__init__ is not called here; `name` is provided
        # by the property below instead. Confirm nothing else relies on
        # state set by the base initializer.
        sig = types.unliteral(signature)
        self.nargs = len(sig.args)
        self.signature = sig
        self.ftype = FunctionPrototype(sig.return_type, sig.args)
        self._key = self.ftype.key

    @property
    def key(self):
        return self._key

    @property
    def name(self):
        return f"{type(self).__name__}[{self.key}]"

    def is_precise(self):
        """Delegate to the wrapped signature's ``is_precise()``."""
        return self.signature.is_precise()

    def get_precise(self):
        """Return this type unchanged."""
        return self

    def dump(self, tab=""):
        """Print a debug dump of this type and its signature to stdout."""
        # NOTE(review): `_code` is not assigned in this class; presumably it
        # comes from the Type base machinery. Confirm.
        print(f"{tab}DUMP {type(self).__name__}[code={self._code}]")
        self.signature.dump(tab=tab + " ")
        print(f"{tab}END DUMP {type(self).__name__}")

    def get_call_type(self, context, args, kws):
        """Resolve the signature used when calling this function type.

        Raises
        ------
        errors.UnsupportedError
            If keyword arguments are supplied.
        ValueError
            If the argument count differs from ``nargs``, or an argument
            cannot be converted to a precise signature argument type at the
            accepted conversion level.
        """
        from numba.cuda import typing

        if kws:
            # First-class functions carry only the type signature
            # information and function address value. So, it is not
            # possible to determine the positional arguments
            # corresponding to the keyword arguments in the call
            # expression. For instance, the definition of the
            # first-class function may not use the same argument names
            # that the caller assumes. [numba/issues/5540].
            raise errors.UnsupportedError(
                "first-class function call cannot use keyword arguments"
            )

        if len(args) != self.nargs:
            raise ValueError(
                f"mismatch of arguments number: {len(args)} vs {self.nargs}"
            )

        sig = self.signature

        # check that arguments types match with the signature types exactly
        for atype, sig_atype in zip(args, sig.args):
            atype = types.unliteral(atype)
            # Imprecise signature argument types are not checked.
            if sig_atype.is_precise():
                conv_score = context.context.can_convert(
                    fromty=atype, toty=sig_atype
                )
                # NOTE(review): presumably scores above `Conversion.safe`
                # are less safe conversions. Confirm against Conversion.
                if (
                    conv_score is None
                    or conv_score > typing.context.Conversion.safe
                ):
                    raise ValueError(
                        f"mismatch of argument types: {atype} vs {sig_atype}"
                    )

        if not sig.is_precise():
            # `dispatchers` is not set by FunctionType.__init__; in this
            # file it is assigned by UndefinedFunctionType.__init__.
            # The loop returns on its first iteration, so only the first
            # dispatcher is consulted. `args` and `kws` are rebound here.
            for dispatcher in self.dispatchers:
                template, pysig, args, kws = dispatcher.get_call_template(
                    args, kws
                )
                new_sig = template(context.context).apply(args, kws)
                return types.unliteral(new_sig)

        return sig

    def check_signature(self, other_sig):
        """Return True if signatures match (up to being precise)."""
        sig = self.signature
        return self.nargs == len(other_sig.args) and (
            sig == other_sig or not sig.is_precise()
        )

    def unify(self, context, other):
        """Return this type when *other* is an UndefinedFunctionType with
        the same argument count; otherwise return None implicitly."""
        if (
            isinstance(other, types.UndefinedFunctionType)
            and self.nargs == other.nargs
        ):
            return self
class FunctionPrototype(Type):
    """
    Prototype of a first-class function type: a return type plus a tuple of
    argument types.  Used internally.
    """

    cconv = None

    def __init__(self, rtype, atypes):
        self.rtype = rtype
        self.atypes = tuple(atypes)

        assert isinstance(rtype, Type), rtype

        def _checked_name(arg_type):
            # Validate and read the name in one step, argument by argument.
            assert isinstance(arg_type, Type), arg_type
            return arg_type.name

        joined_args = ", ".join([_checked_name(a) for a in self.atypes])
        prototype_name = f"{rtype!s}({joined_args})"

        super().__init__(prototype_name)

    @property
    def key(self):
        """The prototype's name doubles as its key."""
        return self.name
+ """ + + def __init__(self, cres): + """ + Parameters + ---------- + cres : CompileResult + Specify compilation result of a Numba jit-decorated function + (that is a value of dispatcher instance ``overloads`` + attribute) + """ + self.cres = cres + name = getattr(cres.fndesc, "llvm_cfunc_wrapper_name") + self.address = cres.library.get_pointer_to_function(name) + + def dump(self, tab=""): + print(f"{tab}DUMP {type(self).__name__} [addr={self.address}]") + self.cres.signature.dump(tab=tab + " ") + print(f"{tab}END DUMP {type(self).__name__}") + + def __wrapper_address__(self): + return self.address + + def signature(self): + return self.cres.signature + + def __call__(self, *args, **kwargs): # used in object-mode + return self.cres.entry_point(*args, **kwargs) diff --git a/numba_cuda/numba/cuda/types/cuda_functions.py b/numba_cuda/numba/cuda/types/cuda_functions.py new file mode 100644 index 000000000..594a89e84 --- /dev/null +++ b/numba_cuda/numba/cuda/types/cuda_functions.py @@ -0,0 +1,804 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def _wrapper(tmp, indent=0):
    """Prefix every line of *tmp* with *indent* spaces.

    Unlike the default behaviour of ``textwrap.indent``, lines consisting
    only of whitespace are prefixed as well.
    """
    prefix = " " * indent

    def _every_line(_line):
        return True

    return textwrap.indent(tmp, prefix, predicate=_every_line)
def argsnkwargs_to_str(args, kwargs):
    """Render positional and keyword arguments as one comma-separated string.

    Positional arguments are rendered with ``str()`` and come first; keyword
    arguments follow as ``key=value`` in the mapping's iteration order.
    """
    positional = map(str, tuple(args))
    keyword = (f"{key}={value}" for key, value in kwargs.items())
    return ", ".join(itertools.chain(positional, keyword))
depth could potentially get massive, so limit it. + ldepth = min(max(self._depth, 0), self._max_depth) + + def template_info(tp): + src_info = tp.get_template_info() + unknown = "unknown" + source_name = src_info.get("name", unknown) + source_file = src_info.get("filename", unknown) + source_lines = src_info.get("lines", unknown) + source_kind = src_info.get("kind", "Unknown template") + return source_name, source_file, source_lines, source_kind + + for i, (k, err_list) in enumerate(self._failures.items()): + err = err_list[0] + nduplicates = len(err_list) + template, error = err.template, err.error + ifo = template_info(template) + source_name, source_file, source_lines, source_kind = ifo + largstr = argstr if err.literal else nolitargstr + + if err.error == "No match.": + err_dict = defaultdict(set) + for errs in err_list: + err_dict[errs.template].add(errs.literal) + # if there's just one template, and it's erroring on + # literal/nonliteral be specific + if len(err_dict) == 1: + template = [_ for _ in err_dict.keys()][0] + source_name, source_file, source_lines, source_kind = ( + template_info(template) + ) + source_lines = source_lines[0] + else: + source_file = "" + source_lines = "N/A" + + msgbuf.append( + _termcolor.errmsg( + _wrapper( + _overload_template.format( + nduplicates=nduplicates, + kind=source_kind.title(), + function=fname, + inof="of", + file=source_file, + line=source_lines, + args=largstr, + ), + ldepth + 1, + ) + ) + ) + msgbuf.append( + _termcolor.highlight(_wrapper(err.error, ldepth + 2)) + ) + else: + # There was at least one match in this failure class, but it + # failed for a specific reason try and report this. 
+ msgbuf.append( + _termcolor.errmsg( + _wrapper( + _overload_template.format( + nduplicates=nduplicates, + kind=source_kind.title(), + function=source_name, + inof="in", + file=source_file, + line=source_lines[0], + args=largstr, + ), + ldepth + 1, + ) + ) + ) + + if isinstance(error, BaseException): + reason = indent + self.format_error(error) + errstr = _err_reasons["specific_error"].format(reason) + else: + errstr = error + # if you are a developer, show the back traces + if config.DEVELOPER_MODE: + if isinstance(error, BaseException): + # if the error is an actual exception instance, trace it + bt = traceback.format_exception( + type(error), error, error.__traceback__ + ) + else: + bt = [""] + bt_as_lines = _bt_as_lines(bt) + nd2indent = "\n{}".format(2 * indent) + errstr += _termcolor.reset( + nd2indent + nd2indent.join(bt_as_lines) + ) + msgbuf.append( + _termcolor.highlight(_wrapper(errstr, ldepth + 2)) + ) + loc = self.get_loc(template, error) + if loc: + msgbuf.append("{}raised from {}".format(indent, loc)) + + # the commented bit rewraps each block, may not be helpful?! 
def _unlit_non_poison(ty):
    """Return ``unliteral(ty)``, rejecting Poison results.

    Raises
    ------
    errors.TypingError
        If the unliteral'd type is an instance of ``types.Poison``.
    """
    plain_ty = unliteral(ty)
    if not isinstance(plain_ty, types.Poison):
        return plain_ty
    raise errors.TypingError(
        f"Poison type used in arguments; got {plain_ty}"
    )
+ """ + if type(other) is type(self) and other.typing_key == self.typing_key: + return type(self)(self.templates + other.templates) + + def get_impl_key(self, sig): + """ + Get the implementation key (used by the target context) for the + given signature. + """ + return self._impl_keys[sig.args] + + def get_call_type(self, context, args, kws): + prefer_lit = [True, False] # old behavior preferring literal + prefer_not = [False, True] # new behavior preferring non-literal + failures = _ResolutionFailures( + context, self, args, kws, depth=self._depth + ) + + # get the order in which to try templates + from numba.core.target_extension import ( + get_local_target, + ) # circular + + target_hw = get_local_target(context) + order = utils.order_by_target_specificity( + target_hw, self.templates, fnkey=self.key[0] + ) + + self._depth += 1 + + for temp_cls in order: + temp = temp_cls(context) + # The template can override the default and prefer literal args + choice = prefer_lit if temp.prefer_literal else prefer_not + for uselit in choice: + try: + if uselit: + sig = temp.apply(args, kws) + else: + nolitargs = tuple([_unlit_non_poison(a) for a in args]) + nolitkws = { + k: _unlit_non_poison(v) for k, v in kws.items() + } + sig = temp.apply(nolitargs, nolitkws) + except Exception as e: + if not isinstance(e, errors.NumbaError): + raise e + sig = None + failures.add_error(temp, False, e, uselit) + else: + if sig is not None: + self._impl_keys[sig.args] = temp.get_impl_key(sig) + self._depth -= 1 + return sig + else: + registered_sigs = getattr(temp, "cases", None) + if registered_sigs is not None: + msg = "No match for registered cases:\n%s" + msg = msg % "\n".join( + " * {}".format(x) for x in registered_sigs + ) + else: + msg = "No match." 
+ failures.add_error(temp, True, msg, uselit) + + failures.raise_error() + + def get_call_signatures(self): + sigs = [] + is_param = False + for temp in self.templates: + sigs += getattr(temp, "cases", []) + is_param = is_param or hasattr(temp, "generic") + return sigs, is_param + + +class Function(BaseFunction, Opaque): + """ + Type class for builtin functions implemented by Numba. + """ + + +class BoundFunction(Callable, Opaque): + """ + A function with an implicit first argument (denoted as *this* below). + """ + + def __init__(self, template, this): + # Create a derived template with an attribute *this* + newcls = type( + template.__name__ + "." + str(this), (template,), dict(this=this) + ) + self.template = newcls + self.typing_key = self.template.key + self.this = this + name = "%s(%s for %s)" % ( + self.__class__.__name__, + self.typing_key, + self.this, + ) + super(BoundFunction, self).__init__(name) + + def unify(self, typingctx, other): + if ( + isinstance(other, BoundFunction) + and self.typing_key == other.typing_key + ): + this = typingctx.unify_pairs(self.this, other.this) + if this is not None: + # XXX is it right that both template instances are distinct? + return self.copy(this=this) + + def copy(self, this): + return type(self)(self.template, this) + + @property + def key(self): + # FIXME: With target-overload, the MethodTemplate can change depending + # on the target. + unique_impl = getattr(self.template, "_overload_func", None) + return self.typing_key, self.this, unique_impl + + def get_impl_key(self, sig): + """ + Get the implementation key (used by the target context) for the + given signature. 
+ """ + return self.typing_key + + def get_call_type(self, context, args, kws): + template = self.template(context) + literal_e = None + nonliteral_e = None + out = None + + choice = [True, False] if template.prefer_literal else [False, True] + for uselit in choice: + if uselit: + # Try with Literal + try: + out = template.apply(args, kws) + except Exception as exc: + if not isinstance(exc, errors.NumbaError): + raise exc + if isinstance(exc, errors.ForceLiteralArg): + raise exc + literal_e = exc + out = None + else: + break + else: + # if the unliteral_args and unliteral_kws are the same as the + # literal ones, set up to not bother retrying + unliteral_args = tuple([_unlit_non_poison(a) for a in args]) + unliteral_kws = { + k: _unlit_non_poison(v) for k, v in kws.items() + } + skip = unliteral_args == args and kws == unliteral_kws + + # If the above template application failed and the non-literal + # args are different to the literal ones, try again with + # literals rewritten as non-literals + if not skip and out is None: + try: + out = template.apply(unliteral_args, unliteral_kws) + except Exception as exc: + if isinstance(exc, errors.ForceLiteralArg): + if template.prefer_literal: + # For template that prefers literal types, + # reaching here means that the literal types + # have failed typing as well. 
class _PickleableWeakRef(weakref.ref):
    """
    Weak reference that can take part in pickling.

    Note that if the referent is not kept alive elsewhere in the pickle,
    the weak reference expires immediately after being reconstructed.
    """

    def __getnewargs__(self):
        """Return ``(referent,)``; raise ReferenceError if it is gone."""
        referent = self()
        if referent is not None:
            return (referent,)
        raise ReferenceError("underlying object has vanished")
class Dispatcher(WeakType, Callable, Dummy):
    """
    Type class for @jit-compiled functions.

    Only a weak reference to the wrapped dispatcher object is stored.
    """

    def __init__(self, dispatcher):
        self._store_object(dispatcher)
        type_name = "type(%s)" % dispatcher
        super().__init__(type_name)

    def dump(self, tab=""):
        """Print a debugging dump of this type and its dispatcher."""
        cls_name = type(self).__name__
        header = f"{tab}DUMP {cls_name}[code={self._code}, name={self.name}]"
        print(header)
        self.dispatcher.dump(tab=tab + " ")
        print(f"{tab}END DUMP")

    def get_call_type(self, context, args, kws):
        """
        Resolve a call to this dispatcher for the given argument types.

        The call template obtained from the dispatcher is applied to the
        arguments; a truthy result has its ``pysig`` replaced with the one
        supplied by the dispatcher before being returned.
        """
        template, pysig, call_args, call_kws = (
            self.dispatcher.get_call_template(args, kws)
        )
        resolved = template(context).apply(call_args, call_kws)
        if not resolved:
            return resolved
        return resolved.replace(pysig=pysig)

    def get_call_signatures(self):
        """Return ``(signatures, True)`` using the dispatcher's
        ``nopython_signatures``.
        """
        return self.dispatcher.nopython_signatures, True

    @property
    def dispatcher(self):
        """
        A strong reference to the underlying dispatcher instance, obtained
        by dereferencing the stored weak reference.
        """
        return self._get_object()

    def get_overload(self, sig):
        """
        Get the compiled overload for the argument types of *sig*.
        """
        arg_types = sig.args
        return self.dispatcher.get_overload(arg_types)

    def get_impl_key(self, sig):
        """
        Get the implementation key for *sig*, which is its overload.
        """
        return self.get_overload(sig)

    def unify(self, context, other):
        """Unify with *other* via ``utils.unified_function_type``."""
        candidates = (self, other)
        return utils.unified_function_type(candidates, require_precise=False)

    def can_convert_to(self, typingctx, other):
        """Return ``Conversion.safe`` when *other* is a ``FunctionType`` whose
        signature the dispatcher can produce a compile result for; otherwise
        return ``None``.
        """
        if not isinstance(other, types.FunctionType):
            return None
        try:
            self.dispatcher.get_compile_result(other.signature)
        except errors.NumbaError:
            return None
        return Conversion.safe
class NamedTupleClass(Callable, Opaque):
    """
    Type class for namedtuple classes.

    Parameters
    ----------
    instance_class :
        The class object this type stands for; it is also used as the
        type's key.
    """

    def __init__(self, instance_class):
        self.instance_class = instance_class
        # Wrap the argument in a 1-tuple: with a bare ``% (instance_class)``
        # a tuple-valued argument would be unpacked by the ``%`` operator
        # instead of being formatted as a single value.
        name = "class(%s)" % (instance_class,)
        super(NamedTupleClass, self).__init__(name)

    def get_call_type(self, context, args, kws):
        """Return ``None``; call resolution happens elsewhere."""
        # Overridden by the __call__ constructor resolution in
        # typing.collections
        return None

    def get_call_signatures(self):
        """Return ``((), True)``: no fixed signatures are listed."""
        return (), True

    def get_impl_key(self, sig):
        """Return this type's class as the implementation key."""
        return type(self)

    @property
    def key(self):
        """The wrapped class object."""
        return self.instance_class
class RecursiveCall(Opaque):
    """
    Recursive call to a Dispatcher.

    Overloads are recorded in a dict keyed by argument types.
    """

    _overloads = None

    def __init__(self, dispatcher_type):
        assert isinstance(dispatcher_type, Dispatcher)
        self.dispatcher_type = dispatcher_type
        super().__init__(f"recursive({dispatcher_type!s})")
        # Create the overload table only if none is visible on the instance.
        if self._overloads is None:
            self._overloads = {}

    def add_overloads(self, args, qualname, uid):
        """Record an overload of the function.

        Parameters
        ----------
        args :
            argument types
        qualname :
            function qualifying name
        uid :
            unique id
        """
        entry = _RecursiveCallOverloads(qualname, uid)
        self._overloads[args] = entry

    def get_overloads(self, args):
        """Return the recorded (qualname, uid) entry for the given argument
        types.
        """
        return self._overloads[args]

    @property
    def key(self):
        """The dispatcher type this recursive call refers to."""
        return self.dispatcher_type
class Generator(SimpleIteratorType):
    """
    Type class for Numba-compiled generator objects.

    ``arg_types`` and ``state_types`` are stored as tuples.
    """

    def __init__(
        self, gen_func, yield_type, arg_types, state_types, has_finalizer
    ):
        self.gen_func = gen_func
        self.arg_types = tuple(arg_types)
        self.state_types = tuple(state_types)
        self.has_finalizer = has_finalizer
        description = (
            f"{yield_type!s} generator(func={self.gen_func!s}, "
            f"args={self.arg_types!s}, "
            f"has_finalizer={self.has_finalizer!s})"
        )
        super().__init__(description, yield_type)

    @property
    def key(self):
        """Tuple of the attributes that identify this generator type."""
        identity = (
            self.gen_func,
            self.arg_types,
            self.yield_type,
            self.has_finalizer,
            self.state_types,
        )
        return identity
class ArrayIterator(SimpleIteratorType):
    """
    Type class for iterators of array and buffer objects.

    Raises
    ------
    TypingError
        If the array type has ``ndim == 0``.
    """

    def __init__(self, array_type):
        self.array_type = array_type
        iter_name = f"iter({self.array_type!s})"
        ndim = array_type.ndim
        if ndim == 0:
            raise TypingError("iteration over a 0-d array")
        if ndim == 1:
            element_type = array_type.dtype
        else:
            # iteration semantics leads to A order layout
            element_type = array_type.copy(
                ndim=array_type.ndim - 1, layout="A"
            )
        super().__init__(iter_name, element_type)
class Undefined(Dummy):
    """
    A type that is left imprecise. This is used as a temporary placeholder
    during type inference in the hope that the type can be later refined.
    """

    def is_precise(self):
        """Always report this type as imprecise."""
        return False
def literal(value):
    """Return a Literal instance for *value*.

    Raises
    ------
    ValueError
        If *value* is itself a ``Literal``.
    LiteralTypingError
        If ``Literal.ctor_map`` has no entry for the type of *value*.
    """
    value_cls = type(value)
    if isinstance(value, Literal):
        raise ValueError(
            f"the function does not accept a Literal type; "
            f"got {value} ({value_cls})"
        )
    try:
        make_literal = Literal.ctor_map[value_cls]
    except KeyError:
        raise LiteralTypingError(f"{value_cls} cannot be used as a literal")
    return make_literal(value)
class CPointer(Type):
    """
    Pointer-to-``dtype`` type, optionally qualified by an address space.

    Attributes
    ----------
    dtype : The pointee type
    addrspace : int or None
        The address space the pointee belongs to; ``None`` (the default)
        when no address space was given.
    """

    mutable = True

    def __init__(self, dtype, addrspace=None):
        self.dtype = dtype
        self.addrspace = addrspace
        # The address space is only part of the name when one was supplied.
        if addrspace is None:
            type_name = "%s*" % dtype
        else:
            type_name = "%s_%s*" % (dtype, addrspace)
        super().__init__(type_name)

    @property
    def key(self):
        # Pointers are distinguished by both pointee type and address space.
        return (self.dtype, self.addrspace)
class NoneType(Opaque):
    """
    The type of the ``None`` singleton.
    """

    def unify(self, typingctx, other):
        """
        Unify None with *other*.

        An Optional or NoneType is returned unchanged; any other type is
        wrapped in an Optional.
        """
        if not isinstance(other, (Optional, NoneType)):
            return Optional(other)
        return other
class SliceType(Type):
    """
    Type of a slice object with either two or three members.
    """

    def __init__(self, name, members):
        assert members in (2, 3)
        self.members = members
        # A step is present once there are at least three members.
        self.has_step = 3 <= members
        super().__init__(name)

    @property
    def key(self):
        # Slice types are distinguished by their member count alone.
        return self.members
class ClassType(Callable, Opaque):
    """
    Type of a jitted class itself (as opposed to one of its instances).
    Calling a value of this type invokes the class constructor.
    """

    mutable = True
    name_prefix = "jitclass"
    instance_type_class = ClassInstanceType

    def __init__(
        self,
        class_def,
        ctor_template_cls,
        struct,
        jit_methods,
        jit_props,
        jit_static_methods,
    ):
        self.class_name = class_def.__name__
        self.class_doc = class_def.__doc__
        self._ctor_template_class = ctor_template_cls
        self.jit_methods = jit_methods
        self.jit_props = jit_props
        self.jit_static_methods = jit_static_methods
        self.struct = struct
        # The name embeds id(self) along with a "field:type" listing of the
        # struct members.
        field_parts = [
            "{0}:{1}".format(fname, ftype) for fname, ftype in struct.items()
        ]
        type_name = "{0}.{1}#{2:x}<{3}>".format(
            self.name_prefix, self.class_name, id(self), ",".join(field_parts)
        )
        super().__init__(type_name)

    def get_call_type(self, context, args, kws):
        # Typing a call is delegated to the specialized constructor template.
        template = self.ctor_template(context)
        return template.apply(args, kws)

    def get_call_signatures(self):
        return (), True

    def get_impl_key(self, sig):
        return type(self)

    @property
    def methods(self):
        """Mapping of method name to the ``py_func`` of each jitted method."""
        return dict(
            (name, jitted.py_func) for name, jitted in self.jit_methods.items()
        )

    @property
    def static_methods(self):
        """Mapping of name to the ``py_func`` of each jitted static method."""
        return dict(
            (name, jitted.py_func)
            for name, jitted in self.jit_static_methods.items()
        )

    @property
    def instance_type(self):
        """The ClassInstanceType built from this class type."""
        return ClassInstanceType(self)

    @property
    def ctor_template(self):
        """The constructor template class, specialized for this class type."""
        return self._specialize_template(self._ctor_template_class)

    def _specialize_template(self, basecls):
        # Derive a same-named subclass of the template whose `key` is this type.
        return type(basecls.__name__, (basecls,), {"key": self})
class ClassDataType(Type):
    """
    Internal only.
    Type of the data held by a class instance.  A ClassInstanceType is
    represented with a pointer to a ClassDataType, which stands for a C
    structure containing all the data fields of the class instance.
    """

    def __init__(self, classtyp):
        self.class_type = classtyp
        type_name = "data.{0}".format(self.class_type.name)
        super().__init__(type_name)
class UnicodeIteratorType(SimpleIteratorType):
    """
    Iterator type named "iter_unicode"; *dtype* is stored as ``data`` and
    also passed to SimpleIteratorType as its second argument.
    """

    def __init__(self, dtype):
        self.data = dtype
        super().__init__("iter_unicode", dtype)
class UnicodeCharSeq(Type):
    """
    A unicode character sequence of fixed length ``count``.
    """

    mutable = True

    def __init__(self, count):
        self.count = count
        super().__init__("[unichr x %d]" % count)

    @property
    def key(self):
        return self.count

    def can_convert_to(self, typingctx, other):
        # Only conversions to another UnicodeCharSeq are rated here.
        if not isinstance(other, UnicodeCharSeq):
            return None
        return Conversion.safe

    def can_convert_from(self, typingctx, other):
        # Assuming that unicode_type itemsize is not greater than
        # numpy.dtype('U1').itemsize that UnicodeCharSeq is based
        # on.
        return Conversion.safe if isinstance(other, UnicodeType) else None

    def __repr__(self):
        return f"UnicodeCharSeq({self.count})"
+ """ + from numba.cuda.core.registry import cpu_target + + ctx = cpu_target.target_context + offset = 0 + fields = [] + lltypes = [] + for k, ty in name_types: + if not isinstance(ty, (Number, NestedArray)): + msg = "Only Number and NestedArray types are supported, found: {}. " + raise TypeError(msg.format(ty)) + if isinstance(ty, NestedArray): + datatype = ctx.data_model_manager[ty].as_storage_type() + else: + datatype = ctx.get_data_type(ty) + lltypes.append(datatype) + size = ctx.get_abi_sizeof(datatype) + align = ctx.get_abi_alignment(datatype) + # align + misaligned = offset % align + if misaligned: + offset += align - misaligned + fields.append( + ( + k, + { + "type": ty, + "offset": offset, + "alignment": align, + }, + ) + ) + offset += size + # Adjust sizeof structure + abi_size = ctx.get_abi_sizeof(ir.LiteralStructType(lltypes)) + return Record(fields, size=abi_size, aligned=True) + + def __init__(self, fields, size, aligned): + fields = self._normalize_fields(fields) + self.fields = dict(fields) + self.size = size + self.aligned = aligned + + # Create description + descbuf = [] + fmt = "{}[type={};offset={}{}]" + for k, infos in fields: + extra = "" + if infos.alignment is not None: + extra += ";alignment={}".format(infos.alignment) + elif infos.title is not None: + extra += ";title={}".format(infos.title) + descbuf.append(fmt.format(k, infos.type, infos.offset, extra)) + + desc = ",".join(descbuf) + name = "Record({};{};{})".format(desc, self.size, self.aligned) + super(Record, self).__init__(name) + + self.bitwidth = self.dtype.itemsize * 8 + + @classmethod + def _normalize_fields(cls, fields): + """ + fields: + [name: str, + value: { + type: Type, + offset: int, + [ alignment: int ], + [ title : str], + }] + """ + res = [] + for name, infos in sorted(fields, key=lambda x: (x[1]["offset"], x[0])): + fd = _RecordField( + type=infos["type"], + offset=infos["offset"], + alignment=infos.get("alignment"), + title=infos.get("title"), + ) + res.append((name, 
fd)) + return res + + @property + def key(self): + # Numpy dtype equality doesn't always succeed, use the name instead + # (https://github.com/numpy/numpy/issues/5715) + return self.name + + @property + def mangling_args(self): + return self.__class__.__name__, (self._code,) + + def __len__(self): + """Returns the number of fields""" + return len(self.fields) + + def offset(self, key): + """Get the byte offset of a field from the start of the structure.""" + return self.fields[key].offset + + def typeof(self, key): + """Get the type of a field.""" + return self.fields[key].type + + def alignof(self, key): + """Get the specified alignment of the field. + + Since field alignment is optional, this may return None. + """ + return self.fields[key].alignment + + def has_titles(self): + """Returns True the record uses titles.""" + return any(fd.title is not None for fd in self.fields.values()) + + def is_title(self, key): + """Returns True if the field named *key* is a title.""" + return self.fields[key].title == key + + @property + def members(self): + """An ordered list of (name, type) for the fields.""" + ordered = sorted(self.fields.items(), key=lambda x: x[1].offset) + return [(k, v.type) for k, v in ordered] + + @property + def dtype(self): + from numba.cuda.np.numpy_support import as_struct_dtype + + return as_struct_dtype(self) + + def can_convert_to(self, typingctx, other): + """ + Convert this Record to the *other*. + + This method only implements width subtyping for records. 
class DType(DTypeSpec, Opaque):
    """
    Type class associated with the `np.dtype`.

    i.e. :code:`assert type(np.dtype('int32')) == np.dtype`

    np.dtype('int32')

    The wrapped Numba type is exposed through the read-only ``dtype``
    property and also serves as the type's key.
    """

    def __init__(self, dtype):
        # The wrapped object must itself be a Numba Type.
        assert isinstance(dtype, Type)
        self._dtype = dtype
        name = "dtype(%s)" % (dtype,)
        # NOTE(review): super() is anchored at DTypeSpec rather than DType,
        # so the __init__ lookup starts *after* DTypeSpec in the MRO.
        # Presumably intentional (or harmless if DTypeSpec defines no
        # __init__) -- confirm against DTypeSpec.
        super(DTypeSpec, self).__init__(name)

    @property
    def key(self):
        return self.dtype

    @property
    def dtype(self):
        return self._dtype

    def __getitem__(self, arg):
        # Take whatever the inherited __getitem__ builds and replace its
        # dtype with the wrapped one.
        res = super(DType, self).__getitem__(arg)
        return res.copy(dtype=self.dtype)
class NumpyNdIterType(IteratorType):
    """
    Type class for `np.nditer()` objects.

    The layout denotes in which order the logical shape is iterated on.
    "C" means logical order (corresponding to in-memory order in C arrays),
    "F" means reverse logical order (corresponding to in-memory order in
    F arrays).
    """

    def __init__(self, arrays):
        # Note that input arrays can also be scalars, in which case they are
        # broadcast.
        self.arrays = tuple(arrays)
        self.layout = self._compute_layout(self.arrays)
        # Inputs without a `dtype` attribute stand in as their own dtype;
        # inputs without an `ndim` attribute count as 0-dimensional.
        self.dtypes = tuple(getattr(a, "dtype", a) for a in self.arrays)
        self.ndim = max(getattr(a, "ndim", 0) for a in self.arrays)
        name = "nditer(ndim={ndim}, layout={layout}, inputs={arrays})".format(
            ndim=self.ndim, layout=self.layout, arrays=self.arrays
        )
        super(NumpyNdIterType, self).__init__(name)

    @classmethod
    def _compute_layout(cls, arrays):
        """
        Pick the iteration layout by majority vote over the Array inputs.

        Returns "F" only when strictly more votes favour "F"; ties and the
        no-array case give "C".
        """
        c = collections.Counter()
        for a in arrays:
            if not isinstance(a, Array):
                continue
            if a.layout in "CF" and a.ndim == 1:
                # A 1-d C- or F-layout array votes for both layouts.
                c["C"] += 1
                c["F"] += 1
            elif a.ndim >= 1:
                c[a.layout] += 1
        return "F" if c["F"] > c["C"] else "C"

    @property
    def key(self):
        return self.arrays

    @property
    def views(self):
        """
        The views yielded by the iterator: one 0-d "C" Array per input dtype.
        """
        return [Array(dtype, 0, "C") for dtype in self.dtypes]

    @property
    def yield_type(self):
        """
        The type yielded per iteration: a tuple of the views when there are
        several inputs, otherwise the single view itself.
        """
        # Imported here rather than at module level.
        from . import BaseTuple

        views = self.views
        if len(views) > 1:
            return BaseTuple.from_types(views)
        else:
            return views[0]

    @cached_property
    def indexers(self):
        """
        A list of (kind, start_dim, end_dim, indices) where:
        - `kind` is either "flat", "indexed", "0d" or "scalar"
        - `start_dim` and `end_dim` are the dimension numbers at which
          this indexing takes place
        - `indices` is the indices of the indexed arrays in self.arrays
        """
        # Maps (kind, start_dim, end_dim) -> positions of the inputs that
        # share that indexer; insertion order is kept.
        d = collections.OrderedDict()
        layout = self.layout
        ndim = self.ndim
        assert layout in "CF"
        for i, a in enumerate(self.arrays):
            if not isinstance(a, Array):
                indexer = ("scalar", 0, 0)
            elif a.ndim == 0:
                indexer = ("0d", 0, 0)
            else:
                # Arrays matching the iteration layout, and any 1-d C/F
                # array, are "flat"; everything else is "indexed".
                if a.layout == layout or (a.ndim == 1 and a.layout in "CF"):
                    kind = "flat"
                else:
                    kind = "indexed"
                if layout == "C":
                    # If iterating in C order, broadcasting is done on the outer indices
                    indexer = (kind, ndim - a.ndim, ndim)
                else:
                    indexer = (kind, 0, a.ndim)
            d.setdefault(indexer, []).append(i)
        # Append the list of input positions to each indexer tuple.
        return list(k + (v,) for k, v in d.items())

    @cached_property
    def need_shaped_indexing(self):
        """
        Whether iterating on this iterator requires keeping track of
        individual indices inside the shape.  If False, only a single index
        over the equivalent flat shape is required, which can make the
        iterator more efficient.
        """
        for kind, start_dim, end_dim, _ in self.indexers:
            if kind in ("0d", "scalar"):
                pass
            elif kind == "flat":
                if (start_dim, end_dim) != (0, self.ndim):
                    # Broadcast flat iteration needs shaped indexing
                    # to know when to restart iteration.
                    return True
            else:
                # Any "indexed" input forces shaped indexing.
                return True
        return False
class Array(Buffer):
    """
    Type class for Numpy arrays.

    An array type is identified by its dtype, number of dimensions, layout,
    mutability and alignment (see ``key``).
    """

    def __init__(
        self, dtype, ndim, layout, readonly=False, name=None, aligned=True
    ):
        # `mutable` and `aligned` are only overridden on the instance when
        # they differ from the inherited values.
        # NOTE(review): presumably Buffer (or a base of it) supplies
        # class-level defaults of True for both -- confirm.
        if readonly:
            self.mutable = False
        if not aligned or (isinstance(dtype, Record) and not dtype.aligned):
            self.aligned = False
        if isinstance(dtype, NestedArray):
            # Fold the nested array's dimensions into this array and use
            # its element dtype.
            ndim += dtype.ndim
            dtype = dtype.dtype
        if name is None:
            # Default name, e.g. "unaligned readonly array(<dtype>, 2d, C)".
            type_name = "array"
            if not self.mutable:
                type_name = "readonly " + type_name
            if not self.aligned:
                type_name = "unaligned " + type_name
            name = "%s(%s, %sd, %s)" % (type_name, dtype, ndim, layout)
        super(Array, self).__init__(dtype, ndim, layout, name=name)

    @property
    def mangling_args(self):
        # Class name plus every attribute that distinguishes the type.
        args = [
            self.dtype,
            self.ndim,
            self.layout,
            "mutable" if self.mutable else "readonly",
            "aligned" if self.aligned else "unaligned",
        ]
        return self.__class__.__name__, args

    def copy(self, dtype=None, ndim=None, layout=None, readonly=None):
        """
        Return a new Array with the given attributes replaced.

        Arguments left as None keep this array's value.  Alignment is
        always carried over from this array.
        """
        if dtype is None:
            dtype = self.dtype
        if ndim is None:
            ndim = self.ndim
        if layout is None:
            layout = self.layout
        if readonly is None:
            readonly = not self.mutable
        return Array(
            dtype=dtype,
            ndim=ndim,
            layout=layout,
            readonly=readonly,
            aligned=self.aligned,
        )

    @property
    def key(self):
        return self.dtype, self.ndim, self.layout, self.mutable, self.aligned

    def unify(self, typingctx, other):
        """
        Unify this with the *other* Array.

        Returns None (implicitly) when *other* is not an Array of the same
        ndim, or when the dtypes differ and other's dtype is precise.
        """
        # If other is array and the ndim matches
        if isinstance(other, Array) and other.ndim == self.ndim:
            # If dtype matches or other.dtype is undefined (inferred)
            if other.dtype == self.dtype or not other.dtype.is_precise():
                # Differing layouts fall back to the "A" layout.
                if self.layout == other.layout:
                    layout = self.layout
                else:
                    layout = "A"
                # The result is read-only if either side is, and aligned
                # only if both sides are.
                readonly = not (self.mutable and other.mutable)
                aligned = self.aligned and other.aligned
                return Array(
                    dtype=self.dtype,
                    ndim=self.ndim,
                    layout=layout,
                    readonly=readonly,
                    aligned=aligned,
                )

    def can_convert_to(self, typingctx, other):
        """
        Convert this Array to the *other*.

        The conversion is rated safe only for an Array with the same ndim
        and dtype whose layout is "A" or equal to this one, and which does
        not add mutability or alignment that this array lacks.  Otherwise
        None is returned implicitly.
        """
        if (
            isinstance(other, Array)
            and other.ndim == self.ndim
            and other.dtype == self.dtype
        ):
            if (
                other.layout in ("A", self.layout)
                and (self.mutable or not other.mutable)
                and (self.aligned or not other.aligned)
            ):
                return Conversion.safe

    def is_precise(self):
        # An array is as precise as its element type.
        return self.dtype.is_precise()

    @property
    def box_type(self):
        """Returns the Python type to box to."""
        return np.ndarray

    def __repr__(self):
        # The fourth positional field is the readonly flag.
        return (
            f"Array({repr(self.dtype)}, {self.ndim}, '{self.layout}', "
            f"{not self.mutable}, aligned={self.aligned})"
        )
class NestedArray(Array):
    """
    An array nested inside a structured type (a "void" type in NumPy
    parlance).  Unlike a plain Array, a NestedArray's type includes the
    full shape and not only the number of dimensions.
    """

    def __init__(self, dtype, shape):
        if isinstance(dtype, NestedArray):
            # Flatten nesting: extend the shape with the inner array's
            # shape and continue with its element dtype.
            inner = Array(dtype.dtype, dtype.ndim, "C")
            shape += dtype.shape
            dtype = inner.dtype
        assert dtype.bitwidth % 8 == 0, (
            "Dtype bitwidth must be a multiple of bytes"
        )
        self._shape = shape
        type_name = "nestedarray(%s, %s)" % (dtype, shape)
        super().__init__(dtype, len(shape), "C", name=type_name)

    @property
    def shape(self):
        return self._shape

    @property
    def nitems(self):
        """Product of all extents in the shape."""
        total = 1
        for extent in self.shape:
            total = total * extent
        return total

    @property
    def size(self):
        """Size of one element in bytes."""
        return self.dtype.bitwidth // 8

    @property
    def strides(self):
        """Byte strides per dimension, with the last dimension varying fastest."""
        step = self.size
        collected = []
        # Walk from the innermost dimension outwards, accumulating the step.
        for extent in reversed(self._shape):
            collected.append(step)
            step *= extent
        collected.reverse()
        return tuple(collected)

    @property
    def key(self):
        return (self.dtype, self.shape)

    def __repr__(self):
        return f"NestedArray({repr(self.dtype)}, {self.shape})"
class PolynomialType(Type):
    """
    Type of a polynomial object described by the types of its ``coef``,
    ``domain`` and ``window`` components.

    Parameters
    ----------
    coef : type of the coefficients
    domain : type of the domain, or None
    window : type of the window, or None
    n_args : int
        Number of arguments that were passed to the constructor.
    """

    def __init__(self, coef, domain=None, window=None, n_args=1):
        # The name must include `window`: previously `domain` was formatted
        # twice, so two types differing only in their window received the
        # same name.
        super(PolynomialType, self).__init__(
            name=f"PolynomialType({coef}, {domain}, {window}, {n_args})"
        )
        self.coef = coef
        self.domain = domain
        self.window = window
        # We use n_args to keep track of the number of arguments in the
        # constructor, since the types of domain and window arguments depend on
        # that and we need that information when boxing
        self.n_args = n_args
@total_ordering
class Float(Number):
    """
    Floating-point number type.  The bit width is parsed from the type
    name, which must start with "float" followed by the width.
    """

    def __init__(self, *args, **kws):
        super().__init__(*args, **kws)
        # Determine bitwidth
        assert self.name.startswith("float")
        self.bitwidth = int(self.name[5:])

    def cast_python_value(self, value):
        # The NumPy scalar constructor has the same name as this type.
        np_scalar_type = getattr(np, self.name)
        return np_scalar_type(value)

    def __lt__(self, other):
        # Ordering is only defined between instances of the very same class.
        if self.__class__ is other.__class__:
            return self.bitwidth < other.bitwidth
        return NotImplemented
class _NPDatetimeBase(Type):
    """
    Shared base for the np.datetime64 and np.timedelta64 type classes.
    Subclasses supply ``type_name``.
    """

    def __init__(self, unit, *args, **kws):
        type_name = "%s[%s]" % (self.type_name, unit)
        self.unit = unit
        self.unit_code = npdatetime_helpers.DATETIME_UNITS[self.unit]
        super().__init__(type_name, *args, **kws)

    def __lt__(self, other):
        if self.__class__ is other.__class__:
            # A coarser-grained unit is "smaller", i.e. less precise values
            # can be represented (but the magnitude of representable values
            # is also greater...).
            return self.unit_code < other.unit_code
        return NotImplemented

    def cast_python_value(self, value):
        # The NumPy scalar class is looked up by this type's name; the unit
        # is only passed along when it is non-empty.
        np_cls = getattr(np, self.type_name)
        return np_cls(value, self.unit) if self.unit else np_cls(value)
+ """ + return EnumMember(self.instance_class, self.dtype) + + +class IntEnumClass(EnumClass): + """ + Type class for IntEnum classes. + """ + + basename = "IntEnum class" + + @cached_property + def member_type(self): + """ + The type of this class' members. + """ + return IntEnumMember(self.instance_class, self.dtype) + + +class EnumMember(Type): + """ + Type class for Enum members. + """ + + basename = "Enum" + class_type_class = EnumClass + + def __init__(self, cls, dtype): + assert isinstance(cls, type) + assert isinstance(dtype, Type) + self.instance_class = cls + self.dtype = dtype + name = "%s<%s>(%s)" % ( + self.basename, + self.dtype, + self.instance_class.__name__, + ) + super(EnumMember, self).__init__(name) + + @property + def key(self): + return self.instance_class, self.dtype + + @property + def class_type(self): + """ + The type of this member's class. + """ + return self.class_type_class(self.instance_class, self.dtype) + + +class IntEnumMember(EnumMember): + """ + Type class for IntEnum members. + """ + + basename = "IntEnum" + class_type_class = IntEnumClass + + def can_convert_to(self, typingctx, other): + """ + Convert IntEnum members to plain integers. + """ + if issubclass(self.instance_class, enum.IntEnum): + conv = typingctx.can_convert(self.dtype, other) + return max(conv, Conversion.safe) diff --git a/numba_cuda/numba/cuda/types/function_type.py b/numba_cuda/numba/cuda/types/function_type.py new file mode 100644 index 000000000..a5d7003b5 --- /dev/null +++ b/numba_cuda/numba/cuda/types/function_type.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), + "numba.core.types.function_type", + "numba.cuda.types.cuda_function_type", +) diff --git a/numba_cuda/numba/cuda/types/functions.py b/numba_cuda/numba/cuda/types/functions.py new file mode 100644 index 000000000..5b78684ec --- /dev/null +++ b/numba_cuda/numba/cuda/types/functions.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.functions", "numba.cuda.types.cuda_functions" +) diff --git a/numba_cuda/numba/cuda/types/iterators.py b/numba_cuda/numba/cuda/types/iterators.py new file mode 100644 index 000000000..8801f3022 --- /dev/null +++ b/numba_cuda/numba/cuda/types/iterators.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.iterators", "numba.cuda.types.cuda_iterators" +) diff --git a/numba_cuda/numba/cuda/types/misc.py b/numba_cuda/numba/cuda/types/misc.py new file mode 100644 index 000000000..e8a0e7014 --- /dev/null +++ b/numba_cuda/numba/cuda/types/misc.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.misc", "numba.cuda.types.cuda_misc" +) diff --git a/numba_cuda/numba/cuda/types/npytypes.py b/numba_cuda/numba/cuda/types/npytypes.py new file mode 100644 index 000000000..1c25cac53 --- /dev/null +++ b/numba_cuda/numba/cuda/types/npytypes.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.npytypes", "numba.cuda.types.cuda_npytypes" +) diff --git a/numba_cuda/numba/cuda/types/scalars.py b/numba_cuda/numba/cuda/types/scalars.py new file mode 100644 index 000000000..09297137a --- /dev/null +++ b/numba_cuda/numba/cuda/types/scalars.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import sys +from numba.cuda.utils import redirect_numba_module + +sys.modules[__name__] = redirect_numba_module( + locals(), "numba.core.types.scalars", "numba.cuda.types.cuda_scalars" +) diff --git a/numba_cuda/numba/cuda/typing/arraydecl.py b/numba_cuda/numba/cuda/typing/arraydecl.py index 17eeeebbb..abaed2f88 100644 --- a/numba_cuda/numba/cuda/typing/arraydecl.py +++ b/numba_cuda/numba/cuda/typing/arraydecl.py @@ -5,7 +5,7 @@ import operator from collections import namedtuple -from numba.core import types +from numba.cuda import types from numba.cuda import utils from numba.cuda.typing.templates import ( AttributeTemplate, @@ -30,7 +30,6 @@ numpy_version = tuple(map(int, np.__version__.split(".")[:2])) - registry = Registry() infer = registry.register infer_global = registry.register_global diff --git a/numba_cuda/numba/cuda/typing/asnumbatype.py b/numba_cuda/numba/cuda/typing/asnumbatype.py index d98d5eec6..50c3f9270 100644 --- a/numba_cuda/numba/cuda/typing/asnumbatype.py +++ b/numba_cuda/numba/cuda/typing/asnumbatype.py @@ -6,7 +6,7 @@ from numba.cuda.typing.typeof import typeof from numba.core import errors -from numba.core import types +from numba.cuda import types class AsNumbaTypeRegistry: diff --git a/numba_cuda/numba/cuda/typing/bufproto.py b/numba_cuda/numba/cuda/typing/bufproto.py index 00ee87634..c55f58339 100644 --- a/numba_cuda/numba/cuda/typing/bufproto.py +++ b/numba_cuda/numba/cuda/typing/bufproto.py @@ -7,7 +7,7 @@ import array -from numba.core import types +from numba.cuda import types from numba.core.errors import NumbaValueError diff --git a/numba_cuda/numba/cuda/typing/builtins.py b/numba_cuda/numba/cuda/typing/builtins.py index 703f5e51c..d8adc0781 100644 --- a/numba_cuda/numba/cuda/typing/builtins.py +++ b/numba_cuda/numba/cuda/typing/builtins.py @@ -6,7 +6,8 @@ import numpy as np import operator -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors 
from numba import prange from numba.parfors.parfor import internal_prange diff --git a/numba_cuda/numba/cuda/typing/cffi_utils.py b/numba_cuda/numba/cuda/typing/cffi_utils.py index 9d9d0da11..d20edda78 100644 --- a/numba_cuda/numba/cuda/typing/cffi_utils.py +++ b/numba_cuda/numba/cuda/typing/cffi_utils.py @@ -9,7 +9,7 @@ from functools import partial import numpy as np -from numba.core import types +from numba.cuda import types from numba.core.errors import TypingError from numba.cuda.typing import templates from numba.cuda.np import numpy_support diff --git a/numba_cuda/numba/cuda/typing/cmathdecl.py b/numba_cuda/numba/cuda/typing/cmathdecl.py index 24eae3a4f..0812fc2a3 100644 --- a/numba_cuda/numba/cuda/typing/cmathdecl.py +++ b/numba_cuda/numba/cuda/typing/cmathdecl.py @@ -3,7 +3,7 @@ import cmath -from numba.core import types +from numba.cuda import types from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry registry = Registry() diff --git a/numba_cuda/numba/cuda/typing/collections.py b/numba_cuda/numba/cuda/typing/collections.py index 1eaf1fa9f..a4e0dde11 100644 --- a/numba_cuda/numba/cuda/typing/collections.py +++ b/numba_cuda/numba/cuda/typing/collections.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import utils import operator from .templates import ( diff --git a/numba_cuda/numba/cuda/typing/context.py b/numba_cuda/numba/cuda/typing/context.py index 62c6cf938..e6ac0ad83 100644 --- a/numba_cuda/numba/cuda/typing/context.py +++ b/numba_cuda/numba/cuda/typing/context.py @@ -10,11 +10,11 @@ import operator from importlib.util import find_spec -from numba.core import types, errors +from numba.core import errors from numba.cuda.typeconv import Conversion, rules from numba.cuda.typing.typeof import typeof, Purpose from numba.cuda.typing import templates -from numba.cuda import utils +from numba.cuda import types, utils from numba.cuda.utils import order_by_target_specificity @@ -291,7 +291,7 @@ def find_matching_getattr_template(self, typ, attr): templates = list(self._get_attribute_templates(typ)) # get the order in which to try templates - from numba.core.target_extension import get_local_target # circular + from numba.core.target_extension import get_local_target target_hw = get_local_target(self) order = order_by_target_specificity(target_hw, templates, fnkey=attr) @@ -512,9 +512,9 @@ def is_external(obj): continue self.insert_attributes(ftcls(self)) for gv, gty in loader.new_registrations("globals"): - # If external_defs_only, check the global type's module + # If external_defs_only, check the global value's module if external_defs_only: - if hasattr(gty, "__module__") and is_external(gty): + if hasattr(gv, "__module__") and not is_external(gv): continue existing = self._lookup_global(gv) if existing is None: diff --git a/numba_cuda/numba/cuda/typing/ctypes_utils.py b/numba_cuda/numba/cuda/typing/ctypes_utils.py index af8ecbfd7..e49772765 100644 --- a/numba_cuda/numba/cuda/typing/ctypes_utils.py +++ b/numba_cuda/numba/cuda/typing/ctypes_utils.py @@ -8,7 +8,7 @@ import ctypes import sys -from numba.core 
import types +from numba.cuda import types from numba.cuda.typing import templates _FROM_CTYPES = { diff --git a/numba_cuda/numba/cuda/typing/dictdecl.py b/numba_cuda/numba/cuda/typing/dictdecl.py index 42e08258e..54ccc0cc0 100644 --- a/numba_cuda/numba/cuda/typing/dictdecl.py +++ b/numba_cuda/numba/cuda/typing/dictdecl.py @@ -5,7 +5,8 @@ This implements the typing template for `dict()`. """ -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from .templates import ( AbstractTemplate, Registry, diff --git a/numba_cuda/numba/cuda/typing/enumdecl.py b/numba_cuda/numba/cuda/typing/enumdecl.py index 10a899262..bb508d880 100644 --- a/numba_cuda/numba/cuda/typing/enumdecl.py +++ b/numba_cuda/numba/cuda/typing/enumdecl.py @@ -6,7 +6,7 @@ """ import operator -from numba.core import types +from numba.cuda import types from numba.cuda.typing.templates import ( AbstractTemplate, AttributeTemplate, diff --git a/numba_cuda/numba/cuda/typing/listdecl.py b/numba_cuda/numba/cuda/typing/listdecl.py index 8754197fd..40de72783 100644 --- a/numba_cuda/numba/cuda/typing/listdecl.py +++ b/numba_cuda/numba/cuda/typing/listdecl.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import operator -from numba.core import types +from numba.cuda import types from .templates import ( AbstractTemplate, AttributeTemplate, diff --git a/numba_cuda/numba/cuda/typing/mathdecl.py b/numba_cuda/numba/cuda/typing/mathdecl.py index d9a1063a7..e62b1f1f5 100644 --- a/numba_cuda/numba/cuda/typing/mathdecl.py +++ b/numba_cuda/numba/cuda/typing/mathdecl.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import math -from numba.core import types +from numba.cuda import types from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry registry = Registry() diff --git a/numba_cuda/numba/cuda/typing/npdatetime.py b/numba_cuda/numba/cuda/typing/npdatetime.py index 1a6811e4d..3a72e32b5 100644 --- 
a/numba_cuda/numba/cuda/typing/npdatetime.py +++ b/numba_cuda/numba/cuda/typing/npdatetime.py @@ -3,7 +3,8 @@ import operator -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda.typing.templates import ( AbstractTemplate, Registry, diff --git a/numba_cuda/numba/cuda/typing/npydecl.py b/numba_cuda/numba/cuda/typing/npydecl.py index b37e5163e..1c20849d0 100644 --- a/numba_cuda/numba/cuda/typing/npydecl.py +++ b/numba_cuda/numba/cuda/typing/npydecl.py @@ -5,7 +5,7 @@ import operator from numba.cuda.typing.templates import AbstractTemplate, Registry, signature -from numba.core import types +from numba.cuda import types from numba.cuda import utils from numba.core.errors import TypingError, NumbaTypeError from numba.cuda.np.numpy_support import ( diff --git a/numba_cuda/numba/cuda/typing/setdecl.py b/numba_cuda/numba/cuda/typing/setdecl.py index 03bc6761f..339941ff3 100644 --- a/numba_cuda/numba/cuda/typing/setdecl.py +++ b/numba_cuda/numba/cuda/typing/setdecl.py @@ -3,7 +3,7 @@ import operator -from numba.core import types +from numba.cuda import types from .templates import ( AbstractTemplate, AttributeTemplate, diff --git a/numba_cuda/numba/cuda/typing/templates.py b/numba_cuda/numba/cuda/typing/templates.py index 987624b21..e5b7cff84 100644 --- a/numba_cuda/numba/cuda/typing/templates.py +++ b/numba_cuda/numba/cuda/typing/templates.py @@ -15,12 +15,13 @@ from types import MethodType, FunctionType, MappingProxyType import numba -from numba.core import types +from numba.cuda import types from numba.core.errors import ( TypingError, InternalError, ) from numba.cuda.core.options import InlineOptions + from numba.cuda import utils from numba.cuda.core import targetconfig @@ -31,7 +32,6 @@ except ImportError: numba_sig_present = False - # info store for inliner callback functions e.g. 
cost model _inline_info = namedtuple("inline_info", "func_ir typemap calltypes signature") diff --git a/numba_cuda/numba/cuda/typing/typeof.py b/numba_cuda/numba/cuda/typing/typeof.py index a9fe8c44c..f199c5af4 100644 --- a/numba_cuda/numba/cuda/typing/typeof.py +++ b/numba_cuda/numba/cuda/typing/typeof.py @@ -9,7 +9,8 @@ import numpy as np from numpy.random.bit_generator import BitGenerator -from numba.core import types, errors +from numba.cuda import types +from numba.core import errors from numba.cuda import utils from numba.cuda.np import numpy_support diff --git a/numba_cuda/numba/cuda/utils.py b/numba_cuda/numba/cuda/utils.py index 67a86ef6f..13203103b 100644 --- a/numba_cuda/numba/cuda/utils.py +++ b/numba_cuda/numba/cuda/utils.py @@ -21,6 +21,7 @@ from types import ModuleType from importlib import import_module +from importlib.util import find_spec import numpy as np from inspect import signature as pysignature # noqa: F401 @@ -34,6 +35,7 @@ from numba.cuda.core import config +from collections.abc import Sequence PYVERSION = config.PYVERSION @@ -477,6 +479,96 @@ def get_nargs_range(pyfunc): return min_nargs, max_nargs +def unified_function_type(numba_types, require_precise=True): + """Returns a unified Numba function type if possible. + + Parameters + ---------- + numba_types : Sequence of numba Type instances. + require_precise : bool + If True, the returned Numba function type must be precise. + + Returns + ------- + typ : {numba.cuda.types.Type, None} + A unified Numba function type. Or ``None`` when the Numba types + cannot be unified, e.g. when the ``numba_types`` contains at + least two different Numba function type instances. + + If ``numba_types`` contains a Numba dispatcher type, the unified + Numba function type will be an imprecise ``UndefinedFunctionType`` + instance, or None when ``require_precise=True`` is specified. 
+ + Specifying ``require_precise=False`` enables unifying imprecise + Numba dispatcher instances when used in tuples or if-then branches + when the precise Numba function cannot be determined on the first + occurrence that is not a call expression. + """ + from numba.core.errors import NumbaExperimentalFeatureWarning + from numba.cuda import types + + if not ( + isinstance(numba_types, Sequence) + and len(numba_types) > 0 + and isinstance(numba_types[0], (types.Dispatcher, types.FunctionType)) + ): + return + + warnings.warn( + "First-class function type feature is experimental", + category=NumbaExperimentalFeatureWarning, + ) + + mnargs, mxargs = None, None + dispatchers = set() + function = None + undefined_function = None + + for t in numba_types: + if isinstance(t, types.Dispatcher): + mnargs1, mxargs1 = get_nargs_range(t.dispatcher.py_func) + if mnargs is None: + mnargs, mxargs = mnargs1, mxargs1 + elif not (mnargs, mxargs) == (mnargs1, mxargs1): + return + dispatchers.add(t.dispatcher) + t = t.dispatcher.get_function_type() + if t is None: + continue + if isinstance(t, types.FunctionType): + if mnargs is None: + mnargs = mxargs = t.nargs + elif not (mnargs == mxargs == t.nargs): + return + if isinstance(t, types.UndefinedFunctionType): + if undefined_function is None: + undefined_function = t + else: + # Refuse to unify using function type + return + dispatchers.update(t.dispatchers) + else: + if function is None: + function = t + else: + assert function == t + else: + return + if require_precise and (function is None or undefined_function is not None): + return + if function is not None: + if undefined_function is not None: + assert function.nargs == undefined_function.nargs + function = undefined_function + elif undefined_function is not None: + undefined_function.dispatchers.update(dispatchers) + function = undefined_function + else: + function = types.UndefinedFunctionType(mnargs, dispatchers) + + return function + + class _RedirectSubpackage(ModuleType): 
"""Redirect a subpackage to a subpackage. @@ -522,6 +614,13 @@ def __reduce__(self): return _RedirectSubpackage, args +def redirect_numba_module(old_module_locals, numba_module, numba_cuda_module): + if find_spec("numba"): + return _RedirectSubpackage(old_module_locals, numba_module) + else: + return _RedirectSubpackage(old_module_locals, numba_cuda_module) + + def get_hashable_key(value): """ Given a value, returns a key that can be used diff --git a/numba_cuda/numba/cuda/vector_types.py b/numba_cuda/numba/cuda/vector_types.py index 10bf5e188..6e46206bd 100644 --- a/numba_cuda/numba/cuda/vector_types.py +++ b/numba_cuda/numba/cuda/vector_types.py @@ -6,9 +6,9 @@ from typing import List, Tuple, Dict -from numba import types +from numba.cuda import types from numba.cuda import cgutils -from numba.core.datamodel import models +from numba.cuda.datamodel import models from numba.cuda.core.imputils import Registry as ImplRegistry from numba.cuda.typing.templates import ConcreteTemplate from numba.cuda.typing.templates import Registry as TypingRegistry @@ -63,7 +63,7 @@ def make_vector_type( ---------- name: str The name of the type. - base_type: numba.types.Type + base_type: numba.cuda.types.Type The primitive type for each element in the vector. attr_names: tuple of str Name for each attribute. diff --git a/pyproject.toml b/pyproject.toml index ab8a8920b..a378dbda4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,12 @@ exclude = [ # errors in device_init because its purpose is to bring together a lot of # the public API to be star-imported in numba.cuda.__init__ "numba_cuda/numba/cuda/device_init.py" = ["F401", "F403", "F405"] +# Ignore star imports, unused imports, and "may be defined by star imports" +# errors in types init files. +"numba_cuda/numba/cuda/types/__init__.py" = ["F401", "F403", "F405"] +"numba_cuda/numba/cuda/types/__init__.pyi" = ["F401", "F403", "F405"] +# Ignore "unused" imports in datamodel init file.
+"numba_cuda/numba/cuda/datamodel/__init__.py" = ["F401"] # libdevice.py is an autogenerated file containing stubs for all the device # functions. Some of the lines in docstrings are a little over-long, as they # contain the URLs of the reference pages in the online libdevice