diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py index 378306128..c64ddc421 100644 --- a/tilelang/language/eager/ast.py +++ b/tilelang/language/eager/ast.py @@ -6,6 +6,7 @@ from contextlib import AbstractContextManager from collections.abc import Iterable + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec @@ -249,6 +250,14 @@ def override(self, name: str): return globals()[name] +def _try_eval(node: ast.expr, nonlocals: dict[str, Any], globals: dict[str, Any]) -> Any: + try: + code = "lambda " + ",".join(nonlocals.keys()) + ": " + ast.unparse(node) + return eval(code, globals)(**nonlocals) + except Exception: + return _empty + + class DSLMutator(ast.NodeTransformer): def __init__(self, nonlocals: dict[str, Any], globals: dict[str, Any], filename: str): self.tmp_counter = 0 @@ -487,11 +496,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef): ) def _try_eval(self, node: ast.expr) -> Any: - try: - code = "lambda " + ",".join(self.nonlocals.keys()) + ": " + ast.unparse(node) - return eval(code, self.globals)(**self.nonlocals) - except Exception: - return _empty + return _try_eval(node, self.nonlocals, self.globals) def _parse_arg_annot(self, stmt: ast.stmt, arg_names: set[str]): if not isinstance(stmt, ast.AnnAssign): @@ -607,6 +612,21 @@ class IRGenerator(Generic[_P, _T]): extra_type_hints: dict[str, Any] = field(default_factory=dict) +def has_internal_prim_func(func: Callable[_P, _T]) -> bool: + tree = utils.get_ast(func) + nonlocals = utils.get_func_nonlocals(func) + for item in ast.walk(tree): + if isinstance(item, ast.FunctionDef): + decors = item.decorator_list + for decor in decors: + if isinstance(decor, ast.Attribute) and decor.attr == "prim_func": + from tilelang.language.eager import prim_func + + if _try_eval(decor, nonlocals, func.__globals__) is prim_func: + return True + return False + + def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: """ Transform a Python function into an IR (Intermediate Representation) generator. diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index 6b00a3426..33cc73903 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -10,7 +10,7 @@ from tvm.ir.expr import Range from tvm.tir.stmt import BufferRegion from tvm.tir.stmt_functor import substitute -from .ast import BaseBuilder, IRGenerator, eval_op, mutate +from .ast import BaseBuilder, IRGenerator, eval_op, has_internal_prim_func, mutate from .utils import construct_strides from tilelang.utils import side_effect import tvm @@ -1016,6 +1016,12 @@ def foo(A, B): with T.Kernel(...): ... # no return """ + if has_internal_prim_func(self.orig_func): + return True + try: + inspect.signature(self.orig_func).bind(*args, **kwargs) + except TypeError: + return False try: prim_func = self.orig_func(*args, **kwargs) # lazy jit must return PrimFunc @@ -1024,7 +1030,7 @@ def foo(A, B): self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func) return True return False - except (JITNoBuilderError, EagerJITBuildError, TypeError): + except (JITNoBuilderError, EagerJITBuildError): # In eager mode, we construct AST directly without prim_func, # so there's no Builder available when the function is called. # When eager-only features like T.const() or T.Kernel() are used,