From e1984893e1b68e86e9504047f19760e3a9cb6bdb Mon Sep 17 00:00:00 2001 From: Vijay Kandiah Date: Wed, 1 Oct 2025 21:44:38 -0700 Subject: [PATCH 1/4] Vendor in typeconv for future CUDA-specific changes --- numba_cuda/numba/cuda/_internal/cuda_fp16.py | 10 +- numba_cuda/numba/cuda/core/typeinfer.py | 2 +- numba_cuda/numba/cuda/target.py | 2 +- .../numba/cuda/tests/cudapy/test_typeinfer.py | 2 +- numba_cuda/numba/cuda/typeconv/__init__.py | 4 + numba_cuda/numba/cuda/typeconv/castgraph.py | 137 +++++++++++++++++ numba_cuda/numba/cuda/typeconv/rules.py | 73 +++++++++ numba_cuda/numba/cuda/typeconv/typeconv.py | 139 ++++++++++++++++++ numba_cuda/numba/cuda/types.py | 2 +- numba_cuda/numba/cuda/typing/context.py | 2 +- 10 files changed, 363 insertions(+), 10 deletions(-) create mode 100644 numba_cuda/numba/cuda/typeconv/__init__.py create mode 100644 numba_cuda/numba/cuda/typeconv/castgraph.py create mode 100644 numba_cuda/numba/cuda/typeconv/rules.py create mode 100644 numba_cuda/numba/cuda/typeconv/typeconv.py diff --git a/numba_cuda/numba/cuda/_internal/cuda_fp16.py b/numba_cuda/numba/cuda/_internal/cuda_fp16.py index 6bf2b0159..96456d9ba 100644 --- a/numba_cuda/numba/cuda/_internal/cuda_fp16.py +++ b/numba_cuda/numba/cuda/_internal/cuda_fp16.py @@ -124,7 +124,7 @@ def __init__(self): self.bitwidth = 2 * 8 def can_convert_from(self, typingctx, other): - from numba.core.typeconv import Conversion + from numba.cuda.typeconv import Conversion if other in []: return Conversion.safe @@ -174,7 +174,7 @@ def __init__(self): self.bitwidth = 4 * 8 def can_convert_from(self, typingctx, other): - from numba.core.typeconv import Conversion + from numba.cuda.typeconv import Conversion if other in []: return Conversion.safe @@ -7903,9 +7903,9 @@ def generic(self, args, kws): # - Conversion.safe if ( - (convertible == numba.core.typeconv.Conversion.exact) - or (convertible == numba.core.typeconv.Conversion.promote) - or (convertible == numba.core.typeconv.Conversion.safe) + (convertible == numba.cuda.typeconv.Conversion.exact) + or (convertible == numba.cuda.typeconv.Conversion.promote) + or (convertible == numba.cuda.typeconv.Conversion.safe) ): return signature(retty, types.float16, types.float16) diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index 38aea47e8..2103d602c 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -48,7 +48,7 @@ NumbaValueError, ) from numba.cuda.core.funcdesc import qualifying_prefix -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion _logger = logging.getLogger(__name__) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 201e4beb6..f8e129270 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -88,7 +88,7 @@ def resolve_value_type(self, val): def can_convert(self, fromty, toty): """ Check whether conversion is possible from *fromty* to *toty*. - If successful, return a numba.typeconv.Conversion instance; + If successful, return a numba.cuda.typeconv.Conversion instance; otherwise None is returned. """ diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py index 09adb20ae..0c8ae598f 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py @@ -4,7 +4,7 @@ import itertools from numba.core import errors, types, typing -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion from numba.cuda.testing import CUDATestCase, skip_on_cudasim from numba.tests.test_typeconv import CompatibilityTestMixin diff --git a/numba_cuda/numba/cuda/typeconv/__init__.py b/numba_cuda/numba/cuda/typeconv/__init__.py new file mode 100644 index 000000000..55011b82a --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from .castgraph import Conversion # noqa F401 diff --git a/numba_cuda/numba/cuda/typeconv/castgraph.py b/numba_cuda/numba/cuda/typeconv/castgraph.py new file mode 100644 index 000000000..aa9f20db7 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/castgraph.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from collections import defaultdict +import enum + + +class Conversion(enum.IntEnum): + """ + A conversion kind from one type to the other. The enum members + are ordered from stricter to looser. + """ + + # The two types are identical + exact = 1 + # The two types are of the same kind, the destination type has more + # extension or precision than the source type (e.g. float32 -> float64, + # or int32 -> int64) + promote = 2 + # The source type can be converted to the destination type without loss + # of information (e.g. int32 -> int64). Note that the conversion may + # still fail explicitly at runtime (e.g. Optional(int32) -> int32) + safe = 3 + # The conversion may appear to succeed at runtime while losing information + # or precision (e.g. int32 -> uint32, float64 -> float32, int64 -> int32, + # etc.) + unsafe = 4 + + # This value is only used internally + nil = 99 + + +class CastSet(object): + """A set of casting rules. + + There is at most one rule per target type. + """ + + def __init__(self): + self._rels = {} + + def insert(self, to, rel): + old = self.get(to) + setrel = min(rel, old) + self._rels[to] = setrel + return old != setrel + + def items(self): + return self._rels.items() + + def get(self, item): + return self._rels.get(item, Conversion.nil) + + def __len__(self): + return len(self._rels) + + def __repr__(self): + body = [ + "{rel}({ty})".format(rel=rel, ty=ty) + for ty, rel in self._rels.items() + ] + return "{" + ", ".join(body) + "}" + + def __contains__(self, item): + return item in self._rels + + def __iter__(self): + return iter(self._rels.keys()) + + def __getitem__(self, item): + return self._rels[item] + + +class TypeGraph(object): + """A graph that maintains the casting relationship of all types. + + This simplifies the definition of casting rules by automatically + propagating the rules. + """ + + def __init__(self, callback=None): + """ + Args + ---- + - callback: callable or None + It is called for each new casting rule with + (from_type, to_type, castrel). + """ + assert callback is None or callable(callback) + self._forwards = defaultdict(CastSet) + self._backwards = defaultdict(set) + self._callback = callback + + def get(self, ty): + return self._forwards[ty] + + def propagate(self, a, b, baserel): + backset = self._backwards[a] + + # Forward propagate the relationship to all nodes that b leads to + for child in self._forwards[b]: + rel = max(baserel, self._forwards[b][child]) + if a != child: + if self._forwards[a].insert(child, rel): + self._callback(a, child, rel) + self._backwards[child].add(a) + + # Propagate the relationship from nodes that connects to a + for backnode in backset: + if backnode != child: + backrel = max(rel, self._forwards[backnode][a]) + if self._forwards[backnode].insert(child, backrel): + self._callback(backnode, child, backrel) + self._backwards[child].add(backnode) + + # Every node that leads to a connects to b + for child in self._backwards[a]: + rel = max(baserel, self._forwards[child][a]) + if b != child: + if self._forwards[child].insert(b, rel): + self._callback(child, b, rel) + self._backwards[b].add(child) + + def insert_rule(self, a, b, rel): + self._forwards[a].insert(b, rel) + self._callback(a, b, rel) + self._backwards[b].add(a) + self.propagate(a, b, rel) + + def promote(self, a, b): + self.insert_rule(a, b, Conversion.promote) + + def safe(self, a, b): + self.insert_rule(a, b, Conversion.safe) + + def unsafe(self, a, b): + self.insert_rule(a, b, Conversion.unsafe) diff --git a/numba_cuda/numba/cuda/typeconv/rules.py b/numba_cuda/numba/cuda/typeconv/rules.py new file mode 100644 index 000000000..2fa5c1158 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/rules.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import itertools +from .typeconv import TypeManager, TypeCastingRules +from numba.core import types +from numba.cuda import config + + +default_type_manager = TypeManager() + + +def dump_number_rules(): + tm = default_type_manager + for a, b in itertools.product(types.number_domain, types.number_domain): + print(a, "->", b, tm.check_compatible(a, b)) + + +if config.USE_LEGACY_TYPE_SYSTEM: # Old type system + + def _init_casting_rules(tm): + tcr = TypeCastingRules(tm) + tcr.safe_unsafe(types.boolean, types.int8) + tcr.safe_unsafe(types.boolean, types.uint8) + + tcr.promote_unsafe(types.int8, types.int16) + tcr.promote_unsafe(types.uint8, types.uint16) + + tcr.promote_unsafe(types.int16, types.int32) + tcr.promote_unsafe(types.uint16, types.uint32) + + tcr.promote_unsafe(types.int32, types.int64) + tcr.promote_unsafe(types.uint32, types.uint64) + + tcr.safe_unsafe(types.uint8, types.int16) + tcr.safe_unsafe(types.uint16, types.int32) + tcr.safe_unsafe(types.uint32, types.int64) + + tcr.safe_unsafe(types.int8, types.float16) + tcr.safe_unsafe(types.int16, types.float32) + tcr.safe_unsafe(types.int32, types.float64) + + tcr.unsafe_unsafe(types.int16, types.float16) + tcr.unsafe_unsafe(types.int32, types.float32) + # XXX this is inconsistent with the above; but we want to prefer + # float64 over int64 when typing a heterogeneous operation, + # e.g. `float64 + int64`. Perhaps we need more granularity in the + # conversion kinds. + tcr.safe_unsafe(types.int64, types.float64) + tcr.safe_unsafe(types.uint64, types.float64) + + tcr.promote_unsafe(types.float16, types.float32) + tcr.promote_unsafe(types.float32, types.float64) + + tcr.safe(types.float32, types.complex64) + tcr.safe(types.float64, types.complex128) + + tcr.promote_unsafe(types.complex64, types.complex128) + + # Allow integers to cast ot void* + tcr.unsafe_unsafe(types.uintp, types.voidptr) + + return tcr +else: # New type system + # Currently left as empty + # If no casting rules are required we may opt to remove + # this framework upon deprecation + def _init_casting_rules(tm): + tcr = TypeCastingRules(tm) + return tcr + + +default_casting_rules = _init_casting_rules(default_type_manager) diff --git a/numba_cuda/numba/cuda/typeconv/typeconv.py b/numba_cuda/numba/cuda/typeconv/typeconv.py new file mode 100644 index 000000000..117afe253 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/typeconv.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +try: + # This is usually the the first C extension import performed when importing + # Numba, if it fails to import, provide some feedback + from numba.core.typeconv import _typeconv +except ImportError as e: + base_url = "https://numba.readthedocs.io/en/stable" + dev_url = f"{base_url}/developer/contributing.html" + user_url = f"{base_url}/user/faq.html#numba-could-not-be-imported" + dashes = "-" * 80 + msg = ( + f"Numba could not be imported.\n{dashes}\nIf you are seeing this " + "message and are undertaking Numba development work, you may need " + "to rebuild Numba.\nPlease see the development set up guide:\n\n" + f"{dev_url}.\n\n{dashes}\nIf you are not working on Numba " + f"development, the original error was: '{str(e)}'.\nFor help, " + f"please visit:\n\n{user_url}\n" + ) + raise ImportError(msg) + +from numba.cuda.typeconv import castgraph, Conversion +from numba.core import types + + +class TypeManager(object): + # The character codes used by the C/C++ API (_typeconv.cpp) + _conversion_codes = { + Conversion.safe: ord("s"), + Conversion.unsafe: ord("u"), + Conversion.promote: ord("p"), + } + + def __init__(self): + self._ptr = _typeconv.new_type_manager() + self._types = set() + + def select_overload( + self, sig, overloads, allow_unsafe, exact_match_required + ): + sig = [t._code for t in sig] + overloads = [[t._code for t in s] for s in overloads] + return _typeconv.select_overload( + self._ptr, sig, overloads, allow_unsafe, exact_match_required + ) + + def check_compatible(self, fromty, toty): + if not isinstance(toty, types.Type): + raise ValueError( + "Specified type '%s' (%s) is not a Numba type" + % (toty, type(toty)) + ) + name = _typeconv.check_compatible(self._ptr, fromty._code, toty._code) + conv = Conversion[name] if name is not None else None + assert conv is not Conversion.nil + return conv + + def set_compatible(self, fromty, toty, by): + code = self._conversion_codes[by] + _typeconv.set_compatible(self._ptr, fromty._code, toty._code, code) + # Ensure the types don't die, otherwise they may be recreated with + # other type codes and pollute the hash table. + self._types.add(fromty) + self._types.add(toty) + + def set_promote(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.promote) + + def set_unsafe_convert(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.unsafe) + + def set_safe_convert(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.safe) + + def get_pointer(self): + return _typeconv.get_pointer(self._ptr) + + +class TypeCastingRules(object): + """ + A helper for establishing type casting rules. + """ + + def __init__(self, tm): + self._tm = tm + self._tg = castgraph.TypeGraph(self._cb_update) + + def promote(self, a, b): + """ + Set `a` can promote to `b` + """ + self._tg.promote(a, b) + + def unsafe(self, a, b): + """ + Set `a` can unsafe convert to `b` + """ + self._tg.unsafe(a, b) + + def safe(self, a, b): + """ + Set `a` can safe convert to `b` + """ + self._tg.safe(a, b) + + def promote_unsafe(self, a, b): + """ + Set `a` can promote to `b` and `b` can unsafe convert to `a` + """ + self.promote(a, b) + self.unsafe(b, a) + + def safe_unsafe(self, a, b): + """ + Set `a` can safe convert to `b` and `b` can unsafe convert to `a` + """ + self._tg.safe(a, b) + self._tg.unsafe(b, a) + + def unsafe_unsafe(self, a, b): + """ + Set `a` can unsafe convert to `b` and `b` can unsafe convert to `a` + """ + self._tg.unsafe(a, b) + self._tg.unsafe(b, a) + + def _cb_update(self, a, b, rel): + """ + Callback for updating. + """ + if rel == Conversion.promote: + self._tm.set_promote(a, b) + elif rel == Conversion.safe: + self._tm.set_safe_convert(a, b) + elif rel == Conversion.unsafe: + self._tm.set_unsafe_convert(a, b) + else: + raise AssertionError(rel) diff --git a/numba_cuda/numba/cuda/types.py b/numba_cuda/numba/cuda/types.py index d1ec8c28d..7e407ac81 100644 --- a/numba_cuda/numba/cuda/types.py +++ b/numba_cuda/numba/cuda/types.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from numba.core import types -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion class Dim3(types.Type): diff --git a/numba_cuda/numba/cuda/typing/context.py b/numba_cuda/numba/cuda/typing/context.py index 386db0511..e511de2cb 100644 --- a/numba_cuda/numba/cuda/typing/context.py +++ b/numba_cuda/numba/cuda/typing/context.py @@ -10,7 +10,7 @@ import operator from numba.core import types, errors -from numba.core.typeconv import Conversion, rules +from numba.cuda.typeconv import Conversion, rules from numba.core.typing.typeof import typeof, Purpose from numba.core.typing import templates from numba.cuda import utils From 237796ea10f49141f065c1f97d5f0661dbc1890d Mon Sep 17 00:00:00 2001 From: Vijay Kandiah Date: Wed, 1 Oct 2025 22:02:51 -0700 Subject: [PATCH 2/4] Add test_typeconv to numba-cuda testing suite --- .../numba/cuda/tests/cudapy/test_typeconv.py | 333 ++++++++++++++++++ .../numba/cuda/tests/cudapy/test_typeinfer.py | 2 +- 2 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py new file mode 100644 index 000000000..16aa2b4d7 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import itertools + +from numba.core import types +from numba.cuda.typeconv.typeconv import TypeManager, TypeCastingRules +from numba.cuda.typeconv import rules +from numba.cuda.typeconv import castgraph, Conversion +import unittest + + +class CompatibilityTestMixin(unittest.TestCase): + def check_number_compatibility(self, check_compatible): + b = types.boolean + i8 = types.int8 + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + u8 = types.uint8 + u32 = types.uint32 + u64 = types.uint64 + f16 = types.float16 + f32 = types.float32 + f64 = types.float64 + c64 = types.complex64 + c128 = types.complex128 + + self.assertEqual(check_compatible(i32, i32), Conversion.exact) + + self.assertEqual(check_compatible(b, i8), Conversion.safe) + self.assertEqual(check_compatible(b, u8), Conversion.safe) + self.assertEqual(check_compatible(i8, b), Conversion.unsafe) + self.assertEqual(check_compatible(u8, b), Conversion.unsafe) + + self.assertEqual(check_compatible(i32, i64), Conversion.promote) + self.assertEqual(check_compatible(i32, u32), Conversion.unsafe) + self.assertEqual(check_compatible(u32, i32), Conversion.unsafe) + self.assertEqual(check_compatible(u32, i64), Conversion.safe) + + self.assertEqual(check_compatible(i16, f16), Conversion.unsafe) + self.assertEqual(check_compatible(i32, f32), Conversion.unsafe) + self.assertEqual(check_compatible(u32, f32), Conversion.unsafe) + self.assertEqual(check_compatible(i32, f64), Conversion.safe) + self.assertEqual(check_compatible(u32, f64), Conversion.safe) + # Note this is inconsistent with i32 -> f32... + self.assertEqual(check_compatible(i64, f64), Conversion.safe) + self.assertEqual(check_compatible(u64, f64), Conversion.safe) + + self.assertEqual(check_compatible(f32, c64), Conversion.safe) + self.assertEqual(check_compatible(f64, c128), Conversion.safe) + self.assertEqual(check_compatible(f64, c64), Conversion.unsafe) + + # Propagated compatibility relationships + self.assertEqual(check_compatible(i16, f64), Conversion.safe) + self.assertEqual(check_compatible(i16, i64), Conversion.promote) + self.assertEqual(check_compatible(i32, c64), Conversion.unsafe) + self.assertEqual(check_compatible(i32, c128), Conversion.safe) + self.assertEqual(check_compatible(i32, u64), Conversion.unsafe) + + for ta, tb in itertools.product( + types.number_domain, types.number_domain + ): + if ta in types.complex_domain and tb not in types.complex_domain: + continue + self.assertTrue( + check_compatible(ta, tb) is not None, + msg="No cast from %s to %s" % (ta, tb), + ) + + +class TestTypeConv(CompatibilityTestMixin, unittest.TestCase): + def test_typeconv(self): + tm = TypeManager() + + i32 = types.int32 + i64 = types.int64 + f32 = types.float32 + + tm.set_promote(i32, i64) + tm.set_unsafe_convert(i32, f32) + + sig = (i32, f32) + ovs = [ + (i32, i32), + (f32, f32), + (i64, i64), + ] + + # allow_unsafe = True => a conversion from i32 to f32 is chosen + sel = tm.select_overload(sig, ovs, True, False) + self.assertEqual(sel, 1) + # allow_unsafe = False => no overload available + with self.assertRaises(TypeError): + sel = tm.select_overload(sig, ovs, False, False) + + def test_default_rules(self): + tm = rules.default_type_manager + self.check_number_compatibility(tm.check_compatible) + + def test_overload1(self): + tm = rules.default_type_manager + + i32 = types.int32 + i64 = types.int64 + + sig = (i64, i32, i32) + ovs = [ + (i32, i32, i32), + (i64, i64, i64), + ] + # The first overload is unsafe, the second is safe => the second + # is always chosen, regardless of allow_unsafe. + self.assertEqual(tm.select_overload(sig, ovs, True, False), 1) + self.assertEqual(tm.select_overload(sig, ovs, False, False), 1) + + def test_overload2(self): + tm = rules.default_type_manager + + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + + sig = (i32, i16, i32) + ovs = [ + # Three promotes + (i64, i64, i64), + # One promotes, two exact types + (i32, i32, i32), + # Two unsafe converts, one exact type + (i16, i16, i16), + ] + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + # The same in reverse order + ovs.reverse() + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + def test_overload3(self): + # Promotes should be preferred over safe converts + tm = rules.default_type_manager + + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + + sig = (i32, i32) + ovs = [ + # Two promotes + (i64, i64), + # Two safe converts + (f64, f64), + ] + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 0, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 0, + ) + + # The same in reverse order + ovs.reverse() + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + def test_overload4(self): + tm = rules.default_type_manager + + i16 = types.int16 + i32 = types.int32 + f16 = types.float16 + f32 = types.float32 + + sig = (i16, f16, f16) + ovs = [ + # One unsafe, one promote, one exact + (f16, f32, f16), + # Two unsafe, one exact types + (f32, i32, f16), + ] + + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 0, + ) + + def test_type_casting_rules(self): + tm = TypeManager() + tcr = TypeCastingRules(tm) + + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + f32 = types.float32 + f16 = types.float16 + made_up = types.Dummy("made_up") + + tcr.promote_unsafe(i32, i64) + tcr.safe_unsafe(i32, f64) + tcr.promote_unsafe(f32, f64) + tcr.promote_unsafe(f16, f32) + tcr.unsafe_unsafe(i16, f16) + + def base_test(): + # As declared + self.assertEqual(tm.check_compatible(i32, i64), Conversion.promote) + self.assertEqual(tm.check_compatible(i32, f64), Conversion.safe) + self.assertEqual(tm.check_compatible(f16, f32), Conversion.promote) + self.assertEqual(tm.check_compatible(f32, f64), Conversion.promote) + self.assertEqual(tm.check_compatible(i64, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, f32), Conversion.unsafe) + + # Propagated + self.assertEqual(tm.check_compatible(i64, f64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, i64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i64, f32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i32, f32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f32, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i16, f16), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f16, i16), Conversion.unsafe) + + # Test base graph + base_test() + + self.assertIsNone(tm.check_compatible(i64, made_up)) + self.assertIsNone(tm.check_compatible(i32, made_up)) + self.assertIsNone(tm.check_compatible(f32, made_up)) + self.assertIsNone(tm.check_compatible(made_up, f64)) + self.assertIsNone(tm.check_compatible(made_up, i64)) + + # Add new test + tcr.promote(f64, made_up) + tcr.unsafe(made_up, i32) + + # Ensure the graph did not change by adding the new type + base_test() + + # To "made up" type + self.assertEqual(tm.check_compatible(i64, made_up), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i32, made_up), Conversion.safe) + self.assertEqual(tm.check_compatible(f32, made_up), Conversion.promote) + self.assertEqual(tm.check_compatible(made_up, f64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(made_up, i64), Conversion.unsafe) + + def test_castgraph_propagate(self): + saved = [] + + def callback(src, dst, rel): + saved.append((src, dst, rel)) + + tg = castgraph.TypeGraph(callback) + + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + f32 = types.float32 + + tg.insert_rule(i32, i64, Conversion.promote) + tg.insert_rule(i64, i32, Conversion.unsafe) + + saved.append(None) + + tg.insert_rule(i32, f64, Conversion.safe) + tg.insert_rule(f64, i32, Conversion.unsafe) + + saved.append(None) + + tg.insert_rule(f32, f64, Conversion.promote) + tg.insert_rule(f64, f32, Conversion.unsafe) + + self.assertIn((i32, i64, Conversion.promote), saved[0:2]) + self.assertIn((i64, i32, Conversion.unsafe), saved[0:2]) + self.assertIs(saved[2], None) + + self.assertIn((i32, f64, Conversion.safe), saved[3:7]) + self.assertIn((f64, i32, Conversion.unsafe), saved[3:7]) + self.assertIn((i64, f64, Conversion.unsafe), saved[3:7]) + self.assertIn((i64, f64, Conversion.unsafe), saved[3:7]) + self.assertIs(saved[7], None) + + self.assertIn((f32, f64, Conversion.promote), saved[8:14]) + self.assertIn((f64, f32, Conversion.unsafe), saved[8:14]) + self.assertIn((f32, i32, Conversion.unsafe), saved[8:14]) + self.assertIn((i32, f32, Conversion.unsafe), saved[8:14]) + self.assertIn((f32, i64, Conversion.unsafe), saved[8:14]) + self.assertIn((i64, f32, Conversion.unsafe), saved[8:14]) + self.assertEqual(len(saved[14:]), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py index 0c8ae598f..d8190a06d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py @@ -7,7 +7,7 @@ from numba.cuda.typeconv import Conversion from numba.cuda.testing import CUDATestCase, skip_on_cudasim -from numba.tests.test_typeconv import CompatibilityTestMixin +from numba.cuda.tests.cudapy.test_typeconv import CompatibilityTestMixin from numba.cuda.core.untyped_passes import TranslateByteCode, IRProcessing from numba.cuda.core.typed_passes import PartialTypeInference from numba.cuda.core.compiler_machinery import FunctionPass, register_pass From e51ef22ef2f8197b6f3f7b6e73c552cfa77364d0 Mon Sep 17 00:00:00 2001 From: Vijay Kandiah Date: Fri, 3 Oct 2025 10:18:33 -0700 Subject: [PATCH 3/4] Vendor in _typeconv cext for CUDA-specific changes --- numba_cuda/numba/cuda/cext/__init__.py | 4 +- numba_cuda/numba/cuda/cext/_typeconv.cpp | 206 +++++++++++++++++++++ numba_cuda/numba/cuda/cext/capsulethunk.h | 111 +++++++++++ numba_cuda/numba/cuda/typeconv/typeconv.py | 2 +- setup.py | 12 +- 5 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 numba_cuda/numba/cuda/cext/_typeconv.cpp create mode 100644 numba_cuda/numba/cuda/cext/capsulethunk.h diff --git a/numba_cuda/numba/cuda/cext/__init__.py b/numba_cuda/numba/cuda/cext/__init__.py index c54155bae..3ad16af1b 100644 --- a/numba_cuda/numba/cuda/cext/__init__.py +++ b/numba_cuda/numba/cuda/cext/__init__.py @@ -90,5 +90,5 @@ def _load_cext_module( _devicearray = _load_cext_module("_devicearray", required=True) _dispatcher = _load_cext_module("_dispatcher", required=True) mviewbuf = _load_cext_module("mviewbuf", required=True) - -__all__ = ["mviewbuf", "_dispatcher", "_devicearray"] +_typeconv = _load_cext_module("_typeconv", required=True) +__all__ = ["mviewbuf", "_dispatcher", "_devicearray", "_typeconv"] diff --git a/numba_cuda/numba/cuda/cext/_typeconv.cpp b/numba_cuda/numba/cuda/cext/_typeconv.cpp new file mode 100644 index 000000000..ca414e08d --- /dev/null +++ b/numba_cuda/numba/cuda/cext/_typeconv.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-2-Clause + +#include "_pymodule.h" +#include "capsulethunk.h" +#include "typeconv.hpp" + +extern "C" { + + +static PyObject* +new_type_manager(PyObject* self, PyObject* args); + +static void +del_type_manager(PyObject *); + +static PyObject* +select_overload(PyObject* self, PyObject* args); + +static PyObject* +check_compatible(PyObject* self, PyObject* args); + +static PyObject* +set_compatible(PyObject* self, PyObject* args); + +static PyObject* +get_pointer(PyObject* self, PyObject* args); + + +static PyMethodDef ext_methods[] = { +#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL } + declmethod(new_type_manager), + declmethod(select_overload), + declmethod(check_compatible), + declmethod(set_compatible), + declmethod(get_pointer), + { NULL }, +#undef declmethod +}; + + +MOD_INIT(_typeconv) { + PyObject *m; + MOD_DEF(m, "_typeconv", "No docs", ext_methods) + if (m == NULL) + return MOD_ERROR_VAL; + + return MOD_SUCCESS_VAL(m); +} + +} // end extern C + +/////////////////////////////////////////////////////////////////////////////// + +const char PY_CAPSULE_TM_NAME[] = "*tm"; +#define BAD_TM_ARGUMENT PyErr_SetString(PyExc_TypeError, \ + "1st argument not TypeManager") + +static +TypeManager* unwrap_TypeManager(PyObject *tm) { + void* p = PyCapsule_GetPointer(tm, PY_CAPSULE_TM_NAME); + return reinterpret_cast(p); +} + +PyObject* +new_type_manager(PyObject* self, PyObject* args) +{ + TypeManager* tm = new TypeManager(); + return PyCapsule_New(tm, PY_CAPSULE_TM_NAME, &del_type_manager); +} + +void +del_type_manager(PyObject *tm) +{ + delete unwrap_TypeManager(tm); +} + +PyObject* +select_overload(PyObject* self, PyObject* args) +{ + PyObject *tmcap, *sigtup, *ovsigstup; + int allow_unsafe; + int exact_match_required; + + if (!PyArg_ParseTuple(args, "OOOii", &tmcap, &sigtup, &ovsigstup, + &allow_unsafe, &exact_match_required)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if (!tm) { + BAD_TM_ARGUMENT; + } + + Py_ssize_t sigsz = PySequence_Size(sigtup); + Py_ssize_t ovsz = PySequence_Size(ovsigstup); + + Type *sig = new Type[sigsz]; + Type *ovsigs = new Type[ovsz * sigsz]; + + for (int i = 0; i < sigsz; ++i) { + sig[i] = Type(PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(sigtup, + i), NULL)); + } + + for (int i = 0; i < ovsz; ++i) { + PyObject *cursig = PySequence_Fast_GET_ITEM(ovsigstup, i); + for (int j = 0; j < sigsz; ++j) { + long tid = PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(cursig, + j), NULL); + ovsigs[i * sigsz + j] = Type(tid); + } + } + + int selected = -42; + int matches = tm->selectOverload(sig, ovsigs, selected, sigsz, ovsz, + (bool) allow_unsafe, + (bool) exact_match_required); + + delete [] sig; + delete [] ovsigs; + + if (matches > 1) { + PyErr_SetString(PyExc_TypeError, "Ambiguous overloading"); + return NULL; + } else if (matches == 0) { + PyErr_SetString(PyExc_TypeError, "No compatible overload"); + return NULL; + } + + return PyLong_FromLong(selected); +} + +PyObject* +check_compatible(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + int from, to; + if (!PyArg_ParseTuple(args, "Oii", &tmcap, &from, &to)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if(!tm) { + BAD_TM_ARGUMENT; + return NULL; + } + + switch(tm->isCompatible(Type(from), Type(to))){ + case TCC_EXACT: + return PyString_FromString("exact"); + case TCC_PROMOTE: + return PyString_FromString("promote"); + case TCC_CONVERT_SAFE: + return PyString_FromString("safe"); + case TCC_CONVERT_UNSAFE: + return PyString_FromString("unsafe"); + default: + Py_RETURN_NONE; + } +} + +PyObject* +set_compatible(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + int from, to, by; + if (!PyArg_ParseTuple(args, "Oiii", &tmcap, &from, &to, &by)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if (!tm) { + BAD_TM_ARGUMENT; + return NULL; + } + TypeCompatibleCode tcc; + switch (by) { + case 'p': // promote + tcc = TCC_PROMOTE; + break; + case 's': // safe convert + tcc = TCC_CONVERT_SAFE; + break; + case 'u': // unsafe convert + tcc = TCC_CONVERT_UNSAFE; + break; + default: + PyErr_SetString(PyExc_ValueError, "Unknown TCC"); + return NULL; + } + + tm->addCompatibility(Type(from), Type(to), tcc); + Py_RETURN_NONE; +} + + +PyObject* +get_pointer(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + if (!PyArg_ParseTuple(args, "O", &tmcap)) { + return NULL; + } + return PyLong_FromVoidPtr(unwrap_TypeManager(tmcap)); +} diff --git a/numba_cuda/numba/cuda/cext/capsulethunk.h b/numba_cuda/numba/cuda/cext/capsulethunk.h new file mode 100644 index 000000000..bbc125d48 --- /dev/null +++ b/numba_cuda/numba/cuda/cext/capsulethunk.h @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-2-Clause + +/** + + This is a modified version of capsulethunk.h for use in llvmpy + +**/ + +#ifndef __CAPSULETHUNK_H +#define __CAPSULETHUNK_H + +#if ( (PY_VERSION_HEX < 0x02070000) \ + || ((PY_VERSION_HEX >= 0x03000000) \ + && (PY_VERSION_HEX < 0x03010000)) ) + +//#define Assert(X) do_assert(!!(X), #X, __FILE__, __LINE__) +#define Assert(X) + +static +void do_assert(int cond, const char * msg, const char *file, unsigned line){ + if (!cond) { + fprintf(stderr, "Assertion failed %s:%d\n%s\n", file, line, msg); + exit(1); + } +} + +typedef void (*PyCapsule_Destructor)(PyObject *); + +struct FakePyCapsule_Desc { + const char *name; + void *context; + PyCapsule_Destructor dtor; + PyObject *parent; + + FakePyCapsule_Desc() : name(0), context(0), dtor(0) {} +}; + +static +FakePyCapsule_Desc* get_pycobj_desc(PyObject *p){ + void *desc = ((PyCObject*)p)->desc; + Assert(desc && "No desc in PyCObject"); + return static_cast(desc); +} + +static +void pycobject_pycapsule_dtor(void *p, void *desc){ + Assert(desc); + Assert(p); + FakePyCapsule_Desc *fpc_desc = static_cast(desc); + Assert(fpc_desc->parent); + Assert(PyCObject_Check(fpc_desc->parent)); + fpc_desc->dtor(static_cast(fpc_desc->parent)); + delete fpc_desc; +} + +static +PyObject* PyCapsule_New(void* ptr, const char *name, PyCapsule_Destructor dtor) +{ + FakePyCapsule_Desc *desc = new FakePyCapsule_Desc; + desc->name = name; + desc->context = NULL; + desc->dtor = dtor; + PyObject *p = PyCObject_FromVoidPtrAndDesc(ptr, desc, + pycobject_pycapsule_dtor); + desc->parent = p; + return p; +} + +static +int PyCapsule_CheckExact(PyObject *p) +{ + return PyCObject_Check(p); +} + +static +void* PyCapsule_GetPointer(PyObject *p, const char *name) +{ + Assert(PyCapsule_CheckExact(p)); + if (strcmp(get_pycobj_desc(p)->name, name) != 0) { + PyErr_SetString(PyExc_ValueError, "Invalid PyCapsule object"); + } + return PyCObject_AsVoidPtr(p); +} + +static +void* PyCapsule_GetContext(PyObject *p) +{ + Assert(p); + Assert(PyCapsule_CheckExact(p)); + return get_pycobj_desc(p)->context; +} + +static +int PyCapsule_SetContext(PyObject *p, void *context) +{ + Assert(PyCapsule_CheckExact(p)); + get_pycobj_desc(p)->context = context; + return 0; +} + +static +const char * PyCapsule_GetName(PyObject *p) +{ +// Assert(PyCapsule_CheckExact(p)); + return get_pycobj_desc(p)->name; +} + +#endif /* #if PY_VERSION_HEX < 0x02070000 */ + +#endif /* __CAPSULETHUNK_H */ diff --git a/numba_cuda/numba/cuda/typeconv/typeconv.py b/numba_cuda/numba/cuda/typeconv/typeconv.py index 117afe253..c414a01d5 100644 --- a/numba_cuda/numba/cuda/typeconv/typeconv.py +++ b/numba_cuda/numba/cuda/typeconv/typeconv.py @@ -4,7 +4,7 @@ try: # This is usually the the first C extension import performed when importing # Numba, if it fails to import, provide some feedback - from numba.core.typeconv import _typeconv + from numba.cuda.cext import _typeconv except ImportError as e: base_url = "https://numba.readthedocs.io/en/stable" dev_url = f"{base_url}/developer/contributing.html" diff --git a/setup.py b/setup.py index cb5d4b4db..f87cb6b79 100644 --- a/setup.py +++ b/setup.py @@ -78,10 +78,20 @@ def get_ext_modules(): **np_compile_args, ) + ext_typeconv = Extension( + name="numba_cuda.numba.cuda.cext._typeconv", + sources=[ + "numba_cuda/numba/cuda/cext/typeconv.cpp", + "numba_cuda/numba/cuda/cext/_typeconv.cpp", + ], + depends=["numba_cuda/numba/cuda/cext/_pymodule.h"], + extra_compile_args=["-std=c++11"], + ) + # Append our cext dir to include_dirs ext_dispatcher.include_dirs.append("numba_cuda/numba/cuda/cext") - return [ext_dispatcher, ext_mviewbuf, ext_devicearray] + return [ext_dispatcher, ext_typeconv, ext_mviewbuf, ext_devicearray] def is_building(): From 0dffe8c6fdf5f4097a76bf3815f35649bf2ce687 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 10 Oct 2025 14:23:15 +0100 Subject: [PATCH 4/4] Remove explanation of ImportError from typeconv The error message doesn't suit Numba-CUDA, is unlikely to be triggered in common scenarios, and even if the import would have failed due to the extensions not being built, this may not be the first imported module in numba-cuda that fails. --- numba_cuda/numba/cuda/typeconv/typeconv.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/numba_cuda/numba/cuda/typeconv/typeconv.py b/numba_cuda/numba/cuda/typeconv/typeconv.py index c414a01d5..a87ce5043 100644 --- a/numba_cuda/numba/cuda/typeconv/typeconv.py +++ b/numba_cuda/numba/cuda/typeconv/typeconv.py @@ -1,25 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -try: - # This is usually the the first C extension import performed when importing - # Numba, if it fails to import, provide some feedback - from numba.cuda.cext import _typeconv -except ImportError as e: - base_url = "https://numba.readthedocs.io/en/stable" - dev_url = f"{base_url}/developer/contributing.html" - user_url = f"{base_url}/user/faq.html#numba-could-not-be-imported" - dashes = "-" * 80 - msg = ( - f"Numba could not be imported.\n{dashes}\nIf you are seeing this " - "message and are undertaking Numba development work, you may need " - "to rebuild Numba.\nPlease see the development set up guide:\n\n" - f"{dev_url}.\n\n{dashes}\nIf you are not working on Numba " - f"development, the original error was: '{str(e)}'.\nFor help, " - f"please visit:\n\n{user_url}\n" - ) - raise ImportError(msg) - +from numba.cuda.cext import _typeconv from numba.cuda.typeconv import castgraph, Conversion from numba.core import types