Skip to content

Commit 2758fb4

Browse files
committed
Add Parallel vectorized cast test
1 parent cbce71f commit 2758fb4

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

testing/python/language/test_tilelang_language_vectorized_cast.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,36 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
1717

1818
@T.prim_func
1919
def main(
20-
A: T.Tensor[(M), dtype_A], # noqa: F821
21-
B: T.Tensor[(M), dtype_B], # noqa: F821
20+
A: T.Tensor[(M,), dtype_A], # noqa: F821
21+
B: T.Tensor[(M,), dtype_B], # noqa: F821
2222
):
2323
with T.Kernel(1, threads=128):
2424
T.copy(A, B)
2525

2626
return main
2727

2828

29+
@tilelang.jit
30+
def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
# Build and return a prim_func that moves a length-M tensor A (dtype_A) into a
# length-M tensor B (dtype_B) through two local fragment buffers, performing
# the element-wise assignment inside an explicit T.Parallel loop.
# NOTE(review): the reason M must be a multiple of 256 is not stated here --
# presumably related to the 128-thread launch below; confirm.
31+
assert M % 256 == 0
32+
33+
@T.prim_func
34+
def main(
35+
A: T.Tensor[(M,), dtype_A], # noqa: F821
36+
B: T.Tensor[(M,), dtype_B], # noqa: F821
37+
):
38+
with T.Kernel(1, threads=128):
# One fragment buffer per dtype: A_local holds the source values, B_local the
# converted values.
39+
A_local = T.alloc_fragment((M,), dtype_A)
40+
B_local = T.alloc_fragment((M,), dtype_B)
41+
42+
T.copy(A, A_local)
43+
for i in T.Parallel(M):
# A_local is declared with dtype_A and B_local with dtype_B, so this
# element-wise assignment is where the dtype conversion takes place.
44+
B_local[i] = A_local[i]
45+
T.copy(B_local, B)
46+
47+
return main
48+
49+
2950
def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
3051
"""Run the vectorized cast kernel and check the correctness.
3152
Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
3758

3859
M = 128 * lanes
3960
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
61+
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
4062

4163
A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
4264
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
65+
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
4366

4467
kernel(A, B)
68+
kernel_parallel(A, C)
4569

4670
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
71+
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
4772

4873
code = kernel.get_kernel_source()
74+
code_parallel = kernel_parallel.get_kernel_source()
4975

50-
assert check_str in code, \
76+
assert check_str in code and check_str in code_parallel, \
5177
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
5278

5379

0 commit comments

Comments
 (0)