diff --git a/numba_cuda/numba/cuda/typing/templates.py b/numba_cuda/numba/cuda/typing/templates.py index 5f1834f61..ae7538758 100644 --- a/numba_cuda/numba/cuda/typing/templates.py +++ b/numba_cuda/numba/cuda/typing/templates.py @@ -1355,12 +1355,18 @@ def __init__(self): self.globals = [] def register(self, item): - assert issubclass(item, FunctionTemplate) + assert issubclass( + item, + (FunctionTemplate, numba.core.typing.templates.FunctionTemplate), + ) self.functions.append(item) return item def register_attr(self, item): - assert issubclass(item, AttributeTemplate) + assert issubclass( + item, + (AttributeTemplate, numba.core.typing.templates.AttributeTemplate), + ) self.attributes.append(item) return item