diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 01695742b..91d85a1a4 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -329,21 +329,15 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - block_H = 64 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # Ensure at least one element equals cache_seqlen random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") @@ -357,13 +351,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) @@ -402,6 +390,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 avg_time = elapsed_time / 1000 avg_flops = total_flops / avg_time print(f"Average time: {avg_time:.6f} seconds") + print(f"Average FLOPS: {avg_flops:.2f} GFLOPS") # Measure performance of reference implementation import flash_attn # noqa: F401 @@ -415,7 +404,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 avg_time_ref = elapsed_time_ref / 1000 avg_flops_ref = total_flops / avg_time_ref print(f"Average time of ref: {avg_time_ref:.6f} seconds") - + print(f"Average FLOPS of ref: {avg_flops_ref:.2f} GFLOPS") print(f"Speedup: {avg_time_ref / avg_time:.2f}x") diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index d4bc9480f..15e552587 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -6,7 +6,6 @@ from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, ) -from tilelang.transform import simplify_prim_func def make_swizzle_layout(shared_buf): @@ -25,7 +24,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, 
diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 762885ec3..04f735950 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -7,7 +7,6 @@ from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, ) -from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(0) @@ -29,7 +28,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb index 5b5df8e6a..9cb343a1e 100644 --- a/examples/lazy_jit/lazyjit.en.ipynb +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm(\n", " A,\n", " B,\n", @@ -209,7 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_dyn_K(A, B):\n", " M, N, K = T.dynamic(\"M, N, K\")\n", " A: T.Tensor[[M, K], T.float16]\n", @@ -248,7 +248,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def as_contingious(A):\n", " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr(\n", " A,\n", " B,\n", @@ -359,7 +359,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr_dyn(A, B, M, N, K):\n", " M: T.int32\n", " N: T.int32\n", @@ -421,7 +421,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def example_wrong_kernel(A):\n", " M = T.const(\"M\")\n", " A: T.Tensor[[M * 2, M * 3], T.float32]\n", @@ -470,7 +470,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dyn_annot(\n", " A: T.ptr, # 1. 
T.ptr type annotation\n", " is_2d=False,\n", @@ -515,7 +515,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def add_one(X, data: T.float32 = 1):\n", " M, N = T.const(\"M, N\")\n", " X: T.Tensor[[M, N], T.float32]\n", @@ -577,7 +577,7 @@ "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dummy_kernel(A, B):\n", " M, N = T.const(\"M, N\")\n", " A: T.Tensor[[M, N], T.float16]\n", @@ -797,7 +797,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def element_wise(A, fn):\n", " N = T.dynamic(\"N\")\n", " A: T.Tensor[[N], T.float32]\n", @@ -857,7 +857,7 @@ " n31(x * 3 + 1, var)\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", " n31(n, A[0])" diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb index 387aff461..d7afafe69 100644 --- a/examples/lazy_jit/lazyjit.zh.ipynb +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm(\n", " A,\n", " B,\n", @@ -209,7 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_dyn_K(A, B):\n", " M, N, K = T.dynamic(\"M, N, K\")\n", " A: T.Tensor[[M, K], T.float16]\n", @@ -248,7 +248,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def as_contingious(A):\n", " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr(\n", " A,\n", " B,\n", @@ -359,7 +359,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def gemm_ptr_dyn(A, B, M, N, K):\n", " M: T.int32\n", " N: T.int32\n", @@ -421,7 +421,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def example_wrong_kernel(A):\n", " M = T.const(\"M\")\n", " A: T.Tensor[[M * 2, M * 3], T.float32]\n", @@ -470,7 +470,7 @@ } ], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dyn_annot(\n", " A: T.ptr, # 1. 
T.ptr type annotation\n", " is_2d=False,\n", @@ -515,7 +515,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def add_one(X, data: T.float32 = 1):\n", " M, N = T.const(\"M, N\")\n", " X: T.Tensor[[M, N], T.float32]\n", @@ -577,7 +577,7 @@ "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def dummy_kernel(A, B):\n", " M, N = T.const(\"M, N\")\n", " A: T.Tensor[[M, N], T.float16]\n", @@ -797,7 +797,7 @@ "metadata": {}, "outputs": [], "source": [ - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def element_wise(A, fn):\n", " N = T.dynamic(\"N\")\n", " A: T.Tensor[[N], T.float32]\n", @@ -857,7 +857,7 @@ " n31(x * 3 + 1, var)\n", "\n", "\n", - "@tilelang.lazy_jit\n", + "@tilelang.jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", " n31(n, A[0])" diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 452e1b2ae..15d7f71e2 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -222,7 +222,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { continue; // Check if buffer exists in use_list_ - if (!use_list_.count(buffer)) { + if (!use_list_.count(buffer) && IsFragmentBuffer(buffer)) { LOG(WARNING) << "Layout inference failed for buffer " << buffer << ". " << "The buffer cannot be inferred with current layout " diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index dcb9570fb..01fa84586 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -54,7 +54,6 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout, } Array layout_shape = layout->OutputShape(); Array output_shape = layout_shape; - if (ptr_type->storage_scope == "shared" || ptr_type->storage_scope == "shared.dyn") { int replicate_extent = 1; @@ -67,6 +66,8 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout, } for (size_t i = 0; i < layout_shape.size(); i++) { auto shape = layout_shape[i].as(); + ICHECK(shape) << "Layout output shape must be constant integer, but got: " + << layout_shape[i]; layout_extent *= shape->value; } replicate_extent = buffer_extent / layout_extent; diff --git a/testing/python/arith/test_arith_hard.py b/testing/python/arith/test_arith_hard.py index 6fc859ba6..45bd86d0d 100644 --- a/testing/python/arith/test_arith_hard.py +++ b/testing/python/arith/test_arith_hard.py @@ -3,6 +3,7 @@ from tvm.arith import Analyzer from tvm.ir.expr import Range from tvm.tir.expr import Not, Or +from tvm.tir import all as tir_all def implies(x, y): @@ -21,30 +22,25 @@ def check_expr(expr): if not result: smtlib2 = analyzer.get_smtlib2(expr) raise AssertionError(f"Failed to prove: {expr}\nSMT-LIB2:\n{smtlib2}") - # assert result, f"Failed to prove: {expr}" - @T.macro def complex_expr_1(): - return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b) + return implies(tir_all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b) check_expr(complex_expr_1()) - @T.macro def complex_expr_2(): - return implies(a < b and b < c and a * d < b * d, b * d < c * d) + return implies(tir_all(a < b, b < c, a * d < b * d), b * d < c * d) check_expr(complex_expr_2()) - @T.macro def complex_expr_3(): - return implies(a >= 0 and a < 128, a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64) + return implies(tir_all(a >= 0, a < 128), a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64) 
check_expr(complex_expr_3()) - @T.macro def complex_expr_4(): return implies( - a >= 0 and a < 128, + tir_all(a >= 0, a < 128), (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 + 16 - (a // 64 + a % 8 // 4) // 2 * 64) // 512 == (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 - (a // 64 + a % 8 // 4) // 2 * 64) // 512, ) @@ -59,9 +55,8 @@ def test_smtlib2(): b = T.Var("b", T.int32) c = T.Var("c", T.int32) - @T.macro def complex_expr_1(): - return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b) + return implies(tir_all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b) e = complex_expr_1() analyzer = Analyzer() diff --git a/testing/python/issue/test_tilelang_issue_1549.py b/testing/python/issue/test_tilelang_issue_1549.py index d23659e37..69d486ea5 100644 --- a/testing/python/issue/test_tilelang_issue_1549.py +++ b/testing/python/issue/test_tilelang_issue_1549.py @@ -4,6 +4,7 @@ import torch +@tilelang.testing.requires_cuda def test_issue_1549_strange_var_vectorization(): @tl.jit def get_wrong_kernel(M: int = 4096): diff --git a/testing/python/issue/test_tilelang_issue_1601.py b/testing/python/issue/test_tilelang_issue_1601.py index 587622185..762a13f4a 100644 --- a/testing/python/issue/test_tilelang_issue_1601.py +++ b/testing/python/issue/test_tilelang_issue_1601.py @@ -3,6 +3,7 @@ import tilelang.language as T +@tilelang.testing.requires_cuda def test_issue_1601(): @tilelang.jit def qwq(): diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py index 1870be745..535e98329 100644 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -6,17 +6,15 @@ from tilelang.intrinsics import ( make_mma_swizzle_layout as make_swizzle_layout, ) - +from tilelang.transform import simplify_prim_func from tilelang.intrinsics.mma_macro_generator import ( INT4TensorCoreIntrinEmitter, INT4TensorCoreIntrinEmitterWithLadderTransform, ) -from tilelang.transform import simplify_prim_func tilelang.testing.set_random_seed(42) -# @simplify_prim_func def tl_matmul( M, N, diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 9d5213e1a..010ddbe8e 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -4,7 +4,8 @@ import tilelang.testing import tvm from tvm.script.ir_builder.base import IRBuilderFrame -from tvm.tir.expr import IntImm, Var +from tvm.tir.expr import IntImm, Var, Not, Or +from tvm.tir import all as tir_all def test_argument(): @@ -39,7 +40,7 @@ def test_argument( def test_expr(): - from tilelang.language.v2.dtypes import _all_dtypes + from tilelang.language.eager.dtypes import _all_dtypes errors = [] for name in _all_dtypes: @@ -204,9 +205,9 @@ def test_str_repr(): def test_var_assign(): - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_var_assign(A: T.Tensor((2,), T.int32)): + @tilelang.jit + def test_var_assign(): + A = T.empty((2,), T.int32) with T.Kernel(1) as _: a: T.int32 = 1 b: T.int32 = a @@ -214,8 +215,9 @@ def test_var_assign(A: T.Tensor((2,), T.int32)): d: T.int32 = a A[0] = b A[1] = d + return A - res = test_var_assign()() + res = test_var_assign() assert res[0] == 1 assert res[1] == 2 @@ -255,9 +257,9 @@ def test_macro_return(): def test_serial_for_with_step(): - 
@tilelang.jit(out_idx=-1) - @T.prim_func - def test_stepped_serial(A: T.Tensor((10,), T.int32)): + @tilelang.jit + def stepped_serial(): + A = T.empty((10,), T.int32) with T.Kernel(1) as _: for i in range(0, 10, 2): T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range") @@ -265,22 +267,22 @@ def test_stepped_serial(A: T.Tensor((10,), T.int32)): for i in range(1, 10, 2): T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range") A[i] = 2.0 + return A - ker = test_stepped_serial() - res = ker() + res = stepped_serial() ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_serial_step_neg(A: T.Tensor((10,), T.int32)): + @tilelang.jit + def stepped_serial_neg(): + A = T.empty((10,), T.int32) with T.Kernel(1) as _: for i in range(10, 0, -1): T.device_assert(0 < i <= 10, "i out of range") A[10 - i] = i + return A - ker = test_serial_step_neg() - res = ker() + res = stepped_serial_neg() ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" @@ -292,8 +294,8 @@ def test_serial_step_neg(A: T.Tensor((10,), T.int32)): def test_swap_logic(): @tilelang.jit - @T.prim_func - def swap_var(A: T.Tensor[(2,), T.float32]): + def swap_var(A): + A: T.Tensor[(2,), T.float32] with T.Kernel(1, threads=1) as _: a = T.alloc_var(T.float32, A[0]) b = T.alloc_var(T.float32, A[1]) @@ -301,20 +303,19 @@ def swap_var(A: T.Tensor[(2,), T.float32]): A[0], A[1] = a, b @tilelang.jit - @T.prim_func - def swap_idx(A: T.Tensor[(2,), T.float32]): + def swap_idx(A): + A: T.Tensor[(2,), T.float32] with T.Kernel(1, threads=1) as _: A[0], A[1] = A[1], A[0] - k_swap_var = swap_var() data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() - k_swap_var(data) + swap_var(data) ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) - k_swap_idx = swap_idx() data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() - k_swap_idx(data) + swap_idx(data) ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() torch.testing.assert_close(data, ref) @@ -322,9 +323,9 @@ def swap_idx(A: T.Tensor[(2,), T.float32]): # TODO(Gong): ROCm is not supported alloc_var with initializer @tilelang.testing.requires_cuda def test_while_loop(): - @tilelang.jit(out_idx=-1) - @T.prim_func - def test_while_loop(A: T.Tensor((1,), T.int32)): + @tilelang.jit + def while_loop(): + A = T.empty((1,), T.int32) with T.Kernel(1) as _: i = T.alloc_var(T.int32, 0) sum = T.alloc_var(T.int32) @@ -332,10 +333,10 @@ def test_while_loop(A: T.Tensor((1,), T.int32)): sum += i i += 1 A[0] = sum + return A - ker = test_while_loop() - A = ker() - assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" + res = while_loop() + assert res[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {res[0].item()}" def test_var_macro(): @@ -447,36 +448,32 @@ def test_boolop(): c = Var("c", T.int32) d = Var("d", T.int32) - @T.macro def cond(): - return not (a < b and b < c and a * d < b * d) or b * d < c * d + return Or(Not(tir_all(a < b, b < c, a * d < b * d)), b * d < c * d) cond() def test_constexpr_if(): @tilelang.jit - def probe(tmp: bool): - @T.prim_func - def foo(A: T.Tensor[[2], T.int32]): - with T.Kernel(1): - if tmp: - v = A[0] - else: - v = A[1] - if tmp: - A[1] = v + 1 - else: - A[0] = v + 1 - - return foo + def probe(A, tmp: bool): + A: 
T.Tensor[(2,), T.int32] + with T.Kernel(1): + if tmp: + v = A[0] + else: + v = A[1] + if tmp: + A[1] = v + 1 + else: + A[0] = v + 1 A = torch.tensor([10, 20], dtype=torch.int32).cuda() expect_1 = torch.tensor([10, 11], dtype=torch.int32).cuda() expect_2 = torch.tensor([12, 11], dtype=torch.int32).cuda() - probe(True)(A) + probe(A, True) assert torch.equal(A, expect_1) - probe(False)(A) + probe(A, False) assert torch.equal(A, expect_2) diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_lazy_jit.py index e3eabdce6..d7eba6c1b 100644 --- a/testing/python/language/test_tilelang_language_lazy_jit.py +++ b/testing/python/language/test_tilelang_language_lazy_jit.py @@ -6,7 +6,7 @@ def test_jit2_gemm(): - @tilelang.lazy_jit(verbose=True) + @tilelang.jit(verbose=True) def gemm( A, B, @@ -45,7 +45,7 @@ def gemm( def test_jit2_gemm_ptr(): - @tilelang.lazy_jit + @tilelang.jit def gemm_ptr( A: T.ptr, B: T.ptr, @@ -102,28 +102,28 @@ def copy_impl(A, B): with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) - @tilelang.lazy_jit + @tilelang.jit def copy1(A, B): N, M = T.const("N, M") A: T.Tensor[[N, M], T.float32] B: T.Tensor[[N, M], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy2( A: T.Tensor[[128, 128], T.float32], B: T.Tensor[[128, 128], T.float32], ): copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy3(A, B): N = T.const("N") A: T.Tensor[[N, 128], T.float32] B: T.Tensor[[N, 128], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy4(A, B): N = T.dynamic("N") M = T.const("M") @@ -131,14 +131,14 @@ def copy4(A, B): B: T.Tensor[[N, M], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy5(A, B): N, M, N_, M_ = T.const("N, M, N_, M_") A: T.StridedTensor[[N, M], [N_, M_], T.float32] B: T.StridedTensor[[N, M], [N_, M_], T.float32] copy_impl(A, B) - @tilelang.lazy_jit + @tilelang.jit def copy6(A, B): N = T.dynamic("N") M, N_, M_ = T.const("M, N_, M_") @@ -146,7 +146,7 @@ def copy6(A, B): B: T.StridedTensor[[N, M], [N_, M_], T.float32] copy_impl(A, B) - tilelang.par_compile([copy.get_tir(T.Tensor((128, 128))) for copy in [copy1, copy2, copy3, copy4]]) + tilelang.par_compile([copy.get_tir(T.Tensor((128, 128)), T.Tensor((128, 128))) for copy in [copy1, copy2, copy3, copy4]]) for copy in [copy1, copy2, copy3, copy4]: A = torch.randn(128, 128, device="cuda") @@ -175,37 +175,37 @@ def copy_impl(A): T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) return B - @tilelang.lazy_jit + @tilelang.jit def copy1(A): M, N = T.const("M, N") A: T.Tensor[[M, N], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy2(A): A: T.Tensor[[128, 128], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy3(A): N = T.const("N") A: T.Tensor[[N, 128], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy4(A): N = T.dynamic("N") M = T.const("M") A: T.Tensor[[N, M], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy5(A): N, M, N_, M_ = T.const("N, M, N_, M_") A: T.StridedTensor[[N, M], [N_, M_], T.float32] return copy_impl(A) - @tilelang.lazy_jit + @tilelang.jit def copy6(A): N = T.dynamic("N") M, N_, M_ = T.const("M, N_, M_") @@ -226,4 +226,5 @@ def copy6(A): if __name__ == "__main__": - 
tilelang.testing.main() + # tilelang.testing.main() + test_jit2_return() diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index da137e019..dd167efe5 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -41,21 +41,21 @@ def main( def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) - jit_kernel = tl.compile(program, execution_backend="cython") + cython_jit_kernel = tl.compile(program, execution_backend="cython") + ffi_jit_kernel = tl.compile(program, execution_backend="tvm_ffi") def ref_program(a, b): return (a @ b.T).to(torch.float32) a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) + ffi_c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) + cython_c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) - c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) - - jit_kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K) - - ref_c = (a @ b.T).to(map_torch_type(accum_dtype)) - - torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2) + ffi_jit_kernel(a, b, ffi_c, M, N, K) + cython_jit_kernel(a.data_ptr(), b.data_ptr(), cython_c.data_ptr(), M, N, K) + torch.testing.assert_close(ffi_c, ref_program(a, b), atol=1e-2, rtol=1e-2) + torch.testing.assert_close(cython_c, ffi_c, atol=1e-2, rtol=1e-2) def test_matmul(): diff --git a/testing/python/language/test_tilelang_language_subtype.py b/testing/python/language/test_tilelang_language_subtype.py index e7f898cc4..ef0641a9d 100644 --- a/testing/python/language/test_tilelang_language_subtype.py +++ b/testing/python/language/test_tilelang_language_subtype.py @@ -6,7 +6,7 @@ import tilelang.language as T -@tilelang.lazy_jit +@tilelang.jit def basic_shape_kernel(x): m = T.dynamic("m") x: T.Tensor[(m, 16), T.float4_e2m1fn] @@ -15,7 +15,7 @@ def basic_shape_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def strided_kernel(x): m = T.dynamic("m") s = T.dynamic("s") @@ -95,7 +95,7 @@ def test_subtype_different_strides(): strided_kernel(t_strided) -@tilelang.lazy_jit +@tilelang.jit def symbolic_last_dim_kernel(x): """Kernel with symbolic variable in the last dimension.""" n = T.dynamic("n") @@ -105,7 +105,7 @@ def symbolic_last_dim_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def symbolic_last_dim_strided_kernel(x): """Kernel with symbolic variable in both shape and stride of last dimension.""" n = T.dynamic("n") @@ -116,7 +116,7 @@ def symbolic_last_dim_strided_kernel(x): pass -@tilelang.lazy_jit +@tilelang.jit def shared_symbolic_kernel(x, y): """Kernel with shared symbolic variable across multiple buffers.""" m = T.dynamic("m") @@ -127,7 +127,7 @@ def shared_symbolic_kernel(x, y): pass -@tilelang.lazy_jit +@tilelang.jit def shared_symbolic_strided_kernel(x, y): """Kernel with shared symbolic variable in strides.""" m = T.dynamic("m") @@ -139,7 +139,7 @@ def shared_symbolic_strided_kernel(x, y): pass -@tilelang.lazy_jit +@tilelang.jit def complex_expr_kernel(x, y): """Kernel with complex expressions involving symbolic variables.""" m = T.dynamic("m") diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py index 9698a66ee..fbe9bb8a8 100644 --- 
a/testing/python/layout/test_tilelang_annotate_loop_layout.py +++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py @@ -4,7 +4,7 @@ # TODO(lei): replicate loop layout and more complicated layout cases -@tilelang.lazy_jit +@tilelang.jit def loop_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] @@ -51,7 +51,7 @@ def loop_layout_fn(i, j): assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code -@tilelang.lazy_jit +@tilelang.jit def copy_with_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] @@ -79,7 +79,7 @@ def loop_layout_fn(i, j, rep): assert "*(float4*)(B + ((i * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(A + ((i * 512) + (((int)threadIdx.x) * 4)));" in code -@tilelang.lazy_jit +@tilelang.jit def replicate_loop_layout_kernel(A, B, loop_layout): M, N = T.const("M, N") A: T.Tensor[(M, N), T.float32] diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 420516173..a41c40a5b 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -142,7 +142,7 @@ def _load_tile_lang_lib(): if env.SKIP_LOADING_TILELANG_SO == "0": _LIB, _LIB_PATH = _load_tile_lang_lib() - from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401 + from .jit import jit, JITKernel, compile, par_compile # noqa: F401 from .profiler import Profiler # noqa: F401 from .cache import clear_cache # noqa: F401 from .utils import ( @@ -160,7 +160,7 @@ def _load_tile_lang_lib(): engine, # noqa: F401 tools, # noqa: F401 ) - from .language.v2 import dtypes # noqa: F401 + from .language.eager import dtypes # noqa: F401 from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401 diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 44c2ff41f..822fb3298 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -24,7 +24,7 @@ except ImportError: # Python < 3.10 from typing_extensions import ParamSpec from tilelang import tvm as tvm -from tilelang.language.v2 import PrimFunc, prim_func, LazyJITFunc +from tilelang.language.eager import PrimFunc, prim_func, JITFunc from tvm.target import Target from tilelang.jit.kernel import JITKernel @@ -192,72 +192,75 @@ def par_compile( @dataclass class JITImpl(Generic[_P, _KP, _T, _Ret]): """ - Detailed Just-In-Time wrapper for TileLang programs. + Just-In-Time compilation wrapper for TileLang programs. - This dataclass encapsulates the configuration and runtime helpers used by the - top-level `jit` and `jit2` decorators. It represents a configured JIT - "factory" that can (a) elaborate TileLang/PrimFunc creators into concrete - TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via - the TVM bridge, (c) cache compiled kernels keyed by call-site arguments - (and optional tuning parameters), and (d) provide parallel compilation - helpers for batch autotuning workflows. + This class provides a unified interface for compiling and executing TileLang + kernels. It supports two execution modes that are automatically inferred: + + Execution Modes + --------------- + - **lazy**: The decorated function returns a PrimFunc explicitly. Calling the + JIT wrapper returns a compiled kernel object, which can be invoked separately. + This mode is useful when you want to inspect or reuse the kernel object. 
+ + Example (lazy mode):: + + @tilelang.jit(out_idx=[-1]) + def matmul(M, N, K, block_M, block_N, block_K): + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), ...): + ... + return kernel # explicitly return PrimFunc + + kernel = matmul(1024, 1024, 1024, 128, 128, 32) # returns kernel + result = kernel(a, b) # execute separately + + - **eager**: The decorated function uses the DSL builder pattern with tensor + type annotations. Calling the JIT wrapper compiles and immediately executes + the kernel, returning the result directly. + + Example (eager mode):: + + @tilelang.jit + def gemm(A, B, C, block_M: int = 64): + M, N, K = T.const("M, N, K") + A: T.Tensor[[M, K], dtype] # tensor shape via annotation + B: T.Tensor[[K, N], dtype] + C: T.Tensor[[M, N], dtype] + with T.Kernel(...): + ... + + gemm(A, B, C) # compiles and executes immediately + + The mode is automatically inferred based on whether the function returns a + PrimFunc (lazy) or uses the builder pattern (eager). Attributes ---------- out_idx : list[int] | int | None - Which output tensor(s) of the compiled kernel should be returned to the - caller. Accepts a single index, a list of indices, or None to return all. - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] - Backend used for exchanging arguments and executing the generated kernel. - target : str | tvm.target.Target - TVM compilation target (e.g. "cuda", "llvm", or "auto"). - target_host : str | tvm.target.Target | None - Host target used for cross-compilation, or None to infer/default. - verbose : bool - Enable verbose messages during compilation/build. + Index(es) of output tensor(s) to return (lazy mode only). + execution_backend : str | None + Backend for kernel execution ("auto", "dlpack", "tvm_ffi", etc.). + target : str | Target | None + TVM compilation target (e.g., "cuda", "llvm", "auto"). + target_host : str | Target | None + Host target for cross-compilation. + verbose : bool | None + Enable verbose compilation output. pass_configs : dict[str, Any] | None - Extra TVM pass configuration options forwarded to the compiler's - PassContext. + TVM pass configuration options. debug_root_path : str | None - If provided, compiled kernel source and the elaborated Python program - are written to this directory to ease debugging and inspection. + Directory to save compiled kernel source for debugging. compile_flags : list[str] | str | None - Additional flags passed to the compiler. A single string will be converted - to a single-element list. + Additional compiler flags. func_source : str - Original Python source string from which the PrimFunc or creator was - derived. Used for diagnostics and debug dumps. + Original Python source code of the decorated function. signature : inspect.Signature - Function signature of the original Python function (useful for tooling). - v2 : bool - Indicates whether the object wraps a "v2" PrimFunc creator (True) or a - plain callable / PrimFunc (False). v2-mode enables argument conversion - hooks and a distinct cache keying strategy. - func : Callable | PrimFunc | PrimFuncCreater - The underlying object: either a user function that returns a PrimFunc - (creator), a PrimFuncCreater, or an already-constructed PrimFunc. - For presentation/readability the function is stored last in the dataclass. - - Behavioral summary - ------------------ - - get_tir(*args, **kwargs) Converts provided call-site arguments into a concrete PrimFunc.
If the - wrapped object is a PrimFuncCreater or a user callable, it is invoked - with the given arguments. If the wrapped object is already a PrimFunc, - it is returned as-is. - - - compile(...) - A convenience wrapper that elaborates and immediately compiles a single - PrimFunc into a JITKernel using the module-level `compile` function. - When `debug_root_path` is set, the compiled C kernel and the source - Python program are saved for inspection. - - - par_compile(configs, ...) - Accepts an iterable of configs (either dicts mapping keyword args or - tuples mapping to positional args). Each config is elaborated to a - PrimFunc and the resulting set is compiled in parallel via the - module-level `par_compile` helper. Returns a list of JITKernel objects - in the same order as the provided configs. + Function signature of the original function. + mode : Literal["auto", "lazy", "eager"] + Execution mode. "auto" infers from function behavior. + func : JITFunc + The wrapped function object. """ out_idx: list[int] | int | None @@ -270,9 +273,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): compile_flags: list[str] | str | None func_source: str signature: inspect.Signature - lazy_jit: bool + mode: Literal["auto", "lazy", "eager"] # place func at the last element for better __repr__ - func: Callable[_P, _T] | PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T] + func: JITFunc[_KP, _T] def __post_init__(self): if self.debug_root_path is not None and not path.isabs(self.debug_root_path): @@ -288,9 +291,8 @@ def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]: """ Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object. """ - if isinstance(self.func, LazyJITFunc): - tir = self.func.get_tir(*args, **kwargs) - elif isinstance(self.func, PrimFunc): + self.initialize_jit_mode(*args, **kwargs) + if isinstance(self.func, PrimFunc): tir = self.func elif callable(self.func): tir = self.func(*args, **kwargs) @@ -299,6 +301,28 @@ def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]: assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}" return tir + def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]: + """ + Infer the JIT execution mode based on function behavior. + + Returns "lazy" if the function explicitly returns a PrimFunc, + or "eager" if it uses the DSL builder pattern. + """ + if self.mode in ("lazy", "eager"): + return self.mode + # auto: infer by checking if function returns PrimFunc directly + if not isinstance(self.func, JITFunc): + return "lazy" + return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager" + + def initialize_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]: + if self.mode == "auto": + self.mode = self._infer_jit_mode(*args, **kwargs) + self.func.set_mode(self.mode) + if self.mode == "eager" and self.out_idx is not None: + raise ValueError("out_idx is only supported in lazy mode. In eager mode, use T.empty() to declare output tensors instead.") + return self.mode + def par_compile( self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False ) -> list[JITKernel[_KP, _T]]: @@ -322,6 +346,7 @@ def par_compile( List[JITKernel] A list of compiled JITKernel objects corresponding to the provided configs. 
""" + configs = list(configs) funcs = [] for cfg in tqdm(configs, desc="Elaborating"): @@ -345,9 +370,9 @@ def par_compile( ) def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: - func = self.get_tir(*args, **kwargs) + prim_func = self.get_tir(*args, **kwargs) kernel_result = compile( - func, + prim_func, out_idx=self.out_idx, execution_backend=self.execution_backend, target=self.target, @@ -368,7 +393,7 @@ def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: with open(path.join(self.debug_root_path, kernel_file), "w") as f: print(kernel_result.get_kernel_source(), file=f) with open(path.join(self.debug_root_path, program_file), "w") as f: - print(func.script(), file=f) + print(prim_func.script(), file=f) return kernel_result @@ -397,22 +422,24 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: } return compile_args - if self.lazy_jit: - kwargs.update(kwargs.pop("__tune_params", {})) - key, kernel_args = self.func.parse_args(*args, **kwargs) - kernel = self._kernel_cache.get(key, None) - if kernel is None: - kernel = self.compile(*args, **kwargs) - self._kernel_cache[key] = kernel - return kernel(*kernel_args.values()) + kwargs.update(kwargs.pop("__tune_params", {})) + + # infer mode early, before parse_args needs it + if self.mode == "auto": + self.mode = self._infer_jit_mode(*args, **kwargs) + self.func.set_mode(self.mode) + key, kernel_args = self.func.parse_args(*args, **kwargs) + kernel = self._kernel_cache.get(key, None) + if kernel is None: + kernel = self.compile(*args, **kwargs) + self._kernel_cache[key] = kernel + + # eager mode: execute kernel immediately and return result + # lazy mode: return kernel object for manual invocation + if self.mode == "eager": + return kernel(*kernel_args.values()) else: - key = self.parse_cache_key(*args, **kwargs) - tune_params = kwargs.pop("__tune_params", {}) - kernel = self._kernel_cache.get(key, None) - if kernel is None: - kernel = self.compile(*args, **kwargs, **tune_params) - self._kernel_cache[key] = kernel return kernel @@ -420,12 +447,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: @overload -def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ... +def jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... @overload def jit( - *, # Indicates subsequent arguments are keyword-only + *, out_idx: Any = None, target: str | Target | None = None, target_host: str | Target | None = None, @@ -434,13 +461,13 @@ def jit( pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, -) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ... +) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... -def jit( # This is the new public interface +def jit( func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, + out_idx: list[int] | int | None = None, target: str | Target | None = None, target_host: str | Target | None = None, execution_backend: ExecutionBackend | None = None, @@ -448,115 +475,36 @@ def jit( # This is the new public interface pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, -): +) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]: """ - Just-In-Time (JIT) compiler decorator for TileLang functions. + JIT compiler decorator for TileLang functions. 
- This decorator can be used without arguments (e.g., `@tilelang.jit`): - Applies JIT compilation with default settings. + Supports two execution modes (automatically inferred): + - **lazy**: Function returns PrimFunc explicitly. Returns compiled kernel object. + - **eager**: Function uses DSL builder pattern. Executes kernel immediately. Parameters ---------- - func_or_out_idx : Any, optional - If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. - If using `@tilelang.jit` directly on a function, this argument is implicitly - the function to be decorated (and `out_idx` will be `None`). - target : Union[str, Target], optional - Compilation target for TVM (e.g., "cuda", "llvm"). If None, reads from - TILELANG_TARGET environment variable (defaults to "auto"). - target_host : Union[str, Target], optional - Target host for cross-compilation. Defaults to None. - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional - Backend for kernel execution and argument passing. If None, reads from - TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). - verbose : bool, optional - Enables verbose logging during compilation. If None, reads from - TILELANG_VERBOSE environment variable (defaults to False). - pass_configs : Optional[Dict[str, Any]], optional - Configurations for TVM's pass context. Defaults to None. - debug_root_path : Optional[str], optional - Directory to save compiled kernel source for debugging. Defaults to None. - - Environment Variables - --------------------- - TILELANG_TARGET : str - Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". - TILELANG_EXECUTION_BACKEND : str - Default execution backend. Defaults to "auto". - TILELANG_VERBOSE : str - Set to "1", "true", "yes", or "on" to enable verbose compilation by default. - - Returns - ------- - Callable - Either a JIT-compiled wrapper around the input function, or a configured decorator - instance that can then be applied to a function. - """ - - def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: - if isinstance(func, PrimFunc): - orig_func = func.orig_func - else: - orig_func = func - return JITImpl( - func=func, - out_idx=out_idx, - execution_backend=execution_backend, - target=target, - target_host=target_host, - verbose=verbose, - pass_configs=pass_configs, - debug_root_path=debug_root_path, - compile_flags=compile_flags, - func_source=inspect.getsource(orig_func), - signature=inspect.signature(orig_func), - lazy_jit=False, - ) - - if func is not None: - return decorator(func) - else: - return decorator - - -@overload -def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... - - -@overload -def lazy_jit( - *, - out_idx: Any = None, - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: ExecutionBackend | None = None, - verbose: bool | None = None, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None, -) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... 
- - -def lazy_jit( - func: Callable[_P, _T] | PrimFunc | None = None, - *, # Indicates subsequent arguments are keyword-only - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: ExecutionBackend | None = None, - verbose: bool | None = None, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None, -): - """ - Lazy JIT compiler decorator - returns the kernel object on first call, then executes it. - - Supports environment variable defaults for target, execution_backend, and verbose. - See `jit` documentation for parameter details and environment variables. + out_idx : list[int] | int | None + Output tensor index(es). Only supported in lazy mode. + target : str | Target | None + TVM compilation target (e.g., "cuda", "llvm", "auto"). + target_host : str | Target | None + Host target for cross-compilation. + execution_backend : ExecutionBackend | None + Backend for kernel execution. + verbose : bool | None + Enable verbose compilation output. + pass_configs : dict[str, Any] | None + TVM pass configuration options. + debug_root_path : str | None + Directory to save compiled kernel source for debugging. + compile_flags : list[str] | str | None + Additional compiler flags. """ compile_args = dict( - out_idx=None, + out_idx=out_idx, execution_backend=execution_backend, target=target, target_host=target_host, @@ -567,9 +515,17 @@ def lazy_jit( ) def decorator(func: Callable[_P, _T]): - pf: LazyJITFunc[_P, _T] = prim_func(func, lazy_jit=True) + mode = "auto" + pf: JITFunc[_P, _T] = prim_func(func, eager_jit=True) + func_source = inspect.getsource(pf.orig_func) + signature = inspect.signature(pf.orig_func) + return JITImpl( - func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True + func=pf, + **compile_args, + func_source=func_source, + signature=signature, + mode=mode, ) return decorator(func) if func is not None else decorator diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 47f873ec8..73ff779bf 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -19,7 +19,7 @@ from tilelang.jit.adapter.base import BaseKernelAdapter from tilelang.utils.language import retrieve_func_from_module from tilelang.engine.param import KernelParam -from tilelang.language.v2.dtypes import dtype +from tilelang.language.eager.dtypes import dtype class TVMFFIKernelAdapter(BaseKernelAdapter): diff --git a/tilelang/jit/exceptions.py b/tilelang/jit/exceptions.py new file mode 100644 index 000000000..844b5edb0 --- /dev/null +++ b/tilelang/jit/exceptions.py @@ -0,0 +1,24 @@ +"""Custom exceptions for TileLang JIT compilation.""" + + +class JITNoBuilderError(Exception): + """ + Exception raised when JIT-related operations require a Builder but none exists. + + In eager mode, TileLang constructs AST directly without an explicit prim_func, + so there must be a Builder available. This error is raised when eager-only + features like T.const() or T.Kernel() are called outside of a JIT/prim_func context. + """ + + pass + + +class EagerJITBuildError(Exception): + """ + Exception raised for errors when building TileLang eager JIT kernels. + + This error indicates that something went wrong during the eager-style + kernel construction process. 
+ """ + + pass diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 62d6f57f1..06cf793a5 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -11,7 +11,7 @@ from . import overrides as _overrides # noqa: F401 # from .tir import prim_func, macro, # noqa: F401 -from .v2 import * # noqa: F401 +from .eager import * # noqa: F401 from .tir.ir import * # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401 from .proxy import ptr, make_tensor, Buffer, Tensor, StridedTensor, FragmentBuffer, SharedBuffer, LocalBuffer # noqa: F401 diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index b195acc74..6100316dd 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -28,9 +28,9 @@ from tvm.script.parser.tir import block_attr from tvm.tir.buffer import Buffer from tvm.tir.expr import FloatImm, IntImm -from .v2 import dtypes as _dtypes -from .v2.dtypes import dtype as tl_dtype -from .v2.builder import OutTensor +from .eager import dtypes as _dtypes +from .eager.dtypes import dtype as tl_dtype +from .eager.builder import OutTensor _Shapes = TypeVarTuple("_Shapes") _DType = TypeVar("_DType") diff --git a/tilelang/language/eager/__init__.py b/tilelang/language/eager/__init__.py new file mode 100644 index 000000000..cf760fed9 --- /dev/null +++ b/tilelang/language/eager/__init__.py @@ -0,0 +1,2 @@ +from .builder import prim_func, macro, PrimFunc, JITFunc, Ref, const # noqa: F401 +from .dtypes import * diff --git a/tilelang/language/v2/ast.py b/tilelang/language/eager/ast.py similarity index 100% rename from tilelang/language/v2/ast.py rename to tilelang/language/eager/ast.py diff --git a/tilelang/language/v2/builder.py b/tilelang/language/eager/builder.py similarity index 84% rename from tilelang/language/v2/builder.py rename to tilelang/language/eager/builder.py index 157f20f3e..b3efb645c 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/eager/builder.py @@ -18,7 +18,7 @@ from tvm.script.ir_builder import tir, IRBuilder from tvm.tir.expr import BufferLoad, CallEffectKind, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var -from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union +from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union, Literal, get_origin from collections.abc import Hashable from collections.abc import Sequence @@ -29,6 +29,7 @@ from typing_extensions import ParamSpec, Self from . import dtypes as dt from . 
import utils +from tilelang.jit.exceptions import JITNoBuilderError, EagerJITBuildError import threading import logging @@ -172,7 +173,7 @@ def __init__(self): self.out_idx = [] self.out_tensor_cnt = 0 self.constexpr_var = set() - self.lazy_jit = False + self.eager_jit = False self.current_file = "" self.current_line = 0 self.current_macro_name = "" @@ -187,12 +188,14 @@ def current(cls) -> Self: @contextmanager def prim_func(self, name): thread_local_storage.builder = self - with self.ir_builder, self.with_frame(tir.prim_func()): - tir.func_name(name) - yield - if len(self.out_idx) != self.out_tensor_cnt: - raise RuntimeError("Not all tensor allocated from `T.empty` are returned") - del thread_local_storage.builder + try: + with self.ir_builder, self.with_frame(tir.prim_func()): + tir.func_name(name) + yield + if len(self.out_idx) != self.out_tensor_cnt: + raise RuntimeError("Not all tensor allocated from `T.empty` are returned") + finally: + del thread_local_storage.builder @contextmanager def macro(self, name=None, annotations=None): @@ -725,7 +728,10 @@ def source(self) -> str: return self.ir_gen.source def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: - builder = Builder.current() or Builder() + builder = Builder.current() + if builder is None: + raise JITNoBuilderError("T.macro can only be used inside @tilelang.jit") + with builder.macro(self.name, self.annotations): res = self.ir_gen.gen(builder)(*args, **kwargs) return res @@ -843,9 +849,26 @@ def get_type_hints(func): def const(name: str, dtype: str = "int32") -> tuple[Var, ...]: + """ + Declare constexpr variables for dynamic tensor dimensions (eager mode only). + + In eager mode, use T.const() to declare shape dimensions that will be + inferred from actual tensor arguments at runtime. + + Example:: + + @tilelang.jit + def kernel(A, B): + M, N = T.const("M, N") + A: T.Tensor[[M, N], T.float32] + ... + """ builder = Builder.current() - assert builder is not None, "const can only be used inside `tilelang.lazy_jit` function" - assert builder.lazy_jit, "const can only be used inside `tilelang.lazy_jit` function" + # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)" + # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)" + if builder is None or not builder.eager_jit: + raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)") + if "," in name: names = re.split(r"\s*,\s*", name) return tuple(builder.constexpr(n, dtype) for n in names) @@ -858,8 +881,17 @@ def const(name: str, dtype: str = "int32") -> tuple[Var, ...]: @dataclass class TirTemplate(Generic[_P, _T]): + """ + Template for generating TIR PrimFunc with dynamic shape substitution. + + For lazy-style functions, the PrimFunc is used directly without substitution. + For eager-style functions, constexpr variables are substituted based on + actual tensor shapes at runtime. 
+ """ + prim_func: PrimFunc[_P, _T] - matcher: dict[Var, tuple[tvm.tir.Var, str, int]] + matcher: dict[Var, tuple[tvm.tir.Var, str, int]] | None = None + is_lazy_style: bool = False # True if from lazy-style (returns PrimFunc directly) @classmethod def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate[_P, _T]: @@ -883,9 +915,16 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate f"Buffer shapes: {shapes}\n" f"Buffer strides: {strides}" ) - return cls(prim_func=prim_func, matcher=matcher) + return cls(prim_func=prim_func, matcher=matcher, is_lazy_style=False) + + @classmethod + def from_lazy_style(cls, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]: + """Create template from lazy-style function that returns PrimFunc directly.""" + return cls(prim_func=prim_func, is_lazy_style=True) def _parse_phase2_key(self, **kwargs): + if self.matcher is None: + return () result = [] for k, ty, i in self.matcher.values(): if ty == "shape": @@ -899,6 +938,8 @@ def _parse_phase2_key(self, **kwargs): return tuple(result) def get_tir(self, **kwargs): + if self.is_lazy_style: + return self.prim_func values = self._parse_phase2_key(**kwargs) subs = {name: value for name, value in zip(self.matcher, values)} result = substitute_primfunc(self.prim_func, subs) @@ -909,12 +950,29 @@ def get_tir(self, **kwargs): @dataclass -class LazyJITFunc(Generic[_P, _T]): +class JITFunc(Generic[_P, _T]): + """ + Internal wrapper for JIT-compiled functions. + + This class handles both lazy and eager execution styles: + + - **lazy style**: Function explicitly returns a PrimFunc. The original function + is called directly to obtain the TIR. + + - **eager style**: Function uses the DSL builder pattern with tensor type + annotations. The TIR is constructed by tracing the function body through + the Builder. + + The style is determined by `_is_lazy_style()` which checks if calling the + original function returns a PrimFunc directly. + """ + orig_func: Callable[_P, _T] arg_names: list[str] tensor_args: dict[str, Buffer | Var] tensor_args_defaults: dict[str, Any] ir_gen: IRGenerator[_P, _T] + mode: Literal["auto", "lazy", "eager"] = "auto" def __post_init__(self): # we don't want it to show up in the constructor @@ -931,30 +989,100 @@ def _parse_phase1_key(self, *args, **kwargs): p1_key = tuple(sorted(kwargs.items())) return p1_key, tensor_args, kwargs - def parse_args(self, *args, **kwargs): - p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) - tir_temp = self.p1_cache.get(p1_key, None) - if tir_temp is None: + def _is_lazy_style(self, *args, **kwargs) -> bool: + """ + Check if the function uses lazy style (explicitly returns PrimFunc). + + Lazy style functions define an inner @T.prim_func and return it: + @jit + def foo(M, N): + @T.prim_func + def kernel(...): ... + return kernel # <- returns PrimFunc + + Eager style functions use the builder pattern with type annotations: + @jit + def foo(A, B): + A: T.Tensor[...] + with T.Kernel(...): ... + # no return + """ + try: + prim_func = self.orig_func(*args, **kwargs) + # lazy jit must return PrimFunc + if isinstance(prim_func, PrimFunc): + p1_key, _, _ = self._parse_phase1_key(*args, **kwargs) + self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func) + return True + return False + except (JITNoBuilderError, EagerJITBuildError): + # In eager mode, we construct AST directly without prim_func, + # so there's no Builder available when the function is called. 
+ # When eager-only features like T.const() or T.Kernel() are used, + # they raise JITNoBuilderError because no Builder exists yet. + # This indicates the function is eager-style, not lazy-style. + return False + + def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]: + """Build TIR template based on the execution mode.""" + if self.mode == "lazy": + # lazy: function returns PrimFunc directly + return TirTemplate.from_lazy_style(self.orig_func(*args, **kwargs)) + elif self.mode == "eager": + # eager: trace function body through Builder to construct TIR builder = Builder() - builder.lazy_jit = True + builder.eager_jit = True with builder.prim_func(self.orig_func.__name__): self.ir_gen.gen(builder)(**self.tensor_args, **kwargs) pf = builder.get() pf.orig_func = self.orig_func if builder.out_idx: pf.out_idx_override = builder.out_idx - tir_temp = TirTemplate.create(pf, builder.constexpr_var) + return TirTemplate.create(pf, builder.constexpr_var) + else: + raise ValueError(f"Invalid jit mode: {self.mode}, expected 'lazy' or 'eager'") + + def parse_args(self, *args, **kwargs): + """Parse arguments and return cache key and tensor args.""" + p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) + if not tensor_args: + return (p1_key, None), kwargs + tir_temp = self.p1_cache.get(p1_key, None) + if tir_temp is None: + # mode should be set by JITImpl before calling parse_args + tir_temp = self._build_tir_template(*args, **kwargs) self.p1_cache[p1_key] = tir_temp p2_key = tir_temp._parse_phase2_key(**tensor_args) return (p1_key, p2_key), tensor_args def get_tir(self, *args, **kwargs): (p1_key, _), tensor_args = self.parse_args(*args, **kwargs) + if p1_key not in self.p1_cache: + # in legacy gemm, we use lazy tir template to build the tir + tir_temp = self._build_tir_template(*args, **kwargs) + self.p1_cache[p1_key] = tir_temp + return tir_temp.get_tir(**tensor_args) return self.p1_cache[p1_key].get_tir(**tensor_args) def __call__(self, *args, **kwargs): return self.get_tir(*args, **kwargs) + def set_mode(self, mode: Literal["lazy", "eager"]): + """Set the JIT execution mode (internal use only).""" + self.mode = mode + + # Proxy function attributes for compatibility with autotuner and inspect. + # These attributes are needed by autotuner to extract closure variables + # and generate cache keys. + _PROXIED_ATTRS = frozenset({"__closure__", "__code__", "__name__", "__globals__", "__wrapped__"}) + + def __getattr__(self, name): + if name in JITFunc._PROXIED_ATTRS: + if name == "__wrapped__": + return self.orig_func + return getattr(self.orig_func, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def substitute_primfunc(prim_func, vmap): analyzer = tvm.arith.Analyzer() @@ -978,7 +1106,7 @@ def substitute_buffer(buf): ) -def prim_func(func: Callable[_P, _T] = None, *, lazy_jit=False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]: +def prim_func(func: Callable[_P, _T] = None, *, eager_jit: bool = False) -> PrimFunc[_P, _T] | JITFunc[_P, _T]: def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]: sig = inspect.signature(func) ir_gen = mutate(func) @@ -992,15 +1120,19 @@ def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, elif param.name in func_annot: annot[param.name] = func_annot[param.name] for k in annot: - if not isinstance(annot[k], type) and callable(annot[k]): + # Call callable annotations (e.g., factory functions) to get the actual type. 
+ # Skip typing generics like Optional[int], Union[...], List[...] which are + # callable but cannot be instantiated. + if not isinstance(annot[k], type) and callable(annot[k]) and get_origin(annot[k]) is None: annot[k] = annot[k]() - if lazy_jit: + + if eager_jit: arg_names = list(sig.parameters.keys()) tensor_args = {k: v for k, v in annot.items() if isinstance(v, (Buffer, Var))} tensor_args_defaults = { k: sig.parameters[k].default for k in tensor_args if sig.parameters[k].default is not sig.parameters[k].empty } - return LazyJITFunc(func, arg_names, tensor_args, tensor_args_defaults, ir_gen) + return JITFunc(func, arg_names, tensor_args, tensor_args_defaults, ir_gen) else: try: builder = Builder() diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/eager/dtypes.py similarity index 100% rename from tilelang/language/v2/dtypes.py rename to tilelang/language/eager/dtypes.py diff --git a/tilelang/language/v2/utils.py b/tilelang/language/eager/utils.py similarity index 100% rename from tilelang/language/v2/utils.py rename to tilelang/language/eager/utils.py diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 8679971e4..28ca8a5dc 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -7,6 +7,7 @@ from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame from tvm.ffi import register_object from tilelang import _ffi_api +from tilelang.jit.exceptions import JITNoBuilderError import threading # Ensure single-dimension kernel bindings can be unpacked like iterables. @@ -279,6 +280,15 @@ def Kernel( with T.Kernel(loop_extent, is_cpu=True) as (i,): ... """ + # In eager mode, we construct AST directly without prim_func, + # so there must be a Builder available. If not, this function + # is being called outside of a JIT/prim_func context. + # lazy import to avoid circular import + from tilelang.language.eager.builder import Builder + + if Builder.current() is None: + raise JITNoBuilderError("T.Kernel() can only be used inside @tilelang.jit or @T.prim_func context. No Builder is available.") + attrs: dict = {} if not is_cpu and threads is None: diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index 6064f68f4..224f0378c 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -5,7 +5,7 @@ from tvm import tir from tvm.tir import IntImm import tvm.script.ir_builder.tir as tb_tir -from .v2.builder import SerialForWithStep, UnrollForWithStep +from .eager.builder import SerialForWithStep, UnrollForWithStep from tilelang import _ffi_api from tvm.script.ir_builder.tir import frame diff --git a/tilelang/language/print_op.py b/tilelang/language/print_op.py index c7ef55f7d..c8c7be81f 100644 --- a/tilelang/language/print_op.py +++ b/tilelang/language/print_op.py @@ -3,7 +3,7 @@ It includes functionality to print variables, print values in buffers, conditionally execute debug prints and assert. 
""" -from tilelang.language.v2.builder import Builder +from tilelang.language.eager.builder import Builder from tvm import tir from typing import Any import tilelang.language as T diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 90a2d5ff3..9a11c0654 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -10,6 +10,7 @@ from tvm.tir import Var, PrimExpr from tvm.script.ir_builder.tir import buffer, handle, match_buffer from tilelang.utils import deprecated +from tilelang.jit.exceptions import JITNoBuilderError class BufferProxy: @@ -275,4 +276,8 @@ def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: + from tilelang.language.eager.builder import Builder + + if Builder.current() is None: + raise JITNoBuilderError("T.make_tensor() can only be used inside @tilelang.jit or @T.prim_func context. No Builder is available.") return Tensor.from_ptr(ptr, shape, dtype, strides) diff --git a/tilelang/language/v2/__init__.py b/tilelang/language/v2/__init__.py deleted file mode 100644 index 65fa646c7..000000000 --- a/tilelang/language/v2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .builder import prim_func, macro, PrimFunc, LazyJITFunc, Ref, const # noqa: F401 -from .dtypes import *