77
88
99@tl .jit
10- def ptr_null_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
11-
12- @T .prim_func
13- def main (
14- a_ptr : T .ptr ,
15- b_ptr : T .ptr ,
16- c_ptr : T .ptr ,
17- bias_ptr : T .ptr ,
18- m : T .int32 ,
19- n : T .int32 ,
20- k : T .int32 ,
21- with_bias : T .bool ,
22- ):
23- A = T .make_tensor (a_ptr , (m , k ), dtype )
24- B = T .make_tensor (b_ptr , (k , n ), dtype )
25- C = T .make_tensor (c_ptr , (m , n ), accum_dtype )
26- Bias = T .make_tensor (bias_ptr , (n ), accum_dtype )
27-
28- # Initialize Kernel Context
29- with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = 128 ) as (bx , by ):
30- A_shared = T .alloc_shared ((block_M , block_K ), dtype )
31- B_shared = T .alloc_shared ((block_N , block_K ), dtype )
32- C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
33-
34- T .clear (C_local )
35-
36- for ko in T .Pipelined (T .ceildiv (k , block_K ), num_stages = 3 ):
37- # Copy tile of A
38- T .copy (A [by * block_M , ko * block_K ], A_shared )
39- T .copy (B [bx * block_N , ko * block_K ], B_shared )
40- T .gemm (A_shared , B_shared , C_local , transpose_B = True )
41-
42- if with_bias :
43- for i , j in T .Parallel (block_M , block_N ):
44- C_local [i , j ] += Bias [bx * block_N + j ]
45-
46- T .copy (C_local , C [by * block_M , bx * block_N ])
47-
48- return main
49-
50-
51- @tl .jit
52- def tensor_null_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
10+ def tensor_null_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" , with_bias = False ):
5311
5412 @T .prim_func
5513 def main (
5614 A : T .Tensor ((M , K ), dtype ),
5715 B : T .Tensor ((K , N ), dtype ),
5816 C : T .Tensor ((M , N ), accum_dtype ),
5917 Bias : T .Tensor ((N ), accum_dtype ),
60- with_bias : T .bool ,
6118 ):
6219 # Initialize Kernel Context
6320 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = 128 ) as (bx , by ):
@@ -83,29 +40,12 @@ def main(
8340
8441
8542def run_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
86- kernel = ptr_null_test (M , N , K , block_M , block_N , block_K , dtype , accum_dtype )
8743
8844 a = torch .randn (M , K , device = "cuda" , dtype = map_torch_type (dtype ))
8945 b = torch .randn (N , K , device = "cuda" , dtype = map_torch_type (dtype ))
9046 c = torch .zeros (M , N , device = "cuda" , dtype = map_torch_type (accum_dtype ))
91- d = torch .randn (N , device = "cuda" , dtype = map_torch_type (accum_dtype ))
92- kernel (a , b , c , None , M , N , K , False )
93-
94- ref_no_bias = (a @ b .T ).to (map_torch_type (accum_dtype ))
95- ref_with_bias = ref_no_bias + d
96-
97- torch .testing .assert_close (c , ref_no_bias , atol = 1e-2 , rtol = 1e-2 )
98-
99- kernel (a , b , c , d , M , N , K , True )
100-
101- torch .testing .assert_close (c , ref_with_bias , atol = 1e-2 , rtol = 1e-2 )
102-
103- kernel = tensor_null_test (M , N , K , block_M , block_N , block_K , dtype , accum_dtype )
104- kernel (a , b , c , None , False )
105- torch .testing .assert_close (c , ref_no_bias , atol = 1e-2 , rtol = 1e-2 )
106- kernel (a , b , c , d , True )
107- torch .testing .assert_close (c , ref_with_bias , atol = 1e-2 , rtol = 1e-2 )
108-
47+ kernel = tensor_null_test (M , N , K , block_M , block_N , block_K , dtype , accum_dtype , with_bias = False )
48+ kernel (a , b , c , None )
10949
11050def test_nullptr ():
11151 run_test (1024 , 1024 , 1024 , 128 , 128 , 32 )
0 commit comments