diff --git a/3rdparty/tvm b/3rdparty/tvm index e47e76a2a..001022bdb 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e47e76a2a0d565e02b6474c06f9f47e1374821f3 +Subproject commit 001022bdb2dbb337d242eed9d208f8555b8edc98 diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index f24aa38b7..849cedc87 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -15,7 +15,7 @@ def get_configs(): return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] -@autotune(configs=get_configs(), warmup=500, rep=100) +@autotune(configs=get_configs()) @tilelang.jit( out_idx=[3], pass_configs={ diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index d4bc9480f..15e552587 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -6,7 +6,6 @@ from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, ) -from tilelang.transform import simplify_prim_func def make_swizzle_layout(shared_buf): @@ -25,7 +24,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 762885ec3..04f735950 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -7,7 +7,6 @@ from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, ) -from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(0) @@ -29,7 +28,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb index 5b5df8e6a..9cb343a1e 100644 --- a/examples/lazy_jit/lazyjit.en.ipynb +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm(\n", " A,\n", " B,\n", @@ -209,7 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_dyn_K(A, B):\n", " M, N, K = T.dynamic(\"M, N, K\")\n", " A: T.Tensor[[M, K], T.float16]\n", @@ -248,7 +248,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def as_contingious(A):\n", " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr(\n", " A,\n", " B,\n", @@ -359,7 +359,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr_dyn(A, B, M, N, K):\n", " M: T.int32\n", " N: T.int32\n", @@ -421,7 +421,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def example_wrong_kernel(A):\n", " M = T.const(\"M\")\n", " A: T.Tensor[[M * 2, M * 3], T.float32]\n", @@ -470,7 +470,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dyn_annot(\n", " A: T.ptr, # 1. 
T.ptr type annotation\n", " is_2d=False,\n", @@ -515,7 +515,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def add_one(X, data: T.float32 = 1):\n", " M, N = T.const(\"M, N\")\n", " X: T.Tensor[[M, N], T.float32]\n", @@ -577,7 +577,7 @@ "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dummy_kernel(A, B):\n", " M, N = T.const(\"M, N\")\n", " A: T.Tensor[[M, N], T.float16]\n", @@ -797,7 +797,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def element_wise(A, fn):\n", " N = T.dynamic(\"N\")\n", " A: T.Tensor[[N], T.float32]\n", @@ -857,7 +857,7 @@ " n31(x * 3 + 1, var)\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", " n31(n, A[0])" diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb index 387aff461..d7afafe69 100644 --- a/examples/lazy_jit/lazyjit.zh.ipynb +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm(\n", " A,\n", " B,\n", @@ -209,7 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_dyn_K(A, B):\n", " M, N, K = T.dynamic(\"M, N, K\")\n", " A: T.Tensor[[M, K], T.float16]\n", @@ -248,7 +248,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def as_contingious(A):\n", " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr(\n", " A,\n", " B,\n", @@ -359,7 +359,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr_dyn(A, B, M, N, K):\n", " M: T.int32\n", " N: T.int32\n", @@ -421,7 +421,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def example_wrong_kernel(A):\n", " M = T.const(\"M\")\n", " A: T.Tensor[[M * 2, M * 3], T.float32]\n", @@ -470,7 +470,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dyn_annot(\n", " A: T.ptr, # 1. 
T.ptr type annotation\n", " is_2d=False,\n", @@ -515,7 +515,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def add_one(X, data: T.float32 = 1):\n", " M, N = T.const(\"M, N\")\n", " X: T.Tensor[[M, N], T.float32]\n", @@ -577,7 +577,7 @@ "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dummy_kernel(A, B):\n", " M, N = T.const(\"M, N\")\n", " A: T.Tensor[[M, N], T.float16]\n", @@ -797,7 +797,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def element_wise(A, fn):\n", " N = T.dynamic(\"N\")\n", " A: T.Tensor[[N], T.float32]\n", @@ -857,7 +857,7 @@ " n31(x * 3 + 1, var)\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", " n31(n, A[0])" diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 9d5213e1a..bc256adaf 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -204,9 +204,9 @@ def test_str_repr(): def test_var_assign(): - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_var_assign(A: T.Tensor((2,), T.int32)): + @tilelang.jit + def test_var_assign(): + A = T.empty((2,), T.int32) with T.Kernel(1) as _: a: T.int32 = 1 b: T.int32 = a @@ -214,8 +214,9 @@ def test_var_assign(A: T.Tensor((2,), T.int32)): d: T.int32 = a A[0] = b A[1] = d + return A - res = test_var_assign()() + res = test_var_assign() assert res[0] == 1 assert res[1] == 2 @@ -255,9 +256,9 @@ def test_macro_return(): def test_serial_for_with_step(): - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_stepped_serial(A: T.Tensor((10,), T.int32)): + @tilelang.jit + def stepped_serial(): + A = T.empty((10,), T.int32) with T.Kernel(1) as _: for i in range(0, 10, 2): T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range") @@ -265,22 +266,22 @@ def test_stepped_serial(A: T.Tensor((10,), T.int32)): for i in range(1, 10, 2): T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range") A[i] = 2.0 + return A - ker = test_stepped_serial() - res = ker() + res = stepped_serial() ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_serial_step_neg(A: T.Tensor((10,), T.int32)): + @tilelang.jit + def stepped_serial_neg(): + A = T.empty((10,), T.int32) with T.Kernel(1) as _: for i in range(10, 0, -1): T.device_assert(0 < i <= 10, "i out of range") A[10 - i] = i + return A - ker = test_serial_step_neg() - res = ker() + res = stepped_serial_neg() ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" @@ -292,8 +293,8 @@ def test_serial_step_neg(A: T.Tensor((10,), T.int32)): def test_swap_logic(): @tilelang.jit - @T.prim_func - def swap_var(A: T.Tensor[(2,), T.float32]): + def swap_var(A): + A: T.Tensor[(2,), T.float32] with T.Kernel(1, threads=1) as _: a = T.alloc_var(T.float32, A[0]) b = T.alloc_var(T.float32, A[1]) @@ -301,20 +302,19 @@ def swap_var(A: T.Tensor[(2,), T.float32]): A[0], A[1] = a, b @tilelang.jit - @T.prim_func - def swap_idx(A: T.Tensor[(2,), T.float32]): + def swap_idx(A): + A: T.Tensor[(2,), T.float32] with T.Kernel(1, threads=1) as _: 
A[0], A[1] = A[1], A[0] - k_swap_var = swap_var() data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() - k_swap_var(data) + swap_var(data) ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) - k_swap_idx = swap_idx() data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() - k_swap_idx(data) + swap_idx(data) ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() torch.testing.assert_close(data, ref) @@ -322,9 +322,9 @@ def swap_idx(A: T.Tensor[(2,), T.float32]): # TODO(Gong): ROCm is not supported alloc_var with initializer @tilelang.testing.requires_cuda def test_while_loop(): - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_while_loop(A: T.Tensor((1,), T.int32)): + @tilelang.jit + def while_loop(): + A = T.empty((1,), T.int32) with T.Kernel(1) as _: i = T.alloc_var(T.int32, 0) sum = T.alloc_var(T.int32) @@ -332,10 +332,10 @@ def test_while_loop(A: T.Tensor((1,), T.int32)): sum += i i += 1 A[0] = sum + return A - ker = test_while_loop() - A = ker() - assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" + res = while_loop() + assert res[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {res[0].item()}" def test_var_macro(): @@ -456,27 +456,24 @@ def cond(): def test_constexpr_if(): @tilelang.jit - def probe(tmp: bool): - @T.prim_func - def foo(A: T.Tensor[[2], T.int32]): - with T.Kernel(1): - if tmp: - v = A[0] - else: - v = A[1] - if tmp: - A[1] = v + 1 - else: - A[0] = v + 1 - - return foo + def probe(A, tmp: bool): + A: T.Tensor[(2,), T.int32] + with T.Kernel(1): + if tmp: + v = A[0] + else: + v = A[1] + if tmp: + A[1] = v + 1 + else: + A[0] = v + 1 A = torch.tensor([10, 20], dtype=torch.int32).cuda() expect_1 = torch.tensor([10, 11], dtype=torch.int32).cuda() expect_2 = torch.tensor([12, 11], dtype=torch.int32).cuda() - probe(True)(A) + probe(A, True) assert torch.equal(A, expect_1) - probe(False)(A) + probe(A, False) assert torch.equal(A, expect_2) diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_lazy_jit.py index e3eabdce6..8e6ff6bb4 100644 --- a/testing/python/language/test_tilelang_language_lazy_jit.py +++ b/testing/python/language/test_tilelang_language_lazy_jit.py @@ -6,7 +6,7 @@ def test_jit2_gemm(): - @tilelang.lazy_jit(verbose=True) + @tilelang.jit(verbose=True) def gemm( A, B, @@ -45,7 +45,7 @@ def gemm( def test_jit2_gemm_ptr(): - @tilelang.lazy_jit + @tilelang.jit def gemm_ptr( A: T.ptr, B: T.ptr, @@ -102,28 +102,28 @@ def copy_impl(A, B): with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) - @tilelang.lazy_jit + @tilelang.jit def copy1(A, B): N, M = T.const("N, M") A: T.Tensor[[N, M], T.float32] B: T.Tensor[[N, M], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy2( A: T.Tensor[[128, 128], T.float32], B: T.Tensor[[128, 128], T.float32], ): copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy3(A, B): N = T.const("N") A: T.Tensor[[N, 128], T.float32] B: T.Tensor[[N, 128], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy4(A, B): N = T.dynamic("N") M = T.const("M") @@ -131,14 +131,14 @@ def copy4(A, B): B: T.Tensor[[N, M], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy5(A, B): N, M, N_, M_ = T.const("N, M, N_, M_") A: T.StridedTensor[[N, M], [N_, M_], T.float32] B: 
T.StridedTensor[[N, M], [N_, M_], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy6(A, B): N = T.dynamic("N") M, N_, M_ = T.const("M, N_, M_") @@ -175,37 +175,37 @@ def copy_impl(A): T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) return B - @tilelang.lazy_jit + @tilelang.jit def copy1(A): M, N = T.const("M, N") A: T.Tensor[[M, N], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy2(A): A: T.Tensor[[128, 128], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy3(A): N = T.const("N") A: T.Tensor[[N, 128], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy4(A): N = T.dynamic("N") M = T.const("M") A: T.Tensor[[N, M], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy5(A): N, M, N_, M_ = T.const("N, M, N_, M_") A: T.StridedTensor[[N, M], [N_, M_], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy6(A): N = T.dynamic("N") M, N_, M_ = T.const("M, N_, M_") diff --git a/testing/python/language/test_tilelang_language_subtype.py b/testing/python/language/test_tilelang_language_subtype.py index e7f898cc4..ef0641a9d 100644 --- a/testing/python/language/test_tilelang_language_subtype.py +++ b/testing/python/language/test_tilelang_language_subtype.py @@ -6,7 +6,7 @@ import tilelang.language as T -@tilelang.lazy_jit +@tilelang.jit def basic_shape_kernel(x): m = T.dynamic("m") x: T.Tensor[(m, 16), T.float4_e2m1fn] @@ -15,7 +15,7 @@ def basic_shape_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def strided_kernel(x): m = T.dynamic("m") s = T.dynamic("s") @@ -95,7 +95,7 @@ def test_subtype_different_strides(): strided_kernel(t_strided) -@tilelang.lazy_jit +@tilelang.jit def symbolic_last_dim_kernel(x): """Kernel with symbolic variable in the last dimension.""" n = T.dynamic("n") @@ -105,7 +105,7 @@ def symbolic_last_dim_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def symbolic_last_dim_strided_kernel(x): """Kernel with symbolic variable in both shape and stride of last dimension.""" n = T.dynamic("n") @@ -116,7 +116,7 @@ def symbolic_last_dim_strided_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def shared_symbolic_kernel(x, y): """Kernel with shared symbolic variable across multiple buffers.""" m = T.dynamic("m") @@ -127,7 +127,7 @@ def shared_symbolic_kernel(x, y): pass -@tilelang.lazy_jit +@tilelang.jit def shared_symbolic_strided_kernel(x, y): """Kernel with shared symbolic variable in strides.""" m = T.dynamic("m") @@ -139,7 +139,7 @@ def shared_symbolic_strided_kernel(x, y): pass -@tilelang.lazy_jit +@tilelang.jit def complex_expr_kernel(x, y): """Kernel with complex expressions involving symbolic variables.""" m = T.dynamic("m") diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py index 9698a66ee..fbe9bb8a8 100644 --- a/testing/python/layout/test_tilelang_annotate_loop_layout.py +++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py @@ -4,7 +4,7 @@ # TODO(lei): replicate loop layout and more complicated layout cases -@tilelang.lazy_jit +@tilelang.jit def loop_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] @@ -51,7 +51,7 @@ def loop_layout_fn(i, j): assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code -@tilelang.lazy_jit +@tilelang.jit def 
copy_with_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] @@ -79,7 +79,7 @@ def loop_layout_fn(i, j, rep): assert "*(float4*)(B + ((i * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(A + ((i * 512) + (((int)threadIdx.x) * 4)));" in code -@tilelang.lazy_jit +@tilelang.jit def replicate_loop_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 420516173..b629ef7b4 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -142,7 +142,7 @@ def _load_tile_lang_lib(): if env.SKIP_LOADING_TILELANG_SO == "0": _LIB, _LIB_PATH = _load_tile_lang_lib() - from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401 + from .jit import jit, JITKernel, compile, par_compile # noqa: F401 from .profiler import Profiler # noqa: F401 from .cache import clear_cache # noqa: F401 from .utils import ( diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 44c2ff41f..5c5bd7c5e 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -192,72 +192,76 @@ def par_compile( @dataclass class JITImpl(Generic[_P, _KP, _T, _Ret]): """ - Detailed Just-In-Time wrapper for TileLang programs. + Just-In-Time compilation wrapper for TileLang programs. - This dataclass encapsulates the configuration and runtime helpers used by the - top-level `jit` and `jit2` decorators. It represents a configured JIT - "factory" that can (a) elaborate TileLang/PrimFunc creators into concrete - TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via - the TVM bridge, (c) cache compiled kernels keyed by call-site arguments - (and optional tuning parameters), and (d) provide parallel compilation - helpers for batch autotuning workflows. + This class provides a unified interface for compiling and executing TileLang + kernels. It supports two execution modes that are automatically inferred: + + Execution Modes + --------------- + - **lazy**: The decorated function returns a PrimFunc explicitly. Calling the + JIT wrapper returns a compiled kernel object, which can be invoked separately. + This mode is useful when you want to inspect or reuse the kernel object. + + Example (lazy mode):: + + @tilelang.jit(out_idx=[-1]) + def matmul(M, N, K, block_M, block_N, block_K): + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), ...): + ... + return kernel # explicitly return PrimFunc + + kernel = matmul(1024, 1024, 1024, 128, 128, 32) # returns kernel + result = kernel(a, b) # execute separately + + - **eager**: The decorated function uses the DSL builder pattern with tensor + type annotations. Calling the JIT wrapper compiles and immediately executes + the kernel, returning the result directly. + + Example (eager mode):: + + @tilelang.jit + def gemm(A, B, C, block_M: int = 64): + M, N, K = T.const("M N K") + A: T.Tensor[[M, K], dtype] # tensor shape via annotation + B: T.Tensor[[K, N], dtype] + C: T.Tensor[[M, N], dtype] + with T.Kernel(...): + ... + # no return, builder constructs TIR implicitly + + gemm(A, B, C) # compiles and executes immediately + + The mode is automatically inferred based on whether the function returns a + PrimFunc (lazy) or uses the builder pattern (eager). Attributes ---------- out_idx : list[int] | int | None - Which output tensor(s) of the compiled kernel should be returned to the - caller. Accepts a single index, a list of indices, or None to return all. 
- execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] - Backend used for exchanging arguments and executing the generated kernel. - target : str | tvm.target.Target - TVM compilation target (e.g. "cuda", "llvm", or "auto"). - target_host : str | tvm.target.Target | None - Host target used for cross-compilation, or None to infer/default. - verbose : bool - Enable verbose messages during compilation/build. + Index(es) of output tensor(s) to return. Accepts single index, list, or None. + execution_backend : str | None + Backend for kernel execution ("auto", "dlpack", "tvm_ffi", etc.). + target : str | Target | None + TVM compilation target (e.g., "cuda", "llvm", "auto"). + target_host : str | Target | None + Host target for cross-compilation. + verbose : bool | None + Enable verbose compilation output. pass_configs : dict[str, Any] | None - Extra TVM pass configuration options forwarded to the compiler's - PassContext. + TVM pass configuration options. debug_root_path : str | None - If provided, compiled kernel source and the elaborated Python program - are written to this directory to ease debugging and inspection. + Directory to save compiled kernel source for debugging. compile_flags : list[str] | str | None - Additional flags passed to the compiler. A single string will be converted - to a single-element list. + Additional compiler flags. func_source : str - Original Python source string from which the PrimFunc or creator was - derived. Used for diagnostics and debug dumps. + Original Python source code of the decorated function. signature : inspect.Signature - Function signature of the original Python function (useful for tooling). - v2 : bool - Indicates whether the object wraps a "v2" PrimFunc creator (True) or a - plain callable / PrimFunc (False). v2-mode enables argument conversion - hooks and a distinct cache keying strategy. - func : Callable | PrimFunc | PrimFuncCreater - The underlying object: either a user function that returns a PrimFunc - (creator), a PrimFuncCreater, or an already-constructed PrimFunc. - For presentation/readability the function is stored last in the dataclass. - - Behavioral summary - ------------------ - - get_tir(*args, **kwargs) - Converts provided call-site arguments into a concrete PrimFunc. If the - wrapped object is a PrimFuncCreater or a user callable, it is invoked - with the given arguments. If the wrapped object is already a PrimFunc, - it is returned as-is. - - - compile(...) - A convenience wrapper that elaborates and immediately compiles a single - PrimFunc into a JITKernel using the module-level `compile` function. - When `debug_root_path` is set, the compiled C kernel and the source - Python program are saved for inspection. - - - par_compile(configs, ...) - Accepts an iterable of configs (either dicts mapping keyword args or - tuples mapping to positional args). Each config is elaborated to a - PrimFunc and the resulting set is compiled in parallel via the - module-level `par_compile` helper. Returns a list of JITKernel objects - in the same order as the provided configs. + Function signature of the original function. + mode : Literal["auto", "lazy", "eager"] + Execution mode. "auto" infers from function behavior. + func : LazyJITFunc + The wrapped function object. 
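+
+    Notes
+    -----
+    In eager mode ``out_idx`` is rejected; outputs are allocated with ``T.empty()``
+    inside the kernel body and returned from the function. A minimal sketch of this
+    pattern, mirroring the updated tests (names are illustrative)::
+
+        @tilelang.jit
+        def fill_ones():
+            A = T.empty((2,), T.int32)   # output tensor allocated in-kernel
+            with T.Kernel(1) as _:
+                A[0] = 1
+                A[1] = 1
+            return A                     # returned buffers become the outputs
+
+        res = fill_ones()                # compiles on first call, then executes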
""" out_idx: list[int] | int | None @@ -270,9 +274,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): compile_flags: list[str] | str | None func_source: str signature: inspect.Signature - lazy_jit: bool + mode: Literal["auto", "lazy", "eager"] # place func at the last element for better __repr__ - func: Callable[_P, _T] | PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T] + func: LazyJITFunc[_KP, _T] def __post_init__(self): if self.debug_root_path is not None and not path.isabs(self.debug_root_path): @@ -288,17 +292,34 @@ def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]: """ Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object. """ - if isinstance(self.func, LazyJITFunc): - tir = self.func.get_tir(*args, **kwargs) - elif isinstance(self.func, PrimFunc): + # Ensure mode is set before calling func + if self.mode == "auto" and isinstance(self.func, LazyJITFunc): + self.mode = self._infer_jit_mode() + self.func.set_mode(self.mode) + + if isinstance(self.func, PrimFunc): tir = self.func - elif callable(self.func): + elif isinstance(self.func, (LazyJITFunc, Callable)): tir = self.func(*args, **kwargs) else: raise ValueError(f"Invalid function type: {type(self.func)}") assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}" return tir + def _infer_jit_mode(self) -> Literal["lazy", "eager"]: + """ + Infer the JIT execution mode based on static AST analysis. + + Returns "lazy" if the function explicitly returns a PrimFunc, + or "eager" if it uses the DSL builder pattern. + """ + if self.mode in ("lazy", "eager"): + return self.mode + # auto: infer using static AST analysis + if not isinstance(self.func, LazyJITFunc): + return "lazy" + return "lazy" if self.func._is_lazy_style() else "eager" + def par_compile( self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False ) -> list[JITKernel[_KP, _T]]: @@ -345,9 +366,18 @@ def par_compile( ) def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: - func = self.get_tir(*args, **kwargs) + # infer jit mode on first compile + if self.mode == "auto": + self.mode = self._infer_jit_mode() + + # out_idx is only supported in lazy mode + if self.mode == "eager" and self.out_idx is not None: + raise ValueError("out_idx is only supported in lazy mode. 
In eager mode, use T.empty() to declare output tensors instead.") + + self.func.set_mode(self.mode) + prim_func = self.get_tir(*args, **kwargs) kernel_result = compile( - func, + prim_func, out_idx=self.out_idx, execution_backend=self.execution_backend, target=self.target, @@ -368,7 +398,7 @@ def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: with open(path.join(self.debug_root_path, kernel_file), "w") as f: print(kernel_result.get_kernel_source(), file=f) with open(path.join(self.debug_root_path, program_file), "w") as f: - print(func.script(), file=f) + print(prim_func.script(), file=f) return kernel_result @@ -397,134 +427,39 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: } return compile_args - if self.lazy_jit: - kwargs.update(kwargs.pop("__tune_params", {})) - key, kernel_args = self.func.parse_args(*args, **kwargs) - kernel = self._kernel_cache.get(key, None) - if kernel is None: - kernel = self.compile(*args, **kwargs) - self._kernel_cache[key] = kernel - return kernel(*kernel_args.values()) + kwargs.update(kwargs.pop("__tune_params", {})) + # infer mode early, before parse_args needs it + if self.mode == "auto": + self.mode = self._infer_jit_mode() + self.func.set_mode(self.mode) + + key, kernel_args = self.func.parse_args(*args, **kwargs) + kernel = self._kernel_cache.get(key, None) + if kernel is None: + kernel = self.compile(*args, **kwargs) + self._kernel_cache[key] = kernel + + # eager mode: execute kernel immediately and return result + # lazy mode: return kernel object for manual invocation + if self.mode == "eager": + return kernel(*kernel_args.values()) else: - key = self.parse_cache_key(*args, **kwargs) - tune_params = kwargs.pop("__tune_params", {}) - kernel = self._kernel_cache.get(key, None) - if kernel is None: - kernel = self.compile(*args, **kwargs, **tune_params) - self._kernel_cache[key] = kernel return kernel ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] -@overload -def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ... - - -@overload -def jit( - *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: ExecutionBackend | None = None, - verbose: bool | None = None, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None, -) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ... - - -def jit( # This is the new public interface - func: Callable[_P, _T] | PrimFunc | None = None, - *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: ExecutionBackend | None = None, - verbose: bool | None = None, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None, -): - """ - Just-In-Time (JIT) compiler decorator for TileLang functions. - - This decorator can be used without arguments (e.g., `@tilelang.jit`): - Applies JIT compilation with default settings. - - Parameters - ---------- - func_or_out_idx : Any, optional - If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. 
- If using `@tilelang.jit` directly on a function, this argument is implicitly - the function to be decorated (and `out_idx` will be `None`). - target : Union[str, Target], optional - Compilation target for TVM (e.g., "cuda", "llvm"). If None, reads from - TILELANG_TARGET environment variable (defaults to "auto"). - target_host : Union[str, Target], optional - Target host for cross-compilation. Defaults to None. - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional - Backend for kernel execution and argument passing. If None, reads from - TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). - verbose : bool, optional - Enables verbose logging during compilation. If None, reads from - TILELANG_VERBOSE environment variable (defaults to False). - pass_configs : Optional[Dict[str, Any]], optional - Configurations for TVM's pass context. Defaults to None. - debug_root_path : Optional[str], optional - Directory to save compiled kernel source for debugging. Defaults to None. - - Environment Variables - --------------------- - TILELANG_TARGET : str - Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". - TILELANG_EXECUTION_BACKEND : str - Default execution backend. Defaults to "auto". - TILELANG_VERBOSE : str - Set to "1", "true", "yes", or "on" to enable verbose compilation by default. - - Returns - ------- - Callable - Either a JIT-compiled wrapper around the input function, or a configured decorator - instance that can then be applied to a function. - """ - - def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: - if isinstance(func, PrimFunc): - orig_func = func.orig_func - else: - orig_func = func - return JITImpl( - func=func, - out_idx=out_idx, - execution_backend=execution_backend, - target=target, - target_host=target_host, - verbose=verbose, - pass_configs=pass_configs, - debug_root_path=debug_root_path, - compile_flags=compile_flags, - func_source=inspect.getsource(orig_func), - signature=inspect.signature(orig_func), - lazy_jit=False, - ) - - if func is not None: - return decorator(func) - else: - return decorator +JITMode = Literal["auto", "lazy", "eager"] @overload -def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... +def jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... @overload -def lazy_jit( +def jit( *, out_idx: Any = None, target: str | Target | None = None, @@ -534,12 +469,14 @@ def lazy_jit( pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, + mode: JITMode = "auto", ) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... -def lazy_jit( +def jit( func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only + out_idx: list[int] | int | None = None, target: str | Target | None = None, target_host: str | Target | None = None, execution_backend: ExecutionBackend | None = None, @@ -547,16 +484,39 @@ def lazy_jit( pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, -): + mode: JITMode = "auto", +) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]: """ - Lazy JIT compiler decorator - returns the kernel object on first call, then executes it. + JIT compiler decorator for TileLang functions. + + Supports two execution modes: + - **lazy**: Function returns PrimFunc explicitly. Returns compiled kernel object. 
+ - **eager**: Function uses DSL builder pattern. Executes kernel immediately. - Supports environment variable defaults for target, execution_backend, and verbose. - See `jit` documentation for parameter details and environment variables. + Parameters + ---------- + out_idx : list[int] | int | None + Output tensor index(es). Only supported in lazy mode. + target : str | Target | None + TVM compilation target (e.g., "cuda", "llvm", "auto"). + target_host : str | Target | None + Host target for cross-compilation. + execution_backend : ExecutionBackend | None + Backend for kernel execution. + verbose : bool | None + Enable verbose compilation output. + pass_configs : dict[str, Any] | None + TVM pass configuration options. + debug_root_path : str | None + Directory to save compiled kernel source for debugging. + compile_flags : list[str] | str | None + Additional compiler flags. + mode : "auto" | "lazy" | "eager" + Execution mode. Default "auto" infers from function structure. """ compile_args = dict( - out_idx=None, + out_idx=out_idx, execution_backend=execution_backend, target=target, target_host=target_host, @@ -568,8 +528,16 @@ def lazy_jit( def decorator(func: Callable[_P, _T]): pf: LazyJITFunc[_P, _T] = prim_func(func, lazy_jit=True) + pf.set_mode(mode) + func_source = inspect.getsource(pf.orig_func) + signature = inspect.signature(pf.orig_func) + return JITImpl( - func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True + func=pf, + **compile_args, + func_source=func_source, + signature=signature, + mode=mode, ) return decorator(func) if func is not None else decorator diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 7734c724e..8c166b02a 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -1,5 +1,6 @@ from __future__ import annotations import ast +import logging from dataclasses import dataclass, field from typing import Callable, Generic, Any, Literal, TypeVar from contextlib import AbstractContextManager @@ -16,6 +17,8 @@ from . import utils from . import dtypes +logger = logging.getLogger(__name__) + _span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"] @@ -581,6 +584,90 @@ class IRGenerator(Generic[_P, _T]): gen: Callable[[BaseBuilder], Callable[_P, _T]] source: str extra_type_hints: dict[str, Any] = field(default_factory=dict) + is_lazy_style: bool = False # True if the function returns a PrimFunc (lazy style) + + +def _has_return_value(tree: ast.FunctionDef) -> bool: + """Check if function has any return statement with a non-None value.""" + for node in ast.walk(tree): + if isinstance(node, ast.Return) and node.value is not None: + # Skip explicit "return None" + if isinstance(node.value, ast.Constant) and node.value.value is None: + continue + return True + return False + + +def _detect_lazy_style(tree: ast.FunctionDef, func_name: str = "") -> bool: + """ + Detect if a function uses lazy style by analyzing its AST. + + Lazy style: function returns a PrimFunc (must have return value). + Eager style: function uses DSL builder pattern (no return needed). + + Detection logic: + 1. No return value -> eager (lazy must return PrimFunc) + 2. Has inner @T.prim_func definition -> lazy + 3. Has eager-mode constructs (T.Kernel, T.const, etc.) -> eager + 4. Has return value but no eager constructs -> lazy (external prim_func) + + Returns True if lazy style is detected. 
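+
+    A rough sketch of how typical bodies are classified (function names are
+    illustrative)::
+
+        def lazy_fn(M, N):              # returns an inner @T.prim_func -> lazy (rule 2)
+            @T.prim_func
+            def kernel(A: T.Tensor((M, N), "float32")):
+                ...
+            return kernel
+
+        def eager_fn(A):                # no return value -> eager (rule 1)
+            M, N = T.const("M, N")
+            A: T.Tensor[[M, N], T.float32]
+            with T.Kernel(1):
+                ...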
+ """ + # Rule 1: No return value means eager mode (lazy must return PrimFunc) + if not _has_return_value(tree): + return False + + # Rule 2: Check for inner @T.prim_func or @prim_func decorated functions + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node is not tree: + for decorator in node.decorator_list: + if ( + isinstance(decorator, ast.Attribute) + and decorator.attr == "prim_func" + or isinstance(decorator, ast.Name) + and decorator.id == "prim_func" + ): + return True # lazy style + + # Rule 3: Check for eager-mode DSL constructs + eager_mode_attrs = {"Kernel", "const", "dynamic", "Tensor", "StridedTensor", "empty"} + eager_mode_names = {"Kernel", "const", "dynamic", "empty"} + + for node in ast.walk(tree): + # Skip inner function definitions + if isinstance(node, ast.FunctionDef) and node is not tree: + continue + + # Check for T.Kernel, T.const, T.Tensor, etc. + if ( + isinstance(node, ast.Attribute) + and node.attr in eager_mode_attrs + and isinstance(node.value, ast.Name) + and node.value.id in ("T", "tilelang", "tl") + ): + return False # eager style + + # Check for T.Tensor[[...], ...] or Tensor[[...], ...] subscript pattern + if isinstance(node, ast.Subscript): + if isinstance(node.value, ast.Attribute) and node.value.attr in ("Tensor", "StridedTensor"): + return False # eager style + if isinstance(node.value, ast.Name) and node.value.id in ("Tensor", "StridedTensor"): + return False # eager style + + # Check for direct function calls: const("M"), Kernel(...), etc. + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in eager_mode_names: + return False # eager style + + # Rule 4: Has return value but no eager constructs -> lazy (external prim_func) + # Warn user since this is inferred without explicit markers + logger.warning( + "Cannot auto-detect JIT mode for '%s': function has return value but no " + "recognizable pattern (inner @T.prim_func or eager constructs like T.Kernel). " + "Assuming lazy mode. 
Consider explicitly setting mode='lazy' or mode='eager' " + "in @tilelang.jit() decorator.", + func_name, + ) + return True def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: @@ -613,6 +700,9 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: filename = inspect.getsourcefile(func) or inspect.getfile(func) nonlocals = utils.get_func_nonlocals(func) + # Detect lazy style before AST transformation + is_lazy = _detect_lazy_style(tree, func.__name__) + # DSLMutator generates a function named `make_closure` # it accepts all names inside nonlocal, and returns the mutated function # this is because we must separate the closure namespace form the global namespace @@ -637,4 +727,4 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: func.__globals__, # use the original globalns ) fn = make_closure(**nonlocals) - return IRGenerator(gen=fn, source=ast.unparse(tree), extra_type_hints=mut.extra_type_hints) + return IRGenerator(gen=fn, source=ast.unparse(tree), extra_type_hints=mut.extra_type_hints, is_lazy_style=is_lazy) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 676c953ee..ffb679964 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -17,7 +17,7 @@ from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import BufferLoad, CallEffectKind, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var -from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union +from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union, Literal, get_origin from collections.abc import Hashable from collections.abc import Sequence @@ -823,9 +823,23 @@ def get_type_hints(func): def const(name: str, dtype: str = "int32") -> tuple[Var, ...]: + """ + Declare constexpr variables for dynamic tensor dimensions (eager mode only). + + In eager mode, use T.const() to declare shape dimensions that will be + inferred from actual tensor arguments at runtime. + + Example:: + + @tilelang.jit + def kernel(A, B): + M, N = T.const("M, N") + A: T.Tensor[[M, N], T.float32] + ... + """ builder = Builder.current() - assert builder is not None, "const can only be used inside `tilelang.lazy_jit` function" - assert builder.lazy_jit, "const can only be used inside `tilelang.lazy_jit` function" + assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)" + assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)" if "," in name: names = re.split(r"\s*,\s*", name) return tuple(builder.constexpr(n, dtype) for n in names) @@ -838,8 +852,17 @@ def const(name: str, dtype: str = "int32") -> tuple[Var, ...]: @dataclass class TirTemplate(Generic[_P, _T]): + """ + Template for generating TIR PrimFunc with dynamic shape substitution. + + For lazy-style functions, the PrimFunc is used directly without substitution. + For eager-style functions, constexpr variables are substituted based on + actual tensor shapes at runtime. 
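+
+    For instance (illustrative shapes), an eager-style kernel declared as::
+
+        @tilelang.jit
+        def copy(A, B):
+            M, N = T.const("M, N")
+            A: T.Tensor[[M, N], T.float32]
+            B: T.Tensor[[M, N], T.float32]
+            ...
+
+    called with ``(128, 256)`` tensors substitutes ``M=128`` and ``N=256`` into the
+    cached template; a later call with ``(64, 64)`` tensors specializes and caches a
+    second kernel.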
+ """ + prim_func: PrimFunc[_P, _T] - matcher: dict[Var, tuple[tvm.tir.Var, str, int]] + matcher: dict[Var, tuple[tvm.tir.Var, str, int]] | None = None + is_lazy_style: bool = False # True if from lazy-style (returns PrimFunc directly) @classmethod def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate[_P, _T]: @@ -863,7 +886,12 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate f"Buffer shapes: {shapes}\n" f"Buffer strides: {strides}" ) - return cls(prim_func=prim_func, matcher=matcher) + return cls(prim_func=prim_func, matcher=matcher, is_lazy_style=False) + + @classmethod + def from_lazy_style(cls, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]: + """Create template from lazy-style function that returns PrimFunc directly.""" + return cls(prim_func=prim_func, is_lazy_style=True) def _parse_phase2_key(self, **kwargs): result = [] @@ -879,6 +907,8 @@ def _parse_phase2_key(self, **kwargs): return tuple(result) def get_tir(self, **kwargs): + if self.is_lazy_style: + return self.prim_func values = self._parse_phase2_key(**kwargs) subs = {name: value for name, value in zip(self.matcher, values)} result = substitute_primfunc(self.prim_func, subs) @@ -890,11 +920,28 @@ def get_tir(self, **kwargs): @dataclass class LazyJITFunc(Generic[_P, _T]): + """ + Internal wrapper for JIT-compiled functions. + + This class handles both lazy and eager execution styles: + + - **lazy style**: Function explicitly returns a PrimFunc. The original function + is called directly to obtain the TIR. + + - **eager style**: Function uses the DSL builder pattern with tensor type + annotations. The TIR is constructed by tracing the function body through + the Builder. + + The style is determined by `_is_lazy_style()` which checks if calling the + original function returns a PrimFunc directly. + """ + orig_func: Callable[_P, _T] arg_names: list[str] tensor_args: dict[str, Buffer | Var] tensor_args_defaults: dict[str, Any] ir_gen: IRGenerator[_P, _T] + mode: Literal["auto", "lazy", "eager"] = "auto" def __post_init__(self): # we don't want it to show up in the constructor @@ -911,10 +958,36 @@ def _parse_phase1_key(self, *args, **kwargs): p1_key = tuple(sorted(kwargs.items())) return p1_key, tensor_args, kwargs - def parse_args(self, *args, **kwargs): - p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) - tir_temp = self.p1_cache.get(p1_key, None) - if tir_temp is None: + def _is_lazy_style(self) -> bool: + """ + Check if the function uses lazy style (explicitly returns PrimFunc). + + This uses static AST analysis performed during mutate() to detect + if the function defines an inner @T.prim_func and returns it. + + Lazy style functions define an inner @T.prim_func and return it: + @jit + def foo(M, N): + @T.prim_func + def kernel(...): ... + return kernel # <- returns PrimFunc + + Eager style functions use the builder pattern with type annotations: + @jit + def foo(A, B): + A: T.Tensor[...] + with T.Kernel(...): ... 
+ # no return + """ + return self.ir_gen.is_lazy_style + + def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]: + """Build TIR template based on the execution mode.""" + if self.mode == "lazy": + # lazy: function returns PrimFunc directly + return TirTemplate.from_lazy_style(self.orig_func(*args, **kwargs)) + elif self.mode == "eager": + # eager: trace function body through Builder to construct TIR builder = Builder() builder.lazy_jit = True with builder.prim_func(self.orig_func.__name__): @@ -923,18 +996,52 @@ def parse_args(self, *args, **kwargs): pf.orig_func = self.orig_func if builder.out_idx: pf.out_idx_override = builder.out_idx - tir_temp = TirTemplate.create(pf, builder.constexpr_var) + return TirTemplate.create(pf, builder.constexpr_var) + else: + raise ValueError(f"Invalid jit mode: {self.mode}, expected 'lazy' or 'eager'") + + def parse_args(self, *args, **kwargs): + """Parse arguments and return cache key and tensor args.""" + p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) + if not tensor_args: + return (p1_key, None), kwargs + tir_temp = self.p1_cache.get(p1_key, None) + if tir_temp is None: + # mode should be set by JITImpl before calling parse_args + tir_temp = self._build_tir_template(*args, **kwargs) self.p1_cache[p1_key] = tir_temp p2_key = tir_temp._parse_phase2_key(**tensor_args) return (p1_key, p2_key), tensor_args def get_tir(self, *args, **kwargs): (p1_key, _), tensor_args = self.parse_args(*args, **kwargs) + if p1_key not in self.p1_cache: + # in legacy gemm, we use lazy tir template to build the tir + tir_temp = self._build_tir_template(*args, **kwargs) + self.p1_cache[p1_key] = tir_temp + return tir_temp.get_tir(**tensor_args) return self.p1_cache[p1_key].get_tir(**tensor_args) def __call__(self, *args, **kwargs): return self.get_tir(*args, **kwargs) + def set_mode(self, mode: Literal["lazy", "eager"]) -> LazyJITFunc[_P, _T]: + """Set the JIT execution mode (internal use only).""" + self.mode = mode + return self + + # Proxy function attributes for compatibility with autotuner and inspect. + # These attributes are needed by autotuner to extract closure variables + # and generate cache keys. + _PROXIED_ATTRS = frozenset({"__closure__", "__code__", "__name__", "__globals__", "__wrapped__"}) + + def __getattr__(self, name): + if name in LazyJITFunc._PROXIED_ATTRS: + if name == "__wrapped__": + return self.orig_func + return getattr(self.orig_func, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def substitute_primfunc(prim_func, vmap): analyzer = tvm.arith.Analyzer() @@ -958,7 +1065,7 @@ def substitute_buffer(buf): ) -def prim_func(func: Callable[_P, _T] = None, *, lazy_jit=False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]: +def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]: def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: sig = inspect.signature(func) ir_gen = mutate(func) @@ -972,8 +1079,12 @@ def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, elif param.name in func_annot: annot[param.name] = func_annot[param.name] for k in annot: - if not isinstance(annot[k], type) and callable(annot[k]): + # Call callable annotations (e.g., factory functions) to get the actual type. + # Skip typing generics like Optional[int], Union[...], List[...] which are + # callable but cannot be instantiated. 
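+            # e.g. get_origin(Optional[int]) is typing.Union, so it is left as-is,
+            # while a plain class or factory has get_origin(...) is None and is called.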
+ if not isinstance(annot[k], type) and callable(annot[k]) and get_origin(annot[k]) is None: annot[k] = annot[k]() + if lazy_jit: arg_names = list(sig.parameters.keys()) tensor_args = {k: v for k, v in annot.items() if isinstance(v, (Buffer, Var))}