Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions testing/python/language/test_tilelang_language_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,74 @@ def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)


def tilelang_copy_bufferload(num_tokens, dtype="float16"):
    """Build a kernel whose T.copy takes two scalar BufferLoad operands.

    Exercises the T.copy(src_load, dst_load) form: both arguments are
    single-element BufferLoad expressions, which the copy lowering turns
    into a plain BufferStore (see the matching change in
    tilelang/language/copy.py in this PR).

    Args:
        num_tokens: Length of the 1-D `indices` and `x` tensors, and the
            number of kernel program instances launched.
        dtype: Element dtype of `x` (default "float16").

    Returns:
        The traced T.prim_func ready for tilelang.compile.
    """

    @T.prim_func
    def main(
            indices: T.Tensor((num_tokens,), "int32"),
            x: T.Tensor((num_tokens,), dtype),
    ):
        # One program instance per token; 32 threads per block.
        with T.Kernel(num_tokens, threads=32) as pid:
            # Single-element local (per-thread) scratch buffer.
            idx = T.alloc_local([1], "int32")
            # Scalar-to-scalar copy: both operands are BufferLoad nodes.
            T.copy(indices[pid], idx[0])
            # Increment x at the gathered index (scatter-style update).
            x[idx[0]] = x[idx[0]] + 1

    return main


def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
    """Compile the scalar-BufferLoad copy kernel; compilation-only check."""
    pass_cfg = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    prim_func = tilelang_copy_bufferload(num_tokens, dtype)
    # Only verify that lowering/compilation succeeds; the kernel is not run.
    tilelang.compile(prim_func, out_idx=[1], pass_configs=pass_cfg)


def test_tilelang_copy_bufferload():
    """Smoke test: the scalar-BufferLoad copy kernel must compile."""
    run_tilelang_copy_bufferload(128)


def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
    """Build a tiled element-wise copy kernel driven by T.Parallel.

    Each (bx, by) block covers one block_M x block_N tile; every element is
    moved with a scalar T.copy whose source and destination are both
    BufferLoad expressions, exercising the BufferLoad->BufferStore lowering
    inside a T.Parallel loop.

    Args:
        M, N: Full matrix dimensions of A and B.
        block_M, block_N: Tile sizes along the row/column dimensions.
        dtype: Element dtype of both tensors (default "float16").

    Returns:
        The traced T.prim_func ready for tilelang.compile.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: 2-D grid, bx over column tiles and
        # by over row tiles, 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                # Per-element scalar copy A -> B at the tile-global index.
                T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j])

    return main


def run_tilelang_copy_buffer_load_with_parallel(M=1024,
                                                N=1024,
                                                block_M=128,
                                                block_N=128,
                                                dtype="float16"):
    """Compile the parallel element-wise copy kernel, run it on CUDA, and
    verify the output tensor matches the input."""
    pass_cfg = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    prim_func = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(prim_func, out_idx=[1], target="cuda", pass_configs=pass_cfg)
    torch_dtype = getattr(torch, dtype)
    src = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    dst = kernel(src)
    # A pure copy must reproduce the input exactly; tolerances match the
    # file's other numeric checks.
    torch.testing.assert_close(dst, src, rtol=1e-2, atol=1e-2)


def test_tilelang_copy_buffer_load_with_parallel():
    """End-to-end check of the T.Parallel per-element copy kernel."""
    run_tilelang_copy_buffer_load_with_parallel(1024, 1024, 128, 128)


if __name__ == "__main__":
    # Script entry point: run this file's tests via tilelang's test runner.
    tilelang.testing.main()
8 changes: 8 additions & 0 deletions tilelang/language/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def get_extent(data):

src_extent = get_extent(src)
dst_extent = get_extent(dst)
# Combine the nested if statements into a single if statement as suggested by SIM102
if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
isinstance(dst, tir.BufferLoad)):
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
return tir.BufferStore(dst.buffer, src, dst.indices)

assert src_extent or dst_extent, "Can't deduce copy extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
Expand Down
Loading