Merged
Changes from all commits
31 commits
9c3cecb
Refactor: Replace @tilelang.lazy_jit with @tilelang.jit in examples a…
LeiWang1999 Jan 6, 2026
f48647b
Enhancement: Improve layout annotation handling in T.copy and paralle…
LeiWang1999 Jan 6, 2026
bfecb9b
Merge branch 'main' of https://github.com/tile-ai/tilelang into refac…
LeiWang1999 Jan 6, 2026
d173672
Add out_idx parameter to @jit and validate it's only used in lazy mode
LeiWang1999 Jan 6, 2026
5740b4e
lintfix
LeiWang1999 Jan 6, 2026
431ddbd
Merge branch 'main' of https://github.com/tile-ai/tilelang into refac…
LeiWang1999 Jan 6, 2026
b62ec8e
Refactor: centralize mode inference in JITImpl.__call__ before parse_…
LeiWang1999 Jan 6, 2026
e97de85
test fix
LeiWang1999 Jan 6, 2026
8b7e1ee
Enhancement: Introduce JITNoBuilderError for eager mode checks in T.K…
LeiWang1999 Jan 7, 2026
5473823
lint fix
LeiWang1999 Jan 7, 2026
ce70005
lint fix
LeiWang1999 Jan 7, 2026
edb8500
Enhancement: Update example_triton_sparse_gqa_decode_varlen_indice.py…
LeiWang1999 Jan 7, 2026
49a0f51
Merge branch 'main' of https://github.com/tile-ai/tilelang into jit_m…
LeiWang1999 Jan 7, 2026
c41d1ed
lint fix
LeiWang1999 Jan 7, 2026
f704c6d
Enhancement: Refactor JIT implementation to centralize mode initializ…
LeiWang1999 Jan 7, 2026
e2bcc75
Merge branch 'main' of https://github.com/tile-ai/tilelang into jit_m…
LeiWang1999 Jan 7, 2026
b78f650
Refactor: Remove unused simplify_prim_func decorator from example_gem…
LeiWang1999 Jan 7, 2026
050f4fc
remove claude.md
LeiWang1999 Jan 7, 2026
1a81308
Refactor: Clean up code formatting in layout.cc, lower_tile_op.cc, an…
LeiWang1999 Jan 7, 2026
f7e00eb
Refactor: Replace logical conditions with tir_all in complex expressi…
LeiWang1999 Jan 7, 2026
14d588a
Enhancement: Add @tilelang.testing.requires_cuda decorator to test fu…
LeiWang1999 Jan 7, 2026
b3c4d77
lint fix
LeiWang1999 Jan 7, 2026
62beb83
Refactor: Simplify output shape calculation in layout.cc by directly …
LeiWang1999 Jan 8, 2026
4093759
Refactor: Update boolean condition in test_boolop to use Or and Not f…
LeiWang1999 Jan 8, 2026
d2d0ce6
Refactor: Update test execution flow in test_tilelang_fragment_loop_c…
LeiWang1999 Jan 8, 2026
e8fc7b0
Refactor JIT function handling: Replace LazyJITFunc with JITFunc acro…
LeiWang1999 Jan 9, 2026
3e6f2cc
Update JITImpl documentation: Clarify output tensor index attribute d…
LeiWang1999 Jan 9, 2026
ef07119
Merge branch 'main' of https://github.com/tile-ai/tilelang into jit_m…
LeiWang1999 Jan 9, 2026
d33550b
Merge branch 'jit_merge_0107' of https://github.com/LeiWang1999/tilel…
LeiWang1999 Jan 9, 2026
36f37f4
Refactor TileLang to transition from v2 to eager language module. Upd…
LeiWang1999 Jan 9, 2026
d86f847
Enhance JIT compilation support: Update JIT function handling to impr…
LeiWang1999 Jan 9, 2026
@@ -329,21 +329,15 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
block_H = 64

Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence

print("cache_seqlens: ", cache_seqlens)

max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")

@@ -357,13 +351,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1

# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
# print("block_indices: ", block_indices)
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
print("actual_num_blocks: ", actual_num_blocks)
# print(block_indices.shape, actual_num_blocks.shape)

max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)

ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)

@@ -402,6 +390,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1
avg_time = elapsed_time / 1000
avg_flops = total_flops / avg_time
print(f"Average time: {avg_time:.6f} seconds")
print(f"Average FLOPS: {avg_flops:.2f} GFLOPS")

# Measure performance of reference implementation
import flash_attn # noqa: F401
@@ -415,7 +404,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1
avg_time_ref = elapsed_time_ref / 1000
avg_flops_ref = total_flops / avg_time_ref
print(f"Average time of ref: {avg_time_ref:.6f} seconds")

print(f"Average FLOPS of ref: {avg_flops_ref:.2f} GFLOPS")
print(f"Speedup: {avg_time_ref / avg_time:.2f}x")


2 changes: 0 additions & 2 deletions examples/gemm/example_gemm_intrinsics.py
@@ -6,7 +6,6 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func


def make_swizzle_layout(shared_buf):
@@ -25,7 +24,6 @@ def transform_func(i, j):


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
2 changes: 0 additions & 2 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
@@ -7,7 +7,6 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)
@@ -29,7 +28,6 @@ def transform_func(i, j):


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
22 changes: 11 additions & 11 deletions examples/lazy_jit/lazyjit.en.ipynb
@@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm(\n",
" A,\n",
" B,\n",
@@ -209,7 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_dyn_K(A, B):\n",
" M, N, K = T.dynamic(\"M, N, K\")\n",
" A: T.Tensor[[M, K], T.float16]\n",
@@ -248,7 +248,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def as_contingious(A):\n",
" M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n",
" A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n",
@@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr(\n",
" A,\n",
" B,\n",
@@ -359,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr_dyn(A, B, M, N, K):\n",
" M: T.int32\n",
" N: T.int32\n",
@@ -421,7 +421,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def example_wrong_kernel(A):\n",
" M = T.const(\"M\")\n",
" A: T.Tensor[[M * 2, M * 3], T.float32]\n",
@@ -470,7 +470,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dyn_annot(\n",
" A: T.ptr, # 1. T.ptr type annotation\n",
" is_2d=False,\n",
@@ -515,7 +515,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def add_one(X, data: T.float32 = 1):\n",
" M, N = T.const(\"M, N\")\n",
" X: T.Tensor[[M, N], T.float32]\n",
@@ -577,7 +577,7 @@
"B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dummy_kernel(A, B):\n",
" M, N = T.const(\"M, N\")\n",
" A: T.Tensor[[M, N], T.float16]\n",
@@ -797,7 +797,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def element_wise(A, fn):\n",
" N = T.dynamic(\"N\")\n",
" A: T.Tensor[[N], T.float32]\n",
@@ -857,7 +857,7 @@
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])"
22 changes: 11 additions & 11 deletions examples/lazy_jit/lazyjit.zh.ipynb
@@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm(\n",
" A,\n",
" B,\n",
@@ -209,7 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_dyn_K(A, B):\n",
" M, N, K = T.dynamic(\"M, N, K\")\n",
" A: T.Tensor[[M, K], T.float16]\n",
@@ -248,7 +248,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def as_contingious(A):\n",
" M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n",
" A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n",
@@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr(\n",
" A,\n",
" B,\n",
@@ -359,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr_dyn(A, B, M, N, K):\n",
" M: T.int32\n",
" N: T.int32\n",
@@ -421,7 +421,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def example_wrong_kernel(A):\n",
" M = T.const(\"M\")\n",
" A: T.Tensor[[M * 2, M * 3], T.float32]\n",
@@ -470,7 +470,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dyn_annot(\n",
" A: T.ptr, # 1. T.ptr type annotation\n",
" is_2d=False,\n",
@@ -515,7 +515,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def add_one(X, data: T.float32 = 1):\n",
" M, N = T.const(\"M, N\")\n",
" X: T.Tensor[[M, N], T.float32]\n",
@@ -577,7 +577,7 @@
"B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dummy_kernel(A, B):\n",
" M, N = T.const(\"M, N\")\n",
" A: T.Tensor[[M, N], T.float16]\n",
@@ -797,7 +797,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def element_wise(A, fn):\n",
" N = T.dynamic(\"N\")\n",
" A: T.Tensor[[N], T.float32]\n",
@@ -857,7 +857,7 @@
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])"
2 changes: 1 addition & 1 deletion src/transform/layout_inference.cc
@@ -222,7 +222,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
continue;

// Check if buffer exists in use_list_
if (!use_list_.count(buffer)) {
if (!use_list_.count(buffer) && IsFragmentBuffer(buffer)) {
LOG(WARNING) << "Layout inference failed for buffer " << buffer
<< ". "
<< "The buffer cannot be inferred with current layout "
3 changes: 2 additions & 1 deletion src/transform/lower_tile_op.cc
@@ -54,7 +54,6 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
}
Array<PrimExpr> layout_shape = layout->OutputShape();
Array<PrimExpr> output_shape = layout_shape;

if (ptr_type->storage_scope == "shared" ||
ptr_type->storage_scope == "shared.dyn") {
int replicate_extent = 1;
@@ -67,6 +66,8 @@
}
for (size_t i = 0; i < layout_shape.size(); i++) {
auto shape = layout_shape[i].as<IntImmNode>();
ICHECK(shape) << "Layout output shape must be constant integer, but got: "
<< layout_shape[i];
layout_extent *= shape->value;
}
replicate_extent = buffer_extent / layout_extent;
17 changes: 6 additions & 11 deletions testing/python/arith/test_arith_hard.py
@@ -3,6 +3,7 @@
from tvm.arith import Analyzer
from tvm.ir.expr import Range
from tvm.tir.expr import Not, Or
from tvm.tir import all as tir_all


def implies(x, y):
@@ -21,30 +22,25 @@ def check_expr(expr):
if not result:
smtlib2 = analyzer.get_smtlib2(expr)
raise AssertionError(f"Failed to prove: {expr}\nSMT-LIB2:\n{smtlib2}")
# assert result, f"Failed to prove: {expr}"

@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
return implies(tir_all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b)

check_expr(complex_expr_1())

@T.macro
def complex_expr_2():
return implies(a < b and b < c and a * d < b * d, b * d < c * d)
return implies(tir_all(a < b, b < c, a * d < b * d), b * d < c * d)

check_expr(complex_expr_2())

@T.macro
def complex_expr_3():
return implies(a >= 0 and a < 128, a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64)
return implies(tir_all(a >= 0, a < 128), a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64)

check_expr(complex_expr_3())

@T.macro
def complex_expr_4():
return implies(
a >= 0 and a < 128,
tir_all(a >= 0, a < 128),
(a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 + 16 - (a // 64 + a % 8 // 4) // 2 * 64) // 512
== (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 - (a // 64 + a % 8 // 4) // 2 * 64) // 512,
)
Expand All @@ -59,9 +55,8 @@ def test_smtlib2():
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)

@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
return implies(tir_all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b)

e = complex_expr_1()
analyzer = Analyzer()
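
The hunks above swap Python's `and` for tvm.tir.all when composing TIR predicates. In the eager path these macro bodies execute as ordinary Python, where `and`/`or` operate on operand truthiness (TVM rejects boolean coercion of symbolic expressions) instead of building And/Or nodes, so the conjunction has to be constructed explicitly. A minimal sketch, assuming only a standard TVM installation and not taken from this PR:

# --- illustrative sketch, not part of the PR diff ---
from tvm import tir
from tvm.tir import all as tir_all

a = tir.Var("a", "int32")
b = tir.Var("b", "int32")

# Builds an explicit conjunction node equivalent to tir.And(a > 0, b > 0).
cond = tir_all(a > 0, b > 0)
print(cond)

# By contrast, `(a > 0) and (b > 0)` would try to coerce the first PrimExpr
# to bool, which does not produce the symbolic And that the prover needs.
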
1 change: 1 addition & 0 deletions testing/python/issue/test_tilelang_issue_1549.py
@@ -4,6 +4,7 @@
import torch


@tilelang.testing.requires_cuda
def test_issue_1549_strange_var_vectorization():
@tl.jit
def get_wrong_kernel(M: int = 4096):
1 change: 1 addition & 0 deletions testing/python/issue/test_tilelang_issue_1601.py
@@ -3,6 +3,7 @@
import tilelang.language as T


@tilelang.testing.requires_cuda
def test_issue_1601():
@tilelang.jit
def qwq():
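
The two test files above gain a @tilelang.testing.requires_cuda guard (added across the suite by commit 14d588a). A minimal usage sketch; the skip-when-no-GPU behavior is an assumption based on how the decorator is applied in this PR:

# --- illustrative sketch, not part of the PR diff ---
import tilelang
import tilelang.testing
import torch


@tilelang.testing.requires_cuda
def test_gpu_only_path():
    # Intended to run only when a CUDA device is present; on CPU-only
    # machines the test should be skipped rather than fail (assumed
    # behavior, mirroring the guards added in this PR).
    assert torch.cuda.is_available()
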
4 changes: 1 addition & 3 deletions testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
@@ -6,17 +6,15 @@
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,
)

from tilelang.transform import simplify_prim_func
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,
INT4TensorCoreIntrinEmitterWithLadderTransform,
)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(42)


# @simplify_prim_func
def tl_matmul(
M,
N,