diff --git a/numbast/src/numbast/static/enum.py b/numbast/src/numbast/static/enum.py index 9c2da2dc..0ad55d06 100644 --- a/numbast/src/numbast/static/enum.py +++ b/numbast/src/numbast/static/enum.py @@ -9,6 +9,7 @@ from numbast.static.renderer import BaseRenderer, get_rendered_imports from numbast.static.types import register_enum_type_str +from numbast.utils import _apply_prefix_removal file_logger = getLogger(f"{__name__}") logger_path = os.path.join(tempfile.gettempdir(), "test.py") @@ -28,28 +29,38 @@ class {enum_name}(IntEnum): """ enumerator_template = " {enumerator} = {value}" - def __init__(self, decl: Enum): + def __init__( + self, decl: Enum, enum_prefix_removal: list[str] | None = None + ): self._decl = decl + self._enum_prefix_removal = enum_prefix_removal or [] + + self._enum_name = _apply_prefix_removal( + self._decl.name, self._enum_prefix_removal + ) + + self._enum_symbols.append(self._enum_name) def _render(self): self.Imports.add("from enum import IntEnum") self.Imports.add("from numba.types import IntEnumMember") self.Imports.add("from numba.types import int64") - register_enum_type_str(self._decl.name, self._decl.name) + register_enum_type_str(self._decl.name, self._enum_name) enumerators = [] for enumerator, value in zip( self._decl.enumerators, self._decl.enumerator_values ): + py_name = _apply_prefix_removal( + enumerator, self._enum_prefix_removal + ) enumerators.append( - self.enumerator_template.format( - enumerator=enumerator, value=value - ) + self.enumerator_template.format(enumerator=py_name, value=value) ) self._python_rendered = self.enum_template.format( - enum_name=self._decl.name, enumerators="\n".join(enumerators) + enum_name=self._enum_name, enumerators="\n".join(enumerators) ) @@ -59,18 +70,21 @@ class StaticEnumsRenderer(BaseRenderer): Since enums creates a new C++ type. It should be invoked before making struct / function bindings. """ - def __init__(self, decls: list[Enum]): + def __init__( + self, decls: list[Enum], enum_prefix_removal: list[str] | None = None + ): super().__init__(decls) self._decls = decls + self._enum_prefix_removal = enum_prefix_removal or [] - self._python_rendered = [] + self._python_rendered: list[str] = [] def _render(self, with_imports): """Render python bindings for enums.""" self._python_str = "" for decl in self._decls: - SER = StaticEnumRenderer(decl) + SER = StaticEnumRenderer(decl, self._enum_prefix_removal) SER._render() self._python_rendered.append(SER._python_rendered) diff --git a/numbast/src/numbast/static/function.py b/numbast/src/numbast/static/function.py index 6c8ef057..217eabb5 100644 --- a/numbast/src/numbast/static/function.py +++ b/numbast/src/numbast/static/function.py @@ -16,7 +16,7 @@ get_shim, ) from numbast.static.types import to_numba_type_str -from numbast.utils import make_function_shim +from numbast.utils import make_function_shim, _apply_prefix_removal from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError from ast_canopy.decl import Function @@ -413,8 +413,9 @@ def __init__( function_prefix_removal: list[str] = [], ): super().__init__(decl, header_path, use_cooperative) - self._function_prefix_removal = function_prefix_removal - self._python_func_name = self._apply_prefix_removal(self._decl.name) + self._python_func_name = _apply_prefix_removal( + decl.name, function_prefix_removal + ) # Override the base class symbol tracking to use the Python function name # Remove the original name that was added by the base class @@ -423,25 +424,6 @@ def __init__( # Add the Python function name (with prefix removal applied) self._function_symbols.append(self._python_func_name) - def _apply_prefix_removal(self, name: str) -> str: - """Apply prefix removal to a function name based on the configuration. - - Parameters - ---------- - name : str - The original function name - - Returns - ------- - str - The function name with prefixes removed - """ - for prefix in self._function_prefix_removal: - if name.startswith(prefix): - return name[len(prefix) :] - - return name - @property def func_name_python(self): """The name of the function in python with prefix removal applied.""" diff --git a/numbast/src/numbast/static/renderer.py b/numbast/src/numbast/static/renderer.py index 1fbcbc4b..58b64eb5 100644 --- a/numbast/src/numbast/static/renderer.py +++ b/numbast/src/numbast/static/renderer.py @@ -83,6 +83,9 @@ def reset(self): _function_symbols: list[str] = [] """List of new function handles to expose.""" + _enum_symbols: list[str] = [] + """List of new enum handles to expose.""" + def __init__(self, decl): self.Imports.add("import numba") self.Imports.add("import io") @@ -128,6 +131,7 @@ def clear_base_renderer_cache(): BaseRenderer._nbtype_symbols.clear() BaseRenderer._record_symbols.clear() BaseRenderer._function_symbols.clear() + BaseRenderer._enum_symbols.clear() def get_reproducible_info( @@ -245,6 +249,18 @@ def _get_function_symbols() -> str: return code +def _get_enum_symbols() -> str: + template = """ +_ENUM_SYMBOLS = [{enum_symbols}] +""" + + symbols = BaseRenderer._enum_symbols + quote_wrapped = [f'"{s}"' for s in symbols] + concat = ",".join(quote_wrapped) + code = template.format(enum_symbols=concat) + return code + + def get_all_exposed_symbols() -> str: """Return the definition of all exposed symbols via `__all__`. @@ -257,13 +273,14 @@ def get_all_exposed_symbols() -> str: nbtype_symbols = _get_nbtype_symbols() record_symbols = _get_record_symbols() function_symbols = _get_function_symbols() + enum_symbols = _get_enum_symbols() all_symbols = f""" {nbtype_symbols} {record_symbols} {function_symbols} - -__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS +{enum_symbols} +__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS + _ENUM_SYMBOLS """ return all_symbols diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index 550f5483..b2990196 100644 --- a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -24,6 +24,7 @@ make_struct_ctor_shim, make_struct_conversion_operator_shim, make_struct_regular_method_shim, + _apply_prefix_removal, ) from numbast.errors import TypeNotFoundError @@ -142,12 +143,14 @@ def {lower_scope_name}(shim_stream, shim_obj): def __init__( self, struct_name: str, + python_struct_name: str, struct_type_class: str, struct_type_name: str, header_path: str, ctor_decl: StructMethod, ): self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_class = struct_type_class self._struct_type_name = struct_type_name self._header_path = header_path @@ -247,7 +250,7 @@ def _render_lowering(self): """Render lowering codes for this struct constructor.""" self._lowering_rendered = self.struct_ctor_lowering_template.format( - struct_name=self._struct_name, + struct_name=self._python_struct_name, param_types=self._nb_param_types_str, struct_type_name=self._struct_type_name, struct_device_caller_name=self._device_caller_name, @@ -336,12 +339,14 @@ def __init__( self, ctor_decls: list[StructMethod], struct_name, + python_struct_name, struct_type_class, struct_type_name, header_path, ): self._ctor_decls = ctor_decls self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_class = struct_type_class self._struct_type_name = struct_type_name self._header_path = header_path @@ -371,7 +376,7 @@ def _render_typing(self, signature_strs: list[str]): self._struct_ctor_typing_rendered = ( self.struct_ctor_template_typing_template.format( struct_ctor_template_name=self._struct_ctor_template_name, - struct_name=self._struct_name, + struct_name=self._python_struct_name, signatures=signatures_str, ) ) @@ -384,6 +389,7 @@ def _render(self): try: renderer = StaticStructCtorRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_class=self._struct_type_class, struct_type_name=self._struct_type_name, header_path=self._header_path, @@ -663,16 +669,14 @@ def c_rendered(self) -> str: return self._c_rendered -class StaticStructRegularMethodsRenderer(BaseRenderer): - """Renderer for all regular (non-operator) member functions of a struct.""" +class StaticStructRegularMethodRenderer(BaseRenderer): + """Renderer for a single regular method of a struct.""" - # ---- Single method renderer ------------------------------------------------- - class _MethodRenderer(BaseRenderer): - c_ext_shim_var_template = """ + c_ext_shim_var_template = """ shim_raw_str = \"\"\"{shim_rendered}\"\"\" """ - struct_method_device_decl_template = """ + struct_method_device_decl_template = """ {device_decl_name} = declare_device( '{unique_shim_name}', {return_type}( @@ -682,12 +686,12 @@ class _MethodRenderer(BaseRenderer): ) """ - struct_method_device_caller_template = """ + struct_method_device_caller_template = """ def {device_caller_name}({nargs}): return {device_decl_name}({nargs}) """ - struct_method_lowering_template = """ + struct_method_lowering_template = """ @lower("{struct_name}.{method_name}", {struct_type_name}, {param_types}) def _{lower_fn_suffix}(context, builder, sig, args): context.active_code_library.add_linking_file(shim_obj) @@ -709,159 +713,162 @@ def _{lower_fn_suffix}(context, builder, sig, args): ) """ - lowering_body_template = """ + lowering_body_template = """ {shim_var} {decl_device} {lowering} """ - lower_scope_template = """ + lower_scope_template = """ def {lower_scope_name}(shim_stream, shim_obj): {body} {lower_scope_name}(shim_stream, shim_obj) """ - def __init__( - self, - struct_name: str, - struct_type_name: str, - header_path: str, - method_decl: StructMethod, - ): - super().__init__(method_decl) - self._struct_name = struct_name - self._struct_type_name = struct_type_name - self._header_path = header_path - self._method_decl = method_decl - - # Cache Numba param and return types (as strings) - self._nb_param_types = [ - to_numba_type_str(arg.unqualified_non_ref_type_name) - for arg in self._method_decl.param_types - ] - self._nb_param_types_str = ( - ", ".join(map(str, self._nb_param_types)) or "" - ) - self._nb_return_type = to_numba_type_str( - self._method_decl.return_type.unqualified_non_ref_type_name - ) - self._nb_return_type_str = str(self._nb_return_type) + def __init__( + self, + struct_name: str, + python_struct_name: str, + struct_type_name: str, + header_path: str, + method_decl: StructMethod, + ): + super().__init__(method_decl) + self._struct_name = struct_name + self._python_struct_name = python_struct_name + self._struct_type_name = struct_type_name + self._header_path = header_path + self._method_decl = method_decl - # Pointers for interop - def wrap_pointer(typ): - return f"CPointer({typ})" + # Cache Numba param and return types (as strings) + self._nb_param_types = [ + to_numba_type_str(arg.unqualified_non_ref_type_name) + for arg in self._method_decl.param_types + ] + self._nb_param_types_str = ( + ", ".join(map(str, self._nb_param_types)) or "" + ) + self._nb_return_type = to_numba_type_str( + self._method_decl.return_type.unqualified_non_ref_type_name + ) + self._nb_return_type_str = str(self._nb_return_type) - _pointer_wrapped_param_types = [ - wrap_pointer(typ) for typ in self._nb_param_types - ] - self._pointer_wrapped_param_types_str = ", ".join( - _pointer_wrapped_param_types - ) + # Pointers for interop + def wrap_pointer(typ): + return f"CPointer({typ})" - # Unique shim name and helpers - self._unique_shim_name = deduplicate_overloads( - f"__{self._method_decl.mangled_name}_nbst" - ) - self._device_decl_name = ( - f"_method_decl_{self._method_decl.mangled_name}" - ) - self._device_caller_name = ( - f"_device_caller_{self._method_decl.mangled_name}" + _pointer_wrapped_param_types = [ + wrap_pointer(typ) for typ in self._nb_param_types + ] + self._pointer_wrapped_param_types_str = ", ".join( + _pointer_wrapped_param_types + ) + + # Unique shim name and helpers + self._unique_shim_name = deduplicate_overloads( + f"__{self._method_decl.mangled_name}_nbst" + ) + self._device_decl_name = ( + f"_method_decl_{self._method_decl.mangled_name}" + ) + self._device_caller_name = ( + f"_device_caller_{self._method_decl.mangled_name}" + ) + self._lower_fn_suffix = f"lower_{self._method_decl.mangled_name}" + self._lower_scope_name = f"_lower_{self._method_decl.mangled_name}" + + @property + def signature_str(self) -> str: + """signature string for typing with recvr specified""" + recvr = self._struct_type_name + if self._nb_param_types_str: + return ( + f"signature({self._nb_return_type_str}, " + f"{self._nb_param_types_str}, recvr={recvr})" ) - self._lower_fn_suffix = f"lower_{self._method_decl.mangled_name}" - self._lower_scope_name = f"_lower_{self._method_decl.mangled_name}" - - @property - def signature_str(self) -> str: - """signature string for typing with recvr specified""" - recvr = self._struct_type_name - if self._nb_param_types_str: - return ( - f"signature({self._nb_return_type_str}, " - f"{self._nb_param_types_str}, recvr={recvr})" - ) - else: - return f"signature({self._nb_return_type_str}, recvr={recvr})" + else: + return f"signature({self._nb_return_type_str}, recvr={recvr})" - def _render_decl_device(self): - self.Imports.add("from numba.cuda import declare_device") - self.Imports.add("from numba.core.typing import signature") - self.Imports.add("from numba.types import CPointer") + def _render_decl_device(self): + self.Imports.add("from numba.cuda import declare_device") + self.Imports.add("from numba.core.typing import signature") + self.Imports.add("from numba.types import CPointer") - decl_device_rendered = self.struct_method_device_decl_template.format( + decl_device_rendered = self.struct_method_device_decl_template.format( + device_decl_name=self._device_decl_name, + unique_shim_name=self._unique_shim_name, + return_type=self._nb_return_type_str, + struct_type_name=self._struct_type_name, + pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, + ) + + nargs = [f"arg_{i}" for i in range(len(self._method_decl.params) + 1)] + nargs_str = ", ".join(nargs) + device_caller_rendered = ( + self.struct_method_device_caller_template.format( + device_caller_name=self._device_caller_name, + nargs=nargs_str, device_decl_name=self._device_decl_name, - unique_shim_name=self._unique_shim_name, - return_type=self._nb_return_type_str, - struct_type_name=self._struct_type_name, - pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, ) + ) - nargs = [ - f"arg_{i}" for i in range(len(self._method_decl.params) + 1) - ] - nargs_str = ", ".join(nargs) - device_caller_rendered = ( - self.struct_method_device_caller_template.format( - device_caller_name=self._device_caller_name, - nargs=nargs_str, - device_decl_name=self._device_decl_name, - ) - ) + self._decl_device_rendered = ( + decl_device_rendered + "\n" + device_caller_rendered + ) - self._decl_device_rendered = ( - decl_device_rendered + "\n" + device_caller_rendered - ) + def _render_shim_function(self): + self._c_ext_shim_rendered = make_struct_regular_method_shim( + shim_name=self._unique_shim_name, + struct_name=self._struct_name, + method_name=self._method_decl.name, + return_type=self._method_decl.return_type.unqualified_non_ref_type_name, + params=self._method_decl.params, + ) + self._c_ext_shim_var_rendered = self.c_ext_shim_var_template.format( + shim_rendered=self._c_ext_shim_rendered + ) + self.ShimFunctions.append(self._c_ext_shim_rendered) - def _render_shim_function(self): - self._c_ext_shim_rendered = make_struct_regular_method_shim( - shim_name=self._unique_shim_name, - struct_name=self._struct_name, - method_name=self._method_decl.name, - return_type=self._method_decl.return_type.unqualified_non_ref_type_name, - params=self._method_decl.params, - ) - self._c_ext_shim_var_rendered = self.c_ext_shim_var_template.format( - shim_rendered=self._c_ext_shim_rendered - ) - self.ShimFunctions.append(self._c_ext_shim_rendered) + def _render_lowering(self): + self.Imports.add("from numba.cuda.cudaimpl import lower") - def _render_lowering(self): - self.Imports.add("from numba.cuda.cudaimpl import lower") + param_types = self._nb_param_types_str or "" + lowering_rendered = self.struct_method_lowering_template.format( + struct_name=self._python_struct_name, + method_name=self._method_decl.name, + struct_type_name=self._struct_type_name, + param_types=param_types, + device_caller_name=self._device_caller_name, + return_type=self._nb_return_type_str, + pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, + lower_fn_suffix=self._lower_fn_suffix, + unique_shim_name=self._unique_shim_name, + ) + self._lowering_rendered = lowering_rendered - param_types = self._nb_param_types_str or "" - lowering_rendered = self.struct_method_lowering_template.format( - struct_name=self._struct_name, - method_name=self._method_decl.name, - struct_type_name=self._struct_type_name, - param_types=param_types, - device_caller_name=self._device_caller_name, - return_type=self._nb_return_type_str, - pointer_wrapped_param_types=self._pointer_wrapped_param_types_str, - lower_fn_suffix=self._lower_fn_suffix, - unique_shim_name=self._unique_shim_name, - ) - self._lowering_rendered = lowering_rendered + def _render(self): + self._render_decl_device() + self._render_shim_function() + self._render_lowering() - def _render(self): - self._render_decl_device() - self._render_shim_function() - self._render_lowering() + lower_body = self.lowering_body_template.format( + shim_var=self._c_ext_shim_var_rendered, + decl_device=self._decl_device_rendered, + lowering=self._lowering_rendered, + ) + lower_body = indent(lower_body, " " * 4) - lower_body = self.lowering_body_template.format( - shim_var=self._c_ext_shim_var_rendered, - decl_device=self._decl_device_rendered, - lowering=self._lowering_rendered, - ) - lower_body = indent(lower_body, " " * 4) + self._python_rendered = self.lower_scope_template.format( + lower_scope_name=self._lower_scope_name, + body=lower_body, + ) + self._c_rendered = self._c_ext_shim_rendered - self._python_rendered = self.lower_scope_template.format( - lower_scope_name=self._lower_scope_name, - body=lower_body, - ) - self._c_rendered = self._c_ext_shim_rendered - # ---- All methods renderer --------------------------------------------------- +class StaticStructRegularMethodsRenderer(BaseRenderer): + """Renderer for all regular (non-operator) member functions of a struct.""" + method_template_typing_template = """ @register class {method_template_name}(ConcreteTemplate): @@ -872,12 +879,14 @@ class {method_template_name}(ConcreteTemplate): def __init__( self, struct_name: str, + python_struct_name: str, struct_type_name: str, header_path: str, method_decls: list[StructMethod], ): super().__init__(method_decls) self._struct_name = struct_name + self._python_struct_name = python_struct_name self._struct_type_name = struct_type_name self._header_path = header_path self._method_decls = method_decls @@ -897,8 +906,9 @@ def _render(self): # Render per-overload lowering and collect signatures for m in self._method_decls: try: - mr = self._MethodRenderer( + mr = StaticStructRegularMethodRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_name=self._struct_type_name, header_path=self._header_path, method_decl=m, @@ -1037,9 +1047,15 @@ def __init__( parent_type: type | None, data_model: type | None, header_path: os.PathLike | str, + struct_prefix_removal: list[str] | None = None, aliases: list[str] = [], ): super().__init__(decl) + self._struct_prefix_removal = struct_prefix_removal or [] + + self._python_struct_name = _apply_prefix_removal( + decl.name, self._struct_prefix_removal + ) self._struct_name = decl.name self._aliases = aliases @@ -1065,19 +1081,21 @@ def __init__( # We use a prefix here to identify internal objects so that C object names # does not interfere with python's name mangling mechanism. - self._struct_type_class_name = f"_type_class_{self._struct_name}" - self._struct_type_name = f"_type_{self._struct_name}" - self._struct_model_name = f"_model_{self._struct_name}" - self._struct_attr_typing_name = f"_attr_typing_{self._struct_name}" + self._struct_type_class_name = f"_type_class_{self._python_struct_name}" + self._struct_type_name = f"_type_{self._python_struct_name}" + self._struct_model_name = f"_model_{self._python_struct_name}" + self._struct_attr_typing_name = ( + f"_attr_typing_{self._python_struct_name}" + ) self._header_path = header_path - CTYPE_TO_NBTYPE_STR[decl.name] = self._struct_type_name + CTYPE_TO_NBTYPE_STR[self._struct_name] = self._struct_type_name # Track the public symbols that should be exposed via a # struct creation self._nbtype_symbols.append(self._struct_type_name) - self._record_symbols.append(self._struct_name) + self._record_symbols.append(self._python_struct_name) def _render_typing(self): """Render typing of the struct.""" @@ -1095,7 +1113,7 @@ def _render_typing(self): struct_type_class_name=self._struct_type_class_name, struct_type_name=self._struct_type_name, parent_type=self._parent_type_str, - struct_name=self._struct_name, + struct_name=self._python_struct_name, struct_alignof=self._decl.alignof_, struct_sizeof=self._decl.sizeof_, implicit_conversion_types=implicit_conversion_types, @@ -1110,7 +1128,7 @@ def _render_python_api(self): self._python_api_rendered = self.python_api_template.format( struct_type_name=self._struct_type_name, - struct_name=self._struct_name, + struct_name=self._python_struct_name, ) def _render_data_model(self): @@ -1211,6 +1229,7 @@ def _render_regular_methods(self): """Render regular member functions of the struct.""" static_methods_renderer = StaticStructRegularMethodsRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_name=self._struct_type_name, header_path=self._header_path, method_decls=self._decl.regular_member_functions(), @@ -1228,6 +1247,7 @@ def _render_struct_ctors(self): """Render constructors of the struct.""" static_ctors_renderer = StaticStructCtorsRenderer( struct_name=self._struct_name, + python_struct_name=self._python_struct_name, struct_type_class=self._struct_type_class_name, struct_type_name=self._struct_type_name, header_path=self._header_path, @@ -1359,11 +1379,13 @@ def __init__( decls: list[Struct], specs: dict[str, tuple[type | None, type | None, os.PathLike]], default_header: os.PathLike | str | None = None, + struct_prefix_removal: list[str] | None = None, excludes: list[str] = [], ): self._decls = decls self._specs = specs self._default_header = default_header + self._struct_prefix_removal = struct_prefix_removal or [] self._python_rendered = [] self._c_rendered = [] @@ -1389,7 +1411,13 @@ def _render( f"CUDA struct {name} does not provide a header path." ) - SSR = StaticStructRenderer(decl, nb_ty, nb_datamodel, header_path) + SSR = StaticStructRenderer( + decl, + nb_ty, + nb_datamodel, + header_path, + self._struct_prefix_removal, + ) self._python_rendered.append(SSR.render_python()) self._c_rendered.append(SSR.render_c()) diff --git a/numbast/src/numbast/static/types.py b/numbast/src/numbast/static/types.py index 6dda8ecf..61af19d8 100644 --- a/numbast/src/numbast/static/types.py +++ b/numbast/src/numbast/static/types.py @@ -17,6 +17,7 @@ def register_enum_type_str(ctype_enum_name: str, enum_name: str): + """Register the C++ enum type name mapping to its Numba type.""" global CTYPE_TO_NBTYPE_STR CTYPE_TO_NBTYPE_STR[ctype_enum_name] = f"IntEnumMember({enum_name}, int64)" diff --git a/numbast/src/numbast/tools/static_binding_generator.py b/numbast/src/numbast/tools/static_binding_generator.py index 795e0ce6..2ed0e615 100644 --- a/numbast/src/numbast/tools/static_binding_generator.py +++ b/numbast/src/numbast/tools/static_binding_generator.py @@ -88,6 +88,8 @@ class Config: api_prefix_removal : dict[str, list[str]] Dictionary mapping declaration types to lists of prefixes to remove from names. For example, {"Function": ["prefix_"]} would remove "prefix_" from function names. + Acceptable keywords: ["Struct", "Function", "Enum"]. Value types are lists of prefix + strings. Specifically, prefixes in enums are also applicable to enum values. module_callbacks : dict[str, str] Dictionary containing setup and teardown callbacks for the module. Expected keys: "setup", "teardown". Each value is a string callback function. @@ -373,7 +375,14 @@ def _typedef_to_aliases(typedef_decls: list[Typedef]) -> dict[str, list[str]]: return aliases -def _generate_structs(struct_decls, header_path, types, data_models, excludes): +def _generate_structs( + struct_decls, + header_path, + types, + data_models, + struct_prefix_removal, + excludes, +): """Convert CLI inputs into structure that fits `StaticStructsRenderer` and create struct bindings.""" specs = {} for struct_decl in struct_decls: @@ -382,7 +391,12 @@ def _generate_structs(struct_decls, header_path, types, data_models, excludes): this_data_model = data_models.get(struct_name, None) specs[struct_name] = (this_type, this_data_model, header_path) - SSR = StaticStructsRenderer(struct_decls, specs, excludes=excludes) + SSR = StaticStructsRenderer( + struct_decls, + specs, + struct_prefix_removal=struct_prefix_removal, + excludes=excludes, + ) return SSR.render_as_str(with_imports=False, with_shim_stream=False) @@ -409,9 +423,11 @@ def _generate_functions( return SFR.render_as_str(with_imports=False, with_shim_stream=False) -def _generate_enums(enum_decls: list[Enum]): +def _generate_enums( + enum_decls: list[Enum], enum_prefix_removal: list[str] = [] +): """Create enum bindings.""" - SER = StaticEnumsRenderer(enum_decls) + SER = StaticEnumsRenderer(enum_decls, enum_prefix_removal) return SER.render_as_str(with_imports=False, with_shim_stream=False) @@ -511,12 +527,15 @@ def _static_binding_generator( aliases = _typedef_to_aliases(typedefs) rendered_aliases = render_aliases(aliases) - enum_bindings = _generate_enums(enums) + enum_bindings = _generate_enums( + enums, config.api_prefix_removal.get("Enum", []) + ) struct_bindings = _generate_structs( structs, entry_point, config.types, config.datamodels, + config.api_prefix_removal.get("Struct", []), config.exclude_structs, ) diff --git a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 index 711db61c..47f63a5b 100644 --- a/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 +++ b/numbast/src/numbast/tools/tests/config/prefix_removal.yml.j2 @@ -9,8 +9,15 @@ GPU Arch: File List: - {{ data }} Exclude: {} -Types: {} -Data Models: {} +Types: + __internal__Foo: Type +Data Models: + __internal__Foo: StructModel API Prefix Removal: Function: - prefix_ + Struct: + - __internal__ + Enum: + - __internal__ + - __nv__ diff --git a/numbast/src/numbast/tools/tests/prefix_removal.cuh b/numbast/src/numbast/tools/tests/prefix_removal.cuh index 343411cb..229afb79 100644 --- a/numbast/src/numbast/tools/tests/prefix_removal.cuh +++ b/numbast/src/numbast/tools/tests/prefix_removal.cuh @@ -4,3 +4,14 @@ // clang-format on int __device__ prefix_foo(int a, int b) { return a + b; } + +struct __internal__Foo { + int x; + __device__ __internal__Foo() : x(0) {} + __device__ int get_x() { return x; } +}; + +enum __internal__Bar { + __nv__BAR_A, + __nv__BAR_B, +}; diff --git a/numbast/src/numbast/tools/tests/test_prefix_removal.py b/numbast/src/numbast/tools/tests/test_prefix_removal.py index 2789d994..4e913ecb 100644 --- a/numbast/src/numbast/tools/tests/test_prefix_removal.py +++ b/numbast/src/numbast/tools/tests/test_prefix_removal.py @@ -1,9 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import os -import subprocess -import sys +from numba import cuda def test_prefix_removal(run_in_isolated_folder, arch_str): @@ -17,7 +15,6 @@ def test_prefix_removal(run_in_isolated_folder, arch_str): ) run_result = res["result"] - output_folder = res["output_folder"] symbols = res["symbols"] alls = symbols["__all__"] @@ -25,33 +22,32 @@ def test_prefix_removal(run_in_isolated_folder, arch_str): # Verify that the function is exposed as "foo" (without the "prefix_" prefix) assert "foo" in alls, f"Expected 'foo' in __all__, got: {alls}" + assert "Foo" in alls, f"Expected 'Foo' in __all__, got: {alls}" # Verify that the original name "prefix_foo" is NOT exposed assert "prefix_foo" not in alls, ( f"Expected 'prefix_foo' NOT in __all__, got: {alls}" ) + assert "__internal__Foo" not in alls, ( + f"Expected '__internal__Foo' NOT in __all__, got: {alls}" + ) + assert "Bar" in alls, f"Expected 'Bar' in __all__, got: {alls}" + assert "__internal__Bar" not in alls, ( + f"Expected '__internal__Bar' NOT in __all__, got: {alls}" + ) - # Test that the function can be imported and used as "foo" - test_kernel_src = """ -from numba import cuda -from prefix_removal import foo - -@cuda.jit -def kernel(): - result = foo(1, 2) # Verify that prefix_foo is accessible as foo - -kernel[1, 1]() -""" + foo = symbols["foo"] + Foo = symbols["Foo"] + Bar = symbols["Bar"] - test_kernel = os.path.join(output_folder, "test.py") - with open(test_kernel, "w") as f: - f.write(test_kernel_src) + @cuda.jit + def kernel(): + result = foo(1, 2) # noqa: F841 + foo_obj = Foo() + x = foo_obj.get_x() # noqa: F841 + x2 = foo_obj.x # noqa: F841 - res = subprocess.run( - [sys.executable, test_kernel], - cwd=output_folder, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) + bar = Bar.BAR_A # noqa: F841 + bar2 = Bar.BAR_B # noqa: F841 - assert res.returncode == 0, res.stdout.decode("utf-8") + kernel[1, 1]() diff --git a/numbast/src/numbast/tools/tests/test_symbol_exposure.py b/numbast/src/numbast/tools/tests/test_symbol_exposure.py index 1c1ea84f..1b2ceac5 100644 --- a/numbast/src/numbast/tools/tests/test_symbol_exposure.py +++ b/numbast/src/numbast/tools/tests/test_symbol_exposure.py @@ -5,6 +5,9 @@ import subprocess import sys +from numbast.static.renderer import clear_base_renderer_cache +from numbast.static.function import clear_function_apis_registry + from cuda.core.experimental import Device dev = Device(0) @@ -13,6 +16,9 @@ def test_symbol_exposure(run_in_isolated_folder, arch_str): """Test that only a limited set of symbols are exposed via __all__ imports.""" + clear_base_renderer_cache() + clear_function_apis_registry() + res = run_in_isolated_folder( "cfg.yml.j2", "data.cuh", diff --git a/numbast/src/numbast/utils.py b/numbast/src/numbast/utils.py index af051473..b9a8ed1f 100644 --- a/numbast/src/numbast/utils.py +++ b/numbast/src/numbast/utils.py @@ -316,3 +316,23 @@ def make_struct_conversion_operator_shim( ) return shim + + +def _apply_prefix_removal(name: str, prefix_to_remove: list[str]) -> str: + """Apply prefix removal to a name based on the configuration. + + Parameters + ---------- + name : str + The original struct, function or enum type name, or named enum values. + + Returns + ------- + str + The name with prefixes removed + """ + for prefix in prefix_to_remove: + if name.startswith(prefix): + return name[len(prefix) :] + + return name