@@ -184,16 +184,19 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
184184
185185@T .prim_func
186186def mma_store_impl (a : T .handle , c : T .handle ) -> None :
187+ s1 = T .var ("int32" )
188+ s0 = T .var ("int32" )
189+
187190 C_warp = T .match_buffer (a , [32 , 4 ], dtype = "float32" , scope = "warp" , offset_factor = 1 )
188- C = T .match_buffer (c , [16 , 8 ], dtype = "float32" , scope = "global" ,offset_factor = 1 )
191+ C = T .match_buffer (c , [16 , 8 ], dtype = "float32" , scope = "global" ,offset_factor = 1 , strides = [ s1 , s0 ] )
189192
190193 with T .block ("root" ):
191194 T .reads (C_warp [0 :32 , 0 :4 ])
192195 T .writes (C [0 :16 , 0 :8 ])
193196 tx = T .env_thread ("threadIdx.x" )
194197 T .launch_thread (tx , 32 )
195198
196- T .evaluate (T .mma_store ("m16n8" , C .data , C . elem_offset , C_warp .access_ptr ( "r" ), tx , dtype = "float32" ))
199+ T .evaluate (T .mma_store ("m16n8" , C .access_ptr ( "w" ), C_warp . data , C_warp .elem_offset , s1 , dtype = "float32" ))
197200
198201
199202tir .TensorIntrin .register ("mma.ldmatrix_a" , ldmatrix_a_desc , ldmatrix_a_impl )
@@ -388,7 +391,6 @@ def lambda_b(i, j):
388391 fused_2 = sch .fuse (f_0 , f_3 )
389392
390393 # print(sch.mod.script())
391-
392394 # return
393395
394396 sch .tensorize (fused_1 , "mma_store" )
@@ -423,20 +425,20 @@ def lambda_b(i, j):
423425 print (sch .mod .script ())
424426 target = "cuda"
425427 f = tvm .build (sch .mod ["main" ], target = target , name = "dense" )
426- print ( f . imported_modules [ 0 ]. get_source ())
427-
428- # dev = tvm.device("cuda", 0 )
429- # a_np = np.random.uniform(size=(N, K )).astype("float16")
430- # b_np = np.random.uniform(size=(K, M)) .astype("float16" )
431- # c_np = np.dot(a_np.astype("float32"), b_np.astype("float32") )
432- # a = tvm.nd.array(a_np , dev)
433- # b = tvm.nd.array(b_np , dev)
434- # c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev )
435- # f = tvm.build(sch.mod["main"], target="cuda", name="dense")
436-
437- # f(a, b, c)
438- # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
439- # print("ok")
428+
429+ dev = tvm . device ( "cuda" , 0 )
430+ a_np = np . random . uniform ( size = ( N , K )). astype ( "float16" )
431+ b_np = np .random .uniform (size = (K , M )).astype ("float16" )
432+ c_np = np .dot ( a_np . astype ( "float32" ), b_np .astype ("float32" ) )
433+ a = tvm . nd . array ( a_np , dev )
434+ b = tvm .nd .array (b_np , dev )
435+ c = tvm .nd .array (np . zeros (( M , N ), dtype = "float32" ) , dev )
436+ f = tvm .build ( sch . mod [ "main" ], target = "cuda" , name = "dense" )
437+
438+ print ( f . imported_modules [ 0 ]. get_source ())
439+ f (a , b , c )
440+ tvm .testing .assert_allclose (c .numpy (), c_np , rtol = 1e-3 )
441+ print ("ok" )
440442
441443# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
442444# gflops = (N * M * K) * 2 / 1e9
0 commit comments