From a60139f73816ecf29692a97b38cd3ad7963b5cb6 Mon Sep 17 00:00:00 2001 From: LJC00118 <317678865@qq.com> Date: Fri, 19 Dec 2025 15:39:53 +0800 Subject: [PATCH 1/6] Refactor CUDA vectorized cast generation and remove unsupported FP8 type --- src/target/codegen_cuda.cc | 316 +++++++----------------------- src/target/utils.cc | 32 +++ src/target/utils.h | 3 + src/transform/layout_inference.cc | 11 +- 4 files changed, 108 insertions(+), 254 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 657871d8f..1211ad7a4 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -15,6 +15,7 @@ #include "../op/builtin.h" #include "./ptx.h" +#include "./utils.h" #include "arith/pattern_match.h" namespace tvm { @@ -128,10 +129,9 @@ static std::string GetTileLangFP8Type(DataType type) { << "Only support scalar and vector types of width (2, 4, 8, 16, 32) " "for FP8"; } - if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() || - type.is_float8_e4m3()) { + if (type.is_float8_e4m3() || type.is_float8_e4m3fn()) { stream << "fp8_e4" << vec << "_t"; - } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) { + } else if (type.is_float8_e5m2()) { stream << "fp8_e5" << vec << "_t"; } else if (type.is_float8_e8m0fnu()) { stream << "fp8_e8" << vec << "_t"; @@ -970,273 +970,95 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { stream << ' ' << sret << ";\n"; std::string src = SSAGetID(PrintExpr(op->value), from_ty); - // Handle conversion between float16 and float32 + int lanes = from_ty.lanes(); + + auto generate_vector_conversion = + [&](const std::string &cast_func, const std::string &src_type, + const std::string &dst_type, const std::string &extra_args = "", + bool src_needs_reinterpret = false, + bool dst_needs_reinterpret = false) { + int num_chunks = lanes / 2; + std::string src_cast = src_needs_reinterpret + ? "reinterpret_cast<" + src_type + "*>" + : "(" + src_type + "*)"; + std::string dst_cast = dst_needs_reinterpret + ? "reinterpret_cast<" + dst_type + "*>" + : "(" + dst_type + "*)"; + + for (int i = 0; i < num_chunks; i++) { + PrintIndent(); + stream << "(" << dst_cast << "(&" << sret << "))[" << i + << "] = " << cast_func << "((" << src_cast << "(&" << src + << "))[" << i << "]" << extra_args << ");\n"; + } + os << sret; + }; + + // Handle conversion from float16 to float32 if (from_ty.is_float16() && target_ty.is_float()) { // Use __half22float2 for vectorized conversion (half2 -> float2) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // half2 -> float2 - PrintIndent(); - stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // half4 -> float4 - PrintIndent(); - stream << "((float2*)(&" << sret << "))[0] = " - << "__half22float2(*(half2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[1] = " - << "__half22float2(*((half2*)(&(" << src << "))+1));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // half8 -> float8 - PrintIndent(); - stream << "((float2*)(&" << sret << "))[0] = " - << "__half22float2(*(half2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[1] = " - << "__half22float2(*((half2*)(&(" << src << "))+1));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[2] = " - << "__half22float2(*((half2*)(&(" << src << "))+2));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[3] = " - << "__half22float2(*((half2*)(&(" << src << "))+3));\n"; - os << sret; + if (lanes == 2 || lanes == 4 || lanes == 8) { + generate_vector_conversion("__half22float2", "half2", "float2"); return; } - } else if (from_ty.is_float() && target_ty.is_float16()) { + } + + // Handle conversion from float32 to float16 + if (from_ty.is_float() && target_ty.is_float16()) { // Use __float22half2_rn for vectorized conversion (float2 -> half2) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // float2 -> half2 - PrintIndent(); - stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&(" - << src << ")));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // float4 -> half4 - PrintIndent(); - stream << "((half2*)(&" << sret << "))[0] = " - << "__float22half2_rn(*(float2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "((half2*)(&" << sret << "))[1] = " - << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // float8 -> half8 - PrintIndent(); - stream << "((half2*)(&" << sret << "))[0] = " - << "__float22half2_rn(*(float2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "((half2*)(&" << sret << "))[1] = " - << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n"; - PrintIndent(); - stream << "((half2*)(&" << sret << "))[2] = " - << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n"; - PrintIndent(); - stream << "((half2*)(&" << sret << "))[3] = " - << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n"; - os << sret; + if (lanes == 2 || lanes == 4 || lanes == 8) { + generate_vector_conversion("__float22half2_rn", "float2", "half2"); return; } } - // Handle conversion between bfloat16 and float32 + // Handle conversion from bfloat16 to float32 if (from_ty.is_bfloat16() && target_ty.is_float()) { // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // bfloat162 -> float2 - PrintIndent(); - stream << sret - << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" - << src << ")));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // bfloat162x2 -> float4 - PrintIndent(); - stream << "((float2*)(&" << sret << "))[0] = " - << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" - << src << ")));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[1] = " - << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" - << src << "))+1));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // bfloat162x4 -> float8 - PrintIndent(); - stream << "((float2*)(&" << sret << "))[0] = " - << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" - << src << ")));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[1] = " - << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" - << src << "))+1));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[2] = " - << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" - << src << "))+2));\n"; - PrintIndent(); - stream << "((float2*)(&" << sret << "))[3] = " - << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" - << src << "))+3));\n"; - os << sret; + if (lanes == 2 || lanes == 4 || lanes == 8) { + generate_vector_conversion("__bfloat1622float2", "__nv_bfloat162", + "float2", "", true, false); return; } - } else if (from_ty.is_float() && target_ty.is_bfloat16()) { + } + + // Handle conversion from float32 to bfloat16 + if (from_ty.is_float() && target_ty.is_bfloat16()) { // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // float2 -> bfloat162 - PrintIndent(); - stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret - << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // float4 -> bfloat162x2 - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = " - << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = " - << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // float8 -> bfloat162x4 - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = " - << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n"; - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = " - << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n"; - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = " - << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n"; - PrintIndent(); - stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = " - << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n"; - os << sret; + if (lanes == 2 || lanes == 4 || lanes == 8) { + generate_vector_conversion("__float22bfloat162_rn", "float2", + "__nv_bfloat162", "", false, true); return; } } // Handle conversion from float32 to float8 (E4M3/E5M2) - if (from_ty.is_float() && (target_ty.is_float8())) { - bool target_type_is_e4m3 = target_ty.is_float8_e4m3() || - target_ty.is_float8_e4m3fn() || - target_ty.is_float8_e4m3fnuz(); - // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion - // (float2 -> fp8x2) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // float2 -> fp8x2 - PrintIndent(); - stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret - << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast(&(" - << src << ")), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // float4 -> fp8x4 - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " - << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src - << ")), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " - << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src - << "))+1), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // float8 -> fp8x8 - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " - << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src - << ")), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " - << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src - << "))+1), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = " - << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src - << "))+2), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - PrintIndent(); - stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = " - << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src - << "))+3), __NV_SATFINITE, " - << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; - os << sret; + if (from_ty.is_float() && tl::IsCudaVectorizableFP8(target_ty)) { + bool target_type_is_e4m3 = + target_ty.is_float8_e4m3() || target_ty.is_float8_e4m3fn(); + std::string type_suffix = target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2"; + + // Use __nv_cvt_float2_to_fp8x2 for vectorized conversion (float2 -> fp8x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + std::string extra_args = ", __NV_SATFINITE, " + type_suffix; + generate_vector_conversion("__nv_cvt_float2_to_fp8x2", "float2", + "__nv_fp8x2_storage_t", extra_args, false, + true); return; } } - if (from_ty.is_float8() && target_ty.is_float()) { - bool from_type_is_e4m3 = from_ty.is_float8_e4m3() || - from_ty.is_float8_e4m3fn() || - from_ty.is_float8_e4m3fnuz(); - // FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion - // (fp8x2 -> float2) - if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { - // fp8x2 -> float2 - PrintIndent(); - stream << "*reinterpret_cast(&(" << sret - << ")) = " - "__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_" - "t*>(&(" - << src << ")), " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - os << sret; - return; - } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { - // fp8x4 -> float4 - PrintIndent(); - stream << "*(float2*)(&" << sret << ") = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - PrintIndent(); - stream << "*((float2*)(&" << sret << ")+1) = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - os << sret; - return; - } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { - // fp8x8 -> float8 - PrintIndent(); - stream << "*(float2*)(&" << sret << ") = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - PrintIndent(); - stream << "*((float2*)(&" << sret << ")+1) = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - PrintIndent(); - stream << "*((float2*)(&" << sret << ")+2) = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[2], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - PrintIndent(); - stream << "*((float2*)(&" << sret << ")+3) = " - << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[3], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; - os << sret; + // Handle conversion from float8 (E4M3/E5M2) to float32 + if (tl::IsCudaVectorizableFP8(from_ty) && target_ty.is_float()) { + bool from_type_is_e4m3 = + from_ty.is_float8_e4m3() || from_ty.is_float8_e4m3fn(); + std::string type_suffix = from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2"; + + // Use __tl_cvt_fp8x2_to_float2 for vectorized conversion (fp8x2 -> float2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + generate_vector_conversion("__tl_cvt_fp8x2_to_float2", + "__nv_fp8x2_storage_t", "float2", + ", " + type_suffix, true, false); return; } } diff --git a/src/target/utils.cc b/src/target/utils.cc index b69e3dd4c..66d32079d 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -134,6 +134,38 @@ int TargetGetWarpSize(Target target) { return res; } +bool IsCudaVectorizableFP8(DataType dtype) { + return dtype.is_float8_e4m3() || dtype.is_float8_e4m3fn() || + dtype.is_float8_e5m2(); +} + +bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) { + // float16 -> float32 + if (from_ty.is_float16() && target_ty.is_float()) + return true; + + // float32 -> float16 + if (from_ty.is_float() && target_ty.is_float16()) + return true; + + // bfloat16 -> float32 + if (from_ty.is_bfloat16() && target_ty.is_float()) + return true; + + // float32 -> bfloat16 + if (from_ty.is_float() && target_ty.is_bfloat16()) + return true; + + // float32 -> float8 (E4M3/E5M2) + if (from_ty.is_float() && IsCudaVectorizableFP8(target_ty)) + return true; + + // float8 (E4M3/E5M2) -> float32 + if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float()) + return true; + return false; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() diff --git a/src/target/utils.h b/src/target/utils.h index bfd88281c..9de2d4d4f 100644 --- a/src/target/utils.h +++ b/src/target/utils.h @@ -30,6 +30,9 @@ bool TargetHasTmem(Target target); bool TargetHasBulkCopy(Target target); int TargetGetWarpSize(Target target); +bool IsCudaVectorizableFP8(DataType dtype); +bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty); + } // namespace tl } // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index b44824aff..de02519ee 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -1177,13 +1177,10 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { if (const auto *cast = obj.as()) { // Check if this is a non-reducer store with Cast operation - DataType src_type = cast->value.dtype(); - DataType dst_type = cast->dtype; - bool src_ok = - src_type.is_float() || src_type.is_bfloat() || src_type.is_float8(); - bool dst_ok = - dst_type.is_float() || dst_type.is_bfloat() || dst_type.is_float8(); - if (src_ok && dst_ok && TargetIsCuda(Target::Current())) { + DataType from_ty = cast->value.dtype(); + DataType target_ty = cast->dtype; + if (IsCudaVectorizableCast(from_ty, target_ty) && + TargetIsCuda(Target::Current())) { has_cast_operations = true; } } From bffedae0b0e4e3eaa379acd1df74aa86615f1ce6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Dec 2025 17:36:21 +0800 Subject: [PATCH 2/6] test fix --- testing/python/debug/test_tilelang_debug_print.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index 3483cffc0..ad94ed419 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -12,13 +12,11 @@ def program(Q: T.Tensor((M, N), dtype)): shared_buf = T.alloc_shared([M, N], dtype) T.print(shared_buf) - jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi") + jit_kernel = tilelang.compile(program) profiler = jit_kernel.get_profiler() profiler.run_once() - def test_debug_print_buffer(): - debug_print_buffer(dtype=T.bool) debug_print_buffer(dtype=T.int8) debug_print_buffer(dtype=T.int16) debug_print_buffer(dtype=T.int32) @@ -31,13 +29,18 @@ def test_debug_print_buffer(): debug_print_buffer(dtype=T.float32) debug_print_buffer(dtype=T.float64) debug_print_buffer(dtype=T.bfloat16) + +@tilelang.testing.requires_cuda +def test_debug_print_buffer_cuda_fp8(): debug_print_buffer(dtype=T.float8_e4m3fn) - debug_print_buffer(dtype=T.float8_e4m3fn) - debug_print_buffer(dtype=T.float8_e4m3fnuz) debug_print_buffer(dtype=T.float8_e5m2) - debug_print_buffer(dtype=T.float8_e5m2fnuz) +@tilelang.testing.requires_rocm +def test_debug_print_buffer_rocm_fp8(): + debug_print_buffer(dtype=T.float8_e4m3fnuz) + debug_print_buffer(dtype=T.float8_e5m2fnuz) + def debug_print_buffer_conditional(M=16, N=16): dtype = T.float16 From 7b76aafe39977e25a27ab65e8435b8116fd43b04 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Dec 2025 17:36:50 +0800 Subject: [PATCH 3/6] lint fix --- testing/python/debug/test_tilelang_debug_print.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index ad94ed419..735eb3e80 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -16,6 +16,7 @@ def program(Q: T.Tensor((M, N), dtype)): profiler = jit_kernel.get_profiler() profiler.run_once() + def test_debug_print_buffer(): debug_print_buffer(dtype=T.int8) debug_print_buffer(dtype=T.int16) @@ -30,6 +31,7 @@ def test_debug_print_buffer(): debug_print_buffer(dtype=T.float64) debug_print_buffer(dtype=T.bfloat16) + @tilelang.testing.requires_cuda def test_debug_print_buffer_cuda_fp8(): debug_print_buffer(dtype=T.float8_e4m3fn) @@ -41,6 +43,7 @@ def test_debug_print_buffer_rocm_fp8(): debug_print_buffer(dtype=T.float8_e4m3fnuz) debug_print_buffer(dtype=T.float8_e5m2fnuz) + def debug_print_buffer_conditional(M=16, N=16): dtype = T.float16 From 0e264c94e6d4a6de8382088255b8730bb76321b3 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Dec 2025 17:55:35 +0800 Subject: [PATCH 4/6] Refactor CUDA vectorized cast function naming for clarity --- src/target/codegen_cuda.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 1211ad7a4..576b25888 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -972,7 +972,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { int lanes = from_ty.lanes(); - auto generate_vector_conversion = + auto PrintVectorizedCast = [&](const std::string &cast_func, const std::string &src_type, const std::string &dst_type, const std::string &extra_args = "", bool src_needs_reinterpret = false, @@ -998,7 +998,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_float16() && target_ty.is_float()) { // Use __half22float2 for vectorized conversion (half2 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { - generate_vector_conversion("__half22float2", "half2", "float2"); + PrintVectorizedCast("__half22float2", "half2", "float2"); return; } } @@ -1007,7 +1007,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_float() && target_ty.is_float16()) { // Use __float22half2_rn for vectorized conversion (float2 -> half2) if (lanes == 2 || lanes == 4 || lanes == 8) { - generate_vector_conversion("__float22half2_rn", "float2", "half2"); + PrintVectorizedCast("__float22half2_rn", "float2", "half2"); return; } } @@ -1016,7 +1016,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_bfloat16() && target_ty.is_float()) { // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { - generate_vector_conversion("__bfloat1622float2", "__nv_bfloat162", + PrintVectorizedCast("__bfloat1622float2", "__nv_bfloat162", "float2", "", true, false); return; } @@ -1026,7 +1026,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_float() && target_ty.is_bfloat16()) { // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162) if (lanes == 2 || lanes == 4 || lanes == 8) { - generate_vector_conversion("__float22bfloat162_rn", "float2", + PrintVectorizedCast("__float22bfloat162_rn", "float2", "__nv_bfloat162", "", false, true); return; } @@ -1041,7 +1041,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { // Use __nv_cvt_float2_to_fp8x2 for vectorized conversion (float2 -> fp8x2) if (lanes == 2 || lanes == 4 || lanes == 8) { std::string extra_args = ", __NV_SATFINITE, " + type_suffix; - generate_vector_conversion("__nv_cvt_float2_to_fp8x2", "float2", + PrintVectorizedCast("__nv_cvt_float2_to_fp8x2", "float2", "__nv_fp8x2_storage_t", extra_args, false, true); return; @@ -1056,7 +1056,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { // Use __tl_cvt_fp8x2_to_float2 for vectorized conversion (fp8x2 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { - generate_vector_conversion("__tl_cvt_fp8x2_to_float2", + PrintVectorizedCast("__tl_cvt_fp8x2_to_float2", "__nv_fp8x2_storage_t", "float2", ", " + type_suffix, true, false); return; From e344863321c83fc17dd4cc50e29fb951d73a1b68 Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Mon, 22 Dec 2025 20:44:18 +0800 Subject: [PATCH 5/6] Add support for float4_e2m1fn type conversions in CUDA vectorized casts - Implemented conversions between float4_e2m1fn and float32, half2, and float2 in utils.cc and cuda_fp4.h. - Updated test_tilelang_language_vectorized_cast.py to validate new conversions and ensure correctness. - Enhanced dtype conversion in dtypes.py to handle float4_e2m1fn appropriately, logging a warning for unsupported types in PyTorch. --- src/target/codegen_cuda.cc | 56 ++++++++++++++++--- src/target/utils.cc | 9 +++ src/tl_templates/cuda/cuda_fp4.h | 48 ++++++++++++++++ .../test_tilelang_language_vectorized_cast.py | 47 ++++++++-------- tilelang/language/v2/dtypes.py | 8 ++- 5 files changed, 134 insertions(+), 34 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 576b25888..517e12094 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1016,8 +1016,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_bfloat16() && target_ty.is_float()) { // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { - PrintVectorizedCast("__bfloat1622float2", "__nv_bfloat162", - "float2", "", true, false); + PrintVectorizedCast("__bfloat1622float2", "__nv_bfloat162", "float2", "", + true, false); return; } } @@ -1026,8 +1026,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (from_ty.is_float() && target_ty.is_bfloat16()) { // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162) if (lanes == 2 || lanes == 4 || lanes == 8) { - PrintVectorizedCast("__float22bfloat162_rn", "float2", - "__nv_bfloat162", "", false, true); + PrintVectorizedCast("__float22bfloat162_rn", "float2", "__nv_bfloat162", + "", false, true); return; } } @@ -1042,8 +1042,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { if (lanes == 2 || lanes == 4 || lanes == 8) { std::string extra_args = ", __NV_SATFINITE, " + type_suffix; PrintVectorizedCast("__nv_cvt_float2_to_fp8x2", "float2", - "__nv_fp8x2_storage_t", extra_args, false, - true); + "__nv_fp8x2_storage_t", extra_args, false, true); return; } } @@ -1056,9 +1055,48 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { // Use __tl_cvt_fp8x2_to_float2 for vectorized conversion (fp8x2 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { - PrintVectorizedCast("__tl_cvt_fp8x2_to_float2", - "__nv_fp8x2_storage_t", "float2", - ", " + type_suffix, true, false); + PrintVectorizedCast("__tl_cvt_fp8x2_to_float2", "__nv_fp8x2_storage_t", + "float2", ", " + type_suffix, true, false); + return; + } + } + + // Handle conversion from float16 to float4 (E2M1) + if (from_ty.is_float16() && target_ty.is_float4_e2m1fn()) { + // Use __tl_cvt_half2_to_fp4x2 for vectorized conversion (half2 -> fp4x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_half2_to_fp4x2", "half2", "uint8_t", "", + false, true); + return; + } + } + + // Handle conversion from float32 to float4 (E2M1) + if (from_ty.is_float() && target_ty.is_float4_e2m1fn()) { + // Use __tl_cvt_float2_to_fp4x2 for vectorized conversion (float2 -> fp4x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_float2_to_fp4x2", "float2", "uint8_t", "", + false, true); + return; + } + } + + // Handle conversion from float4 (E2M1) to float16 + if (from_ty.is_float4_e2m1fn() && target_ty.is_float16()) { + // Use __tl_cvt_fp4x2_to_half2 for vectorized conversion (fp4x2 -> half2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_fp4x2_to_half2", "uint8_t", "half2", "", + true, false); + return; + } + } + + // Handle conversion from float4 (E2M1) to float32 + if (from_ty.is_float4_e2m1fn() && target_ty.is_float()) { + // Use __tl_cvt_fp4x2_to_float2 for vectorized conversion (fp4x2 -> float2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_fp4x2_to_float2", "uint8_t", "float2", "", + true, false); return; } } diff --git a/src/target/utils.cc b/src/target/utils.cc index 66d32079d..993590ffb 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -163,6 +163,15 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) { // float8 (E4M3/E5M2) -> float32 if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float()) return true; + + // float4_e2m1fn -> float32 + if (from_ty.is_float4_e2m1fn() && target_ty.is_float()) + return true; + + // float32 -> float4_e2m1fn + if (from_ty.is_float() && target_ty.is_float4_e2m1fn()) + return true; + return false; } diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index e3f56622f..b76246442 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -154,4 +154,52 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t( return result; } +// fp4_e2m1x2 (1 byte) -> half2 +// Uses PTX cvt.rn.f16x2.e2m1x2 instruction +TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const uint8_t src) { + half2 out; + uint32_t *out_ptr = reinterpret_cast(&out); + uint16_t src_packed = static_cast(src); + asm volatile("{\n" + ".reg .b8 byte0, byte1;\n" + "mov.b16 {byte0, byte1}, %1;\n" + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" + "}\n" + : "=r"(*out_ptr) + : "h"(src_packed)); + return out; +} + +// fp4_e2m1x2 (1 byte) -> float2 +TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const uint8_t src) { + half2 tmp = __tl_cvt_fp4x2_to_half2(src); + float2 result; + result.x = __half2float(tmp.x); + result.y = __half2float(tmp.y); + return result; +} + +// half2 -> fp4_e2m1x2 (1 byte) +// Uses PTX cvt.rn.satfinite.e2m1x2.f16x2 instruction +TL_DEVICE uint8_t __tl_cvt_half2_to_fp4x2(const half2 src) { + uint16_t out; + uint32_t const *src_ptr = reinterpret_cast(&src); + asm volatile("{\n" + ".reg .b8 result_byte;\n" + "cvt.rn.satfinite.e2m1x2.f16x2 result_byte, %1;\n" + "mov.b16 %0, {result_byte, 0};\n" + "}\n" + : "=h"(out) + : "r"(*src_ptr)); + return static_cast(out); +} + +// float2 -> fp4_e2m1x2 (1 byte) +TL_DEVICE uint8_t __tl_cvt_float2_to_fp4x2(const float2 src) { + half2 tmp; + tmp.x = __float2half(src.x); + tmp.y = __float2half(src.y); + return __tl_cvt_half2_to_fp4x2(tmp); +} + #endif diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index 1a0a0942a..f4f28fd30 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -3,14 +3,6 @@ import tilelang.testing import tilelang.language as T -str2dtype = { - T.float32: torch.float32, - T.float16: torch.float16, - T.bfloat16: torch.bfloat16, - T.float8_e4m3fn: torch.float8_e4m3fn, - T.float8_e5m2: torch.float8_e5m2, -} - @tilelang.jit def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): @@ -48,34 +40,39 @@ def main( return main -def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2): +def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str, lanes: int = 2): """Run the vectorized cast kernel and check the correctness. Args: - src_dtype_str: The source data type string. - dst_dtype_str: The destination data type string. + src_dtype: The source data type. + dst_dtype: The destination data type. check_str: Used to ensure vectorized cast is used. lanes: The number of lanes of the source and destination data types. """ M = 128 * lanes - kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str) - kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str) + kernel = vectorized_cast_kernel(M, src_dtype, dst_dtype) + kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype, dst_dtype) + + code = kernel.get_kernel_source() + code_parallel = kernel_parallel.get_kernel_source() + print(code) + assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!" + + if src_dtype == T.float4_e2m1fn or dst_dtype == T.float4_e2m1fn: + return A_float = torch.randn(M, dtype=torch.float32, device="cuda") - A = A_float.to(str2dtype[src_dtype_str]) - B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda") - C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda") + A = A_float.to(src_dtype.as_torch()) + + A = A_float.to(src_dtype.as_torch()) + B = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda") + C = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda") kernel(A, B) kernel_parallel(A, C) - torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B) - torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C) - - code = kernel.get_kernel_source() - code_parallel = kernel_parallel.get_kernel_source() - - assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" + torch.testing.assert_close(A.to(dst_dtype.as_torch()), B) + torch.testing.assert_close(A.to(dst_dtype.as_torch()), C) @pytest.mark.parametrize( @@ -97,6 +94,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + (T.float4_e2m1fn, T.float16, "__tl_cvt_fp4x2_to_half2", 2), + (T.float16, T.float4_e2m1fn, "__tl_cvt_half2_to_fp4x2", 2), + (T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2), + (T.float32, T.float4_e2m1fn, "__tl_cvt_float2_to_fp4x2", 2), ], ) def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index a42ba5a67..1649da6e4 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -5,6 +5,7 @@ from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi import numpy as np +from tilelang import logger _T = TypeVar("_T") @@ -175,7 +176,7 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: elif dtype_str == "float8_e5m2": assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0" return torch.float8_e5m2 - elif dtype_str == "e4m3fnuz_float8": + elif dtype_str == "float8_e4m3fnuz": assert hasattr(torch, "float8_e4m3fnuz"), ( "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" ) @@ -189,7 +190,10 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: assert hasattr(torch, "float4_e2m1fnx2"), ( "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0" ) - return torch.float4_e2m1fnx2 + return torch.float4_e2m1fn_x2 + elif dtype_str == "float4_e2m1fn": + logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.") + return torch.float4_e2m1fn_x2 elif dtype_str in _STR_TO_TORCH_DTYPE: return _STR_TO_TORCH_DTYPE[dtype_str] From e68e073c5fb634a285c4fbb76785115d900c01cd Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Mon, 22 Dec 2025 21:20:34 +0800 Subject: [PATCH 6/6] Enhance vectorized cast tests for new data types - Added tests for vectorized casting of float8 and float4 data types, ensuring compatibility with CUDA compute versions. - Refactored existing test functions to improve clarity and organization, separating tests for different data types. - Updated parameterization to include additional test cases for new conversions. --- .../test_tilelang_language_vectorized_cast.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index f4f28fd30..991b2a8eb 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -90,17 +90,39 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str, (T.float32, T.bfloat16, "__float22bfloat162_rn", 4), (T.bfloat16, T.float32, "__bfloat1622float2", 2), (T.bfloat16, T.float32, "__bfloat1622float2", 4), + ], +) +def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): + run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(8, 9) +@pytest.mark.parametrize( + "src_dtype, dst_dtype, check_str, lanes", + [ (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + ], +) +def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes): + run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(10, 0) +@pytest.mark.parametrize( + "src_dtype, dst_dtype, check_str, lanes", + [ (T.float4_e2m1fn, T.float16, "__tl_cvt_fp4x2_to_half2", 2), (T.float16, T.float4_e2m1fn, "__tl_cvt_half2_to_fp4x2", 2), (T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2), (T.float32, T.float4_e2m1fn, "__tl_cvt_float2_to_fp4x2", 2), ], ) -def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): +def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes): run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)