Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions numbast/src/numbast/static/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from numbast.static.renderer import BaseRenderer, get_rendered_imports
from numbast.static.types import register_enum_type_str
from numbast.utils import _apply_prefix_removal

file_logger = getLogger(f"{__name__}")
logger_path = os.path.join(tempfile.gettempdir(), "test.py")
Expand All @@ -28,28 +29,38 @@ class {enum_name}(IntEnum):
"""
enumerator_template = " {enumerator} = {value}"

def __init__(self, decl: Enum):
def __init__(
self, decl: Enum, enum_prefix_removal: list[str] | None = None
):
self._decl = decl
self._enum_prefix_removal = enum_prefix_removal or []

self._enum_name = _apply_prefix_removal(
self._decl.name, self._enum_prefix_removal
)

self._enum_symbols.append(self._enum_name)

def _render(self):
self.Imports.add("from enum import IntEnum")
self.Imports.add("from numba.types import IntEnumMember")
self.Imports.add("from numba.types import int64")

register_enum_type_str(self._decl.name, self._decl.name)
register_enum_type_str(self._decl.name, self._enum_name)

enumerators = []
for enumerator, value in zip(
self._decl.enumerators, self._decl.enumerator_values
):
py_name = _apply_prefix_removal(
enumerator, self._enum_prefix_removal
)
enumerators.append(
self.enumerator_template.format(
enumerator=enumerator, value=value
)
self.enumerator_template.format(enumerator=py_name, value=value)
)

self._python_rendered = self.enum_template.format(
enum_name=self._decl.name, enumerators="\n".join(enumerators)
enum_name=self._enum_name, enumerators="\n".join(enumerators)
)


Expand All @@ -59,18 +70,21 @@ class StaticEnumsRenderer(BaseRenderer):
Since enums creates a new C++ type. It should be invoked before making struct / function bindings.
"""

def __init__(self, decls: list[Enum]):
def __init__(
self, decls: list[Enum], enum_prefix_removal: list[str] | None = None
):
super().__init__(decls)
self._decls = decls
self._enum_prefix_removal = enum_prefix_removal or []

self._python_rendered = []
self._python_rendered: list[str] = []

def _render(self, with_imports):
"""Render python bindings for enums."""
self._python_str = ""

for decl in self._decls:
SER = StaticEnumRenderer(decl)
SER = StaticEnumRenderer(decl, self._enum_prefix_removal)
SER._render()
self._python_rendered.append(SER._python_rendered)

Expand Down
26 changes: 4 additions & 22 deletions numbast/src/numbast/static/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_shim,
)
from numbast.static.types import to_numba_type_str
from numbast.utils import make_function_shim
from numbast.utils import make_function_shim, _apply_prefix_removal
from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError

from ast_canopy.decl import Function
Expand Down Expand Up @@ -413,8 +413,9 @@ def __init__(
function_prefix_removal: list[str] = [],
):
super().__init__(decl, header_path, use_cooperative)
self._function_prefix_removal = function_prefix_removal
self._python_func_name = self._apply_prefix_removal(self._decl.name)
self._python_func_name = _apply_prefix_removal(
decl.name, function_prefix_removal
)

# Override the base class symbol tracking to use the Python function name
# Remove the original name that was added by the base class
Expand All @@ -423,25 +424,6 @@ def __init__(
# Add the Python function name (with prefix removal applied)
self._function_symbols.append(self._python_func_name)

def _apply_prefix_removal(self, name: str) -> str:
"""Apply prefix removal to a function name based on the configuration.

Parameters
----------
name : str
The original function name

Returns
-------
str
The function name with prefixes removed
"""
for prefix in self._function_prefix_removal:
if name.startswith(prefix):
return name[len(prefix) :]

return name

@property
def func_name_python(self):
"""The name of the function in python with prefix removal applied."""
Expand Down
21 changes: 19 additions & 2 deletions numbast/src/numbast/static/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def reset(self):
_function_symbols: list[str] = []
"""List of new function handles to expose."""

_enum_symbols: list[str] = []
"""List of new enum handles to expose."""

def __init__(self, decl):
self.Imports.add("import numba")
self.Imports.add("import io")
Expand Down Expand Up @@ -128,6 +131,7 @@ def clear_base_renderer_cache():
BaseRenderer._nbtype_symbols.clear()
BaseRenderer._record_symbols.clear()
BaseRenderer._function_symbols.clear()
BaseRenderer._enum_symbols.clear()


def get_reproducible_info(
Expand Down Expand Up @@ -245,6 +249,18 @@ def _get_function_symbols() -> str:
return code


def _get_enum_symbols() -> str:
template = """
_ENUM_SYMBOLS = [{enum_symbols}]
"""

symbols = BaseRenderer._enum_symbols
quote_wrapped = [f'"{s}"' for s in symbols]
concat = ",".join(quote_wrapped)
code = template.format(enum_symbols=concat)
return code


def get_all_exposed_symbols() -> str:
"""Return the definition of all exposed symbols via `__all__`.

Expand All @@ -257,13 +273,14 @@ def get_all_exposed_symbols() -> str:
nbtype_symbols = _get_nbtype_symbols()
record_symbols = _get_record_symbols()
function_symbols = _get_function_symbols()
enum_symbols = _get_enum_symbols()

all_symbols = f"""
{nbtype_symbols}
{record_symbols}
{function_symbols}

__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS
{enum_symbols}
__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS + _ENUM_SYMBOLS
"""

return all_symbols
Expand Down
Loading
Loading