diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 93f8c4980..16a9d5f32 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -2,6 +2,7 @@ import tilelang import tilelang.language as T from tilelang.utils.tensor import torch_assert_close +from tilelang.utils import determine_fp8_type, determine_torch_fp8_type import itertools @@ -17,8 +18,9 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) return [a, b] @@ -53,7 +55,7 @@ def get_configs(): ) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): - dtype = T.float8_e4m3fnuz + dtype = determine_fp8_type() accum_dtype = T.float32 @T.prim_func @@ -104,8 +106,9 @@ def gemm_fp8_ss( def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py index 63a68e90f..fc7fb4400 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py @@ -7,6 +7,7 @@ from tilelang.tileop.base import GemmWarpPolicy from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) @@ -45,12 +46,14 @@ def tl_matmul( num_stages, k_pack=2, num_threads=256, - in_dtype=T.float8_e4m3fnuz, + in_dtype=None, out_dtype=T.float32, accum_dtype=T.float32, a_transposed=False, b_transposed=True, ): + if in_dtype is None: + in_dtype = determine_fp8_type() b_preshuffle = True warp_size = 64 num_warps = num_threads // warp_size @@ -164,7 +167,7 @@ def shuffle_weight( def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True): - in_dtype = T.float8_e4m3fnuz + in_dtype = determine_fp8_type() out_dtype = T.float32 accum_dtype = T.float32 kernel = tl_matmul( diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index 086997975..3b575c78e 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -1,6 +1,7 @@ import torch import tilelang import tilelang.language as T +from tilelang.utils import determine_fp8_type def calc_diff(x, y): @@ -55,21 +56,24 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) - test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type("e5m2")) def run_regression_perf(): M, N, K = 4096, 4096, 4096 - dtype = "float8_e4m3" + dtype = determine_fp8_type() kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - dtype = "float8_e5m2" - kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) - profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") - return (latency_e4m3 + latency_e5m2) / 2 + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + latency_e4m3 = profiler_e4m3.do_bench() + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index a702e8ae0..39c6fc333 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -1,6 +1,7 @@ import torch import tilelang import tilelang.language as T +from tilelang.utils import determine_fp8_type @tilelang.jit(out_idx=[-1]) @@ -73,21 +74,26 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) - test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type("e5m2")) def run_regression_perf(): M, N, K = 1024, 1024, 8192 - dtype = "float8_e4m3" + dtype = determine_fp8_type() kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - dtype = "float8_e5m2" - kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) - profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") - return (latency_e4m3 + latency_e5m2) / 2 + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() + if torch.version.hip is None: + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 162092204..1015a7463 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -4,10 +4,10 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) +from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter from tilelang.utils.tensor import map_torch_type +from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) @@ -39,26 +39,17 @@ def tl_matmul( assert in_dtype in [ T.float16, T.float8_e4m3fn, + T.float8_e4m3fnuz, T.float8_e5m2, + T.float8_e5m2fnuz, T.int8, - ], "Currently only float16 and int8 are supported" + ], "Currently only float16, float8, and int8 are supported" assert out_dtype in [ T.float16, T.float32, T.int32, ], "Currently only float16, float32 and int32 are supported" - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - # This is a debug config block_row_warps = 2 block_col_warps = 2 @@ -78,6 +69,38 @@ def tl_matmul( B_shape = (N, K) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) + is_hip = torch.version.hip is not None + # MMA Wrapper to Auto Generate Code for MMA/MFMA + if is_hip: + mma_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + else: + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + micro_size_x = mma_emitter.M_DIM + micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x)) + micro_size_k = mma_emitter.k_dim C_shared_shape = ( block_M // micro_size_x, block_N // micro_size_y, @@ -85,27 +108,12 @@ def tl_matmul( micro_size_y, ) - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) + threads = mma_emitter.threads + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + local_size_c = mma_emitter.local_size_out + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols @T.prim_func def gemm_fp8_intrinsic( @@ -158,7 +166,10 @@ def gemm_fp8_intrinsic( ) # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) + if is_hip: + mma_emitter.mfma(A_local, B_local, C_local, ki) + else: + mma_emitter.mma(A_local, B_local, C_local) # Perform STMatrix mma_emitter.stmatrix( @@ -192,7 +203,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + elif in_dtype in { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + }: A = torch.randn(M, K).to(in_dtype).cuda() B = torch.randn(N, K).to(in_dtype).cuda() else: @@ -218,18 +234,23 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def main(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + e4m3_dtype = determine_fp8_type() + assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32) + e5m2_dtype = determine_fp8_type("e5m2") + assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32) def run_regression_perf(): M, N, K = 4096, 4096, 4096 out_dtype, accum_dtype = "float32", "float32" - in_dtype = T.float8_e4m3fn + in_dtype = determine_fp8_type() kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) print(kernel_e4m3.get_kernel_source()) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() return latency_e4m3 diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 01e594108..5c477f9a0 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -942,6 +942,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { {"float8_e4m3fnx8", "long"}, {"float8_e5m2fnuzx4", "fp8_e5_4_t"}, {"float8_e5m2fnuzx8", "long"}, + {"float8_e5m2x4", "fp8_e5_4_t"}, + {"float8_e5m2x8", "long"}, {"float32x16", "float32x16"}}; std::string call_mfma_code = R"({ *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), diff --git a/src/tl_templates/hip/hip_fp8.h b/src/tl_templates/hip/hip_fp8.h index 19208a514..326785490 100644 --- a/src/tl_templates/hip/hip_fp8.h +++ b/src/tl_templates/hip/hip_fp8.h @@ -1,11 +1,15 @@ #pragma once #include +#include #define HIP_FP8_ENABLED 1 #define TILELANG_FP8_E4M3_VARIANT_FN 0 #define TILELANG_FP8_E4M3_VARIANT_FNUZ 1 +#define TILELANG_FP8_E5M2_VARIANT_FN 0 +#define TILELANG_FP8_E5M2_VARIANT_FNUZ 1 + #ifndef TILELANG_FP8_E4M3_VARIANT #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define TILELANG_FP8_E4M3_VARIANT TILELANG_FP8_E4M3_VARIANT_FNUZ @@ -14,55 +18,128 @@ #endif #endif +#ifndef TILELANG_FP8_E5M2_VARIANT +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define TILELANG_FP8_E5M2_VARIANT TILELANG_FP8_E5M2_VARIANT_FNUZ +#else +#define TILELANG_FP8_E5M2_VARIANT TILELANG_FP8_E5M2_VARIANT_FN +#endif +#endif + #if (TILELANG_FP8_E4M3_VARIANT == TILELANG_FP8_E4M3_VARIANT_FN) #if defined(__clang__) && defined(__HIPCC__) -#if __is_identifier(__hip_fp8_e4m3) +#if !__is_identifier(__hip_fp8_e4m3) #define TILELANG_HAVE_FP8_E4M3_FN 1 #endif #endif #endif #if defined(TILELANG_HAVE_FP8_E4M3_FN) -using fp8_e4_t = __hip_fp8_e4m3; -using fp8_e4_2_t = __hip_fp8x2_e4m3; -using fp8_e4_4_storage_t = __hip_fp8x4_e4m3; +using hip_fp8_e4_t = __hip_fp8_e4m3; +using hip_fp8x2_e4_t = __hip_fp8x2_e4m3; +using hip_fp8x4_e4_t = __hip_fp8x4_e4m3; #else // FNUZ path (MI300X and universal fallback) -using fp8_e4_t = __hip_fp8_e4m3_fnuz; -using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz; -using fp8_e4_4_storage_t = __hip_fp8x4_e4m3_fnuz; +using hip_fp8_e4_t = __hip_fp8_e4m3_fnuz; +using hip_fp8x2_e4_t = __hip_fp8x2_e4m3_fnuz; +using hip_fp8x4_e4_t = __hip_fp8x4_e4m3_fnuz; +#endif + +#if (TILELANG_FP8_E5M2_VARIANT == TILELANG_FP8_E5M2_VARIANT_FN) +#if defined(__clang__) && defined(__HIPCC__) +#if !__is_identifier(__hip_fp8_e5m2) +#define TILELANG_HAVE_FP8_E5M2_FN 1 +#endif #endif +#endif + +#if defined(TILELANG_HAVE_FP8_E5M2_FN) +using hip_fp8_e5_t = __hip_fp8_e5m2; +using hip_fp8x2_e5_t = __hip_fp8x2_e5m2; +using hip_fp8x4_e5_t = __hip_fp8x4_e5m2; +#else +using hip_fp8_e5_t = __hip_fp8_e5m2_fnuz; +using hip_fp8x2_e5_t = __hip_fp8x2_e5m2_fnuz; +using hip_fp8x4_e5_t = __hip_fp8x4_e5m2_fnuz; +#endif + +struct fp8_e4_t { + unsigned char data; + __device__ fp8_e4_t() {} + __device__ fp8_e4_t(hip_fp8_e4_t val) { + data = *reinterpret_cast(&val); + } + __device__ fp8_e4_t(float val) { + constexpr __hip_fp8_interpretation_t interp = +#if (TILELANG_FP8_E4M3_VARIANT == TILELANG_FP8_E4M3_VARIANT_FNUZ) + __HIP_E4M3_FNUZ; +#else + __HIP_E4M3; +#endif + data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, interp); + } + __device__ operator hip_fp8_e4_t() const { + return *reinterpret_cast(&data); + } + __device__ operator float() const { + return static_cast(static_cast(*this)); + } +}; + +using fp8_e4_2_t = hip_fp8x2_e4_t; +using fp8_e4_4_storage_t = uint32_t; // Additional FP8 types for compatibility -using fp8_e5_t = __hip_fp8_e5m2_fnuz; -using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz; +using fp8_e5_2_t = hip_fp8x2_e5_t; + +struct fp8_e5_t { + unsigned char data; + __device__ fp8_e5_t() {} + __device__ fp8_e5_t(hip_fp8_e5_t val) { + data = *reinterpret_cast(&val); + } + __device__ fp8_e5_t(float val) { + constexpr __hip_fp8_interpretation_t interp = +#if (TILELANG_FP8_E5M2_VARIANT == TILELANG_FP8_E5M2_VARIANT_FNUZ) + __HIP_E5M2_FNUZ; +#else + __HIP_E5M2; +#endif + data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, interp); + } + __device__ operator hip_fp8_e5_t() const { + return *reinterpret_cast(&data); + } + __device__ operator float() const { + return static_cast(static_cast(*this)); + } +}; // Note: E8M0 types are not supported in current HIP version // using fp8_e8_t = __hip_fp8_e8m0_fnuz; // using fp8_e8_2_t = __hip_fp8x2_e8m0_fnuz; // Simple wrapper that provides member access for generated code -struct fp8_e4_4_t { +struct __align__(4) fp8_e4_4_t { union { - // __hip_fp8x4_e4m3_fnuz data; fp8_e4_4_storage_t data; struct { - fp8_e4_t x, y, z, w; + fp8_e4_t x; + fp8_e4_t y; + fp8_e4_t z; + fp8_e4_t w; }; }; - // Default constructor - __device__ fp8_e4_4_t() = default; - - // Constructor from __hip_fp8x4_e4m3_fnuz + __device__ fp8_e4_4_t() {} __device__ fp8_e4_4_t(const fp8_e4_4_storage_t &val) : data(val) {} + __device__ fp8_e4_4_t(const hip_fp8x4_e4_t &val) { + data = *reinterpret_cast(&val); + } - // Constructor from float4 - __device__ fp8_e4_4_t(const float4 &val) : data(val) {} - - // Conversion operator to __hip_fp8x4_e4m3_fnuz - __device__ operator fp8_e4_4_storage_t() const { return data; } + __device__ operator hip_fp8x4_e4_t() const { + return *reinterpret_cast(&data); + } - // Assignment operator __device__ fp8_e4_4_t &operator=(const fp8_e4_4_storage_t &val) { data = val; return *this; @@ -80,16 +157,25 @@ struct __align__(16) fp8_e4_16_t { }; // FP8 E5M2 vector types -struct fp8_e5_4_t { +using fp8_e5_4_storage_t = uint32_t; + +struct __align__(4) fp8_e5_4_t { union { - __hip_fp8x4_e5m2_fnuz data; + fp8_e5_4_storage_t data; struct { - fp8_e5_t x, y, z, w; + fp8_e5_t x; + fp8_e5_t y; + fp8_e5_t z; + fp8_e5_t w; }; }; - __device__ fp8_e5_4_t() = delete; - __device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {} - __device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; } + __device__ fp8_e5_4_t() {} + __device__ fp8_e5_4_t(const hip_fp8x4_e5_t &val) { + data = *reinterpret_cast(&val); + } + __device__ operator hip_fp8x4_e5_t() const { + return *reinterpret_cast(&data); + } }; struct __align__(8) fp8_e5_8_t { diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 8770f7416..e2a217563 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -1,6 +1,7 @@ import tilelang.language as T from tilelang import tvm as tvm import tilelang.testing +from tilelang.utils import determine_fp8_type import pytest @@ -147,8 +148,8 @@ def test_gemm_ss_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeA @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float8_e4m3fnuz, T.float8_e4m3fnuz, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, determine_fp8_type("e5m2"), determine_fp8_type("e5m2"), T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_ss_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -459,8 +460,8 @@ def test_gemm_sr_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeA "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ # TODO: There is precision problem needs to repair - # (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float8_e4m3fnuz, T.float8_e4m3fnuz, T.float32, 128, 128, 32, 2, 128), + # (128, 128, 128, True, True, determine_fp8_type("e5m2"), determine_fp8_type("e5m2"), T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_sr_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -624,7 +625,7 @@ def test_gemm_rr_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeA [ # TODO: There is precision problem needs to repair # (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float8_e4m3fnuz, T.float8_e4m3fnuz, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_rr_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index fa211e16d..f60b9a924 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -49,6 +49,7 @@ class MatrixCoreIntrinEmitter: "int32": "int32", "float8_e4m3": "e4m3", "float8_e5m2": "e5m2", + "float8_e4m3fn": "e4m3fn", "float8_e4m3fnuz": "e4m3fnuz", "float8_e5m2fnuz": "e5m2fnuz", } @@ -108,7 +109,7 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): - if a_dtype in ["float8_e4m3fnuz", "float8_e5m2fnuz", T.int8]: + if a_dtype in ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", T.int8]: self.k_dim = 32 return a_dtype = DataType(a_dtype) @@ -141,12 +142,17 @@ def _initialize_mfma_prefix(self, k_dim=16): "float32": "f32", "int8": "i8", "int32": "i32", + "float8_e4m3fn": "fp8", "float8_e4m3fnuz": "fp8", - "float8_e5m2fnuz": "fp8", + # ROCm treats E5M2 as BF8 in MFMA intrinsics. + "float8_e5m2": "bf8", + "float8_e5m2fnuz": "bf8", }[in_dtype] if in_dtype_abbrv == "fp8": self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8" + elif in_dtype_abbrv == "bf8": + self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_bf8_bf8" elif in_dtype_abbrv == "i8": self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8" elif in_dtype_abbrv == "bf16": diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index f2e8ec441..335f72f24 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -1,6 +1,10 @@ """The profiler and convert to torch utils""" -from .target import determine_target # noqa: F401 +from .target import ( # noqa: F401 + determine_target, + determine_fp8_type, + determine_torch_fp8_type, +) from .tensor import TensorSupplyType, torch_assert_close, map_torch_type # noqa: F401 from .language import ( is_global, # noqa: F401 diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 93df938c1..0c1904b38 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -64,6 +64,37 @@ def check_metal_availability() -> bool: return arch == "arm64" +def determine_fp8_type(fp8_format: Literal["e4m3", "e5m2"] = "e4m3") -> str: + """ + Select the correct FP8 dtype string for the current platform. + - CUDA defaults to FP8 E4M3FN / E5M2. + - ROCm uses FNUZ except gfx950 (OCP), which prefers non-FNUZ when available. + """ + if fp8_format not in {"e4m3", "e5m2"}: + raise ValueError(f"Unsupported FP8 format: {fp8_format}") + if torch.version.hip is None: + return "float8_e4m3fn" if fp8_format == "e4m3" else "float8_e5m2" + if not torch.cuda.is_available(): + return "float8_e4m3fnuz" if fp8_format == "e4m3" else "float8_e5m2fnuz" + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + if fp8_format == "e4m3": + if gcn_arch.startswith("gfx950"): + return "float8_e4m3fn" + return "float8_e4m3fnuz" + if gcn_arch.startswith("gfx950") and hasattr(torch, "float8_e5m2"): + return "float8_e5m2" + return "float8_e5m2fnuz" + + +def determine_torch_fp8_type(fp8_format: Literal["e4m3", "e5m2"] = "e4m3") -> torch.dtype: + dtype_name = determine_fp8_type(fp8_format) + torch_dtype = getattr(torch, dtype_name, None) + if torch_dtype is None: + raise RuntimeError(f"PyTorch does not expose dtype {dtype_name}") + return torch_dtype + + def normalize_cutedsl_target(target: str | Target) -> Target | None: if isinstance(target, Target): if target.kind.name == "cuda" and "cutedsl" in target.keys: