Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions testing/python/language/test_tilelang_language_assume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import tilelang
import tilelang.language as T
import tilelang.testing


def test_assume_remove_boundary_check():

@tilelang.jit
def kernel_with_assume():
N = T.dynamic('N')

@T.prim_func
def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
with T.Kernel(1, threads=32) as _:
for i in T.serial(r - l + 1):
T.assume(l + i >= 0 and l + i < N)
A[l + i] = 0

return main

jit_kernel = kernel_with_assume()
source = jit_kernel.get_kernel_source()

assert ("if (" not in source)


def test_assume_enable_vectorization():

@tilelang.jit
def kernel_vectorize(M):
N = T.dynamic('N')
vectorize_size = 4

@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()

base_idx = tid * 4
T.assume(N % vectorize_size == 0)

for i in T.vectorized(vectorize_size):
T.assume(base_idx + i < N)
B[tid, base_idx + i] = A[tid, base_idx + i]

return main

jit_kernel = kernel_vectorize(128)
source = jit_kernel.get_kernel_source()

assert ("float4" in source) and ("if (" not in source)


def test_assume_complex_indexing():

@tilelang.jit
def kernel_complex():
M = T.dynamic('M')
N = T.dynamic('N')

@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()
for j in T.serial(N):
i_src = T.min(j + 233, tid + 2)
j_src = j * T.ceildiv(j, i_src) * j - 1

T.assume(i_src >= 0 and i_src < M)
T.assume(j_src >= 0 and j_src < N)

B[tid, j] = A[i_src, j_src]

return main

jit_kernel = kernel_complex()
source = jit_kernel.get_kernel_source()

assert ("if (" not in source)


if __name__ == '__main__':
tilelang.testing.main()
Loading