Skip to content

Commit c28a687

Browse files
committed
[test] use cal_diff for assertion
1 parent 1022c9c commit c28a687

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def main(
8282
T.annotate_layout({
8383
E:
8484
make_cutlass_metadata_layout(
85-
E, mma_dtype="float16", arch="9.0", block_k=block_K),
85+
E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
8686
E_shared:
8787
make_cutlass_metadata_layout(
88-
E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
88+
E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
8989
})
9090
T.disable_warp_group_reg_alloc()
9191
T.clear(C_frag)
@@ -216,14 +216,18 @@ def _matmul(A, B):
216216

217217
C = _matmul(A, B)
218218

219-
torch_assert_close(
220-
C_sp.to(torch.float32),
221-
C.to(torch.float32),
222-
rtol=1e-3,
223-
atol=1e-3,
224-
base_name="tilelang_sp",
225-
ref_name="ref_dense",
226-
)
219+
if 'float8' in in_dtype:
220+
diff = calc_diff(C_sp, C)
221+
assert diff < 1e-3, f"{diff=}"
222+
else:
223+
torch_assert_close(
224+
C_sp.to(torch.float32),
225+
C.to(torch.float32),
226+
rtol=1e-3,
227+
atol=1e-3,
228+
base_name="tilelang_sp",
229+
ref_name="ref_dense",
230+
)
227231
print("pass")
228232

229233

@@ -335,7 +339,7 @@ def test_gemm_sp_sm90():
335339
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
336340
True)
337341

338-
run_gemm_sp_sm90(256, 256, 256, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
342+
run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
339343
True)
340344
run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
341345

0 commit comments

Comments
 (0)