From ea61230377501c6fbb14e14026574eac9d1167ad Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 19 Jan 2026 15:38:13 +0800 Subject: [PATCH 1/6] [Refactor] Support lazy evaluate in eager jit --- .../test_tilelang_language_eager_jit.py | 7 -- tilelang/jit/__init__.py | 3 +- tilelang/language/eager/ast.py | 10 ++ tilelang/language/eager/builder.py | 91 ++++++++++++------- 4 files changed, 71 insertions(+), 40 deletions(-) diff --git a/testing/python/language/test_tilelang_language_eager_jit.py b/testing/python/language/test_tilelang_language_eager_jit.py index 0d58f310d..28f8fe8aa 100644 --- a/testing/python/language/test_tilelang_language_eager_jit.py +++ b/testing/python/language/test_tilelang_language_eager_jit.py @@ -166,11 +166,6 @@ def test_jit2_return(): def copy_impl(A): M, N = A.shape B = T.empty(M, N, dtype=A.dtype) - M, N = A.shape - M_, N_ = B.shape - assert M == M_, f"M mismatch {M} {M_}" - assert N == N_, f"N mismatch {N} {N_}" - # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) return B @@ -212,8 +207,6 @@ def copy6(A): A: T.StridedTensor[[N, M], [N_, M_], T.float32] return copy_impl(A) - tilelang.par_compile([copy.get_tir(T.Tensor((128, 128))) for copy in [copy1, copy2, copy3, copy4]]) - for copy in [copy1, copy2, copy3, copy4]: A = torch.randn(128, 128, device="cuda") B = copy(A) diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 822fb3298..a0d0e6907 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -313,7 +313,8 @@ def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy" # auto: infer by checking if function returns PrimFunc directly if not isinstance(self.func, JITFunc): return "lazy" - return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager" + is_lazy_style = self.func._is_lazy_style(*args, **kwargs) + return "lazy" if is_lazy_style else "eager" def initialize_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]: if self.mode == "auto": diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py index 378306128..d54e38f63 100644 --- a/tilelang/language/eager/ast.py +++ b/tilelang/language/eager/ast.py @@ -570,9 +570,19 @@ def visit_Return(self, node: ast.Return): return quote("return __tb.ret(value)", value=node.value, span=node) def visit_With(self, node: ast.With): + is_kernel_ctx = False + for expr in node.items: + cexpr = expr.context_expr + if isinstance(cexpr, ast.Call) and isinstance(cexpr.func, ast.Attribute) and cexpr.func.attr == "Kernel": + eval_res = self._try_eval(cexpr.func) + from tilelang.language import Kernel + if eval_res is Kernel: + is_kernel_ctx = True node = self.generic_visit(node) for expr in node.items: expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr) + if is_kernel_ctx: + return [quote1('if __tb.skip_kernel_ctx(): return'), node] return node def visit_Assert(self, node: ast.Assert): diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index 6b00a3426..54ae9ec07 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -164,6 +164,9 @@ def is_var(v: Any) -> bool: return isinstance(v, Buffer) and v.scope() == "local.var" +EagerJITStage = Literal['phase1', 'phase2', 'none'] + + class Builder(BaseBuilder): def __init__(self): self.frames: list[AnyFrame] = [] @@ -173,7 +176,8 @@ def __init__(self): self.out_idx = [] self.out_tensor_cnt = 0 self.constexpr_var = set() - self.eager_jit = False + self.eager_jit: EagerJITStage = 'none' + self.eager_jit_subs: dict[str, PrimExpr] = {} self.current_file = "" self.current_line = 0 self.current_macro_name = "" @@ -192,7 +196,7 @@ def prim_func(self, name): with self.ir_builder, self.with_frame(tir.prim_func()): tir.func_name(name) yield - if len(self.out_idx) != self.out_tensor_cnt: + if self.eager_jit != 'phase1' and len(self.out_idx) != self.out_tensor_cnt: raise RuntimeError("Not all tensor allocated from `T.empty` are returned") finally: del thread_local_storage.builder @@ -597,6 +601,8 @@ def ret(self, value=None): ) return value else: + if self.eager_jit == 'phase1': + return NotImplemented if not isinstance(value, tuple): value = (value,) for v in value: @@ -683,6 +689,7 @@ def override(self, name: str): def constexpr(self, name: str, dtype: str = "int32") -> Var: var = tir.Var(name, dtype) self.constexpr_var.add(var) + var.orig_name = name return var def set_fileline(self, filename: str, lineno: int, name: str): @@ -694,6 +701,8 @@ def get_fileline_stack(self, stacklevel=1): stack = self.macro_fileline_stack + [(self.current_file, self.current_line, self.current_macro_name)] return stack[: len(stack) - stacklevel + 1] + def skip_kernel_ctx(self): + return self.eager_jit == 'phase1' _P = ParamSpec("_P") _T = TypeVar("_T") @@ -866,17 +875,29 @@ def kernel(A, B): builder = Builder.current() # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)" # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)" - if builder is None or not builder.eager_jit: + if builder is None or not builder.eager_jit != 'none': raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)") - if "," in name: - names = re.split(r"\s*,\s*", name) - return tuple(builder.constexpr(n, dtype) for n in names) - if " " in name: - names = re.split(r"\s+", name) - return tuple(builder.constexpr(n, dtype) for n in names) - else: - return builder.constexpr(name, dtype) + if builder.eager_jit == 'phase1': + # in stage 1, we create constexpr variables + if "," in name: + names = re.split(r"\s*,\s*", name) + return tuple(builder.constexpr(n, dtype) for n in names) + if " " in name: + names = re.split(r"\s+", name) + return tuple(builder.constexpr(n, dtype) for n in names) + else: + return builder.constexpr(name, dtype) + elif builder.eager_jit == 'phase2': + # in stage 2, we substitute constexpr variables with actual values + if "," in name: + names = re.split(r"\s*,\s*", name) + return tuple(builder.eager_jit_subs[n] for n in names) + if " " in name: + names = re.split(r"\s+", name) + return tuple(builder.eager_jit_subs[n] for n in names) + else: + return builder.eager_jit_subs[name] @dataclass @@ -889,12 +910,15 @@ class TirTemplate(Generic[_P, _T]): actual tensor shapes at runtime. """ + name: str prim_func: PrimFunc[_P, _T] matcher: dict[Var, tuple[tvm.tir.Var, str, int, str]] | None = None + constexprs: set[Var] = None is_lazy_style: bool = False # True if from lazy-style (returns PrimFunc directly) + ir_gen: IRGenerator[_P, _T] | None = None @classmethod - def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate[_P, _T]: + def create(cls, name: str, prim_func: PrimFunc[_P, _T], constexpr: set[Var], ir_gen: IRGenerator[_P, _T] | None = None) -> TirTemplate[_P, _T]: matcher = {} for k, v in prim_func.buffer_map.items(): for i, s in enumerate(v.shape): @@ -915,12 +939,13 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate f"Buffer shapes: {shapes}\n" f"Buffer strides: {strides}" ) - return cls(prim_func=prim_func, matcher=matcher, is_lazy_style=False) + matcher = {k: matcher[k] for k in constexpr} + return cls(name=name, prim_func=prim_func, matcher=matcher, constexprs=constexpr, is_lazy_style=False, ir_gen=ir_gen) @classmethod - def from_lazy_style(cls, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]: + def from_lazy_style(cls, name: str, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]: """Create template from lazy-style function that returns PrimFunc directly.""" - return cls(prim_func=prim_func, is_lazy_style=True) + return cls(name=name, prim_func=prim_func, is_lazy_style=True) def _parse_phase2_key(self, **kwargs): if self.matcher is None: @@ -946,16 +971,20 @@ def _parse_phase2_key(self, **kwargs): ) return tuple(result) - def get_tir(self, **kwargs): + def get_tir(self, tensor_args, given_tensor_args, kwargs): if self.is_lazy_style: return self.prim_func - values = self._parse_phase2_key(**kwargs) - subs = {name: value for name, value in zip(self.matcher, values)} - result = substitute_primfunc(self.prim_func, subs) - result.orig_func = self.prim_func.orig_func - if hasattr(self.prim_func, "out_idx_override"): - result.out_idx_override = self.prim_func.out_idx_override - return result + values = self._parse_phase2_key(**given_tensor_args, **kwargs) + subs = {name.orig_name: value for name, value in zip(self.matcher, values)} + builder = Builder() + builder.eager_jit = 'phase2' + builder.eager_jit_subs = subs + with builder.prim_func(self.name): + self.ir_gen.gen(builder)(**tensor_args, **kwargs) + pf = builder.get() + if builder.out_idx: + pf.out_idx_override = builder.out_idx + return pf @dataclass @@ -1021,7 +1050,7 @@ def foo(A, B): # lazy jit must return PrimFunc if isinstance(prim_func, PrimFunc): p1_key, _, _ = self._parse_phase1_key(*args, **kwargs) - self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func) + self.p1_cache[p1_key] = TirTemplate.from_lazy_style(self.orig_func.__name__, prim_func) return True return False except (JITNoBuilderError, EagerJITBuildError, TypeError): @@ -1036,18 +1065,18 @@ def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]: """Build TIR template based on the execution mode.""" if self.mode == "lazy": # lazy: function returns PrimFunc directly - return TirTemplate.from_lazy_style(self.orig_func(*args, **kwargs)) + return TirTemplate.from_lazy_style(self.orig_func.__name__, self.orig_func(*args, **kwargs)) elif self.mode == "eager": # eager: trace function body through Builder to construct TIR builder = Builder() - builder.eager_jit = True + builder.eager_jit = 'phase1' with builder.prim_func(self.orig_func.__name__): self.ir_gen.gen(builder)(**self.tensor_args, **kwargs) pf = builder.get() pf.orig_func = self.orig_func if builder.out_idx: pf.out_idx_override = builder.out_idx - return TirTemplate.create(pf, builder.constexpr_var) + return TirTemplate.create(self.orig_func.__name__, pf, builder.constexpr_var, self.ir_gen) else: raise ValueError(f"Invalid jit mode: {self.mode}, expected 'lazy' or 'eager'") @@ -1061,17 +1090,15 @@ def parse_args(self, *args, **kwargs): # mode should be set by JITImpl before calling parse_args tir_temp = self._build_tir_template(**kwargs) self.p1_cache[p1_key] = tir_temp - p2_key = tir_temp._parse_phase2_key(**tensor_args) + p2_key = tir_temp._parse_phase2_key(**tensor_args, **kwargs) return (p1_key, p2_key), tensor_args def get_tir(self, *args, **kwargs): p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) if p1_key not in self.p1_cache: # in legacy gemm, we use lazy tir template to build the tir - tir_temp = self._build_tir_template(**kwargs) - self.p1_cache[p1_key] = tir_temp - return tir_temp.get_tir(**tensor_args, **kwargs) - return self.p1_cache[p1_key].get_tir(**tensor_args, **kwargs) + self.p1_cache[p1_key] = self._build_tir_template(**kwargs) + return self.p1_cache[p1_key].get_tir(self.tensor_args, tensor_args, kwargs) def __call__(self, *args, **kwargs): return self.get_tir(*args, **kwargs) From 6f0c36e31c0c4636687be098879c5849878f71e7 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 19 Jan 2026 15:40:07 +0800 Subject: [PATCH 2/6] add test script --- .../python/issue/test_tilelang_issue_1690.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 testing/python/issue/test_tilelang_issue_1690.py diff --git a/testing/python/issue/test_tilelang_issue_1690.py b/testing/python/issue/test_tilelang_issue_1690.py new file mode 100644 index 000000000..fe9967c29 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1690.py @@ -0,0 +1,19 @@ +import tilelang +import tilelang.testing +import tilelang.language as T + +def test_issue_1690(): + + @tilelang.jit() + def test(A): + N = T.const('N') + A: T.Tensor[[N], T.float32] + with T.Kernel(): + tmp = T.alloc_fragment((N,), T.float32) + tmp_max = T.alloc_fragment(1, T.float32) + T.copy(A, tmp) + T.reduce_max(tmp, tmp_max, dim=0) + test.compile(N=16) + +if __name__ == '__main__': + tilelang.testing.main() \ No newline at end of file From 5a779dcb1c6b630604c88e206af5967b726b96b8 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 19 Jan 2026 15:40:29 +0800 Subject: [PATCH 3/6] fix lint error --- .../python/issue/test_tilelang_issue_1690.py | 10 +++++--- tilelang/language/eager/ast.py | 3 ++- tilelang/language/eager/builder.py | 25 +++++++++++-------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/testing/python/issue/test_tilelang_issue_1690.py b/testing/python/issue/test_tilelang_issue_1690.py index fe9967c29..832240538 100644 --- a/testing/python/issue/test_tilelang_issue_1690.py +++ b/testing/python/issue/test_tilelang_issue_1690.py @@ -2,18 +2,20 @@ import tilelang.testing import tilelang.language as T -def test_issue_1690(): +def test_issue_1690(): @tilelang.jit() def test(A): - N = T.const('N') + N = T.const("N") A: T.Tensor[[N], T.float32] with T.Kernel(): tmp = T.alloc_fragment((N,), T.float32) tmp_max = T.alloc_fragment(1, T.float32) T.copy(A, tmp) T.reduce_max(tmp, tmp_max, dim=0) + test.compile(N=16) -if __name__ == '__main__': - tilelang.testing.main() \ No newline at end of file + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py index d54e38f63..d887ba83e 100644 --- a/tilelang/language/eager/ast.py +++ b/tilelang/language/eager/ast.py @@ -576,13 +576,14 @@ def visit_With(self, node: ast.With): if isinstance(cexpr, ast.Call) and isinstance(cexpr.func, ast.Attribute) and cexpr.func.attr == "Kernel": eval_res = self._try_eval(cexpr.func) from tilelang.language import Kernel + if eval_res is Kernel: is_kernel_ctx = True node = self.generic_visit(node) for expr in node.items: expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr) if is_kernel_ctx: - return [quote1('if __tb.skip_kernel_ctx(): return'), node] + return [quote1("if __tb.skip_kernel_ctx(): return"), node] return node def visit_Assert(self, node: ast.Assert): diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index 54ae9ec07..ec9f5d21b 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -164,7 +164,7 @@ def is_var(v: Any) -> bool: return isinstance(v, Buffer) and v.scope() == "local.var" -EagerJITStage = Literal['phase1', 'phase2', 'none'] +EagerJITStage = Literal["phase1", "phase2", "none"] class Builder(BaseBuilder): @@ -176,7 +176,7 @@ def __init__(self): self.out_idx = [] self.out_tensor_cnt = 0 self.constexpr_var = set() - self.eager_jit: EagerJITStage = 'none' + self.eager_jit: EagerJITStage = "none" self.eager_jit_subs: dict[str, PrimExpr] = {} self.current_file = "" self.current_line = 0 @@ -196,7 +196,7 @@ def prim_func(self, name): with self.ir_builder, self.with_frame(tir.prim_func()): tir.func_name(name) yield - if self.eager_jit != 'phase1' and len(self.out_idx) != self.out_tensor_cnt: + if self.eager_jit != "phase1" and len(self.out_idx) != self.out_tensor_cnt: raise RuntimeError("Not all tensor allocated from `T.empty` are returned") finally: del thread_local_storage.builder @@ -601,7 +601,7 @@ def ret(self, value=None): ) return value else: - if self.eager_jit == 'phase1': + if self.eager_jit == "phase1": return NotImplemented if not isinstance(value, tuple): value = (value,) @@ -702,7 +702,8 @@ def get_fileline_stack(self, stacklevel=1): return stack[: len(stack) - stacklevel + 1] def skip_kernel_ctx(self): - return self.eager_jit == 'phase1' + return self.eager_jit == "phase1" + _P = ParamSpec("_P") _T = TypeVar("_T") @@ -875,10 +876,10 @@ def kernel(A, B): builder = Builder.current() # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)" # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)" - if builder is None or not builder.eager_jit != 'none': + if builder is None or not builder.eager_jit != "none": raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)") - if builder.eager_jit == 'phase1': + if builder.eager_jit == "phase1": # in stage 1, we create constexpr variables if "," in name: names = re.split(r"\s*,\s*", name) @@ -888,7 +889,7 @@ def kernel(A, B): return tuple(builder.constexpr(n, dtype) for n in names) else: return builder.constexpr(name, dtype) - elif builder.eager_jit == 'phase2': + elif builder.eager_jit == "phase2": # in stage 2, we substitute constexpr variables with actual values if "," in name: names = re.split(r"\s*,\s*", name) @@ -918,7 +919,9 @@ class TirTemplate(Generic[_P, _T]): ir_gen: IRGenerator[_P, _T] | None = None @classmethod - def create(cls, name: str, prim_func: PrimFunc[_P, _T], constexpr: set[Var], ir_gen: IRGenerator[_P, _T] | None = None) -> TirTemplate[_P, _T]: + def create( + cls, name: str, prim_func: PrimFunc[_P, _T], constexpr: set[Var], ir_gen: IRGenerator[_P, _T] | None = None + ) -> TirTemplate[_P, _T]: matcher = {} for k, v in prim_func.buffer_map.items(): for i, s in enumerate(v.shape): @@ -977,7 +980,7 @@ def get_tir(self, tensor_args, given_tensor_args, kwargs): values = self._parse_phase2_key(**given_tensor_args, **kwargs) subs = {name.orig_name: value for name, value in zip(self.matcher, values)} builder = Builder() - builder.eager_jit = 'phase2' + builder.eager_jit = "phase2" builder.eager_jit_subs = subs with builder.prim_func(self.name): self.ir_gen.gen(builder)(**tensor_args, **kwargs) @@ -1069,7 +1072,7 @@ def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]: elif self.mode == "eager": # eager: trace function body through Builder to construct TIR builder = Builder() - builder.eager_jit = 'phase1' + builder.eager_jit = "phase1" with builder.prim_func(self.orig_func.__name__): self.ir_gen.gen(builder)(**self.tensor_args, **kwargs) pf = builder.get() From d152cf157c99047a5be22b92f3c1c6f3bbe0bbe2 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 19 Jan 2026 15:49:12 +0800 Subject: [PATCH 4/6] canonicalize signature and error of `T.empty` --- tilelang/language/allocate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index c14a733b7..36a2d8353 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -31,6 +31,7 @@ from . import dtypes as _dtypes from .dtypes import dtype as tl_dtype from .eager.builder import OutTensor +from .proxy import Tensor _Shapes = TypeVarTuple("_Shapes") _DType = TypeVar("_DType") @@ -264,10 +265,10 @@ def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32): @overload -def empty(shape, dtype: str = _dtypes.float32): ... +def empty(shape, dtype: str = _dtypes.float32) -> Tensor: ... -def empty(*shape, dtype: str = _dtypes.float32): +def empty(*shape, dtype: str = _dtypes.float32) -> Tensor: if len(shape) == 1 and isinstance(shape[0], (tuple, list)): return OutTensor(shape[0], dtype) elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): @@ -275,4 +276,4 @@ def empty(*shape, dtype: str = _dtypes.float32): elif all([isinstance(x, (int, PrimExpr)) for x in shape]): return OutTensor(shape, dtype) else: - raise RuntimeError(f"Invalid shape {shape}") + raise TypeError(f"Invalid shape {shape}") From a94c23a92fe335267e8fd8e5ae05d19369e720df Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:04:50 +0800 Subject: [PATCH 5/6] Fix comment from genai --- tilelang/language/eager/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index ec9f5d21b..7255a1913 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -876,7 +876,7 @@ def kernel(A, B): builder = Builder.current() # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)" # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)" - if builder is None or not builder.eager_jit != "none": + if builder is None or builder.eager_jit == "none": raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)") if builder.eager_jit == "phase1": From fe807f78c452b8e4cb3eede57aed31dfdbd7e5ef Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Fri, 23 Jan 2026 14:13:38 +0800 Subject: [PATCH 6/6] [EagerJIT] Add comment for eager jit stage --- tilelang/language/eager/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index 7255a1913..c2d86b5a6 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -164,6 +164,9 @@ def is_var(v: Any) -> bool: return isinstance(v, Buffer) and v.scope() == "local.var" +# phase1: eager jit obtain function signature +# phase2: eager jit elaborate function +# none: not inside eager jit, i.e. it is lazyjit EagerJITStage = Literal["phase1", "phase2", "none"]