@@ -82,10 +82,10 @@ def main(
8282 T .annotate_layout ({
8383 E :
8484 make_cutlass_metadata_layout (
85- E , mma_dtype = "float16" , arch = "9.0" , block_k = block_K ),
85+ E , mma_dtype = in_dtype , arch = "9.0" , block_k = block_K ),
8686 E_shared :
8787 make_cutlass_metadata_layout (
88- E_shared , mma_dtype = "float16" , arch = "9.0" , block_k = block_K ),
88+ E_shared , mma_dtype = in_dtype , arch = "9.0" , block_k = block_K ),
8989 })
9090 T .disable_warp_group_reg_alloc ()
9191 T .clear (C_frag )
@@ -216,14 +216,18 @@ def _matmul(A, B):
216216
217217 C = _matmul (A , B )
218218
219- torch_assert_close (
220- C_sp .to (torch .float32 ),
221- C .to (torch .float32 ),
222- rtol = 1e-3 ,
223- atol = 1e-3 ,
224- base_name = "tilelang_sp" ,
225- ref_name = "ref_dense" ,
226- )
219+ if 'float8' in in_dtype :
220+ diff = calc_diff (C_sp , C )
221+ assert diff < 1e-3 , f"{ diff = } "
222+ else :
223+ torch_assert_close (
224+ C_sp .to (torch .float32 ),
225+ C .to (torch .float32 ),
226+ rtol = 1e-3 ,
227+ atol = 1e-3 ,
228+ base_name = "tilelang_sp" ,
229+ ref_name = "ref_dense" ,
230+ )
227231 print ("pass" )
228232
229233
@@ -335,7 +339,7 @@ def test_gemm_sp_sm90():
335339 run_gemm_sp_sm90 (512 , 1024 , 768 , "float16" , "float32" , "float32" , 64 , 64 , 64 , 0 , 128 , True ,
336340 True )
337341
338- run_gemm_sp_sm90 (256 , 256 , 256 , "float8_e4m3" , "float16" , "float16" , 64 , 64 , 64 , 2 , 128 , False ,
342+ run_gemm_sp_sm90 (512 , 1024 , 768 , "float8_e4m3" , "float16" , "float16" , 64 , 64 , 64 , 2 , 128 , False ,
339343 True )
340344 run_gemm_sp_sm90 (512 , 1024 , 768 , "int8" , "int32" , "int32" , 64 , 64 , 64 , 2 , 128 , False , True )
341345
0 commit comments