Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5d6e821
remove debug print
LeiWang1999 Oct 7, 2025
81f1cae
Remove inline let expressions from the LowerAndLegalize function in p…
LeiWang1999 Oct 7, 2025
abc2f8c
add test
LeiWang1999 Oct 7, 2025
f1aa27e
Update sparse MLA examples to support SKV adjustment and correctness …
LeiWang1999 Oct 7, 2025
6efbef8
reduce test shape
LeiWang1999 Oct 7, 2025
9354899
Update documentation structure and refactor main function parameters …
LeiWang1999 Oct 7, 2025
f372812
Update buffer access checks in merge_shared_memory_allocations.cc
LeiWang1999 Oct 7, 2025
ede05d3
lint fix
LeiWang1999 Oct 7, 2025
cc3138a
Support pipeline with LetStmt
LeiWang1999 Oct 9, 2025
597d8b1
lint fix
LeiWang1999 Oct 9, 2025
36736f3
• Fix LowerTileOp let handling to avoid LetInline dependency
LeiWang1999 Oct 9, 2025
0da83a8
Merge branch 'main' of https://github.com/tile-ai/tilelang into issue…
LeiWang1999 Oct 10, 2025
d0648e5
fix for wgmma pipeline with let binding
LeiWang1999 Oct 10, 2025
6f115d3
lint fix
LeiWang1999 Oct 10, 2025
04d66d6
test fix
LeiWang1999 Oct 10, 2025
afc668d
reduce smem usage.
LeiWang1999 Oct 10, 2025
c637767
let binding enhancement
LeiWang1999 Oct 10, 2025
531a3ae
fix for dpgm
LeiWang1999 Oct 10, 2025
4fa28c9
fix simplify
LeiWang1999 Oct 10, 2025
22ae8c5
lint fix
LeiWang1999 Oct 10, 2025
9ee3540
use tilelang.Simplify instead of tir.Simplify
LeiWang1999 Oct 10, 2025
2311fc7
• Add TL_FORCE_LET_INLINE pass config and gate eager LetInline usage
LeiWang1999 Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions testing/python/issue/test_tilelang_issue_814.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import tilelang
import tilelang.testing
import tilelang.language as T
import torch


def test_tmp_var(N, block_N, dtype="float"):
    """Build a TileLang kernel that reuses one temporary value in two stores.

    Each element is clamped from below at 1 into ``tmp``; ``tmp / 2`` is then
    written to ``B`` and ``tmp * 2`` is written back to ``A``.

    Args:
        N: Length of the 1-D tensors ``A`` and ``B``.
        block_N: Number of elements handled per kernel block; the launch uses
            ``ceildiv(N, block_N)`` blocks.
        dtype: Element type used for both tensors.

    Returns:
        The ``T.prim_func`` kernel. It is not compiled here.
    """

    @T.prim_func
    def kernel(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
            for i in T.Parallel(block_N):
                idx = bx * block_N + i
                # `tmp` is read from A once and is expected to keep that value
                # for both stores, even though the second store overwrites
                # A[idx].
                tmp = T.max(A[idx], 1)
                B[idx] = tmp / 2
                A[idx] = tmp * 2

    return kernel


# Despite its ``test_`` prefix this is a kernel factory with required
# arguments, not a test case. Without this flag pytest collects it and errors
# out looking for fixtures named ``N`` and ``block_N``.
test_tmp_var.__test__ = False


def run_tmp_var_test(N=1024, block_N=128):
    """Compile the temporary-variable kernel for CUDA and check its results.

    The kernel is run on a random float tensor ``a`` and an uninitialised
    tensor ``b``; both are then compared against a PyTorch reference computed
    from a copy of ``a`` taken before the launch.

    Args:
        N: Number of elements in each tensor.
        block_N: Block size forwarded to the kernel factory.
    """
    prim_func = test_tmp_var(N, block_N)
    lowering_options = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    # NOTE(review): out_idx lists both parameters as outputs, yet both tensors
    # are also passed explicitly in the call below -- confirm this matches the
    # calling convention of the compiled kernel.
    compiled = tilelang.compile(
        prim_func,
        out_idx=[0, 1],
        target="cuda",
        pass_configs=lowering_options,
    )

    a = torch.randn(N, device="cuda", dtype=torch.float)
    b = torch.empty(N, device="cuda", dtype=torch.float)

    # The kernel writes into `a`, so snapshot the input before launching.
    a_before = a.clone()

    compiled(a, b)

    # Reference: clamp from below at 1, then derive both outputs from the
    # same clamped value.
    lower_bound = torch.tensor(1.0, dtype=torch.float, device="cuda")
    clamped = torch.maximum(a_before, lower_bound)
    expected_b = clamped / 2
    expected_a = clamped * 2

    tilelang.testing.torch_assert_close(a, expected_a, rtol=1e-2, atol=1e-2)
    tilelang.testing.torch_assert_close(b, expected_b, rtol=1e-2, atol=1e-2)


def test_issue_814():
    """Check that a temporary variable is handled correctly and not over-inlined."""
    problem_size = {"N": 1024, "block_N": 128}
    run_tmp_var_test(**problem_size)


# When executed directly as a script, hand control to tilelang's test entry
# point.
if __name__ == "__main__":
    tilelang.testing.main()
2 changes: 0 additions & 2 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
"""
mod = tir.transform.BindTarget(target)(mod)

# Inline let expressions and statements
mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Inject assumes to speedup tvm prover
Expand Down
1 change: 0 additions & 1 deletion tilelang/language/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
for extent in extents:
new_extents.append(extent)
extents = new_extents
print("after extents", extents)
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)

Expand Down
Loading