From 96ffc6743b14a2740b370e51118c8233b93dda53 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 6 Nov 2025 13:45:53 -0800 Subject: [PATCH 01/15] eagerly use all vended numba.cuda modules --- numbast/benchmarks/test_arithmetic.py | 2 +- numbast/src/numbast/function.py | 2 +- numbast/src/numbast/functor.py | 4 +-- numbast/src/numbast/static/function.py | 8 ++--- numbast/src/numbast/static/renderer.py | 6 ++-- numbast/src/numbast/static/struct.py | 30 +++++++++---------- .../static/tests/test_conflict_shim_names.py | 4 +-- .../tests/test_function_static_bindings.py | 2 +- .../static/tests/test_link_two_files.py | 4 +-- .../static/tests/test_operator_bindings.py | 4 +-- .../numbast/static/tests/test_static_demo.py | 4 +-- .../tests/test_struct_static_bindings.py | 4 +-- numbast/src/numbast/struct.py | 8 ++--- numbast/src/numbast/types.py | 2 +- numbast/tests/data/sample_function.cuh | 8 +++++ numbast/tests/demo/demo.py | 2 +- numbast/tests/test_struct.py | 2 +- 17 files changed, 52 insertions(+), 44 deletions(-) create mode 100644 numbast/tests/data/sample_function.cuh diff --git a/numbast/benchmarks/test_arithmetic.py b/numbast/benchmarks/test_arithmetic.py index e41d8d98..06d7e6dc 100644 --- a/numbast/benchmarks/test_arithmetic.py +++ b/numbast/benchmarks/test_arithmetic.py @@ -1,6 +1,6 @@ import numba.cuda as cuda +from numba.cuda.types import float32 import numpy as np -from numba import float32 import pytest diff --git a/numbast/src/numbast/function.py b/numbast/src/numbast/function.py index b1b4df3d..5e3d3a83 100644 --- a/numbast/src/numbast/function.py +++ b/numbast/src/numbast/function.py @@ -5,7 +5,7 @@ from collections import defaultdict from numba import types as nbtypes -from numba.core.typing import signature as nb_signature, Signature +from numba.cuda.typing import signature as nb_signature, Signature from numba.cuda.typing.templates import ConcreteTemplate from numba.cuda import declare_device from numba.cuda.cudadecl import register_global, 
register diff --git a/numbast/src/numbast/functor.py b/numbast/src/numbast/functor.py index 552a3d13..c2808cf9 100644 --- a/numbast/src/numbast/functor.py +++ b/numbast/src/numbast/functor.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from numba import types -from numba.core.extending import typeof_impl -from numba.core.imputils import lower_constant +from numba.cuda.extending import typeof_impl +from numba.cuda.core.imputils import lower_constant from numba import cuda from numbast.types import NUMBA_TO_CTYPE_MAPS as N2C, FunctorType diff --git a/numbast/src/numbast/static/function.py b/numbast/src/numbast/static/function.py index 6c8ef057..35605f06 100644 --- a/numbast/src/numbast/static/function.py +++ b/numbast/src/numbast/static/function.py @@ -227,11 +227,11 @@ def _render_decl_device(self): """Render codes that declares a foreign function for this function in Numba.""" self.Imports.add("from numba.cuda import declare_device") - self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.cuda.typing import signature") # All arguments are passed by pointers in C-CPP shim interop - self.Imports.add("from numba.types import CPointer") + self.Imports.add("from numba.cuda.types import CPointer") # Numba ABI returns int32 for exception codes - self.Imports.add("from numba.types import int32") + self.Imports.add("from numba.cuda.types import int32") decl_device_rendered = self.decl_device_template.format( decl_name=self._deduplicated_shim_name, @@ -273,7 +273,7 @@ def _render_shim_function(self): def _render_lowering(self): """Render lowering codes for this struct constructor.""" - self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.cuda.typing import signature") use_cooperative = "" if self._use_cooperative: diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 1fbcbc4b..1be50e89 100644 --- a/numbast/src/numbast/static/renderer.py +++ 
b/numbast/src/numbast/static/renderer.py @@ -104,7 +104,7 @@ def _try_import_numba_type(cls, typ: str): cls._imported_numba_types.add(typ) elif typ in numba.types.__dict__: - cls.Imports.add(f"from numba.types import {typ}") + cls.Imports.add(f"from numba.cuda.types import {typ}") cls._imported_numba_types.add(typ) else: @@ -283,7 +283,7 @@ def registry_setup(use_separate_registry: bool) -> str: "from numba.cuda.typing.templates import Registry as TypingRegistry" ) BaseRenderer.Imports.add( - "from numba.core.imputils import Registry as TargetRegistry, lower_cast" + "from numba.cuda.core.imputils import Registry as TargetRegistry, lower_cast" ) return BaseRenderer.SeparateRegistrySetup else: @@ -299,5 +299,5 @@ def registry_setup(use_separate_registry: bool) -> str: BaseRenderer.Imports.add( "from numba.cuda.cudaimpl import lower_constant" ) - BaseRenderer.Imports.add("from numba.core.extending import lower_cast") + BaseRenderer.Imports.add("from numba.cuda.extending import lower_cast") return "" diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index 18d70edf..4a7ae9ba 100644 --- a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -7,8 +7,8 @@ import tempfile import warnings -from numba.types import Type -from numba.core.datamodel.models import StructModel, PrimitiveModel +from numba.cuda.types import Type +from numba.cuda.datamodel.models import StructModel, PrimitiveModel from ast_canopy.pylibastcanopy import access_kind, method_kind from ast_canopy.decl import Struct, StructMethod @@ -201,11 +201,11 @@ def _render_decl_device(self): """Render codes that declares a foreign function for this constructor in Numba.""" self.Imports.add("from numba.cuda import declare_device") - self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.cuda.typing import signature") # All arguments are passed by pointers in C-CPP shim interop - self.Imports.add("from numba.types 
import CPointer") + self.Imports.add("from numba.cuda.types import CPointer") # Numba ABI returns int32 for exception codes - self.Imports.add("from numba.types import int32") + self.Imports.add("from numba.cuda.types import int32") decl_device_rendered = self.struct_ctor_decl_device_template.format( struct_ctor_device_decl_str=self._struct_ctor_device_decl_str, @@ -363,7 +363,7 @@ def _render_typing(self, signature_strs: list[str]): self.Imports.add( "from numba.cuda.typing.templates import ConcreteTemplate" ) - self.Imports.add("from numba.types import Function") + self.Imports.add("from numba.cuda.types import Function") signatures_str = ", ".join(signature_strs) @@ -522,9 +522,9 @@ def _render_decl_device(self): """Render codes that declares a foreign function for this constructor in Numba.""" self.Imports.add("from numba.cuda import declare_device") - self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.cuda.typing import signature") # All arguments are passed by pointers in C-CPP shim interop - self.Imports.add("from numba.types import CPointer") + self.Imports.add("from numba.cuda.types import CPointer") decl_device_rendered = ( self.struct_conversion_op_decl_device_template.format( @@ -565,7 +565,7 @@ def _render_shim_function(self): def _render_lowering(self): """Render lowering codes for this struct constructor.""" - self.Imports.add("from numba.core.extending import lower_cast") + self.Imports.add("from numba.cuda.extending import lower_cast") self._lowering_rendered = ( self.struct_conversion_op_lowering_template.format( @@ -689,7 +689,7 @@ def __init__(self): self.bitwidth = {struct_sizeof} * 8 def can_convert_from(self, typingctx, other): - from numba.core.typeconv import Conversion + from numba.cuda.typeconv import Conversion if other in [{implicit_conversion_types}]: return Conversion.safe @@ -764,12 +764,12 @@ def __init__( self._data_model = data_model self.Imports.add( - f"from numba.types import 
{self._parent_type.__qualname__}" + f"from numba.cuda.types import {self._parent_type.__qualname__}" ) self._parent_type_str = self._parent_type.__qualname__ self.Imports.add( - f"from numba.core.datamodel import {self._data_model.__qualname__}" + f"from numba.cuda.datamodel import {self._data_model.__qualname__}" ) self._data_model_str = self._data_model.__qualname__ @@ -816,7 +816,7 @@ def _render_python_api(self): This is the python handle to use it in Numba kernels. """ - self.Imports.add("from numba.extending import as_numba_type") + self.Imports.add("from numba.cuda.extending import as_numba_type") self._python_api_rendered = self.python_api_template.format( struct_type_name=self._struct_type_name, @@ -826,7 +826,7 @@ def _render_python_api(self): def _render_data_model(self): """Renders the data model of the struct.""" - self.Imports.add("from numba.core.extending import register_model") + self.Imports.add("from numba.cuda.extending import register_model") if self._data_model == PrimitiveModel: self.Imports.add("from llvmlite import ir") @@ -868,7 +868,7 @@ def _render_struct_attr(self): "from numba.cuda.typing.templates import AttributeTemplate" ) self.Imports.add( - "from numba.core.extending import make_attribute_wrapper" + "from numba.cuda.extending import make_attribute_wrapper" ) public_fields = [ diff --git a/numbast/src/numbast/static/tests/test_conflict_shim_names.py b/numbast/src/numbast/static/tests/test_conflict_shim_names.py index b621556d..91e7dce1 100644 --- a/numbast/src/numbast/static/tests/test_conflict_shim_names.py +++ b/numbast/src/numbast/static/tests/test_conflict_shim_names.py @@ -3,8 +3,8 @@ from numba import cuda -from numba.types import Type -from numba.core.datamodel import StructModel +from numba.cuda.types import Type +from numba.cuda.datamodel import StructModel def _check_shim_name_contains_mangled_name(src: str, mangled_name: str): diff --git a/numbast/src/numbast/static/tests/test_function_static_bindings.py 
b/numbast/src/numbast/static/tests/test_function_static_bindings.py index 750e920f..be1ad953 100644 --- a/numbast/src/numbast/static/tests/test_function_static_bindings.py +++ b/numbast/src/numbast/static/tests/test_function_static_bindings.py @@ -6,7 +6,7 @@ import numpy as np import cffi -from numba.types import int32, float32 +from numba.cuda.types import int32, float32 from numba import cuda from numba.cuda import device_array diff --git a/numbast/src/numbast/static/tests/test_link_two_files.py b/numbast/src/numbast/static/tests/test_link_two_files.py index 13bc5183..317145b9 100644 --- a/numbast/src/numbast/static/tests/test_link_two_files.py +++ b/numbast/src/numbast/static/tests/test_link_two_files.py @@ -4,8 +4,8 @@ import pytest from numba import cuda -from numba.types import Type, Number -from numba.core.datamodel import StructModel, PrimitiveModel +from numba.cuda.types import Type, Number +from numba.cuda.datamodel import StructModel, PrimitiveModel from ast_canopy import parse_declarations_from_source from numbast.static.renderer import clear_base_renderer_cache, registry_setup diff --git a/numbast/src/numbast/static/tests/test_operator_bindings.py b/numbast/src/numbast/static/tests/test_operator_bindings.py index ff22a087..27eddc46 100644 --- a/numbast/src/numbast/static/tests/test_operator_bindings.py +++ b/numbast/src/numbast/static/tests/test_operator_bindings.py @@ -4,8 +4,8 @@ import pytest from numba import cuda -from numba.types import Type -from numba.core.datamodel import StructModel +from numba.cuda.types import Type +from numba.cuda.datamodel import StructModel from numba.cuda import device_array from ast_canopy import parse_declarations_from_source diff --git a/numbast/src/numbast/static/tests/test_static_demo.py b/numbast/src/numbast/static/tests/test_static_demo.py index 058a1a21..99250b57 100644 --- a/numbast/src/numbast/static/tests/test_static_demo.py +++ b/numbast/src/numbast/static/tests/test_static_demo.py @@ -4,8 +4,8 @@ import 
pytest from numba import cuda -from numba.types import Number, float64 -from numba.core.datamodel import PrimitiveModel +from numba.cuda.types import Number, float64 +from numba.cuda.datamodel import PrimitiveModel from numba.cuda import device_array from ast_canopy import parse_declarations_from_source diff --git a/numbast/src/numbast/static/tests/test_struct_static_bindings.py b/numbast/src/numbast/static/tests/test_struct_static_bindings.py index ea93988a..a538b98a 100644 --- a/numbast/src/numbast/static/tests/test_struct_static_bindings.py +++ b/numbast/src/numbast/static/tests/test_struct_static_bindings.py @@ -4,8 +4,8 @@ import pytest from numba import cuda -from numba.types import Type, Number -from numba.core.datamodel import StructModel, PrimitiveModel +from numba.cuda.types import Type, Number +from numba.cuda.datamodel import StructModel, PrimitiveModel from numba.cuda import device_array from ast_canopy import parse_declarations_from_source diff --git a/numbast/src/numbast/struct.py b/numbast/src/numbast/struct.py index c61dc182..899d83be 100644 --- a/numbast/src/numbast/struct.py +++ b/numbast/src/numbast/struct.py @@ -5,15 +5,15 @@ from llvmlite import ir -from numba import types as nbtypes -from numba.core.extending import ( +from numba.cuda import types as nbtypes +from numba.cuda.extending import ( register_model, lower_cast, make_attribute_wrapper, ) -from numba.core.typing import signature as nb_signature +from numba.cuda.typing import signature as nb_signature from numba.cuda.typing.templates import ConcreteTemplate, AttributeTemplate -from numba.core.datamodel.models import StructModel, PrimitiveModel +from numba.cuda.datamodel.models import StructModel, PrimitiveModel from numba.cuda import declare_device from numba.cuda.cudadecl import register_global, register, register_attr from numba.cuda.cudaimpl import lower diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index c0f16faa..f75a26fc 100644 --- 
a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -6,8 +6,8 @@ from numba import types as nbtypes from numba.cuda.types import bfloat16 +from numba.cuda.typing.typeof import typeof from numba.cuda.vector_types import vector_types -from numba.misc.special import typeof class FunctorType(nbtypes.Type): diff --git a/numbast/tests/data/sample_function.cuh b/numbast/tests/data/sample_function.cuh new file mode 100644 index 00000000..9083cc4c --- /dev/null +++ b/numbast/tests/data/sample_function.cuh @@ -0,0 +1,8 @@ +// clang-format off +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// clang-format on + +#pragma once + +void __device__ func_with_void_ptr_arg(void *ptr) {} diff --git a/numbast/tests/demo/demo.py b/numbast/tests/demo/demo.py index 5a1e017c..5bdb03de 100644 --- a/numbast/tests/demo/demo.py +++ b/numbast/tests/demo/demo.py @@ -3,7 +3,7 @@ from numbast import bind_cxx_struct, bind_cxx_function, MemoryShimWriter from numba import types, cuda -from numba.core.datamodel.models import PrimitiveModel +from numba.cuda.datamodel.models import PrimitiveModel import numpy as np diff --git a/numbast/tests/test_struct.py b/numbast/tests/test_struct.py index 66ede4a5..f84baf26 100644 --- a/numbast/tests/test_struct.py +++ b/numbast/tests/test_struct.py @@ -4,7 +4,7 @@ import os from numba import types -from numba.core.datamodel import StructModel +from numba.cuda.datamodel import StructModel from llvmlite import ir From c93c98ae9d6452627f1ebda4fb2439195fb770c1 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Nov 2025 11:14:39 -0800 Subject: [PATCH 02/15] skip unknown types in struct conversion operator gen --- numbast/src/numbast/static/struct.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index 550f5483..5e770223 100644 --- 
a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -640,13 +640,20 @@ def _render(self): """Render all struct constructors.""" for convop_decl in self._convop_decls: - renderer = StaticStructConversionOperatorRenderer( - struct_name=self._struct_name, - struct_type_class=self._struct_type_class, - struct_type_name=self._struct_type_name, - header_path=self._header_path, - convop_decl=convop_decl, - ) + try: + renderer = StaticStructConversionOperatorRenderer( + struct_name=self._struct_name, + struct_type_class=self._struct_type_class, + struct_type_name=self._struct_type_name, + header_path=self._header_path, + convop_decl=convop_decl, + ) + except TypeNotFoundError as e: + warnings.warn( + f"{e._type_name} is not known to Numbast. Skipping " + f"binding for {str(convop_decl)}" + ) + continue renderer._render() self._python_rendered += renderer._python_rendered @@ -1359,11 +1366,13 @@ def __init__( decls: list[Struct], specs: dict[str, tuple[type | None, type | None, os.PathLike]], default_header: os.PathLike | str | None = None, + struct_prefix_removal: list[str] = [], excludes: list[str] = [], ): self._decls = decls self._specs = specs self._default_header = default_header + self._struct_prefix_removal = struct_prefix_removal self._python_rendered = [] self._c_rendered = [] From b72a05c4c12cff9d5f4f0a5cb999bd9f1c8b076c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Nov 2025 11:19:49 -0800 Subject: [PATCH 03/15] factor out prefix removal helper --- numbast/src/numbast/static/function.py | 26 ++++---------------------- numbast/src/numbast/utils.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/numbast/src/numbast/static/function.py b/numbast/src/numbast/static/function.py index 6c8ef057..217eabb5 100644 --- a/numbast/src/numbast/static/function.py +++ b/numbast/src/numbast/static/function.py @@ -16,7 +16,7 @@ get_shim, ) from numbast.static.types import to_numba_type_str -from 
numbast.utils import make_function_shim +from numbast.utils import make_function_shim, _apply_prefix_removal from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError from ast_canopy.decl import Function @@ -413,8 +413,9 @@ def __init__( function_prefix_removal: list[str] = [], ): super().__init__(decl, header_path, use_cooperative) - self._function_prefix_removal = function_prefix_removal - self._python_func_name = self._apply_prefix_removal(self._decl.name) + self._python_func_name = _apply_prefix_removal( + decl.name, function_prefix_removal + ) # Override the base class symbol tracking to use the Python function name # Remove the original name that was added by the base class @@ -423,25 +424,6 @@ def __init__( # Add the Python function name (with prefix removal applied) self._function_symbols.append(self._python_func_name) - def _apply_prefix_removal(self, name: str) -> str: - """Apply prefix removal to a function name based on the configuration. - - Parameters - ---------- - name : str - The original function name - - Returns - ------- - str - The function name with prefixes removed - """ - for prefix in self._function_prefix_removal: - if name.startswith(prefix): - return name[len(prefix) :] - - return name - @property def func_name_python(self): """The name of the function in python with prefix removal applied.""" diff --git a/numbast/src/numbast/utils.py b/numbast/src/numbast/utils.py index af051473..54aeb7bc 100644 --- a/numbast/src/numbast/utils.py +++ b/numbast/src/numbast/utils.py @@ -316,3 +316,23 @@ def make_struct_conversion_operator_shim( ) return shim + + +def _apply_prefix_removal(name: str, prefix_to_remove: list[str]) -> str: + """Apply prefix removal to a function name based on the configuration. 
+ + Parameters + ---------- + name : str + The original function name + + Returns + ------- + str + The function name with prefixes removed + """ + for prefix in prefix_to_remove: + if name.startswith(prefix): + return name[len(prefix) :] + + return name From a4d078ffdab1e6a4bd2d72d79774a5959e4da684 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Nov 2025 14:29:46 -0800 Subject: [PATCH 04/15] move singular regular method to standalone class, support prefix removal of struct --- numbast/src/numbast/static/struct.py | 314 ++++++++++-------- .../numbast/tools/static_binding_generator.py | 17 +- .../tools/tests/config/prefix_removal.yml.j2 | 8 +- .../numbast/tools/tests/prefix_removal.cuh | 6 + .../tools/tests/test_prefix_removal.py | 43 +-- 5 files changed, 213 insertions(+), 175 deletions(-) diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index 5e770223..e338bd1e 100644 --- a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -24,6 +24,7 @@ make_struct_ctor_shim, make_struct_conversion_operator_shim, make_struct_regular_method_shim, + _apply_prefix_removal, ) from numbast.errors import TypeNotFoundError @@ -142,12 +143,14 @@ def {lower_scope_name}(shim_stream, shim_obj): def __init__( self, struct_name: str, + python_struct_name: str, struct_type_class: str, struct_type_name: str, header_path: str, ctor_decl: StructMethod, ): self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_class = struct_type_class self._struct_type_name = struct_type_name self._header_path = header_path @@ -247,7 +250,7 @@ def _render_lowering(self): """Render lowering codes for this struct constructor.""" self._lowering_rendered = self.struct_ctor_lowering_template.format( - struct_name=self._struct_name, + struct_name=self._python_struct_name, param_types=self._nb_param_types_str, struct_type_name=self._struct_type_name, 
struct_device_caller_name=self._device_caller_name, @@ -336,12 +339,14 @@ def __init__( self, ctor_decls: list[StructMethod], struct_name, + python_struct_name, struct_type_class, struct_type_name, header_path, ): self._ctor_decls = ctor_decls self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_class = struct_type_class self._struct_type_name = struct_type_name self._header_path = header_path @@ -371,7 +376,7 @@ def _render_typing(self, signature_strs: list[str]): self._struct_ctor_typing_rendered = ( self.struct_ctor_template_typing_template.format( struct_ctor_template_name=self._struct_ctor_template_name, - struct_name=self._struct_name, + struct_name=self._python_struct_name, signatures=signatures_str, ) ) @@ -384,6 +389,7 @@ def _render(self): try: renderer = StaticStructCtorRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_class=self._struct_type_class, struct_type_name=self._struct_type_name, header_path=self._header_path, @@ -670,16 +676,14 @@ def c_rendered(self) -> str: return self._c_rendered -class StaticStructRegularMethodsRenderer(BaseRenderer): - """Renderer for all regular (non-operator) member functions of a struct.""" +class StaticStructRegularMethodRenderer(BaseRenderer): + """Renderer for a single regular method of a struct.""" - # ---- Single method renderer ------------------------------------------------- - class _MethodRenderer(BaseRenderer): - c_ext_shim_var_template = """ + c_ext_shim_var_template = """ shim_raw_str = \"\"\"{shim_rendered}\"\"\" """ - struct_method_device_decl_template = """ + struct_method_device_decl_template = """ {device_decl_name} = declare_device( '{unique_shim_name}', {return_type}( @@ -689,12 +693,12 @@ class _MethodRenderer(BaseRenderer): ) """ - struct_method_device_caller_template = """ + struct_method_device_caller_template = """ def {device_caller_name}({nargs}): return {device_decl_name}({nargs}) """ - 
struct_method_lowering_template = """ + struct_method_lowering_template = """ @lower("{struct_name}.{method_name}", {struct_type_name}, {param_types}) def _{lower_fn_suffix}(context, builder, sig, args): context.active_code_library.add_linking_file(shim_obj) @@ -716,159 +720,162 @@ def _{lower_fn_suffix}(context, builder, sig, args): ) """ - lowering_body_template = """ + lowering_body_template = """ {shim_var} {decl_device} {lowering} """ - lower_scope_template = """ + lower_scope_template = """ def {lower_scope_name}(shim_stream, shim_obj): {body} {lower_scope_name}(shim_stream, shim_obj) """ - def __init__( - self, - struct_name: str, - struct_type_name: str, - header_path: str, - method_decl: StructMethod, - ): - super().__init__(method_decl) - self._struct_name = struct_name - self._struct_type_name = struct_type_name - self._header_path = header_path - self._method_decl = method_decl - - # Cache Numba param and return types (as strings) - self._nb_param_types = [ - to_numba_type_str(arg.unqualified_non_ref_type_name) - for arg in self._method_decl.param_types - ] - self._nb_param_types_str = ( - ", ".join(map(str, self._nb_param_types)) or "" - ) - self._nb_return_type = to_numba_type_str( - self._method_decl.return_type.unqualified_non_ref_type_name - ) - self._nb_return_type_str = str(self._nb_return_type) + def __init__( + self, + struct_name: str, + python_struct_name: str, + struct_type_name: str, + header_path: str, + method_decl: StructMethod, + ): + super().__init__(method_decl) + self._struct_name = struct_name + self._python_struct_name = python_struct_name + self._struct_type_name = struct_type_name + self._header_path = header_path + self._method_decl = method_decl - # Pointers for interop - def wrap_pointer(typ): - return f"CPointer({typ})" + # Cache Numba param and return types (as strings) + self._nb_param_types = [ + to_numba_type_str(arg.unqualified_non_ref_type_name) + for arg in self._method_decl.param_types + ] + self._nb_param_types_str = 
( + ", ".join(map(str, self._nb_param_types)) or "" + ) + self._nb_return_type = to_numba_type_str( + self._method_decl.return_type.unqualified_non_ref_type_name + ) + self._nb_return_type_str = str(self._nb_return_type) - _pointer_wrapped_param_types = [ - wrap_pointer(typ) for typ in self._nb_param_types - ] - self._pointer_wrapped_param_types_str = ", ".join( - _pointer_wrapped_param_types - ) + # Pointers for interop + def wrap_pointer(typ): + return f"CPointer({typ})" - # Unique shim name and helpers - self._unique_shim_name = deduplicate_overloads( - f"__{self._method_decl.mangled_name}_nbst" - ) - self._device_decl_name = ( - f"_method_decl_{self._method_decl.mangled_name}" - ) - self._device_caller_name = ( - f"_device_caller_{self._method_decl.mangled_name}" + _pointer_wrapped_param_types = [ + wrap_pointer(typ) for typ in self._nb_param_types + ] + self._pointer_wrapped_param_types_str = ", ".join( + _pointer_wrapped_param_types + ) + + # Unique shim name and helpers + self._unique_shim_name = deduplicate_overloads( + f"__{self._method_decl.mangled_name}_nbst" + ) + self._device_decl_name = ( + f"_method_decl_{self._method_decl.mangled_name}" + ) + self._device_caller_name = ( + f"_device_caller_{self._method_decl.mangled_name}" + ) + self._lower_fn_suffix = f"lower_{self._method_decl.mangled_name}" + self._lower_scope_name = f"_lower_{self._method_decl.mangled_name}" + + @property + def signature_str(self) -> str: + """signature string for typing with recvr specified""" + recvr = self._struct_type_name + if self._nb_param_types_str: + return ( + f"signature({self._nb_return_type_str}, " + f"{self._nb_param_types_str}, recvr={recvr})" ) - self._lower_fn_suffix = f"lower_{self._method_decl.mangled_name}" - self._lower_scope_name = f"_lower_{self._method_decl.mangled_name}" - - @property - def signature_str(self) -> str: - """signature string for typing with recvr specified""" - recvr = self._struct_type_name - if self._nb_param_types_str: - return ( - 
f"signature({self._nb_return_type_str}, " - f"{self._nb_param_types_str}, recvr={recvr})" - ) - else: - return f"signature({self._nb_return_type_str}, recvr={recvr})" + else: + return f"signature({self._nb_return_type_str}, recvr={recvr})" - def _render_decl_device(self): - self.Imports.add("from numba.cuda import declare_device") - self.Imports.add("from numba.core.typing import signature") - self.Imports.add("from numba.types import CPointer") + def _render_decl_device(self): + self.Imports.add("from numba.cuda import declare_device") + self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.types import CPointer") - decl_device_rendered = self.struct_method_device_decl_template.format( + decl_device_rendered = self.struct_method_device_decl_template.format( + device_decl_name=self._device_decl_name, + unique_shim_name=self._unique_shim_name, + return_type=self._nb_return_type_str, + struct_type_name=self._struct_type_name, + pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, + ) + + nargs = [f"arg_{i}" for i in range(len(self._method_decl.params) + 1)] + nargs_str = ", ".join(nargs) + device_caller_rendered = ( + self.struct_method_device_caller_template.format( + device_caller_name=self._device_caller_name, + nargs=nargs_str, device_decl_name=self._device_decl_name, - unique_shim_name=self._unique_shim_name, - return_type=self._nb_return_type_str, - struct_type_name=self._struct_type_name, - pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, ) + ) - nargs = [ - f"arg_{i}" for i in range(len(self._method_decl.params) + 1) - ] - nargs_str = ", ".join(nargs) - device_caller_rendered = ( - self.struct_method_device_caller_template.format( - device_caller_name=self._device_caller_name, - nargs=nargs_str, - device_decl_name=self._device_decl_name, - ) - ) + self._decl_device_rendered = ( + decl_device_rendered + "\n" + device_caller_rendered + ) - self._decl_device_rendered = ( - decl_device_rendered 
+ "\n" + device_caller_rendered - ) + def _render_shim_function(self): + self._c_ext_shim_rendered = make_struct_regular_method_shim( + shim_name=self._unique_shim_name, + struct_name=self._struct_name, + method_name=self._method_decl.name, + return_type=self._method_decl.return_type.unqualified_non_ref_type_name, + params=self._method_decl.params, + ) + self._c_ext_shim_var_rendered = self.c_ext_shim_var_template.format( + shim_rendered=self._c_ext_shim_rendered + ) + self.ShimFunctions.append(self._c_ext_shim_rendered) - def _render_shim_function(self): - self._c_ext_shim_rendered = make_struct_regular_method_shim( - shim_name=self._unique_shim_name, - struct_name=self._struct_name, - method_name=self._method_decl.name, - return_type=self._method_decl.return_type.unqualified_non_ref_type_name, - params=self._method_decl.params, - ) - self._c_ext_shim_var_rendered = self.c_ext_shim_var_template.format( - shim_rendered=self._c_ext_shim_rendered - ) - self.ShimFunctions.append(self._c_ext_shim_rendered) + def _render_lowering(self): + self.Imports.add("from numba.cuda.cudaimpl import lower") - def _render_lowering(self): - self.Imports.add("from numba.cuda.cudaimpl import lower") + param_types = self._nb_param_types_str or "" + lowering_rendered = self.struct_method_lowering_template.format( + struct_name=self._python_struct_name, + method_name=self._method_decl.name, + struct_type_name=self._struct_type_name, + param_types=param_types, + device_caller_name=self._device_caller_name, + return_type=self._nb_return_type_str, + pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, + lower_fn_suffix=self._lower_fn_suffix, + unique_shim_name=self._unique_shim_name, + ) + self._lowering_rendered = lowering_rendered - param_types = self._nb_param_types_str or "" - lowering_rendered = self.struct_method_lowering_template.format( - struct_name=self._struct_name, - method_name=self._method_decl.name, - struct_type_name=self._struct_type_name, - 
param_types=param_types, - device_caller_name=self._device_caller_name, - return_type=self._nb_return_type_str, - pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, - lower_fn_suffix=self._lower_fn_suffix, - unique_shim_name=self._unique_shim_name, - ) - self._lowering_rendered = lowering_rendered + def _render(self): + self._render_decl_device() + self._render_shim_function() + self._render_lowering() - def _render(self): - self._render_decl_device() - self._render_shim_function() - self._render_lowering() + lower_body = self.lowering_body_template.format( + shim_var=self._c_ext_shim_var_rendered, + decl_device=self._decl_device_rendered, + lowering=self._lowering_rendered, + ) + lower_body = indent(lower_body, " " * 4) - lower_body = self.lowering_body_template.format( - shim_var=self._c_ext_shim_var_rendered, - decl_device=self._decl_device_rendered, - lowering=self._lowering_rendered, - ) - lower_body = indent(lower_body, " " * 4) + self._python_rendered = self.lower_scope_template.format( + lower_scope_name=self._lower_scope_name, + body=lower_body, + ) + self._c_rendered = self._c_ext_shim_rendered - self._python_rendered = self.lower_scope_template.format( - lower_scope_name=self._lower_scope_name, - body=lower_body, - ) - self._c_rendered = self._c_ext_shim_rendered - # ---- All methods renderer --------------------------------------------------- +class StaticStructRegularMethodsRenderer(BaseRenderer): + """Renderer for all regular (non-operator) member functions of a struct.""" + method_template_typing_template = """ @register class {method_template_name}(ConcreteTemplate): @@ -879,12 +886,14 @@ class {method_template_name}(ConcreteTemplate): def __init__( self, struct_name: str, + python_struct_name: str, struct_type_name: str, header_path: str, method_decls: list[StructMethod], ): super().__init__(method_decls) self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_name = struct_type_name 
self._header_path = header_path self._method_decls = method_decls @@ -904,8 +913,9 @@ def _render(self): # Render per-overload lowering and collect signatures for m in self._method_decls: try: - mr = self._MethodRenderer( + mr = StaticStructRegularMethodRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_name=self._struct_type_name, header_path=self._header_path, method_decl=m, @@ -1044,9 +1054,13 @@ def __init__( parent_type: type | None, data_model: type | None, header_path: os.PathLike | str, + struct_prefix_removal: list[str] = [], aliases: list[str] = [], ): super().__init__(decl) + self._python_struct_name = _apply_prefix_removal( + decl.name, struct_prefix_removal + ) self._struct_name = decl.name self._aliases = aliases @@ -1072,19 +1086,21 @@ def __init__( # We use a prefix here to identify internal objects so that C object names # does not interfere with python's name mangling mechanism. - self._struct_type_class_name = f"_type_class_{self._struct_name}" - self._struct_type_name = f"_type_{self._struct_name}" - self._struct_model_name = f"_model_{self._struct_name}" - self._struct_attr_typing_name = f"_attr_typing_{self._struct_name}" + self._struct_type_class_name = f"_type_class_{self._python_struct_name}" + self._struct_type_name = f"_type_{self._python_struct_name}" + self._struct_model_name = f"_model_{self._python_struct_name}" + self._struct_attr_typing_name = ( + f"_attr_typing_{self._python_struct_name}" + ) self._header_path = header_path - CTYPE_TO_NBTYPE_STR[decl.name] = self._struct_type_name + CTYPE_TO_NBTYPE_STR[self._struct_name] = self._struct_type_name # Track the public symbols that should be exposed via a # struct creation self._nbtype_symbols.append(self._struct_type_name) - self._record_symbols.append(self._struct_name) + self._record_symbols.append(self._python_struct_name) def _render_typing(self): """Render typing of the struct.""" @@ -1102,7 +1118,7 @@ def _render_typing(self): 
struct_type_class_name=self._struct_type_class_name, struct_type_name=self._struct_type_name, parent_type=self._parent_type_str, - struct_name=self._struct_name, + struct_name=self._python_struct_name, struct_alignof=self._decl.alignof_, struct_sizeof=self._decl.sizeof_, implicit_conversion_types=implicit_conversion_types, @@ -1117,7 +1133,7 @@ def _render_python_api(self): self._python_api_rendered = self.python_api_template.format( struct_type_name=self._struct_type_name, - struct_name=self._struct_name, + struct_name=self._python_struct_name, ) def _render_data_model(self): @@ -1218,6 +1234,7 @@ def _render_regular_methods(self): """Render regular member functions of the struct.""" static_methods_renderer = StaticStructRegularMethodsRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_name=self._struct_type_name, header_path=self._header_path, method_decls=self._decl.regular_member_functions(), @@ -1235,6 +1252,7 @@ def _render_struct_ctors(self): """Render constructors of the struct.""" static_ctors_renderer = StaticStructCtorsRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_class=self._struct_type_class_name, struct_type_name=self._struct_type_name, header_path=self._header_path, @@ -1398,7 +1416,13 @@ def _render( f"CUDA struct {name} does not provide a header path." 
) - SSR = StaticStructRenderer(decl, nb_ty, nb_datamodel, header_path) + SSR = StaticStructRenderer( + decl, + nb_ty, + nb_datamodel, + header_path, + self._struct_prefix_removal, + ) self._python_rendered.append(SSR.render_python()) self._c_rendered.append(SSR.render_c()) diff --git a/numbast/src/numbast/tools/static_binding_generator.py b/numbast/src/numbast/tools/static_binding_generator.py index 795e0ce6..fadd7a7b 100644 --- a/numbast/src/numbast/tools/static_binding_generator.py +++ b/numbast/src/numbast/tools/static_binding_generator.py @@ -373,7 +373,14 @@ def _typedef_to_aliases(typedef_decls: list[Typedef]) -> dict[str, list[str]]: return aliases -def _generate_structs(struct_decls, header_path, types, data_models, excludes): +def _generate_structs( + struct_decls, + header_path, + types, + data_models, + struct_prefix_removal, + excludes, +): """Convert CLI inputs into structure that fits `StaticStructsRenderer` and create struct bindings.""" specs = {} for struct_decl in struct_decls: @@ -382,7 +389,12 @@ def _generate_structs(struct_decls, header_path, types, data_models, excludes): this_data_model = data_models.get(struct_name, None) specs[struct_name] = (this_type, this_data_model, header_path) - SSR = StaticStructsRenderer(struct_decls, specs, excludes=excludes) + SSR = StaticStructsRenderer( + struct_decls, + specs, + struct_prefix_removal=struct_prefix_removal, + excludes=excludes, + ) return SSR.render_as_str(with_imports=False, with_shim_stream=False) @@ -517,6 +529,7 @@ def _static_binding_generator( entry_point, config.types, config.datamodels, + config.api_prefix_removal.get("Struct", []), config.exclude_structs, ) diff --git a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 index 711db61c..13d1b180 100644 --- a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 +++ b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 @@ -9,8 +9,12 @@ GPU Arch: File 
List: - {{ data }} Exclude: {} -Types: {} -Data Models: {} +Types: + __internal__Foo: Type +Data Models: + __internal__Foo: StructModel API Prefix Removal: Function: - prefix_ + Struct: + - __internal__ diff --git a/numbast/src/numbast/tools/tests/prefix_removal.cuh b/numbast/src/numbast/tools/tests/prefix_removal.cuh index 343411cb..575d0176 100644 --- a/numbast/src/numbast/tools/tests/prefix_removal.cuh +++ b/numbast/src/numbast/tools/tests/prefix_removal.cuh @@ -4,3 +4,9 @@ // clang-format on int __device__ prefix_foo(int a, int b) { return a + b; } + +struct __internal__Foo { + int x; + __device__ __internal__Foo() : x(0) {} + __device__ int get_x() { return x; } +}; diff --git a/numbast/src/numbast/tools/tests/test_prefix_removal.py b/numbast/src/numbast/tools/tests/test_prefix_removal.py index 2789d994..a8e4c77b 100644 --- a/numbast/src/numbast/tools/tests/test_prefix_removal.py +++ b/numbast/src/numbast/tools/tests/test_prefix_removal.py @@ -1,9 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -import os -import subprocess -import sys +from numba import cuda def test_prefix_removal(run_in_isolated_folder, arch_str): @@ -17,41 +15,34 @@ def test_prefix_removal(run_in_isolated_folder, arch_str): ) run_result = res["result"] - output_folder = res["output_folder"] symbols = res["symbols"] alls = symbols["__all__"] + binding_path = res["binding_path"] + print(f"{binding_path=}") + assert run_result.exit_code == 0 # Verify that the function is exposed as "foo" (without the "prefix_" prefix) assert "foo" in alls, f"Expected 'foo' in __all__, got: {alls}" + assert "Foo" in alls, f"Expected 'Foo' in __all__, got: {alls}" # Verify that the original name "prefix_foo" is NOT exposed assert "prefix_foo" not in alls, ( f"Expected 'prefix_foo' NOT in __all__, got: {alls}" ) + assert "__internal__Foo" not in alls, ( + f"Expected '__internal__Foo' NOT in __all__, got: {alls}" + ) - # Test that the function can be imported and used as "foo" - test_kernel_src = """ -from numba import cuda -from prefix_removal import foo - -@cuda.jit -def kernel(): - result = foo(1, 2) # Verify that prefix_foo is accessible as foo - -kernel[1, 1]() -""" - - test_kernel = os.path.join(output_folder, "test.py") - with open(test_kernel, "w") as f: - f.write(test_kernel_src) + foo = symbols["foo"] + Foo = symbols["Foo"] - res = subprocess.run( - [sys.executable, test_kernel], - cwd=output_folder, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) + @cuda.jit + def kernel(): + result = foo(1, 2) # noqa: F841 + foo_obj = Foo() + x = foo_obj.get_x() # noqa: F841 + x2 = foo_obj.x # noqa: F841 - assert res.returncode == 0, res.stdout.decode("utf-8") + kernel[1, 1]() From 85271f72d2ab271caba9b5a3ab4dcbacb74bf612 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Nov 2025 16:06:59 -0800 Subject: [PATCH 05/15] add __nv_bfloat16 test to existing support --- numbast/src/numbast/static/tests/data/bf16.cuh | 4 ++++ 
numbast/src/numbast/static/tests/test_bf16_support.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/numbast/src/numbast/static/tests/data/bf16.cuh b/numbast/src/numbast/static/tests/data/bf16.cuh index 6ad17c5d..75b06f60 100644 --- a/numbast/src/numbast/static/tests/data/bf16.cuh +++ b/numbast/src/numbast/static/tests/data/bf16.cuh @@ -10,4 +10,8 @@ nv_bfloat16 inline __device__ add(nv_bfloat16 a, nv_bfloat16 b) { return a + b; } +__nv_bfloat16 inline __device__ add2(__nv_bfloat16 a, __nv_bfloat16 b) { + return a + b; +} + #endif diff --git a/numbast/src/numbast/static/tests/test_bf16_support.py b/numbast/src/numbast/static/tests/test_bf16_support.py index 2dc281a9..c58247d0 100644 --- a/numbast/src/numbast/static/tests/test_bf16_support.py +++ b/numbast/src/numbast/static/tests/test_bf16_support.py @@ -10,16 +10,19 @@ def test_bindings_from_bf16(make_binding): binding1 = res1["bindings"] add = binding1["add"] + add2 = binding1["add2"] @cuda.jit def kernel(arr): x = add(bfloat16(3.14), bfloat16(3.14)) arr[0] = float32(x) + arr[1] = float32(add2(bfloat16(3.14), bfloat16(3.14))) - arr = cuda.device_array((1,), dtype="float32") + arr = cuda.device_array((2,), dtype="float32") kernel[1, 1](arr) assert pytest.approx(arr[0], 1e-2) == 6.28 + assert pytest.approx(arr[1], 1e-2) == 6.28 # Check that bfloat16 is imported assert "from numba.cuda.types import bfloat16" in res1["src"] From 158288030fd4c851071f48cf6d7a78f8a7bb121c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Nov 2025 16:07:47 -0800 Subject: [PATCH 06/15] support removal of enum prefix, and expose enums to __all__ --- numbast/src/numbast/static/enum.py | 20 +++++++++++++------ numbast/src/numbast/static/renderer.py | 20 +++++++++++++++++-- .../numbast/tools/static_binding_generator.py | 10 +++++++--- .../tools/tests/config/prefix_removal.yml.j2 | 2 ++ .../numbast/tools/tests/prefix_removal.cuh | 5 +++++ .../tools/tests/test_prefix_removal.py | 4 ++++ 6 files changed, 50 
insertions(+), 11 deletions(-) diff --git a/numbast/src/numbast/static/enum.py b/numbast/src/numbast/static/enum.py index efe7ebfa..048ebebb 100644 --- a/numbast/src/numbast/static/enum.py +++ b/numbast/src/numbast/static/enum.py @@ -9,6 +9,7 @@ from numbast.static.renderer import BaseRenderer, get_rendered_imports from numbast.static.types import register_enum_type_str +from numbast.utils import _apply_prefix_removal file_logger = getLogger(f"{__name__}") logger_path = os.path.join(tempfile.gettempdir(), "test.py") @@ -28,13 +29,19 @@ class {enum_name}(IntEnum): """ enumerator_template = " {enumerator} = {value}" - def __init__(self, decl: Enum): + def __init__(self, decl: Enum, enum_prefix_removal: list[str] = []): self._decl = decl + self._enum_name = _apply_prefix_removal( + self._decl.name, enum_prefix_removal + ) + + self._enum_symbols.append(self._enum_name) + def _render(self): self.Imports.add("from enum import IntEnum") - register_enum_type_str(self._decl.name) + register_enum_type_str(self._enum_name) enumerators = [] for enumerator, value in zip( @@ -47,7 +54,7 @@ def _render(self): ) self._python_rendered = self.enum_template.format( - enum_name=self._decl.name, enumerators="\n".join(enumerators) + enum_name=self._enum_name, enumerators="\n".join(enumerators) ) @@ -57,18 +64,19 @@ class StaticEnumsRenderer(BaseRenderer): Since enums creates a new C++ type. It should be invoked before making struct / function bindings. 
""" - def __init__(self, decls: list[Enum]): + def __init__(self, decls: list[Enum], enum_prefix_removal: list[str] = []): super().__init__(decls) self._decls = decls + self._enum_prefix_removal = enum_prefix_removal - self._python_rendered = [] + self._python_rendered: list[tuple[set[str], str]] = [] def _render(self, with_imports): """Render python bindings for enums.""" self._python_str = "" for decl in self._decls: - SER = StaticEnumRenderer(decl) + SER = StaticEnumRenderer(decl, self._enum_prefix_removal) SER._render() self._python_rendered.append(SER._python_rendered) diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 1fbcbc4b..15450a2f 100644 --- a/numbast/src/numbast/static/renderer.py +++ b/numbast/src/numbast/static/renderer.py @@ -83,6 +83,9 @@ def reset(self): _function_symbols: list[str] = [] """List of new function handles to expose.""" + _enum_symbols: list[str] = [] + """List of new enum handles to expose.""" + def __init__(self, decl): self.Imports.add("import numba") self.Imports.add("import io") @@ -245,6 +248,18 @@ def _get_function_symbols() -> str: return code +def _get_enum_symbols() -> str: + template = """ +_ENUM_SYMBOLS = [{enum_symbols}] +""" + + symbols = BaseRenderer._enum_symbols + quote_wrapped = [f'"{s}"' for s in symbols] + concat = ",".join(quote_wrapped) + code = template.format(enum_symbols=concat) + return code + + def get_all_exposed_symbols() -> str: """Return the definition of all exposed symbols via `__all__`. 
@@ -257,13 +272,14 @@ def get_all_exposed_symbols() -> str: nbtype_symbols = _get_nbtype_symbols() record_symbols = _get_record_symbols() function_symbols = _get_function_symbols() + enum_symbols = _get_enum_symbols() all_symbols = f""" {nbtype_symbols} {record_symbols} {function_symbols} - -__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS +{enum_symbols} +__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS + _ENUM_SYMBOLS """ return all_symbols diff --git a/numbast/src/numbast/tools/static_binding_generator.py b/numbast/src/numbast/tools/static_binding_generator.py index fadd7a7b..a65e6c6b 100644 --- a/numbast/src/numbast/tools/static_binding_generator.py +++ b/numbast/src/numbast/tools/static_binding_generator.py @@ -421,9 +421,11 @@ def _generate_functions( return SFR.render_as_str(with_imports=False, with_shim_stream=False) -def _generate_enums(enum_decls: list[Enum]): +def _generate_enums( + enum_decls: list[Enum], enum_prefix_removal: list[str] = [] +): """Create enum bindings.""" - SER = StaticEnumsRenderer(enum_decls) + SER = StaticEnumsRenderer(enum_decls, enum_prefix_removal) return SER.render_as_str(with_imports=False, with_shim_stream=False) @@ -523,7 +525,9 @@ def _static_binding_generator( aliases = _typedef_to_aliases(typedefs) rendered_aliases = render_aliases(aliases) - enum_bindings = _generate_enums(enums) + enum_bindings = _generate_enums( + enums, config.api_prefix_removal.get("Enum", []) + ) struct_bindings = _generate_structs( structs, entry_point, diff --git a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 index 13d1b180..c98f348e 100644 --- a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 +++ b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 @@ -18,3 +18,5 @@ API Prefix Removal: - prefix_ Struct: - __internal__ + Enum: + - __internal__ diff --git a/numbast/src/numbast/tools/tests/prefix_removal.cuh 
b/numbast/src/numbast/tools/tests/prefix_removal.cuh index 575d0176..9cf17a12 100644 --- a/numbast/src/numbast/tools/tests/prefix_removal.cuh +++ b/numbast/src/numbast/tools/tests/prefix_removal.cuh @@ -10,3 +10,8 @@ struct __internal__Foo { __device__ __internal__Foo() : x(0) {} __device__ int get_x() { return x; } }; + +enum __internal__Bar { + BAR_A, + BAR_B, +}; diff --git a/numbast/src/numbast/tools/tests/test_prefix_removal.py b/numbast/src/numbast/tools/tests/test_prefix_removal.py index a8e4c77b..b0f7bf13 100644 --- a/numbast/src/numbast/tools/tests/test_prefix_removal.py +++ b/numbast/src/numbast/tools/tests/test_prefix_removal.py @@ -37,6 +37,7 @@ def test_prefix_removal(run_in_isolated_folder, arch_str): foo = symbols["foo"] Foo = symbols["Foo"] + Bar = symbols["Bar"] @cuda.jit def kernel(): @@ -45,4 +46,7 @@ def kernel(): x = foo_obj.get_x() # noqa: F841 x2 = foo_obj.x # noqa: F841 + bar = Bar.BAR_A # noqa: F841 + bar2 = Bar.BAR_B # noqa: F841 + kernel[1, 1]() From 70001df153fa05474b1b23b62dad4df5818e737c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 3 Dec 2025 17:14:34 -0800 Subject: [PATCH 07/15] add bfloat16_raw type to type system --- numbast/src/numbast/static/renderer.py | 6 ++++++ numbast/src/numbast/static/types.py | 4 ++++ numbast/src/numbast/types.py | 6 ++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 5e337088..8544e1da 100644 --- a/numbast/src/numbast/static/renderer.py +++ b/numbast/src/numbast/static/renderer.py @@ -109,6 +109,12 @@ def _try_import_numba_type(cls, typ: str): cls.Imports.add("from numba.cuda.types import bfloat16") cls._imported_numba_types.add(typ) + elif typ == "__nv_bfloat16_raw": + cls.Imports.add( + "from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307 as bfloat16_raw_type" + ) + cls._imported_numba_types.add(typ) + elif typ in vector_types: # CUDA target specific types cls.Imports.add("from
numba.cuda.vector_types import vector_types") diff --git a/numbast/src/numbast/static/types.py b/numbast/src/numbast/static/types.py index 60d5459a..feb6fa8b 100644 --- a/numbast/src/numbast/static/types.py +++ b/numbast/src/numbast/static/types.py @@ -59,6 +59,10 @@ def to_numba_type_str(ty: str): BaseRenderer._try_import_numba_type("__nv_bfloat16") return "bfloat16" + if ty == "__nv_bfloat16_raw": + BaseRenderer._try_import_numba_type("__nv_bfloat16_raw") + return "bfloat16_raw_type" + if ty.endswith("*"): base_ty = ty.rstrip("*").rstrip(" ") ptr_ty_str = f"CPointer({to_numba_type_str(base_ty)})" diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 8c5ce316..43112229 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -9,6 +9,8 @@ from numba.cuda.vector_types import vector_types from numba.misc.special import typeof +from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307 + class FunctorType(nbtypes.Type): def __init__(self, name): @@ -83,8 +85,8 @@ def register_enum_type(cxx_name: str, e: IntEnum): def to_numba_type(ty: str): - if ty == "__nv_bfloat16": - return bfloat16 + if ty == "__nv_bfloat16_raw": + return _type_unnamed1405307 if "FunctorType" in ty: return FunctorType(ty[:-11]) From 3979ad21f776c9ef1cca2a2523556017a0e7f4f7 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 3 Dec 2025 17:28:16 -0800 Subject: [PATCH 08/15] use target registry lower_cast registerer in separate registry mode --- numbast/src/numbast/static/renderer.py | 3 ++- numbast/src/numbast/static/struct.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 8544e1da..cee8bb9c 100644 --- a/numbast/src/numbast/static/renderer.py +++ b/numbast/src/numbast/static/renderer.py @@ -22,6 +22,7 @@ class BaseRenderer: lower = target_registry.lower lower_attr = target_registry.lower_getattr lower_constant = 
target_registry.lower_constant +lower_cast = target_registry.lower_cast """ KeyedStringIO = """ @@ -347,7 +348,7 @@ def registry_setup(use_separate_registry: bool) -> str: "from numba.cuda.typing.templates import Registry as TypingRegistry" ) BaseRenderer.Imports.add( - "from numba.core.imputils import Registry as TargetRegistry, lower_cast" + "from numba.core.imputils import Registry as TargetRegistry" ) return BaseRenderer.SeparateRegistrySetup else: diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index 9c53d20e..6f791991 100644 --- a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -608,8 +608,6 @@ def _render_shim_function(self): def _render_lowering(self): """Render lowering codes for this struct constructor.""" - self.Imports.add("from numba.core.extending import lower_cast") - self._lowering_rendered = ( self.struct_conversion_op_lowering_template.format( struct_name=self._struct_name, From df501fe53d1fffc63a457a5f874cdc2a8bd85d38 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 4 Dec 2025 13:14:26 -0800 Subject: [PATCH 09/15] pin numba-cuda versions to >=0.21.0 --- ci/test_conda_python.sh | 2 +- conda/environment-cu12.yaml | 2 +- conda/environment-cu13.yaml | 2 +- conda/environment_template.yaml | 2 +- conda/recipes/numbast/meta.yaml | 2 +- conda/recipes/numbast_extensions/meta.yaml | 2 +- numbast/pyproject.toml | 6 +++--- numbast_extensions/pyproject.toml | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ci/test_conda_python.sh b/ci/test_conda_python.sh index 73c88e17..75a9c7bf 100755 --- a/ci/test_conda_python.sh +++ b/ci/test_conda_python.sh @@ -23,7 +23,7 @@ rapids-mamba-retry create -n test \ cuda-version=${RAPIDS_CUDA_VERSION%.*} \ cuda-nvrtc \ numba >=0.59 \ - "numba-cuda>=0.20.1,<0.21.0" \ + "numba-cuda>=0.21.0,<0.23.0" \ cuda-cudart-dev \ python=${RAPIDS_PY_VERSION} \ cffi \ diff --git a/conda/environment-cu12.yaml b/conda/environment-cu12.yaml 
index 970b4c12..f8debd67 100644 --- a/conda/environment-cu12.yaml +++ b/conda/environment-cu12.yaml @@ -9,7 +9,7 @@ dependencies: - python>=3.10 - cmake>=3.29 # Allow overriding CMAKE_PREFIX_PATH and CMAKE_INSTALL_PREFIX - clangdev >=18,<22.0 - - numba-cuda[cu12] >=0.20.1,<0.21.0 + - numba-cuda[cu12] >=0.21.0,<0.23.0 - pybind11 - pytest - pytest-benchmark diff --git a/conda/environment-cu13.yaml b/conda/environment-cu13.yaml index d7be7861..a02c8935 100644 --- a/conda/environment-cu13.yaml +++ b/conda/environment-cu13.yaml @@ -9,7 +9,7 @@ dependencies: - python>=3.10 - cmake>=3.29 # Allow overriding CMAKE_PREFIX_PATH and CMAKE_INSTALL_PREFIX - clangdev >=18,<22.0 - - numba-cuda[cu13] >=0.20.1,<0.21.0 + - numba-cuda[cu13] >=0.21.0,<0.23.0 - pybind11 - pytest - pytest-benchmark diff --git a/conda/environment_template.yaml b/conda/environment_template.yaml index cbc83926..f9374889 100644 --- a/conda/environment_template.yaml +++ b/conda/environment_template.yaml @@ -9,7 +9,7 @@ dependencies: - python={{ python_version }} - cmake>=3.29 # Allow overriding CMAKE_PREFIX_PATH and CMAKE_INSTALL_PREFIX - clangdev >=18,<22.0 - - numba-cuda >=0.20.1,<0.21.0 + - numba-cuda >=0.21.0,<0.23.0 - pybind11 - pytest - pytest-benchmark diff --git a/conda/recipes/numbast/meta.yaml b/conda/recipes/numbast/meta.yaml index 688cf62d..d9794ff5 100644 --- a/conda/recipes/numbast/meta.yaml +++ b/conda/recipes/numbast/meta.yaml @@ -52,7 +52,7 @@ requirements: - cuda-version >=12.5 - numba >=0.59 - python - - numba-cuda >=0.20.1,<0.21.0 + - numba-cuda >=0.21.0,<0.23.0 - pyyaml - click - jinja2 diff --git a/conda/recipes/numbast_extensions/meta.yaml b/conda/recipes/numbast_extensions/meta.yaml index a3fcd0ba..b15e924b 100644 --- a/conda/recipes/numbast_extensions/meta.yaml +++ b/conda/recipes/numbast_extensions/meta.yaml @@ -43,7 +43,7 @@ requirements: - cuda-cudart-dev - python - numba >=0.59 - - numba-cuda >=0.20.1,<0.21.0 + - numba-cuda >=0.21.0,<0.23.0 - numbast >=0.2.0 test: diff --git 
a/numbast/pyproject.toml b/numbast/pyproject.toml index d8394d27..af490037 100644 --- a/numbast/pyproject.toml +++ b/numbast/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ ] dependencies = [ "numba>=0.59.0", - "numba-cuda>=0.20.1,<0.21.0", + "numba-cuda>=0.21.0,<0.23.0", "ast_canopy>=0.5.0", "pyyaml", "click", @@ -41,12 +41,12 @@ repository = "https://github.com/NVIDIA/numbast" [project.optional-dependencies] dev = ["ruff"] docs = ["sphinx>=7.0", "sphinx-copybutton>=0.5.2"] -test-cu12 = ["pytest", "cffi", "numba-cuda[cu12]>=0.20.1,<0.21.0"] +test-cu12 = ["pytest", "cffi", "numba-cuda[cu12]>=0.21.0,<0.23.0"] test-cu13 = [ "pytest", "cffi", "cuda-toolkit[cudart,crt,curand,cccl,nvcc]==13.*", - "numba-cuda[cu13]>=0.20.1,<0.21.0", + "numba-cuda[cu13]>=0.21.0,<0.23.0", ] diff --git a/numbast_extensions/pyproject.toml b/numbast_extensions/pyproject.toml index bb8fc9a7..6065c596 100644 --- a/numbast_extensions/pyproject.toml +++ b/numbast_extensions/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ] -dependencies = ["numba-cuda>=0.20.1,<0.21.0", "numbast>=0.2.0"] +dependencies = ["numba-cuda>=0.21.0,<0.23.0", "numbast>=0.2.0"] [project.urls] homepage = "https://github.com/NVIDIA/numbast" From 5b10afb70816af4532dbff30d44d3ecbbff5284d Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 4 Dec 2025 13:35:55 -0800 Subject: [PATCH 10/15] document version requirements for bindings v.s. numba-cuda --- docs/source/faq.rst | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 96a0fc8a..8c694ed7 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -122,3 +122,48 @@ Notes libc++, ensure its headers are installed (e.g., ``/usr/include/c++/v1`` or Conda ``libcxx-devel``) and pass ``-stdlib=libc++``; otherwise libstdc++ is typically selected by default on Linux. 
- AST Canopy is linked against Clang 20. For host resources, please use corresponding Clang version. + + +Generated bindings and Numba-CUDA version requirements +------------------------------------------------------- + +**What version of numba-cuda do generated bindings require?** + +Bindings generated by Numbast have specific version requirements for ``numba-cuda`` at runtime. The version of +Numbast used to generate the bindings determines the compatible ``numba-cuda`` versions. + +.. list-table:: Numbast to numba-cuda compatibility + :header-rows: 1 + :widths: 30 70 + + * - Numbast Version + - Required numba-cuda Version + * - 0.6.0 (current dev) + - ``>=0.21.0,<0.23.0`` + * - 0.5.x + - ``>=0.20.1,<0.21.0`` + +**Why do generated bindings have version requirements?** + +Numbast generates Python code that uses Numba-CUDA's internal APIs. These APIs can change between releases, so +bindings generated with a specific version of Numbast are tested against a specific range of ``numba-cuda`` versions. + +**How do I ensure compatibility?** + +*For dynamic binding generation:* + +- The correct ``numba-cuda`` version constraints are automatically enforced at the package dependency level + and managed by your package manager (pip or conda). When you install Numbast, compatible versions of + ``numba-cuda`` are installed automatically via the dependencies specified in ``pyproject.toml`` and Conda + environment files. + +*For static binding generation:* + +- When distributing generated bindings, document the required ``numba-cuda`` version range in your package + dependencies so users can install a compatible version. +- Generated static bindings (see :doc:`Static binding generation `) can be regenerated with newer + versions of Numbast if you need to support newer ``numba-cuda`` releases. + +.. note:: + These version restrictions may be relaxed or removed once ``numba-cuda`` releases a stable 1.0 version with + stabilized public APIs. 
Until then, bindings are tested against specific version ranges to ensure compatibility. From 0fe0bf048c266e37c3114e10946e28c9c84b4b5e Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 4 Dec 2025 14:53:02 -0800 Subject: [PATCH 11/15] check on numba.cuda.types --- numbast/src/numbast/static/renderer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 1cec1e84..c2842069 100644 --- a/numbast/src/numbast/static/renderer.py +++ b/numbast/src/numbast/static/renderer.py @@ -122,7 +122,7 @@ def _try_import_numba_type(cls, typ: str): cls.Imported_VectorTypes.append(typ) cls._imported_numba_types.add(typ) - elif typ in numba.types.__dict__: + elif typ in numba.cuda.types.__dict__: cls.Imports.add(f"from numba.cuda.types import {typ}") cls._imported_numba_types.add(typ) From fa5f7afa2cde5175ee5b5e8cf46e07f01ff0d0e8 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 4 Dec 2025 15:04:02 -0800 Subject: [PATCH 12/15] avoid unused parameter --- numbast/tests/data/sample_function.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/tests/data/sample_function.cuh b/numbast/tests/data/sample_function.cuh index 9083cc4c..cc0ba3fc 100644 --- a/numbast/tests/data/sample_function.cuh +++ b/numbast/tests/data/sample_function.cuh @@ -5,4 +5,4 @@ #pragma once -void __device__ func_with_void_ptr_arg(void *ptr) {} +__device__ void func_with_void_ptr_arg(void *ptr) { (void)ptr; } From 0375b487344f1b293b5df15aa0d99211bcdd2496 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 4 Dec 2025 15:25:21 -0800 Subject: [PATCH 13/15] use numba_typeref_ctor from vended module --- numbast/src/numbast/class_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/src/numbast/class_template.py b/numbast/src/numbast/class_template.py index 60cf81a0..3a37c37e 100644 --- a/numbast/src/numbast/class_template.py +++ 
b/numbast/src/numbast/class_template.py @@ -24,7 +24,7 @@ from numba.cuda import declare_device from numba.cuda.cudadecl import register_global, register, register_attr from numba.cuda.cudaimpl import lower -from numba.core.imputils import numba_typeref_ctor +from numba.cuda.core.imputils import numba_typeref_ctor from numba.core.typing.npydecl import parse_dtype from numba.core.errors import RequireLiteralValue, TypingError From ac0afbff35cae6b9e2c69f4931a1d9f1176eeb22 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 9 Dec 2025 11:35:53 -0800 Subject: [PATCH 14/15] adding cudaroundmode enum to type system --- numbast/src/numbast/static/enum.py | 5 +++-- numbast/src/numbast/static/types.py | 20 ++++++++++++++++++-- numbast/src/numbast/types.py | 16 +++++++++++++--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/numbast/src/numbast/static/enum.py b/numbast/src/numbast/static/enum.py index ce9bd63c..14f200fd 100644 --- a/numbast/src/numbast/static/enum.py +++ b/numbast/src/numbast/static/enum.py @@ -62,8 +62,9 @@ def _render(self): - Writes the resulting Python class source into `self._python_rendered`. 
""" self.Imports.add("from enum import IntEnum") - self.Imports.add("from numba.types import IntEnumMember") - self.Imports.add("from numba.types import int64") + + BaseRenderer._try_import_numba_type("IntEnumMember") + BaseRenderer._try_import_numba_type("int64") register_enum_type_str(self._decl.name, self._enum_name) diff --git a/numbast/src/numbast/static/types.py b/numbast/src/numbast/static/types.py index feb6fa8b..be9d181e 100644 --- a/numbast/src/numbast/static/types.py +++ b/numbast/src/numbast/static/types.py @@ -16,17 +16,26 @@ CTYPE_TO_NBTYPE_STR = copy.deepcopy(_DEFAULT_CTYPE_TO_NBTYPE_STR_MAP) -def register_enum_type_str(ctype_enum_name: str, enum_name: str): +def register_enum_type_str( + ctype_enum_name: str, enum_name: str, underlying_integer_type: str = "int32" +): """ Register a mapping from a C++ enum type name to its corresponding Numba type string. Parameters: ctype_enum_name (str): The C++ enum type name to register (as it appears in C/C++ headers). enum_name (str): The enum identifier to use inside the generated Numba type string (becomes the first argument to `IntEnumMember`). + underlying_integer_type (str): The underlying integer type to use for the enum. 
""" global CTYPE_TO_NBTYPE_STR - CTYPE_TO_NBTYPE_STR[ctype_enum_name] = f"IntEnumMember({enum_name}, int64)" + CTYPE_TO_NBTYPE_STR[ctype_enum_name] = ( + f"IntEnumMember({enum_name}, {underlying_integer_type})" + ) + + +# Add additional enum type mappings here +register_enum_type_str("cudaRoundMode", "cudaRoundMode", "int32") def reset_types(): @@ -55,6 +64,13 @@ def to_numba_type_str(ty: str): The corresponding string representing a Numba type """ + if ty == "cudaRoundMode": + BaseRenderer.Imports.add( + "from cuda.bindings.runtime import cudaRoundMode" + ) + BaseRenderer._try_import_numba_type("IntEnumMember") + return CTYPE_TO_NBTYPE_STR[ty] + if ty == "__nv_bfloat16": BaseRenderer._try_import_numba_type("__nv_bfloat16") return "bfloat16" diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 15fcb896..2a1aadc7 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -6,11 +6,13 @@ from numba import types as nbtypes from numba.cuda.types import bfloat16 -from numba.cuda.typing.typeof import typeof from numba.cuda.vector_types import vector_types + from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307 +from cuda.bindings import runtime + class FunctorType(nbtypes.Type): def __init__(self, name): @@ -78,10 +80,14 @@ def __init__(self, name): } -def register_enum_type(cxx_name: str, e: IntEnum): +def register_enum_type( + cxx_name: str, + e: IntEnum, + underlying_integer_type: nbtypes.Type = nbtypes.int32, +): global CTYPE_MAPS - CTYPE_MAPS[cxx_name] = typeof(e) + CTYPE_MAPS[cxx_name] = nbtypes.IntEnumMember(e, underlying_integer_type) def to_numba_type(ty: str): @@ -125,3 +131,7 @@ def is_c_integral_type(typ_str: str) -> bool: def is_c_floating_type(typ_str: str) -> bool: return typ_str in FLOATING_TYPE_MAPS + + +# Register CUDA Python Types +register_enum_type("cudaRoundMode", runtime.cudaRoundMode, nbtypes.int32) From 3ea27f9e58a644a5dcde54fe26fad3f9f8af8b03 Mon Sep 17 00:00:00 2001 From: Michael Wang 
Date: Fri, 26 Dec 2025 12:19:46 -0800 Subject: [PATCH 15/15] pass through clang binary path to tooling --- ast_canopy/ast_canopy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ast_canopy/ast_canopy/api.py b/ast_canopy/ast_canopy/api.py index f66a49d7..0af065a5 100644 --- a/ast_canopy/ast_canopy/api.py +++ b/ast_canopy/ast_canopy/api.py @@ -433,7 +433,7 @@ def parse_declarations_from_source( # 3. CUDA Toolkit include directories # 4. Additional include directories command_line_options = [ - "clang++", + clang_binary if clang_binary is not None else "clang++", *clang_verbose_flag, "--cuda-device-only", "-xcuda",