diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
index 4007bebe3..65b2d5cff 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 import tilelang.testing
 from tilelang import tvm as tvm
@@ -207,17 +208,33 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+@pytest.mark.parametrize(
+    "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack",
+    [
+        (128, 128, 128, "float16", "float16", "float32", False, True, 1),
+        (128, 256, 256, "float16", "float32", "float32", False, True, 1),
+        (128, 256, 256, "float16", "float32", "float32", False, True, 2),
+        (128, 128, 128, "int8", "int32", "int32", False, True, 1),
+        (128, 256, 256, "int8", "int32", "int32", False, True, 1),
+        (128, 256, 256, "int8", "int32", "int32", False, True, 2),
+        (128, 256, 256, "int8", "int32", "int32", False, False, 1),
+        (128, 256, 256, "int8", "int32", "int32", False, False, 2),
+        (128, 128, 128, "float8_e4m3fnuz", "float16", "float32", False, True, 1),
+        (128, 256, 256, "float8_e4m3fnuz", "float32", "float32", False, True, 1),
+        (128, 256, 256, "float8_e4m3fnuz", "float32", "float32", False, True, 2),
+        (128, 256, 256, "float8_e4m3fnuz", "float32", "float32", False, False, 1),
+    ],
+)
 @tilelang.testing.requires_rocm
-def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
-    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
-    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
-    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
-    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
-    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
-    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
-    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
+def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
+    assert_tl_matmul_correctness(
+        M,
+        N,
+        K,
+        in_dtype,
+        out_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=a_transposed,
+        b_transposed=b_transposed,
+        k_pack=k_pack,
+    )
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
index 393a77b78..eb2c6cbca 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 import tilelang.testing
 from tilelang import tvm as tvm
@@ -257,19 +258,46 @@ def assert_tl_matmul_correctness(
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+@pytest.mark.parametrize(
+    "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load",
+    [
+        (256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False),
+        (256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False),
+        (256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False),
+        (256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False),
+        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False),
+        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False),
+        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False),
+        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False),
+    ],
+)
 @tilelang.testing.requires_rocm
-def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
-
-    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)
-
-    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
-    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)
+def test_assert_tl_matmul(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    a_transposed,
+    b_transposed,
+    k_pack,
+    b_preshuffle,
+    b_g2l_load,
+):
+    assert_tl_matmul_correctness(
+        M,
+        N,
+        K,
+        in_dtype,
+        out_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=a_transposed,
+        b_transposed=b_transposed,
+        k_pack=k_pack,
+        b_preshuffle=b_preshuffle,
+        b_g2l_load=b_g2l_load,
+    )


 if __name__ == "__main__":
diff --git a/testing/python/amd/test_tilelang_test_amd.py b/testing/python/amd/test_tilelang_test_amd.py
index 0666fd479..c9c3bedbb 100644
--- a/testing/python/amd/test_tilelang_test_amd.py
+++ b/testing/python/amd/test_tilelang_test_amd.py
@@ -1,3 +1,4 @@
+import pytest
 from tilelang import tvm as tvm
 import tilelang as tl
 import tilelang.language as T
@@ -95,31 +96,49 @@ def ref_program(A, B):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


+@pytest.mark.parametrize(
+    "trans_A, trans_B, k_pack",
+    [
+        (False, False, 1),
+        (False, True, 1),
+        (True, True, 1),
+        (True, False, 1),
+        (False, True, 2),
+    ],
+)
 @tilelang.testing.requires_rocm
-def test_gemm_f16f32f32_nt():
-    run_gemm(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
+def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
+    run_gemm(1024, 1024, 1024, trans_A, trans_B, "float16", "float32", "float32", 128, 128, 32, k_pack=k_pack)


+@pytest.mark.parametrize(
+    "trans_A, trans_B, k_pack",
+    [
+        (False, False, 1),
+        (False, True, 1),
+        (True, True, 1),
+        (True, False, 1),
+        (False, True, 2),
+    ],
+)
 @tilelang.testing.requires_rocm
-def test_gemm_bf16f32f32_nt():
-    run_gemm(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
"bfloat16", "float32", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2) +def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack): + run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=k_pack) +@pytest.mark.parametrize( + "trans_A, trans_B, k_pack", + [ + (False, False, 1), + (False, True, 1), + (True, True, 1), + (True, False, 1), + (False, True, 2), + ], +) @tilelang.testing.requires_rocm -def test_gemm_bf16bf16f32(): - run_gemm(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2) +def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack): + run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=k_pack) def matmul_rs( diff --git a/testing/python/fastmath/test_mathops_fastmath.py b/testing/python/fastmath/test_mathops_fastmath.py index 7809983e8..72eddd960 100644 --- a/testing/python/fastmath/test_mathops_fastmath.py +++ b/testing/python/fastmath/test_mathops_fastmath.py @@ -1,3 +1,4 @@ +import pytest import tilelang import tilelang.language as T import torch @@ -242,13 +243,9 @@ def main( print(f"✓ {mathop_name} numerical test passed") -@tilelang.testing.requires_cuda -def test_mathops_generate_no_fastmath(): - """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" - # Based on test results, our tl.* intrinsics actually generate - # no fastmath versions - # This appears to be the intended behavior - single_arg_mathops = [ +@pytest.mark.parametrize( + "name, func", + [ ("exp", T.exp), ("exp2", T.exp2), ("exp10", T.exp10), @@ -270,24 +267,26 @@ def test_mathops_generate_no_fastmath(): ("trunc", T.trunc), ("round", T.round), ("nearbyint", T.nearbyint), - ] - - for name, func in single_arg_mathops: - run_single_arg_mathop_test(name, func, dtype="float32") - print(f"✓ {name} test passed") + ], +) +@tilelang.testing.requires_cuda +def test_mathops_generate_no_fastmath(name, func): + """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" + run_single_arg_mathop_test(name, func, dtype="float32") + print(f"✓ {name} test passed") -@tilelang.testing.requires_cuda -def test_two_arg_mathops_fastmath(): - """Test all two-argument mathops""" - # Two argument mathops - two_arg_mathops = [ +@pytest.mark.parametrize( + "name, func", + [ ("pow", T.pow), ("fmod", T.fmod), - ] - - for name, func in two_arg_mathops: - run_two_arg_mathop_test(name, func, dtype="float32") + ], +) +@tilelang.testing.requires_cuda +def test_two_arg_mathops_fastmath(name, func): + """Test all two-argument mathops""" + run_two_arg_mathop_test(name, func, dtype="float32") @tilelang.testing.requires_cuda @@ -296,11 +295,9 @@ def test_abs_maps_to_fabs(): run_abs_test() -@tilelang.testing.requires_cuda -def test_fastmath_versions(): - """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code""" - # Test fastmath versions - 
+@pytest.mark.parametrize(
+    "name, func",
+    [
         ("__exp", T.__exp),
         ("__exp10", T.__exp10),
         ("__log", T.__log),
@@ -309,11 +306,13 @@
         ("__tan", T.__tan),
         ("__cos", T.__cos),
         ("__sin", T.__sin),
-    ]
-
-    for name, func in fastmath_mathops:
-        run_fastmath_mathop_test(name, func, dtype="float32")
-        print(f"✓ {name} test passed")
+    ],
+)
+@tilelang.testing.requires_cuda
+def test_fastmath_versions(name, func):
+    """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
+    run_fastmath_mathop_test(name, func, dtype="float32")
+    print(f"✓ {name} test passed")


 if __name__ == "__main__":
diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py
index 2fd1554a8..a9ab86985 100644
--- a/testing/python/language/test_tilelang_language_vectorized_cast.py
+++ b/testing/python/language/test_tilelang_language_vectorized_cast.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 import tilelang.testing
 import tilelang.language as T
@@ -77,38 +78,29 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


-def test_vectorized_cast():
-    # fp32 -> fp16
-    run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
-    run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
-
-    # fp16 -> fp32
-    run_vectorized_cast("float16", "float32", "__half22float2", 2)
-    run_vectorized_cast("float16", "float32", "__half22float2", 4)
-
-    # fp32 -> fp8_e4m3
-    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
-    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
-
-    # fp32 -> fp8_e5m2
-    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
-    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
-
-    # fp32 -> bf16
-    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
-    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
-
-    # bf16 -> fp32
-    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
-    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
-
-    # fp8_e4m3 -> fp32
-    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
-    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)
-
-    # fp8_e5m2 -> fp32
-    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
-    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)
+@pytest.mark.parametrize(
+    "src_dtype, dst_dtype, check_str, lanes",
+    [
+        ("float32", "float16", "__float22half2_rn", 2),
+        ("float32", "float16", "__float22half2_rn", 4),
+        ("float16", "float32", "__half22float2", 2),
+        ("float16", "float32", "__half22float2", 4),
+        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2),
+        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4),
+        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2),
+        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4),
+        ("float32", "bfloat16", "__float22bfloat162_rn", 2),
+        ("float32", "bfloat16", "__float22bfloat162_rn", 4),
+        ("bfloat16", "float32", "__bfloat1622float2", 2),
+        ("bfloat16", "float32", "__bfloat1622float2", 4),
+        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2),
+        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4),
"__tl_cvt_fp8x2_to_float2", 4), + ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2), + ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4), + ], +) +def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): + run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes) if __name__ == "__main__": diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index a13e4533e..de8a9f9dc 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -109,30 +109,27 @@ def ref_program(A, B): @pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved") -def test_gemm_ss(): - # More test case can be found in kernel/test_tilelang_kernel_gemm.py - # GEMM tests for float16 - run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2) - run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2) - run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2) - run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2) - # n8 test - run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128) - - # int8 test - run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2) - - # float8 tests - run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) - - # tfloat32 test - run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), + (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), + (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), + (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), + (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + ], +) +def test_gemm_ss(M, N, K, 
+    run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


 def matmul_rs(
@@ -247,30 +244,27 @@ def ref_program(A, B):


 @pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved")
-def test_gemm_rs():
-    # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
-
-    # n8 tests
-    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
-
-    # int8 tests
-    run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
-
-    # float8 tests
-    run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
-
-    # float32 tests
-    run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
-    run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+@pytest.mark.parametrize(
+    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
+    [
+        (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
+        (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
+        (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
+        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
+        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
+        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
+        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
+        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
+        (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
+        (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
+        (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
+        (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
+    ],
+)
+def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
+    run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


 def matmul_sr(
@@ -384,31 +378,27 @@ def ref_program(A, B):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


-def test_gemm_sr():
-    # GEMM tests for float16
-    run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
"float16", "float16", 128, 256, 32, 2) - run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) - - # n8 tests - run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128) - - # int8 tests - run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2) - - # float8 tests - run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) - - # float32 tests - # TODO(lei): fix in future - run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), + (128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + ], +) +def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) def matmul_rr( @@ -526,31 +516,29 @@ def ref_program(A, B): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) -def test_gemm_rr(): - # GEMM tests for float16 - run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2) - # n8 tests - run_gemm_rr(128, 8, 128, 
False, True, "float16", "float16", "float16", 128, 8, 32, 2) - run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2) - - # int8 tests - run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2) - - # float8 tests - run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2) - - # float32 tests - run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) - run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), + (128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128), + (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128), + (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), + (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + ], +) +def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) if __name__ == "__main__": diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index 4ced4f837..3a703a002 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -1,3 +1,4 @@ +import pytest import torch import tilelang import tilelang.testing @@ -303,50 +304,53 @@ def run_gemm_sp_sm80( @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(9, 0) -def test_gemm_sp_sm90(): - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256) - - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", 
"float32", 64, 64, 64, 0, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128) - - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128) - - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128) - - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True) - - run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True) - run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True) +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", + [ + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False), + (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True), + (512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True), + (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True), + ], +) +def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): + run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(8, 0) @tilelang.testing.requires_cuda_compute_version_le(8, 9) -def test_gemm_sp_sm80(): - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128) - - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) - - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128) - - 
-    run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True)
-    run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True)
-
-    run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True)
-    run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
-    run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True)
+@pytest.mark.parametrize(
+    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
+    [
+        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
+        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
+        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False),
+        (512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True),
+        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True),
+        (512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True),
+        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True),
+        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
+        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
+    ],
+)
+def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
+    run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)


 if __name__ == "__main__":
diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
index 276bce4d9..cd4123d99 100644
--- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
+++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
@@ -1,3 +1,4 @@
+import pytest
 from tilelang import tvm as tvm
 from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
 from tilelang.utils.tensor import torch_assert_close, map_torch_type
@@ -153,33 +154,24 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
     return A, B


-def test_gemm_ss():
-    # More test case can be found in kernel/test_tilelang_kernel_gemm.py
-    # GEMM tests for float16
-    # TODO: support transposed A compressor
-    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2)
-    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2)
-    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2)
-    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2)
-
-    # n8 test
-    run_gemm_ss(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
-
-    # int8 test
-    run_gemm_ss(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2)
-    run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
-
-    # float8 tests
-    run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
-    run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
-
-    # tfloat32 test
-    # run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+@pytest.mark.parametrize(
+    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
+    [
+        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
+        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
+        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
+        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
+        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
+        (128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
+    ],
+)
+def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
+    run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


 def matmul_rs(
@@ -313,30 +305,23 @@ def _matmul(A, B):
     print("pass")


-def test_gemm_rs():
-    # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
-
-    # n8 tests
-    run_gemm_rs(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
-
-    # int8 tests
-    run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
-
-    # float8 tests
-    run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
-
-    # float32 tests
-    # run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+@pytest.mark.parametrize(
+    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
+    [
+        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
+        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
+        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
+    ],
+)
+def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
+    run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


 def matmul_sr(
@@ -470,30 +455,23 @@ def _matmul(A, B):
     print("pass")


-def test_gemm_sr():
-    # GEMM tests for float16
-    run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
-    run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
-
-    # n8 tests
-    run_gemm_sr(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
-
-    # int8 tests
-    run_gemm_sr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2)
-    run_gemm_sr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2)
-    run_gemm_sr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
-    run_gemm_sr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
-
-    # float8 tests
-    run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
-
-    # float32 tests
-    # run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
-    # run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
+@pytest.mark.parametrize(
+    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
+    [
+        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
+        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
+        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128),
+        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128),
"int8", "int32", 128, 128, 128, 2, 128), + (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + ], +) +def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) def matmul_rr( @@ -631,31 +609,25 @@ def _matmul(A, B): print("pass") -def test_gemm_rr(): - # GEMM tests for float16 - run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2) - run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2) - # n8 tests - run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2) - run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2) - - # int8 tests - run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2) - run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2) - run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2) - run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) - - # float8 tests - run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) - - # float32 tests - # run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2) - # run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2) - # run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2) - # run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2) +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", + [ + (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), + (128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128), + (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128), + (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), + (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + ], +) +def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): + run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, 


 if __name__ == "__main__":
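
Every conversion in this patch follows the same shape, sketched below for reference. The sketch is illustrative and not part of the patch: `requires_rocm` is a hypothetical stand-in for `tilelang.testing.requires_rocm` (any pytest skip-style marker stacks with `parametrize` the same way), and the parameter names are borrowed from the tests above. Each tuple becomes an independent test node, so one failing configuration no longer aborts the remaining cases.

    import pytest

    # Hypothetical stand-in for tilelang.testing.requires_rocm; the real
    # suite applies the library's marker directly.
    requires_rocm = pytest.mark.skipif(False, reason="illustrative stand-in")

    @pytest.mark.parametrize(
        "trans_A, trans_B, k_pack",
        [
            (False, False, 1),
            (False, True, 2),
        ],
    )
    @requires_rocm
    def test_gemm(trans_A, trans_B, k_pack):
        # pytest creates one node per tuple, e.g. test_gemm[False-True-2].
        assert k_pack in (1, 2)

With per-case node ids, a single configuration can be rerun directly, for example:

    pytest "testing/python/amd/test_tilelang_test_amd.py::test_gemm_f16f32f32_nt[False-True-2]"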