diff --git a/numba_cuda/numba/cuda/extending.py b/numba_cuda/numba/cuda/extending.py index 90b2e8f57..1c560c43a 100644 --- a/numba_cuda/numba/cuda/extending.py +++ b/numba_cuda/numba/cuda/extending.py @@ -136,6 +136,7 @@ def overload( strict=True, inline="never", prefer_literal=False, + target="cuda", **kwargs, ): """ @@ -205,6 +206,8 @@ def len_impl(seq): # TODO: abort now if the kwarg 'target' relates to an unregistered target, # this requires sorting out the circular imports first. + kwargs["target"] = target + def decorate(overload_func): template = make_overload_template( func, overload_func, opts, strict, inline, prefer_literal, **kwargs @@ -251,7 +254,7 @@ def ov_wrap(*args, **kwargs): return wrap(*args) -def overload_attribute(typ, attr, **kwargs): +def overload_attribute(typ, attr, target="cuda", **kwargs): """ A decorator marking the decorated function as typing and implementing attribute *attr* for the given Numba type in nopython mode. @@ -270,6 +273,8 @@ def get(arr): # TODO implement setters from numba.core.typing.templates import make_overload_attribute_template + kwargs["target"] = target + def decorate(overload_func): template = make_overload_attribute_template( typ, attr, overload_func, **kwargs @@ -302,7 +307,7 @@ def decorate(overload_func): return decorate -def overload_method(typ, attr, **kwargs): +def overload_method(typ, attr, target="cuda", **kwargs): """ A decorator marking the decorated function as typing and implementing method *attr* for the given Numba type in nopython mode. @@ -324,10 +329,13 @@ def take_impl(arr, indices): return take_impl """ + + kwargs["target"] = target + return _overload_method_common(typ, attr, **kwargs) -def overload_classmethod(typ, attr, **kwargs): +def overload_classmethod(typ, attr, target="cuda", **kwargs): """ A decorator marking the decorated function as typing and implementing classmethod *attr* for the given Numba type in nopython mode. @@ -352,6 +360,9 @@ def impl(cls, nitems): def foo(n): return types.Array.make(n) """ + + kwargs["target"] = target + return _overload_method_common(types.TypeRef(typ), attr, **kwargs)