Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions numba_cuda/numba/cuda/_internal/cuda_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self):
self.bitwidth = 2 * 8

def can_convert_from(self, typingctx, other):
from numba.core.typeconv import Conversion
from numba.cuda.typeconv import Conversion

if other in []:
return Conversion.safe
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(self):
self.bitwidth = 4 * 8

def can_convert_from(self, typingctx, other):
from numba.core.typeconv import Conversion
from numba.cuda.typeconv import Conversion

if other in []:
return Conversion.safe
Expand Down Expand Up @@ -7903,9 +7903,9 @@ def generic(self, args, kws):
# - Conversion.safe

if (
(convertible == numba.core.typeconv.Conversion.exact)
or (convertible == numba.core.typeconv.Conversion.promote)
or (convertible == numba.core.typeconv.Conversion.safe)
(convertible == numba.cuda.typeconv.Conversion.exact)
or (convertible == numba.cuda.typeconv.Conversion.promote)
or (convertible == numba.cuda.typeconv.Conversion.safe)
):
return signature(retty, types.float16, types.float16)

Expand Down
206 changes: 206 additions & 0 deletions numba_cuda/numba/cuda/cext/_typeconv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-2-Clause

#include "_pymodule.h"
#include "capsulethunk.h"
#include "typeconv.hpp"

extern "C" {


// Forward declarations of the module-level functions; the definitions follow
// the module init code further down in this file.

// Allocates a TypeManager and returns it wrapped in a capsule.
static PyObject*
new_type_manager(PyObject* self, PyObject* args);

// Capsule destructor: deletes the TypeManager held by the capsule.
static void
del_type_manager(PyObject *);

// Args: (tm capsule, signature, overload signatures, allow_unsafe,
// exact_match_required). Returns the selected overload index as an int;
// raises TypeError when no overload or more than one overload matches.
static PyObject*
select_overload(PyObject* self, PyObject* args);

// Args: (tm capsule, from, to). Returns "exact", "promote", "safe" or
// "unsafe", or None for any other compatibility code.
static PyObject*
check_compatible(PyObject* self, PyObject* args);

// Args: (tm capsule, from, to, by) where `by` is the character code 'p'
// (promote), 's' (safe convert) or 'u' (unsafe convert). Returns None.
static PyObject*
set_compatible(PyObject* self, PyObject* args);

// Args: (tm capsule). Returns the TypeManager address as a Python int.
static PyObject*
get_pointer(PyObject* self, PyObject* args);


// Method table for the _typeconv module. Every function is exposed under its
// own C name, takes positional arguments only (METH_VARARGS) and has no
// docstring.
#define TYPECONV_METHOD(fn) { #fn, (PyCFunction)fn, METH_VARARGS, NULL }
static PyMethodDef ext_methods[] = {
    TYPECONV_METHOD(new_type_manager),
    TYPECONV_METHOD(select_overload),
    TYPECONV_METHOD(check_compatible),
    TYPECONV_METHOD(set_compatible),
    TYPECONV_METHOD(get_pointer),
    { NULL },  // sentinel terminating the table
};
#undef TYPECONV_METHOD


// Module entry point: builds the "_typeconv" module from ext_methods and
// hands it back to the interpreter, or reports failure if creation failed.
MOD_INIT(_typeconv) {
    PyObject *mod;
    MOD_DEF(mod, "_typeconv", "No docs", ext_methods)
    if (mod != NULL) {
        return MOD_SUCCESS_VAL(mod);
    }
    return MOD_ERROR_VAL;
}

} // end extern C

///////////////////////////////////////////////////////////////////////////////

// Name tag given to every TypeManager capsule; the same tag is passed to
// PyCapsule_GetPointer when the pointer is extracted again.
const char PY_CAPSULE_TM_NAME[] = "*tm";
// Sets a TypeError saying the first argument is not a TypeManager.
// The macro only sets the exception: it does NOT return, so the caller must
// follow it with its own `return NULL;`.
#define BAD_TM_ARGUMENT PyErr_SetString(PyExc_TypeError, \
"1st argument not TypeManager")

// Returns the TypeManager pointer stored in capsule `tm`, i.e. whatever
// PyCapsule_GetPointer yields for the PY_CAPSULE_TM_NAME tag, cast to
// TypeManager*.
static
TypeManager* unwrap_TypeManager(PyObject *tm) {
    return reinterpret_cast<TypeManager*>(
        PyCapsule_GetPointer(tm, PY_CAPSULE_TM_NAME));
}

// Creates a fresh TypeManager and returns it wrapped in a capsule tagged with
// PY_CAPSULE_TM_NAME. The capsule owns the TypeManager: del_type_manager is
// registered as its destructor.
//
// Returns NULL (with the exception set by PyCapsule_New) if the capsule
// cannot be created.
PyObject*
new_type_manager(PyObject* self, PyObject* args)
{
    TypeManager *tm = new TypeManager();
    PyObject *capsule = PyCapsule_New(tm, PY_CAPSULE_TM_NAME,
                                      &del_type_manager);
    if (capsule == NULL) {
        // The destructor never runs for a capsule that was not created, so
        // the TypeManager has to be released here to avoid leaking it.
        delete tm;
        return NULL;
    }
    return capsule;
}

// Capsule destructor: releases the TypeManager stored in capsule `tm`.
void
del_type_manager(PyObject *tm)
{
    TypeManager *manager = unwrap_TypeManager(tm);
    delete manager;
}

// select_overload(tm, sig, ovsigs, allow_unsafe, exact_match_required)
//
// Asks the TypeManager to pick, among the candidate signatures in `ovsigs`,
// the one matching the call signature `sig`.
//
//   tm                    capsule produced by new_type_manager
//   sig                   sequence of integer type ids, one per argument
//   ovsigs                sequence of candidate signatures, each a sequence
//                         with at least len(sig) integer type ids
//   allow_unsafe          int flag forwarded to TypeManager::selectOverload
//   exact_match_required  int flag forwarded to TypeManager::selectOverload
//
// Returns the index of the selected overload as a Python int.
// Raises TypeError when `tm` is not a TypeManager capsule, when no overload is
// compatible, or when more than one overload matches.
//
// NOTE(review): items are read with PySequence_Fast_GET_ITEM, which is only
// valid for tuples and lists -- callers are presumably passing tuples; confirm.
PyObject*
select_overload(PyObject* self, PyObject* args)
{
    PyObject *tmcap, *sigtup, *ovsigstup;
    int allow_unsafe;
    int exact_match_required;

    if (!PyArg_ParseTuple(args, "OOOii", &tmcap, &sigtup, &ovsigstup,
                          &allow_unsafe, &exact_match_required)) {
        return NULL;
    }

    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (!tm) {
        BAD_TM_ARGUMENT;
        // BAD_TM_ARGUMENT only sets the exception; without this return the
        // null pointer would be dereferenced below.
        return NULL;
    }

    // PySequence_Size returns -1 with an exception set for non-sequences;
    // bail out before the value is used as an array length.
    Py_ssize_t sigsz = PySequence_Size(sigtup);
    if (sigsz < 0) {
        return NULL;
    }
    Py_ssize_t ovsz = PySequence_Size(ovsigstup);
    if (ovsz < 0) {
        return NULL;
    }

    Type *sig = new Type[sigsz];
    Type *ovsigs = new Type[ovsz * sigsz];

    for (Py_ssize_t i = 0; i < sigsz; ++i) {
        sig[i] = Type(PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(sigtup,
                                                                  i), NULL));
    }

    // Candidate signatures are flattened row-major: overload i occupies
    // ovsigs[i * sigsz .. i * sigsz + sigsz - 1].
    for (Py_ssize_t i = 0; i < ovsz; ++i) {
        PyObject *cursig = PySequence_Fast_GET_ITEM(ovsigstup, i);
        for (Py_ssize_t j = 0; j < sigsz; ++j) {
            long tid = PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(cursig,
                                                                   j), NULL);
            ovsigs[i * sigsz + j] = Type(tid);
        }
    }

    // A type id that is not an integer leaves an exception pending; report it
    // instead of running overload resolution on garbage ids.
    if (PyErr_Occurred()) {
        delete [] sig;
        delete [] ovsigs;
        return NULL;
    }

    int selected = -42;
    int matches = tm->selectOverload(sig, ovsigs, selected, sigsz, ovsz,
                                     (bool) allow_unsafe,
                                     (bool) exact_match_required);

    delete [] sig;
    delete [] ovsigs;

    if (matches > 1) {
        PyErr_SetString(PyExc_TypeError, "Ambiguous overloading");
        return NULL;
    } else if (matches == 0) {
        PyErr_SetString(PyExc_TypeError, "No compatible overload");
        return NULL;
    }

    return PyLong_FromLong(selected);
}

// check_compatible(tm, from, to)
//
// Looks up how type id `from` converts to type id `to` in the TypeManager and
// returns the result as a string: "exact", "promote", "safe" or "unsafe".
// Any other compatibility code yields None.
// Raises TypeError when `tm` is not a TypeManager capsule.
PyObject*
check_compatible(PyObject* self, PyObject* args)
{
    PyObject *capsule;
    int src_id, dst_id;
    if (!PyArg_ParseTuple(args, "Oii", &capsule, &src_id, &dst_id)) {
        return NULL;
    }

    TypeManager *manager = unwrap_TypeManager(capsule);
    if (manager == NULL) {
        BAD_TM_ARGUMENT;
        return NULL;
    }

    // Map the compatibility code to its label; NULL means "no label".
    const char *label = NULL;
    switch (manager->isCompatible(Type(src_id), Type(dst_id))) {
    case TCC_EXACT:
        label = "exact";
        break;
    case TCC_PROMOTE:
        label = "promote";
        break;
    case TCC_CONVERT_SAFE:
        label = "safe";
        break;
    case TCC_CONVERT_UNSAFE:
        label = "unsafe";
        break;
    default:
        break;
    }

    if (label == NULL) {
        Py_RETURN_NONE;
    }
    return PyString_FromString(label);
}

// set_compatible(tm, from, to, by)
//
// Registers in the TypeManager that type id `from` converts to type id `to`.
// `by` is a character code selecting the kind of conversion:
//   'p' promote, 's' safe convert, 'u' unsafe convert.
// Returns None. Raises TypeError when `tm` is not a TypeManager capsule and
// ValueError for any other `by` code.
PyObject*
set_compatible(PyObject* self, PyObject* args)
{
    PyObject *capsule;
    int src_id, dst_id, kind;
    if (!PyArg_ParseTuple(args, "Oiii", &capsule, &src_id, &dst_id, &kind)) {
        return NULL;
    }

    TypeManager *manager = unwrap_TypeManager(capsule);
    if (manager == NULL) {
        BAD_TM_ARGUMENT;
        return NULL;
    }

    TypeCompatibleCode code;
    if (kind == 'p') {
        code = TCC_PROMOTE;
    } else if (kind == 's') {
        code = TCC_CONVERT_SAFE;
    } else if (kind == 'u') {
        code = TCC_CONVERT_UNSAFE;
    } else {
        PyErr_SetString(PyExc_ValueError, "Unknown TCC");
        return NULL;
    }

    manager->addCompatibility(Type(src_id), Type(dst_id), code);
    Py_RETURN_NONE;
}


// get_pointer(tm)
//
// Returns the address of the TypeManager held by capsule `tm` as a Python
// int. Raises TypeError when `tm` is not a TypeManager capsule.
PyObject*
get_pointer(PyObject* self, PyObject* args)
{
    PyObject *tmcap;
    if (!PyArg_ParseTuple(args, "O", &tmcap)) {
        return NULL;
    }

    // Check the unwrap result like the sibling functions do: converting a
    // failed (NULL) unwrap would hand back a value while an exception is
    // still pending.
    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (!tm) {
        BAD_TM_ARGUMENT;
        return NULL;
    }
    return PyLong_FromVoidPtr(tm);
}
111 changes: 111 additions & 0 deletions numba_cuda/numba/cuda/cext/capsulethunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-2-Clause

/**

This is a modified version of capsulethunk.h for use in llvmpy

**/

#ifndef __CAPSULETHUNK_H
#define __CAPSULETHUNK_H

#if ( (PY_VERSION_HEX < 0x02070000) \
|| ((PY_VERSION_HEX >= 0x03000000) \
&& (PY_VERSION_HEX < 0x03010000)) )

//#define Assert(X) do_assert(!!(X), #X, __FILE__, __LINE__)
#define Assert(X)

// Aborts the process with a diagnostic when `cond` is zero.
//   cond  condition result; non-zero means the assertion holds
//   msg   text of the asserted expression
//   file  source file of the assertion
//   line  source line of the assertion
// On failure prints "Assertion failed <file>:<line>" and the message to
// stderr, then exits with status 1.
static
void do_assert(int cond, const char * msg, const char *file, unsigned line){
    if (!cond) {
        // `line` is unsigned, so it must be printed with %u rather than %d.
        fprintf(stderr, "Assertion failed %s:%u\n%s\n", file, line, msg);
        exit(1);
    }
}

typedef void (*PyCapsule_Destructor)(PyObject *);

struct FakePyCapsule_Desc {
const char *name;
void *context;
PyCapsule_Destructor dtor;
PyObject *parent;

FakePyCapsule_Desc() : name(0), context(0), dtor(0) {}
};

// Returns the FakePyCapsule_Desc record stored in the `desc` slot of the
// PyCObject `p`.
static
FakePyCapsule_Desc* get_pycobj_desc(PyObject *p){
    PyCObject *cobj = (PyCObject*)p;
    Assert(cobj->desc && "No desc in PyCObject");
    return static_cast<FakePyCapsule_Desc*>(cobj->desc);
}

// PyCObject destructor used for emulated capsules.
//   p     the wrapped pointer (unused beyond the sanity check)
//   desc  the FakePyCapsule_Desc record created by PyCapsule_New
// Runs the user-supplied capsule destructor on the owning object, if one was
// given, and then frees the record.
static
void pycobject_pycapsule_dtor(void *p, void *desc){
    Assert(desc);
    Assert(p);
    FakePyCapsule_Desc *fpc_desc = static_cast<FakePyCapsule_Desc*>(desc);
    Assert(fpc_desc->parent);
    Assert(PyCObject_Check(fpc_desc->parent));
    // PyCapsule_New stores whatever destructor it is given, which may be
    // NULL; only call it when one was actually supplied.
    if (fpc_desc->dtor) {
        fpc_desc->dtor(static_cast<PyObject*>(fpc_desc->parent));
    }
    delete fpc_desc;
}

// Emulation of PyCapsule_New on top of PyCObject.
//   ptr   pointer to wrap
//   name  capsule name tag (stored, not copied)
//   dtor  destructor to run on the capsule object when it is released
// Returns the new object, or NULL if the underlying PyCObject could not be
// created.
static
PyObject* PyCapsule_New(void* ptr, const char *name, PyCapsule_Destructor dtor)
{
    FakePyCapsule_Desc *desc = new FakePyCapsule_Desc;
    desc->name = name;
    desc->context = NULL;
    desc->dtor = dtor;
    PyObject *p = PyCObject_FromVoidPtrAndDesc(ptr, desc,
                                               pycobject_pycapsule_dtor);
    if (p == NULL) {
        // No object exists to own the record, so free it here instead of
        // leaking it.
        delete desc;
        return NULL;
    }
    // Back-reference needed by pycobject_pycapsule_dtor to call `dtor`.
    desc->parent = p;
    return p;
}

// Emulation of PyCapsule_CheckExact: non-zero when `p` is a PyCObject.
static
int PyCapsule_CheckExact(PyObject *p)
{
    const int is_cobject = PyCObject_Check(p);
    return is_cobject;
}

// Emulation of PyCapsule_GetPointer.
//   p     emulated capsule (a PyCObject created by PyCapsule_New)
//   name  expected name tag
// Returns the wrapped pointer when `name` matches the tag stored in the
// capsule. On a mismatch sets ValueError and returns NULL, so callers can
// detect the failure from the return value.
static
void* PyCapsule_GetPointer(PyObject *p, const char *name)
{
    Assert(PyCapsule_CheckExact(p));
    const char *stored = get_pycobj_desc(p)->name;
    // Names match when they are the same pointer (covers both being NULL) or
    // when both are non-NULL and compare equal; strcmp must never see NULL.
    const bool same_name =
        (stored == name) ||
        (stored != NULL && name != NULL && strcmp(stored, name) == 0);
    if (!same_name) {
        PyErr_SetString(PyExc_ValueError, "Invalid PyCapsule object");
        return NULL;
    }
    return PyCObject_AsVoidPtr(p);
}

// Emulation of PyCapsule_GetContext: returns the context pointer stored in
// the capsule's descriptor record.
static
void* PyCapsule_GetContext(PyObject *p)
{
    Assert(p);
    Assert(PyCapsule_CheckExact(p));
    FakePyCapsule_Desc *record = get_pycobj_desc(p);
    return record->context;
}

// Emulation of PyCapsule_SetContext: stores `context` in the capsule's
// descriptor record. Always returns 0.
static
int PyCapsule_SetContext(PyObject *p, void *context)
{
    Assert(PyCapsule_CheckExact(p));
    FakePyCapsule_Desc *record = get_pycobj_desc(p);
    record->context = context;
    return 0;
}

// Emulation of PyCapsule_GetName: returns the name tag stored in the
// capsule's descriptor record. No type check is performed on `p`.
static
const char * PyCapsule_GetName(PyObject *p)
{
    FakePyCapsule_Desc *record = get_pycobj_desc(p);
    return record->name;
}

#endif /* #if PY_VERSION_HEX < 0x02070000 */

#endif /* __CAPSULETHUNK_H */
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/core/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
NumbaValueError,
)
from numba.cuda.core.funcdesc import qualifying_prefix
from numba.core.typeconv import Conversion
from numba.cuda.typeconv import Conversion

_logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def resolve_value_type(self, val):
def can_convert(self, fromty, toty):
"""
Check whether conversion is possible from *fromty* to *toty*.
If successful, return a numba.typeconv.Conversion instance;
If successful, return a numba.cuda.typeconv.Conversion instance;
otherwise None is returned.
"""

Expand Down
Loading