Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions numba_cuda/numba/cuda/extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def overload(
strict=True,
inline="never",
prefer_literal=False,
target="cuda",
**kwargs,
):
"""
Expand Down Expand Up @@ -205,6 +206,8 @@ def len_impl(seq):
# TODO: abort now if the kwarg 'target' relates to an unregistered target,
# this requires sorting out the circular imports first.

kwargs["target"] = target

def decorate(overload_func):
template = make_overload_template(
func, overload_func, opts, strict, inline, prefer_literal, **kwargs
Expand Down Expand Up @@ -251,7 +254,7 @@ def ov_wrap(*args, **kwargs):
return wrap(*args)


def overload_attribute(typ, attr, **kwargs):
def overload_attribute(typ, attr, target="cuda", **kwargs):
"""
A decorator marking the decorated function as typing and implementing
attribute *attr* for the given Numba type in nopython mode.
Expand All @@ -270,6 +273,8 @@ def get(arr):
# TODO implement setters
from numba.core.typing.templates import make_overload_attribute_template

kwargs["target"] = target

def decorate(overload_func):
template = make_overload_attribute_template(
typ, attr, overload_func, **kwargs
Expand Down Expand Up @@ -302,7 +307,7 @@ def decorate(overload_func):
return decorate


def overload_method(typ, attr, **kwargs):
def overload_method(typ, attr, target="cuda", **kwargs):
"""
A decorator marking the decorated function as typing and implementing
method *attr* for the given Numba type in nopython mode.
Expand All @@ -324,10 +329,13 @@ def take_impl(arr, indices):

return take_impl
"""

kwargs["target"] = target

return _overload_method_common(typ, attr, **kwargs)


def overload_classmethod(typ, attr, **kwargs):
def overload_classmethod(typ, attr, target="cuda", **kwargs):
"""
A decorator marking the decorated function as typing and implementing
classmethod *attr* for the given Numba type in nopython mode.
Expand All @@ -352,6 +360,9 @@ def impl(cls, nitems):
def foo(n):
return types.Array.make(n)
"""

kwargs["target"] = target

return _overload_method_common(types.TypeRef(typ), attr, **kwargs)


Expand Down