Skip to content

Commit f36286b

Browse files
committed
[Enhancement] Add buffer load copy functions and improve copy logic in tilelang
- Introduced new functions for buffer load copy with stride and parallel execution. - Enhanced the copy logic in `copy.py` to simplify nested if statements for BufferLoad nodes. - Added corresponding test cases for the new buffer load functionalities.
1 parent 91d5ef5 commit f36286b

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

testing/python/language/test_tilelang_language_copy.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,74 @@ def test_tilelang_copy_with_stride():
8686
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
8787

8888

89+
def tilelang_copy_bufferload(num_tokens, dtype="float16"):
90+
91+
@T.prim_func
92+
def main(
93+
indices: T.Tensor((num_tokens,), "int32"),
94+
x: T.Tensor((num_tokens,), dtype),
95+
):
96+
with T.Kernel(num_tokens, threads=32) as pid:
97+
idx = T.alloc_local([1], "int32")
98+
T.copy(indices[pid], idx[0])
99+
x[idx[0]] = x[idx[0]] + 1
100+
101+
return main
102+
103+
104+
def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
105+
program = tilelang_copy_bufferload(num_tokens, dtype)
106+
# test compilation only
107+
tilelang.compile(
108+
program,
109+
out_idx=[1],
110+
pass_configs={
111+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
112+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
113+
})
114+
115+
116+
def test_tilelang_copy_bufferload():
117+
run_tilelang_copy_bufferload(num_tokens=128)
118+
119+
120+
def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
121+
122+
@T.prim_func
123+
def main(
124+
A: T.Tensor((M, N), dtype),
125+
B: T.Tensor((M, N), dtype),
126+
):
127+
# Initialize Kernel Context
128+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
129+
for i, j in T.Parallel(block_M, block_N):
130+
T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j])
131+
132+
return main
133+
134+
135+
def run_tilelang_copy_buffer_load_with_parallel(M=1024,
136+
N=1024,
137+
block_M=128,
138+
block_N=128,
139+
dtype="float16"):
140+
program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
141+
kernel = tilelang.compile(
142+
program,
143+
out_idx=[1],
144+
target="cuda",
145+
pass_configs={
146+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
147+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
148+
})
149+
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
150+
b = kernel(a)
151+
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
152+
153+
154+
def test_tilelang_copy_buffer_load_with_parallel():
155+
run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)
156+
157+
89158
if __name__ == "__main__":
90159
tilelang.testing.main()

tilelang/language/copy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def get_extent(data):
4545

4646
src_extent = get_extent(src)
4747
dst_extent = get_extent(dst)
48+
# Combine the nested if statements into a single if statement as suggested by SIM102
49+
if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
50+
isinstance(dst, tir.BufferLoad)):
51+
# check if the case is like this:
52+
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
53+
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
54+
return tir.BufferStore(dst.buffer, src, dst.indices)
55+
4856
assert src_extent or dst_extent, "Can't deduce copy extents from args"
4957
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
5058
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)

0 commit comments

Comments
 (0)