@@ -17,15 +17,36 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
1717
1818 @T .prim_func
1919 def main (
20- A : T .Tensor [(M ), dtype_A ], # noqa: F821
21- B : T .Tensor [(M ), dtype_B ], # noqa: F821
20+ A : T .Tensor [(M , ), dtype_A ], # noqa: F821
21+ B : T .Tensor [(M , ), dtype_B ], # noqa: F821
2222 ):
2323 with T .Kernel (1 , threads = 128 ):
2424 T .copy (A , B )
2525
2626 return main
2727
2828
@tilelang.jit
def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    """Build a kernel that converts ``A`` to ``dtype_B`` inside a ``T.Parallel`` loop.

    The kernel stages ``A`` into a fragment buffer, assigns it element by
    element into a second fragment declared with ``dtype_B``, and then copies
    that fragment out to ``B``.

    Args:
        M: Length of the 1-D input and output tensors; must be a multiple of 256.
        dtype_A: Element type name of the input tensor ``A``.
        dtype_B: Element type name of the output tensor ``B``.

    Returns:
        The ``T.prim_func`` describing the kernel (wrapped by ``tilelang.jit``).
    """
    # NOTE(review): the 256 multiple presumably keeps an even element count per
    # thread with threads=128 -- confirm the exact requirement.
    assert M % 256 == 0

    @T.prim_func
    def main(
            A: T.Tensor[(M,), dtype_A],  # noqa: F821
            B: T.Tensor[(M,), dtype_B],  # noqa: F821
    ):
        with T.Kernel(1, threads=128):
            src_frag = T.alloc_fragment((M,), dtype_A)
            dst_frag = T.alloc_fragment((M,), dtype_B)

            # Stage the input into the source fragment.
            T.copy(A, src_frag)
            # The two fragments have different declared dtypes, so this
            # element-wise assignment is where the conversion takes place.
            for idx in T.Parallel(M):
                dst_frag[idx] = src_frag[idx]
            # Write the converted values back to the output tensor.
            T.copy(dst_frag, B)

    return main
48+
49+
2950def run_vectorized_cast (src_dtype_str : str , dst_dtype_str : str , check_str : str , lanes : int = 2 ):
3051 """Run the vectorized cast kernel and check the correctness.
3152 Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
3758
3859 M = 128 * lanes
3960 kernel = vectorized_cast_kernel (M , src_dtype_str , dst_dtype_str )
61+ kernel_parallel = parallel_vectorized_cast_kernel (M , src_dtype_str , dst_dtype_str )
4062
4163 A = torch .randn (M , dtype = str2dtype [src_dtype_str ]).cuda ()
4264 B = torch .zeros (M , dtype = str2dtype [dst_dtype_str ]).cuda ()
65+ C = torch .zeros (M , dtype = str2dtype [dst_dtype_str ]).cuda ()
4366
4467 kernel (A , B )
68+ kernel_parallel (A , C )
4569
4670 torch .testing .assert_close (A .to (str2dtype [dst_dtype_str ]), B )
71+ torch .testing .assert_close (A .to (str2dtype [dst_dtype_str ]), C )
4772
4873 code = kernel .get_kernel_source ()
74+ code_parallel = kernel_parallel .get_kernel_source ()
4975
50- assert check_str in code , \
76+ assert check_str in code and check_str in code_parallel , \
5177 f"Cast { src_dtype_str } to { dst_dtype_str } with { lanes = } is not vectorized!"
5278
5379
0 commit comments