From 68a5289c807b2eace77f4a43e76eccadfb6217e6 Mon Sep 17 00:00:00 2001 From: SiriusNEO Date: Mon, 26 Jan 2026 18:02:33 +0800 Subject: [PATCH 1/3] [Feature] Support E8M0 related vectorized cast --- examples/gemm/example_gemm_autotune.py | 3 +- src/target/codegen_cuda.cc | 53 ++++++++++++++++++ src/target/utils.cc | 12 ++++ src/tl_templates/cuda/copy.h | 2 + src/tl_templates/cuda/cuda_fp8.h | 55 +++++++++++++++++++ .../test_tilelang_language_vectorized_cast.py | 12 +++- 6 files changed, 135 insertions(+), 2 deletions(-) diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 016d448a4..d8085e2f1 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -208,7 +208,7 @@ def gemm_autotune( def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): - use_autotune = True + use_autotune = False if use_autotune: result = get_best_config(M, N, K, with_roller) print(result.config) @@ -219,6 +219,7 @@ def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False # benchmark profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + print(kernel.get_kernel_source()) tilelang_latency = profiler.do_bench() ref_latency = profiler.do_bench(ref_program) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 259343f6d..19349eb51 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1177,6 +1177,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { os << sret; }; + // A list of casting functions that are supported by TileLang templates. + // To add a new type conversion, you should do the following things: + // 1. Add the new conversion function in tl_templates. (__tl_cvt_xx) + // 2. Add a new if statement like the one below. + // 3. 
In src/target/utils.cc, allow this vectorizable cast. + // Handle conversion from float16 to float32 if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32) { // Use __half22float2 for vectorized conversion (half2 -> float2) @@ -1245,6 +1251,53 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } } + // Handle conversion from float8 (E8M0) to bfloat16 + if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16()) { + // Use __tl_cvt_e8m0x2_to_bfloat162 for vectorized conversion (fp8_e8m0x2 -> + // bfloat162) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_e8m0x2_to_bfloat162", + "__nv_fp8x2_storage_t", "__nv_bfloat162", "", true, + false); + return; + } + } + + // Handle conversion from bfloat16 to float8 (E8M0) + if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu()) { + // Use __tl_cvt_bfloat162_to_e8m0x2 for vectorized conversion (bfloat162 -> + // fp8_e8m0x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_bfloat162_to_e8m0x2", "__nv_bfloat162", + "__nv_fp8x2_storage_t", "", false, true); + return; + } + } + + // Handle conversion from float to float8 (E8M0) + if (from_ty.is_float() && from_ty.bits() == 32 && + target_ty.is_float8_e8m0fnu()) { + // Use __nv_cvt_float2_to_e8m0x2 for vectorized conversion (float2 -> + // fp8_e8m0x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_float2_to_e8m0x2", "float2", + "__nv_fp8x2_storage_t", "", false, true); + return; + } + } + + // Handle conversion from double to float8 (E8M0) + if (from_ty.is_float() && from_ty.bits() == 64 && + target_ty.is_float8_e8m0fnu()) { + // Use __nv_cvt_double2_to_e8m0x2 for vectorized conversion (double2 -> + // fp8_e8m0x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_double2_to_e8m0x2", "double2", + "__nv_fp8x2_storage_t", "", false, true); + return; + } + } + // Handle conversion from float16 to float4 
(E2M1) if (from_ty.is_float16() && target_ty.is_float4_e2m1fn()) { // Use __tl_cvt_half2_to_fp4x2 for vectorized conversion (half2 -> fp4x2) diff --git a/src/target/utils.cc b/src/target/utils.cc index 8df9a7138..cdb300562 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -178,6 +178,18 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) { if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float()) return true; + // float8 (E8M0) -> bfloat16 + if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16()) + return true; + + // bfloat16 -> float8 (E4M3/E5M2) + if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu()) + return true; + + // float32/double -> float8 (E4M3/E5M2) + if (from_ty.is_float() && target_ty.is_float8_e8m0fnu()) + return true; + // float4_e2m1fn -> float32 if (from_ty.is_float4_e2m1fn() && target_ty.is_float()) return true; diff --git a/src/tl_templates/cuda/copy.h b/src/tl_templates/cuda/copy.h index 0fa7b9d91..203ba6b5e 100644 --- a/src/tl_templates/cuda/copy.h +++ b/src/tl_templates/cuda/copy.h @@ -13,6 +13,8 @@ namespace tl { +#define TL_DEVICE __forceinline__ __device__ + TL_DEVICE void cp_async_commit() { asm volatile("cp.async.commit_group;\n" ::); } diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h index c80046296..c7156fc56 100644 --- a/src/tl_templates/cuda/cuda_fp8.h +++ b/src/tl_templates/cuda/cuda_fp8.h @@ -312,3 +312,58 @@ __tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x, result.y = (float)tmp.y; return result; } + +// ============================================================================ +// FP8 E8M0 Related Conversions +// ============================================================================ +#if defined(TL_HAS_FP8_E8M0) + +// fp8_e8m0 -> bfloat16 +TL_DEVICE __nv_bfloat16 +__tl_cvt_e8m0_to_bfloat16(const __nv_fp8_storage_t src) { + __nv_bfloat16_raw raw = __nv_cvt_e8m0_to_bf16raw(src); + return *reinterpret_cast<const __nv_bfloat16 *>(&raw); +} + +// fp8_e8m0x2 -> bfloat16x2 
+TL_DEVICE __nv_bfloat162 +__tl_cvt_e8m0x2_to_bfloat162(const __nv_fp8x2_storage_t src) { + __nv_bfloat162_raw raw = __nv_cvt_e8m0x2_to_bf162raw(src); + return *reinterpret_cast<const __nv_bfloat162 *>(&raw); +} + +// bfloat16 -> fp8_e8m0 +TL_DEVICE +__nv_fp8_storage_t __tl_cvt_bfloat16_to_e8m0(const __nv_bfloat16 src) { + __nv_bfloat16_raw raw = *reinterpret_cast<const __nv_bfloat16_raw *>(&src); + return __nv_cvt_bfloat16raw_to_e8m0(raw, __NV_SATFINITE, cudaRoundNearest); +} + +// bfloat162 -> fp8_e8m0x2 +TL_DEVICE __nv_fp8x2_storage_t +__tl_cvt_bfloat162_to_e8m0x2(const __nv_bfloat162 src) { + __nv_bfloat162_raw raw = *reinterpret_cast<const __nv_bfloat162_raw *>(&src); + return __nv_cvt_bfloat162raw_to_e8m0x2(raw, __NV_SATFINITE, cudaRoundNearest); +} + +// float -> fp8_e8m0 +TL_DEVICE __nv_fp8_storage_t __tl_cvt_float_to_e8m0(const float src) { + return __nv_cvt_float_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest); +} + +// float2 -> fp8_e8m0x2 +TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_float2_to_e8m0x2(const float2 src) { + return __nv_cvt_float2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest); +} + +// double -> fp8_e8m0 +TL_DEVICE __nv_fp8_storage_t __tl_cvt_double_to_e8m0(const double src) { + return __nv_cvt_double_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest); +} + +// double2 -> fp8_e8m0x2 +TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_double2_to_e8m0x2(const double2 src) { + return __nv_cvt_double2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest); +} + +#endif diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index e4684f70c..ce678236c 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -55,9 +55,14 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str, code = kernel.get_kernel_source() code_parallel = kernel_parallel.get_kernel_source() - print(code) + # print(code) + # print(code_parallel) assert check_str in 
code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!" + # Requires torch >= 2.8 + if src_dtype == T.float8_e8m0fnu or dst_dtype == T.float8_e8m0fnu: + return + if src_dtype == T.float4_e2m1fn or dst_dtype == T.float4_e2m1fn: return @@ -106,6 +111,11 @@ def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + # E8M0 <-> FP16 + (T.float8_e8m0fnu, T.bfloat16, "__tl_cvt_e8m0x2_to_bfloat162", 2), + (T.bfloat16, T.float8_e8m0fnu, "__tl_cvt_bfloat162_to_e8m0x2", 2), + (T.float32, T.float8_e8m0fnu, "__tl_cvt_float2_to_e8m0x2", 2), + (T.float64, T.float8_e8m0fnu, "__tl_cvt_double2_to_e8m0x2", 2), ], ) def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes): From eeeb8aa8d17cd2c4a5c5857e645ed232c8439bf9 Mon Sep 17 00:00:00 2001 From: SiriusNEO Date: Mon, 26 Jan 2026 18:04:56 +0800 Subject: [PATCH 2/3] fix --- examples/gemm/example_gemm_autotune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index d8085e2f1..daecbafe4 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -208,7 +208,7 @@ def gemm_autotune( def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): - use_autotune = False + use_autotune = True if use_autotune: result = get_best_config(M, N, K, with_roller) print(result.config) From 0269293ebb86e67b875555ca9726cc1bb6a12e19 Mon Sep 17 00:00:00 2001 From: SiriusNEO Date: Tue, 27 Jan 2026 10:58:25 +0800 Subject: [PATCH 3/3] address comments --- examples/gemm/example_gemm_autotune.py | 1 - src/target/codegen_cuda.cc | 4 ++-- src/target/utils.cc | 7 +++++-- src/tl_templates/cuda/copy.h | 2 -- 
src/tl_templates/cuda/cuda_fp8.h | 2 +- .../language/test_tilelang_language_vectorized_cast.py | 6 +++--- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index daecbafe4..016d448a4 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -219,7 +219,6 @@ def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False # benchmark profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - print(kernel.get_kernel_source()) tilelang_latency = profiler.do_bench() ref_latency = profiler.do_bench(ref_program) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 19349eb51..8c8149bd8 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1277,7 +1277,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { // Handle conversion from float to float8 (E8M0) if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_float8_e8m0fnu()) { - // Use __nv_cvt_float2_to_e8m0x2 for vectorized conversion (float2 -> + // Use __tl_cvt_float2_to_e8m0x2 for vectorized conversion (float2 -> // fp8_e8m0x2) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__tl_cvt_float2_to_e8m0x2", "float2", @@ -1289,7 +1289,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { // Handle conversion from double to float8 (E8M0) if (from_ty.is_float() && from_ty.bits() == 64 && target_ty.is_float8_e8m0fnu()) { - // Use __nv_cvt_double2_to_e8m0x2 for vectorized conversion (double2 -> + // Use __tl_cvt_double2_to_e8m0x2 for vectorized conversion (double2 -> // fp8_e8m0x2) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__tl_cvt_double2_to_e8m0x2", "double2", diff --git a/src/target/utils.cc b/src/target/utils.cc index cdb300562..cd9aa8f72 100644 
--- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -149,6 +149,9 @@ int TargetGetWarpSize(Target target) { } bool IsCudaVectorizableFP8(DataType dtype) { + // NOTE: E8M0 is a special type of FP8 which is not handled here + // We only handle FP8 types which can be represented with + // __nv_fp8_interpretation_t here return dtype.is_float8_e4m3() || dtype.is_float8_e4m3fn() || dtype.is_float8_e5m2(); } @@ -182,11 +185,11 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) { if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16()) return true; - // bfloat16 -> float8 (E4M3/E5M2) + // bfloat16 -> float8 (E8M0) if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu()) return true; - // float32/double -> float8 (E4M3/E5M2) + // float32/double -> float8 (E8M0) if (from_ty.is_float() && target_ty.is_float8_e8m0fnu()) return true; diff --git a/src/tl_templates/cuda/copy.h b/src/tl_templates/cuda/copy.h index 203ba6b5e..0fa7b9d91 100644 --- a/src/tl_templates/cuda/copy.h +++ b/src/tl_templates/cuda/copy.h @@ -13,8 +13,6 @@ namespace tl { -#define TL_DEVICE __forceinline__ __device__ - TL_DEVICE void cp_async_commit() { asm volatile("cp.async.commit_group;\n" ::); } diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h index c7156fc56..bbd634e62 100644 --- a/src/tl_templates/cuda/cuda_fp8.h +++ b/src/tl_templates/cuda/cuda_fp8.h @@ -316,7 +316,7 @@ __tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x, // ============================================================================ // FP8 E8M0 Related Conversions // ============================================================================ -#if defined(TL_HAS_FP8_E8M0) +#if TL_HAS_FP8_E8M0 // fp8_e8m0 -> bfloat16 TL_DEVICE __nv_bfloat16 diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index ce678236c..a7b84804b 100644 --- 
a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -55,8 +55,6 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str, code = kernel.get_kernel_source() code_parallel = kernel_parallel.get_kernel_source() - # print(code) - # print(code_parallel) assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!" # Requires torch >= 2.8 @@ -111,10 +109,12 @@ def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), - # E8M0 <-> FP16 + # E8M0 <-> BFloat16 (T.float8_e8m0fnu, T.bfloat16, "__tl_cvt_e8m0x2_to_bfloat162", 2), (T.bfloat16, T.float8_e8m0fnu, "__tl_cvt_bfloat162_to_e8m0x2", 2), + # Float -> E8M0 (T.float32, T.float8_e8m0fnu, "__tl_cvt_float2_to_e8m0x2", 2), + # Double -> E8M0 (T.float64, T.float8_e8m0fnu, "__tl_cvt_double2_to_e8m0x2", 2), ], )