diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py
index 953f1b0b4..1a09165ba 100644
--- a/testing/python/language/test_tilelang_language_copy.py
+++ b/testing/python/language/test_tilelang_language_copy.py
@@ -86,5 +86,74 @@ def test_tilelang_copy_with_stride():
     run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
 
 
+def tilelang_copy_bufferload(num_tokens, dtype="float16"):
+
+    @T.prim_func
+    def main(
+            indices: T.Tensor((num_tokens,), "int32"),
+            x: T.Tensor((num_tokens,), dtype),
+    ):
+        with T.Kernel(num_tokens, threads=32) as pid:
+            idx = T.alloc_local([1], "int32")
+            T.copy(indices[pid], idx[0])
+            x[idx[0]] = x[idx[0]] + 1
+
+    return main
+
+
+def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
+    program = tilelang_copy_bufferload(num_tokens, dtype)
+    # test compilation only
+    tilelang.compile(
+        program,
+        out_idx=[1],
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
+        })
+
+
+def test_tilelang_copy_bufferload():
+    run_tilelang_copy_bufferload(num_tokens=128)
+
+
+def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M, N), dtype),
+    ):
+        # Initialize Kernel Context
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            for i, j in T.Parallel(block_M, block_N):
+                T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j])
+
+    return main
+
+
+def run_tilelang_copy_buffer_load_with_parallel(M=1024,
+                                                N=1024,
+                                                block_M=128,
+                                                block_N=128,
+                                                dtype="float16"):
+    program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
+    kernel = tilelang.compile(
+        program,
+        out_idx=[1],
+        target="cuda",
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
+        })
+    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
+    b = kernel(a)
+    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
+
+
+def test_tilelang_copy_buffer_load_with_parallel():
+    run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index 125cbd18a..0be3e21ac 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -45,6 +45,14 @@ def get_extent(data):
 
     src_extent = get_extent(src)
     dst_extent = get_extent(dst)
+    # Scalar-to-scalar copy: neither side has a deducible extent and both
+    # src and dst are BufferLoad nodes, e.g.
+    #   copy(buffer_a[i], buffer_b[i])
+    # Lower it directly to a single BufferStore: buffer_b[i] = buffer_a[i]
+    if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
+            isinstance(dst, tir.BufferLoad)):
+        return tir.BufferStore(dst.buffer, src, dst.indices)
+
     assert src_extent or dst_extent, "Can't deduce copy extents from args"
     src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
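
For reference, the BufferLoad-to-BufferStore lowering added in tilelang/language/copy.py can be reproduced in isolation with TVM's TIR API. The sketch below assumes the `tir` module used in copy.py is TVM's `tvm.tir`; the buffers `A` and `B` and the index variable are illustrative, not part of the patch:

    from tvm import tir

    # Two 1-D buffers and a symbolic index, mirroring the
    # copy(buffer_a[i], buffer_b[i]) case described in copy.py.
    A = tir.decl_buffer((128,), "float16", name="A")
    B = tir.decl_buffer((128,), "float16", name="B")
    i = tir.Var("i", "int32")

    src = tir.BufferLoad(A, [i])  # A[i]
    dst = tir.BufferLoad(B, [i])  # B[i]

    # What the patched T.copy returns when both arguments are scalar
    # BufferLoad nodes: a single store, B[i] = A[i].
    store = tir.BufferStore(dst.buffer, src, dst.indices)
    print(store)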