@@ -53,9 +53,9 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
5353 4 ,
5454 ".b16" ,
5555 A_warp .data ,
56- 8 * tx ,
57- A_shared .data ,
58- 16 * (tx % 16 ) + 8 * (tx // 16 ),
56+ A_warp . elem_offset + 8 * tx ,
57+ A_shared .access_ptr ( "r" ) ,
58+ s1 * (tx % 16 ) + 8 * (tx // 16 ),
5959 dtype = "float16" ,
6060 )
6161 )
@@ -106,9 +106,9 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
106106 4 ,
107107 ".b16" ,
108108 B_warp .data ,
109- 8 * tx ,
110- B_shared .data ,
111- 16 * (tx % 16 ) + 8 * (tx // 16 ),
109+ B_warp . elem_offset + 8 * tx ,
110+ B_shared .access_ptr ( "r" ) ,
111+ s1 * (tx % 16 ) + 8 * (tx // 16 ),
112112 dtype = "float16" ,
113113 )
114114 )
@@ -313,9 +313,10 @@ def schedule(sch: tir.Schedule):
313313 k_factors = sch .sample_perfect_tile (k , n = 3 )
314314 num_ty = sch .get (i_factors [2 ]) * sch .get (j_factors [2 ])
315315 else :
316- i_factors = [1 , 16 , 4 , 2 , 2 ]
317- j_factors = [1 , 32 , 1 , 8 , 1 ]
318- k_factors = [64 , 4 , 1 ]
316+ i_factors = [4 , 8 , 2 , 4 , 1 ]
317+ j_factors = [1 , 64 , 2 , 1 , 2 ]
318+ k_factors = [128 , 2 , 1 ]
319+
319320 num_ty = i_factors [2 ] * j_factors [2 ]
320321
321322 i0 , i1 , i2 , i3 , i4 = sch .split (i , factors = i_factors )
@@ -487,12 +488,12 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
487488c = tvm .nd .array (np .zeros ((M , N ), dtype = "float32" ), dev )
488489f = tvm .build (sch .mod ["main" ], target = "cuda" , name = "dense" )
489490
490- # print(f.imported_modules[0].get_source())
491- # f(a, b, c)
492- # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
493- # print("ok")
491+ print (f .imported_modules [0 ].get_source ())
492+ f (a , b , c )
493+ tvm .testing .assert_allclose (c .numpy (), c_np , rtol = 1e-3 )
494+ print ("ok" )
494495
495- # evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
496- # gflops = (N * M * K) * 2 / 1e9
497- # time_ms = evaluator(a, b, c).mean * 1e3
498- # print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
496+ evaluator = f .time_evaluator (f .entry_name , dev , number = 1000 )
497+ gflops = (N * M * K ) * 2 / 1e9
498+ time_ms = evaluator (a , b , c ).mean * 1e3
499+ print ("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms , gflops / (time_ms / 1e3 )))
0 commit comments