diff --git a/numba_cuda/numba/cuda/_internal/cuda_fp16.py b/numba_cuda/numba/cuda/_internal/cuda_fp16.py index 6bf2b0159..96456d9ba 100644 --- a/numba_cuda/numba/cuda/_internal/cuda_fp16.py +++ b/numba_cuda/numba/cuda/_internal/cuda_fp16.py @@ -124,7 +124,7 @@ def __init__(self): self.bitwidth = 2 * 8 def can_convert_from(self, typingctx, other): - from numba.core.typeconv import Conversion + from numba.cuda.typeconv import Conversion if other in []: return Conversion.safe @@ -174,7 +174,7 @@ def __init__(self): self.bitwidth = 4 * 8 def can_convert_from(self, typingctx, other): - from numba.core.typeconv import Conversion + from numba.cuda.typeconv import Conversion if other in []: return Conversion.safe @@ -7903,9 +7903,9 @@ def generic(self, args, kws): # - Conversion.safe if ( - (convertible == numba.core.typeconv.Conversion.exact) - or (convertible == numba.core.typeconv.Conversion.promote) - or (convertible == numba.core.typeconv.Conversion.safe) + (convertible == numba.cuda.typeconv.Conversion.exact) + or (convertible == numba.cuda.typeconv.Conversion.promote) + or (convertible == numba.cuda.typeconv.Conversion.safe) ): return signature(retty, types.float16, types.float16) diff --git a/numba_cuda/numba/cuda/cext/_typeconv.cpp b/numba_cuda/numba/cuda/cext/_typeconv.cpp new file mode 100644 index 000000000..ca414e08d --- /dev/null +++ b/numba_cuda/numba/cuda/cext/_typeconv.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: BSD-2-Clause + +#include "_pymodule.h" +#include "capsulethunk.h" +#include "typeconv.hpp" + +extern "C" { + + +static PyObject* +new_type_manager(PyObject* self, PyObject* args); + +static void +del_type_manager(PyObject *); + +static PyObject* +select_overload(PyObject* self, PyObject* args); + +static PyObject* +check_compatible(PyObject* self, PyObject* args); + +static PyObject* +set_compatible(PyObject* self, PyObject* args); + +static PyObject* +get_pointer(PyObject* self, PyObject* args); + + +static PyMethodDef ext_methods[] = { +#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL } + declmethod(new_type_manager), + declmethod(select_overload), + declmethod(check_compatible), + declmethod(set_compatible), + declmethod(get_pointer), + { NULL }, +#undef declmethod +}; + + +MOD_INIT(_typeconv) { + PyObject *m; + MOD_DEF(m, "_typeconv", "No docs", ext_methods) + if (m == NULL) + return MOD_ERROR_VAL; + + return MOD_SUCCESS_VAL(m); +} + +} // end extern C + +/////////////////////////////////////////////////////////////////////////////// + +const char PY_CAPSULE_TM_NAME[] = "*tm"; +#define BAD_TM_ARGUMENT PyErr_SetString(PyExc_TypeError, \ + "1st argument not TypeManager") + +static +TypeManager* unwrap_TypeManager(PyObject *tm) { + void* p = PyCapsule_GetPointer(tm, PY_CAPSULE_TM_NAME); + return reinterpret_cast(p); +} + +PyObject* +new_type_manager(PyObject* self, PyObject* args) +{ + TypeManager* tm = new TypeManager(); + return PyCapsule_New(tm, PY_CAPSULE_TM_NAME, &del_type_manager); +} + +void +del_type_manager(PyObject *tm) +{ + delete unwrap_TypeManager(tm); +} + +PyObject* +select_overload(PyObject* self, PyObject* args) +{ + PyObject *tmcap, *sigtup, *ovsigstup; + int allow_unsafe; + int exact_match_required; + + if (!PyArg_ParseTuple(args, "OOOii", &tmcap, &sigtup, &ovsigstup, + &allow_unsafe, &exact_match_required)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if (!tm) { 
+ BAD_TM_ARGUMENT; + } + + Py_ssize_t sigsz = PySequence_Size(sigtup); + Py_ssize_t ovsz = PySequence_Size(ovsigstup); + + Type *sig = new Type[sigsz]; + Type *ovsigs = new Type[ovsz * sigsz]; + + for (int i = 0; i < sigsz; ++i) { + sig[i] = Type(PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(sigtup, + i), NULL)); + } + + for (int i = 0; i < ovsz; ++i) { + PyObject *cursig = PySequence_Fast_GET_ITEM(ovsigstup, i); + for (int j = 0; j < sigsz; ++j) { + long tid = PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(cursig, + j), NULL); + ovsigs[i * sigsz + j] = Type(tid); + } + } + + int selected = -42; + int matches = tm->selectOverload(sig, ovsigs, selected, sigsz, ovsz, + (bool) allow_unsafe, + (bool) exact_match_required); + + delete [] sig; + delete [] ovsigs; + + if (matches > 1) { + PyErr_SetString(PyExc_TypeError, "Ambiguous overloading"); + return NULL; + } else if (matches == 0) { + PyErr_SetString(PyExc_TypeError, "No compatible overload"); + return NULL; + } + + return PyLong_FromLong(selected); +} + +PyObject* +check_compatible(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + int from, to; + if (!PyArg_ParseTuple(args, "Oii", &tmcap, &from, &to)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if(!tm) { + BAD_TM_ARGUMENT; + return NULL; + } + + switch(tm->isCompatible(Type(from), Type(to))){ + case TCC_EXACT: + return PyString_FromString("exact"); + case TCC_PROMOTE: + return PyString_FromString("promote"); + case TCC_CONVERT_SAFE: + return PyString_FromString("safe"); + case TCC_CONVERT_UNSAFE: + return PyString_FromString("unsafe"); + default: + Py_RETURN_NONE; + } +} + +PyObject* +set_compatible(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + int from, to, by; + if (!PyArg_ParseTuple(args, "Oiii", &tmcap, &from, &to, &by)) { + return NULL; + } + + TypeManager *tm = unwrap_TypeManager(tmcap); + if (!tm) { + BAD_TM_ARGUMENT; + return NULL; + } + TypeCompatibleCode tcc; + switch (by) { + case 'p': // promote + tcc = 
TCC_PROMOTE; + break; + case 's': // safe convert + tcc = TCC_CONVERT_SAFE; + break; + case 'u': // unsafe convert + tcc = TCC_CONVERT_UNSAFE; + break; + default: + PyErr_SetString(PyExc_ValueError, "Unknown TCC"); + return NULL; + } + + tm->addCompatibility(Type(from), Type(to), tcc); + Py_RETURN_NONE; +} + + +PyObject* +get_pointer(PyObject* self, PyObject* args) +{ + PyObject *tmcap; + if (!PyArg_ParseTuple(args, "O", &tmcap)) { + return NULL; + } + return PyLong_FromVoidPtr(unwrap_TypeManager(tmcap)); +} diff --git a/numba_cuda/numba/cuda/cext/capsulethunk.h b/numba_cuda/numba/cuda/cext/capsulethunk.h new file mode 100644 index 000000000..bbc125d48 --- /dev/null +++ b/numba_cuda/numba/cuda/cext/capsulethunk.h @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-2-Clause + +/** + + This is a modified version of capsulethunk.h for use in llvmpy + +**/ + +#ifndef __CAPSULETHUNK_H +#define __CAPSULETHUNK_H + +#if ( (PY_VERSION_HEX < 0x02070000) \ + || ((PY_VERSION_HEX >= 0x03000000) \ + && (PY_VERSION_HEX < 0x03010000)) ) + +//#define Assert(X) do_assert(!!(X), #X, __FILE__, __LINE__) +#define Assert(X) + +static +void do_assert(int cond, const char * msg, const char *file, unsigned line){ + if (!cond) { + fprintf(stderr, "Assertion failed %s:%d\n%s\n", file, line, msg); + exit(1); + } +} + +typedef void (*PyCapsule_Destructor)(PyObject *); + +struct FakePyCapsule_Desc { + const char *name; + void *context; + PyCapsule_Destructor dtor; + PyObject *parent; + + FakePyCapsule_Desc() : name(0), context(0), dtor(0) {} +}; + +static +FakePyCapsule_Desc* get_pycobj_desc(PyObject *p){ + void *desc = ((PyCObject*)p)->desc; + Assert(desc && "No desc in PyCObject"); + return static_cast(desc); +} + +static +void pycobject_pycapsule_dtor(void *p, void *desc){ + Assert(desc); + Assert(p); + FakePyCapsule_Desc *fpc_desc = static_cast(desc); + Assert(fpc_desc->parent); + 
Assert(PyCObject_Check(fpc_desc->parent)); + fpc_desc->dtor(static_cast(fpc_desc->parent)); + delete fpc_desc; +} + +static +PyObject* PyCapsule_New(void* ptr, const char *name, PyCapsule_Destructor dtor) +{ + FakePyCapsule_Desc *desc = new FakePyCapsule_Desc; + desc->name = name; + desc->context = NULL; + desc->dtor = dtor; + PyObject *p = PyCObject_FromVoidPtrAndDesc(ptr, desc, + pycobject_pycapsule_dtor); + desc->parent = p; + return p; +} + +static +int PyCapsule_CheckExact(PyObject *p) +{ + return PyCObject_Check(p); +} + +static +void* PyCapsule_GetPointer(PyObject *p, const char *name) +{ + Assert(PyCapsule_CheckExact(p)); + if (strcmp(get_pycobj_desc(p)->name, name) != 0) { + PyErr_SetString(PyExc_ValueError, "Invalid PyCapsule object"); + } + return PyCObject_AsVoidPtr(p); +} + +static +void* PyCapsule_GetContext(PyObject *p) +{ + Assert(p); + Assert(PyCapsule_CheckExact(p)); + return get_pycobj_desc(p)->context; +} + +static +int PyCapsule_SetContext(PyObject *p, void *context) +{ + Assert(PyCapsule_CheckExact(p)); + get_pycobj_desc(p)->context = context; + return 0; +} + +static +const char * PyCapsule_GetName(PyObject *p) +{ +// Assert(PyCapsule_CheckExact(p)); + return get_pycobj_desc(p)->name; +} + +#endif /* #if PY_VERSION_HEX < 0x02070000 */ + +#endif /* __CAPSULETHUNK_H */ diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index 38aea47e8..2103d602c 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -48,7 +48,7 @@ NumbaValueError, ) from numba.cuda.core.funcdesc import qualifying_prefix -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion _logger = logging.getLogger(__name__) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index fea7313f1..c2db1a442 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -88,7 +88,7 @@ def resolve_value_type(self, val): def 
can_convert(self, fromty, toty): """ Check whether conversion is possible from *fromty* to *toty*. - If successful, return a numba.typeconv.Conversion instance; + If successful, return a numba.cuda.typeconv.Conversion instance; otherwise None is returned. """ diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py new file mode 100644 index 000000000..16aa2b4d7 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import itertools + +from numba.core import types +from numba.cuda.typeconv.typeconv import TypeManager, TypeCastingRules +from numba.cuda.typeconv import rules +from numba.cuda.typeconv import castgraph, Conversion +import unittest + + +class CompatibilityTestMixin(unittest.TestCase): + def check_number_compatibility(self, check_compatible): + b = types.boolean + i8 = types.int8 + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + u8 = types.uint8 + u32 = types.uint32 + u64 = types.uint64 + f16 = types.float16 + f32 = types.float32 + f64 = types.float64 + c64 = types.complex64 + c128 = types.complex128 + + self.assertEqual(check_compatible(i32, i32), Conversion.exact) + + self.assertEqual(check_compatible(b, i8), Conversion.safe) + self.assertEqual(check_compatible(b, u8), Conversion.safe) + self.assertEqual(check_compatible(i8, b), Conversion.unsafe) + self.assertEqual(check_compatible(u8, b), Conversion.unsafe) + + self.assertEqual(check_compatible(i32, i64), Conversion.promote) + self.assertEqual(check_compatible(i32, u32), Conversion.unsafe) + self.assertEqual(check_compatible(u32, i32), Conversion.unsafe) + self.assertEqual(check_compatible(u32, i64), Conversion.safe) + + self.assertEqual(check_compatible(i16, f16), Conversion.unsafe) + self.assertEqual(check_compatible(i32, f32), Conversion.unsafe) + 
self.assertEqual(check_compatible(u32, f32), Conversion.unsafe) + self.assertEqual(check_compatible(i32, f64), Conversion.safe) + self.assertEqual(check_compatible(u32, f64), Conversion.safe) + # Note this is inconsistent with i32 -> f32... + self.assertEqual(check_compatible(i64, f64), Conversion.safe) + self.assertEqual(check_compatible(u64, f64), Conversion.safe) + + self.assertEqual(check_compatible(f32, c64), Conversion.safe) + self.assertEqual(check_compatible(f64, c128), Conversion.safe) + self.assertEqual(check_compatible(f64, c64), Conversion.unsafe) + + # Propagated compatibility relationships + self.assertEqual(check_compatible(i16, f64), Conversion.safe) + self.assertEqual(check_compatible(i16, i64), Conversion.promote) + self.assertEqual(check_compatible(i32, c64), Conversion.unsafe) + self.assertEqual(check_compatible(i32, c128), Conversion.safe) + self.assertEqual(check_compatible(i32, u64), Conversion.unsafe) + + for ta, tb in itertools.product( + types.number_domain, types.number_domain + ): + if ta in types.complex_domain and tb not in types.complex_domain: + continue + self.assertTrue( + check_compatible(ta, tb) is not None, + msg="No cast from %s to %s" % (ta, tb), + ) + + +class TestTypeConv(CompatibilityTestMixin, unittest.TestCase): + def test_typeconv(self): + tm = TypeManager() + + i32 = types.int32 + i64 = types.int64 + f32 = types.float32 + + tm.set_promote(i32, i64) + tm.set_unsafe_convert(i32, f32) + + sig = (i32, f32) + ovs = [ + (i32, i32), + (f32, f32), + (i64, i64), + ] + + # allow_unsafe = True => a conversion from i32 to f32 is chosen + sel = tm.select_overload(sig, ovs, True, False) + self.assertEqual(sel, 1) + # allow_unsafe = False => no overload available + with self.assertRaises(TypeError): + sel = tm.select_overload(sig, ovs, False, False) + + def test_default_rules(self): + tm = rules.default_type_manager + self.check_number_compatibility(tm.check_compatible) + + def test_overload1(self): + tm = rules.default_type_manager + 
+ i32 = types.int32 + i64 = types.int64 + + sig = (i64, i32, i32) + ovs = [ + (i32, i32, i32), + (i64, i64, i64), + ] + # The first overload is unsafe, the second is safe => the second + # is always chosen, regardless of allow_unsafe. + self.assertEqual(tm.select_overload(sig, ovs, True, False), 1) + self.assertEqual(tm.select_overload(sig, ovs, False, False), 1) + + def test_overload2(self): + tm = rules.default_type_manager + + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + + sig = (i32, i16, i32) + ovs = [ + # Three promotes + (i64, i64, i64), + # One promotes, two exact types + (i32, i32, i32), + # Two unsafe converts, one exact type + (i16, i16, i16), + ] + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + # The same in reverse order + ovs.reverse() + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + def test_overload3(self): + # Promotes should be preferred over safe converts + tm = rules.default_type_manager + + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + + sig = (i32, i32) + ovs = [ + # Two promotes + (i64, i64), + # Two safe converts + (f64, f64), + ] + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 0, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 0, + ) + + # The same in reverse order + ovs.reverse() + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=False, exact_match_required=False + ), + 1, + ) + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 1, + ) + + def 
test_overload4(self): + tm = rules.default_type_manager + + i16 = types.int16 + i32 = types.int32 + f16 = types.float16 + f32 = types.float32 + + sig = (i16, f16, f16) + ovs = [ + # One unsafe, one promote, one exact + (f16, f32, f16), + # Two unsafe, one exact types + (f32, i32, f16), + ] + + self.assertEqual( + tm.select_overload( + sig, ovs, allow_unsafe=True, exact_match_required=False + ), + 0, + ) + + def test_type_casting_rules(self): + tm = TypeManager() + tcr = TypeCastingRules(tm) + + i16 = types.int16 + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + f32 = types.float32 + f16 = types.float16 + made_up = types.Dummy("made_up") + + tcr.promote_unsafe(i32, i64) + tcr.safe_unsafe(i32, f64) + tcr.promote_unsafe(f32, f64) + tcr.promote_unsafe(f16, f32) + tcr.unsafe_unsafe(i16, f16) + + def base_test(): + # As declared + self.assertEqual(tm.check_compatible(i32, i64), Conversion.promote) + self.assertEqual(tm.check_compatible(i32, f64), Conversion.safe) + self.assertEqual(tm.check_compatible(f16, f32), Conversion.promote) + self.assertEqual(tm.check_compatible(f32, f64), Conversion.promote) + self.assertEqual(tm.check_compatible(i64, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, f32), Conversion.unsafe) + + # Propagated + self.assertEqual(tm.check_compatible(i64, f64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f64, i64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i64, f32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i32, f32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f32, i32), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i16, f16), Conversion.unsafe) + self.assertEqual(tm.check_compatible(f16, i16), Conversion.unsafe) + + # Test base graph + base_test() + + self.assertIsNone(tm.check_compatible(i64, made_up)) + self.assertIsNone(tm.check_compatible(i32, made_up)) + 
self.assertIsNone(tm.check_compatible(f32, made_up)) + self.assertIsNone(tm.check_compatible(made_up, f64)) + self.assertIsNone(tm.check_compatible(made_up, i64)) + + # Add new test + tcr.promote(f64, made_up) + tcr.unsafe(made_up, i32) + + # Ensure the graph did not change by adding the new type + base_test() + + # To "made up" type + self.assertEqual(tm.check_compatible(i64, made_up), Conversion.unsafe) + self.assertEqual(tm.check_compatible(i32, made_up), Conversion.safe) + self.assertEqual(tm.check_compatible(f32, made_up), Conversion.promote) + self.assertEqual(tm.check_compatible(made_up, f64), Conversion.unsafe) + self.assertEqual(tm.check_compatible(made_up, i64), Conversion.unsafe) + + def test_castgraph_propagate(self): + saved = [] + + def callback(src, dst, rel): + saved.append((src, dst, rel)) + + tg = castgraph.TypeGraph(callback) + + i32 = types.int32 + i64 = types.int64 + f64 = types.float64 + f32 = types.float32 + + tg.insert_rule(i32, i64, Conversion.promote) + tg.insert_rule(i64, i32, Conversion.unsafe) + + saved.append(None) + + tg.insert_rule(i32, f64, Conversion.safe) + tg.insert_rule(f64, i32, Conversion.unsafe) + + saved.append(None) + + tg.insert_rule(f32, f64, Conversion.promote) + tg.insert_rule(f64, f32, Conversion.unsafe) + + self.assertIn((i32, i64, Conversion.promote), saved[0:2]) + self.assertIn((i64, i32, Conversion.unsafe), saved[0:2]) + self.assertIs(saved[2], None) + + self.assertIn((i32, f64, Conversion.safe), saved[3:7]) + self.assertIn((f64, i32, Conversion.unsafe), saved[3:7]) + self.assertIn((i64, f64, Conversion.unsafe), saved[3:7]) + self.assertIn((i64, f64, Conversion.unsafe), saved[3:7]) + self.assertIs(saved[7], None) + + self.assertIn((f32, f64, Conversion.promote), saved[8:14]) + self.assertIn((f64, f32, Conversion.unsafe), saved[8:14]) + self.assertIn((f32, i32, Conversion.unsafe), saved[8:14]) + self.assertIn((i32, f32, Conversion.unsafe), saved[8:14]) + self.assertIn((f32, i64, Conversion.unsafe), saved[8:14]) + 
self.assertIn((i64, f32, Conversion.unsafe), saved[8:14]) + self.assertEqual(len(saved[14:]), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py index 09adb20ae..d8190a06d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py @@ -4,10 +4,10 @@ import itertools from numba.core import errors, types, typing -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion from numba.cuda.testing import CUDATestCase, skip_on_cudasim -from numba.tests.test_typeconv import CompatibilityTestMixin +from numba.cuda.tests.cudapy.test_typeconv import CompatibilityTestMixin from numba.cuda.core.untyped_passes import TranslateByteCode, IRProcessing from numba.cuda.core.typed_passes import PartialTypeInference from numba.cuda.core.compiler_machinery import FunctionPass, register_pass diff --git a/numba_cuda/numba/cuda/typeconv/__init__.py b/numba_cuda/numba/cuda/typeconv/__init__.py new file mode 100644 index 000000000..55011b82a --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from .castgraph import Conversion # noqa F401 diff --git a/numba_cuda/numba/cuda/typeconv/castgraph.py b/numba_cuda/numba/cuda/typeconv/castgraph.py new file mode 100644 index 000000000..aa9f20db7 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/castgraph.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from collections import defaultdict +import enum + + +class Conversion(enum.IntEnum): + """ + A conversion kind from one type to the other. The enum members + are ordered from stricter to looser. 
+ """ + + # The two types are identical + exact = 1 + # The two types are of the same kind, the destination type has more + # extension or precision than the source type (e.g. float32 -> float64, + # or int32 -> int64) + promote = 2 + # The source type can be converted to the destination type without loss + # of information (e.g. int32 -> int64). Note that the conversion may + # still fail explicitly at runtime (e.g. Optional(int32) -> int32) + safe = 3 + # The conversion may appear to succeed at runtime while losing information + # or precision (e.g. int32 -> uint32, float64 -> float32, int64 -> int32, + # etc.) + unsafe = 4 + + # This value is only used internally + nil = 99 + + +class CastSet(object): + """A set of casting rules. + + There is at most one rule per target type. + """ + + def __init__(self): + self._rels = {} + + def insert(self, to, rel): + old = self.get(to) + setrel = min(rel, old) + self._rels[to] = setrel + return old != setrel + + def items(self): + return self._rels.items() + + def get(self, item): + return self._rels.get(item, Conversion.nil) + + def __len__(self): + return len(self._rels) + + def __repr__(self): + body = [ + "{rel}({ty})".format(rel=rel, ty=ty) + for ty, rel in self._rels.items() + ] + return "{" + ", ".join(body) + "}" + + def __contains__(self, item): + return item in self._rels + + def __iter__(self): + return iter(self._rels.keys()) + + def __getitem__(self, item): + return self._rels[item] + + +class TypeGraph(object): + """A graph that maintains the casting relationship of all types. + + This simplifies the definition of casting rules by automatically + propagating the rules. + """ + + def __init__(self, callback=None): + """ + Args + ---- + - callback: callable or None + It is called for each new casting rule with + (from_type, to_type, castrel). 
+ """ + assert callback is None or callable(callback) + self._forwards = defaultdict(CastSet) + self._backwards = defaultdict(set) + self._callback = callback + + def get(self, ty): + return self._forwards[ty] + + def propagate(self, a, b, baserel): + backset = self._backwards[a] + + # Forward propagate the relationship to all nodes that b leads to + for child in self._forwards[b]: + rel = max(baserel, self._forwards[b][child]) + if a != child: + if self._forwards[a].insert(child, rel): + self._callback(a, child, rel) + self._backwards[child].add(a) + + # Propagate the relationship from nodes that connects to a + for backnode in backset: + if backnode != child: + backrel = max(rel, self._forwards[backnode][a]) + if self._forwards[backnode].insert(child, backrel): + self._callback(backnode, child, backrel) + self._backwards[child].add(backnode) + + # Every node that leads to a connects to b + for child in self._backwards[a]: + rel = max(baserel, self._forwards[child][a]) + if b != child: + if self._forwards[child].insert(b, rel): + self._callback(child, b, rel) + self._backwards[b].add(child) + + def insert_rule(self, a, b, rel): + self._forwards[a].insert(b, rel) + self._callback(a, b, rel) + self._backwards[b].add(a) + self.propagate(a, b, rel) + + def promote(self, a, b): + self.insert_rule(a, b, Conversion.promote) + + def safe(self, a, b): + self.insert_rule(a, b, Conversion.safe) + + def unsafe(self, a, b): + self.insert_rule(a, b, Conversion.unsafe) diff --git a/numba_cuda/numba/cuda/typeconv/rules.py b/numba_cuda/numba/cuda/typeconv/rules.py new file mode 100644 index 000000000..2fa5c1158 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/rules.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import itertools +from .typeconv import TypeManager, TypeCastingRules +from numba.core import types +from numba.cuda import config + + +default_type_manager = TypeManager() + + +def dump_number_rules(): + tm = default_type_manager + for a, b in itertools.product(types.number_domain, types.number_domain): + print(a, "->", b, tm.check_compatible(a, b)) + + +if config.USE_LEGACY_TYPE_SYSTEM: # Old type system + + def _init_casting_rules(tm): + tcr = TypeCastingRules(tm) + tcr.safe_unsafe(types.boolean, types.int8) + tcr.safe_unsafe(types.boolean, types.uint8) + + tcr.promote_unsafe(types.int8, types.int16) + tcr.promote_unsafe(types.uint8, types.uint16) + + tcr.promote_unsafe(types.int16, types.int32) + tcr.promote_unsafe(types.uint16, types.uint32) + + tcr.promote_unsafe(types.int32, types.int64) + tcr.promote_unsafe(types.uint32, types.uint64) + + tcr.safe_unsafe(types.uint8, types.int16) + tcr.safe_unsafe(types.uint16, types.int32) + tcr.safe_unsafe(types.uint32, types.int64) + + tcr.safe_unsafe(types.int8, types.float16) + tcr.safe_unsafe(types.int16, types.float32) + tcr.safe_unsafe(types.int32, types.float64) + + tcr.unsafe_unsafe(types.int16, types.float16) + tcr.unsafe_unsafe(types.int32, types.float32) + # XXX this is inconsistent with the above; but we want to prefer + # float64 over int64 when typing a heterogeneous operation, + # e.g. `float64 + int64`. Perhaps we need more granularity in the + # conversion kinds. 
+ tcr.safe_unsafe(types.int64, types.float64) + tcr.safe_unsafe(types.uint64, types.float64) + + tcr.promote_unsafe(types.float16, types.float32) + tcr.promote_unsafe(types.float32, types.float64) + + tcr.safe(types.float32, types.complex64) + tcr.safe(types.float64, types.complex128) + + tcr.promote_unsafe(types.complex64, types.complex128) + + # Allow integers to cast to void* + tcr.unsafe_unsafe(types.uintp, types.voidptr) + + return tcr +else: # New type system + # Currently left as empty + # If no casting rules are required we may opt to remove + # this framework upon deprecation + def _init_casting_rules(tm): + tcr = TypeCastingRules(tm) + return tcr + + +default_casting_rules = _init_casting_rules(default_type_manager) diff --git a/numba_cuda/numba/cuda/typeconv/typeconv.py b/numba_cuda/numba/cuda/typeconv/typeconv.py new file mode 100644 index 000000000..a87ce5043 --- /dev/null +++ b/numba_cuda/numba/cuda/typeconv/typeconv.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +from numba.cuda.cext import _typeconv +from numba.cuda.typeconv import castgraph, Conversion +from numba.core import types + + +class TypeManager(object): + # The character codes used by the C/C++ API (_typeconv.cpp) + _conversion_codes = { + Conversion.safe: ord("s"), + Conversion.unsafe: ord("u"), + Conversion.promote: ord("p"), + } + + def __init__(self): + self._ptr = _typeconv.new_type_manager() + self._types = set() + + def select_overload( + self, sig, overloads, allow_unsafe, exact_match_required + ): + sig = [t._code for t in sig] + overloads = [[t._code for t in s] for s in overloads] + return _typeconv.select_overload( + self._ptr, sig, overloads, allow_unsafe, exact_match_required + ) + + def check_compatible(self, fromty, toty): + if not isinstance(toty, types.Type): + raise ValueError( + "Specified type '%s' (%s) is not a Numba type" + % (toty, type(toty)) + ) + name = _typeconv.check_compatible(self._ptr, fromty._code, toty._code) + conv = Conversion[name] if name is not None else None + assert conv is not Conversion.nil + return conv + + def set_compatible(self, fromty, toty, by): + code = self._conversion_codes[by] + _typeconv.set_compatible(self._ptr, fromty._code, toty._code, code) + # Ensure the types don't die, otherwise they may be recreated with + # other type codes and pollute the hash table. + self._types.add(fromty) + self._types.add(toty) + + def set_promote(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.promote) + + def set_unsafe_convert(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.unsafe) + + def set_safe_convert(self, fromty, toty): + self.set_compatible(fromty, toty, Conversion.safe) + + def get_pointer(self): + return _typeconv.get_pointer(self._ptr) + + +class TypeCastingRules(object): + """ + A helper for establishing type casting rules. 
+ """ + + def __init__(self, tm): + self._tm = tm + self._tg = castgraph.TypeGraph(self._cb_update) + + def promote(self, a, b): + """ + Set `a` can promote to `b` + """ + self._tg.promote(a, b) + + def unsafe(self, a, b): + """ + Set `a` can unsafe convert to `b` + """ + self._tg.unsafe(a, b) + + def safe(self, a, b): + """ + Set `a` can safe convert to `b` + """ + self._tg.safe(a, b) + + def promote_unsafe(self, a, b): + """ + Set `a` can promote to `b` and `b` can unsafe convert to `a` + """ + self.promote(a, b) + self.unsafe(b, a) + + def safe_unsafe(self, a, b): + """ + Set `a` can safe convert to `b` and `b` can unsafe convert to `a` + """ + self._tg.safe(a, b) + self._tg.unsafe(b, a) + + def unsafe_unsafe(self, a, b): + """ + Set `a` can unsafe convert to `b` and `b` can unsafe convert to `a` + """ + self._tg.unsafe(a, b) + self._tg.unsafe(b, a) + + def _cb_update(self, a, b, rel): + """ + Callback for updating. + """ + if rel == Conversion.promote: + self._tm.set_promote(a, b) + elif rel == Conversion.safe: + self._tm.set_safe_convert(a, b) + elif rel == Conversion.unsafe: + self._tm.set_unsafe_convert(a, b) + else: + raise AssertionError(rel) diff --git a/numba_cuda/numba/cuda/types.py b/numba_cuda/numba/cuda/types.py index d1ec8c28d..7e407ac81 100644 --- a/numba_cuda/numba/cuda/types.py +++ b/numba_cuda/numba/cuda/types.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from numba.core import types -from numba.core.typeconv import Conversion +from numba.cuda.typeconv import Conversion class Dim3(types.Type): diff --git a/numba_cuda/numba/cuda/typing/context.py b/numba_cuda/numba/cuda/typing/context.py index 386db0511..e511de2cb 100644 --- a/numba_cuda/numba/cuda/typing/context.py +++ b/numba_cuda/numba/cuda/typing/context.py @@ -10,7 +10,7 @@ import operator from numba.core import types, errors -from numba.core.typeconv import Conversion, rules +from numba.cuda.typeconv import Conversion, rules from numba.core.typing.typeof import typeof, Purpose 
from numba.core.typing import templates from numba.cuda import utils diff --git a/setup.py b/setup.py index cb5d4b4db..f87cb6b79 100644 --- a/setup.py +++ b/setup.py @@ -78,10 +78,20 @@ def get_ext_modules(): **np_compile_args, ) + ext_typeconv = Extension( + name="numba_cuda.numba.cuda.cext._typeconv", + sources=[ + "numba_cuda/numba/cuda/cext/typeconv.cpp", + "numba_cuda/numba/cuda/cext/_typeconv.cpp", + ], + depends=["numba_cuda/numba/cuda/cext/_pymodule.h"], + extra_compile_args=["-std=c++11"], + ) + # Append our cext dir to include_dirs ext_dispatcher.include_dirs.append("numba_cuda/numba/cuda/cext") - return [ext_dispatcher, ext_mviewbuf, ext_devicearray] + return [ext_dispatcher, ext_typeconv, ext_mviewbuf, ext_devicearray] def is_building():