diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 1a9fb06c2..1cf642138 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     }
   }
 
+  if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) &&
+      target_ty.is_float()) {
+    // FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
+    // (fp8x2 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // fp8x2 -> float2
+      PrintIndent();
+      stream << "*reinterpret_cast<float2 *>(&(" << sret
+             << ")) = "
+                "__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
+                "t*>(&("
+             << src << ")), "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // fp8x4 -> float4
+      PrintIndent();
+      stream << "*(float2*)(&" << sret << ") = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[0], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "*((float2*)(&" << sret << ")+1) = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[1], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // fp8x8 -> float8
+      PrintIndent();
+      stream << "*(float2*)(&" << sret << ") = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[0], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "*((float2*)(&" << sret << ")+1) = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[1], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "*((float2*)(&" << sret << ")+2) = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[2], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "*((float2*)(&" << sret << ")+3) = "
+             << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+             << "))[3], "
+             << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    }
+  }
+
   // Fallback: elementwise cast
   for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
     std::ostringstream val;
diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h
index 2efb8f111..83c380c69 100644
--- a/src/tl_templates/cuda/cuda_fp8.h
+++ b/src/tl_templates/cuda/cuda_fp8.h
@@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
   fp8_e4_16_t x;
   fp8_e4_16_t y;
 
-  __device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
     x.x = *(fp8_e4_8_t *)&rhs.x;
     x.y = *(fp8_e4_8_t *)&rhs.y;
     y.x = *(fp8_e4_8_t *)&rhs.z;
@@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
   fp8_e5_16_t x;
   fp8_e5_16_t y;
 
-  __device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
     x.x = *(fp8_e5_8_t *)&rhs.x;
     x.y = *(fp8_e5_8_t *)&rhs.y;
     y.x = *(fp8_e5_8_t *)&rhs.z;
@@ -78,7 +78,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
 };
 
 // Pack two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
+TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
   fp8_e4_2_t result;
   result.x = x;
   result.y = y;
@@ -86,9 +86,8 @@ __forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
 }
 
 // Pack four fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                       fp8_e4_t x2,
-                                                       fp8_e4_t x3) {
+TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3) {
   fp8_e4_4_t result;
   result.x = x0;
   result.y = x1;
@@ -98,11 +97,9 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
 }
 
 // Pack eight fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                       fp8_e4_t x2, fp8_e4_t x3,
-                                                       fp8_e4_t x4, fp8_e4_t x5,
-                                                       fp8_e4_t x6,
-                                                       fp8_e4_t x7) {
+TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                     fp8_e4_t x6, fp8_e4_t x7) {
   fp8_e4_8_t result;
   result.x = make_fp8_e4_4_t(x0, x1, x2, x3);
   result.y = make_fp8_e4_4_t(x4, x5, x6, x7);
@@ -110,11 +107,12 @@ __forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
 }
 
 // Pack sixteen fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_16_t
-make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
-                 fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
-                 fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
-                 fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
+TL_DEVICE fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                       fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                       fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
+                                       fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                       fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
+                                       fp8_e4_t y7) {
   fp8_e4_16_t result;
   result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
   result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
@@ -122,7 +120,7 @@ make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
 }
 
 // Pack thirty-two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
+TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t(
     fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4,
     fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9,
     fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14,
@@ -139,7 +137,7 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
                               y12, y13, y14, y15);
   return result;
 }
 
 // Pack two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
+TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
   fp8_e5_2_t result;
   result.x = x;
   result.y = y;
@@ -147,9 +145,8 @@ __forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
 }
 
 // Pack four fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                       fp8_e5_t x2,
-                                                       fp8_e5_t x3) {
+TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3) {
   fp8_e5_4_t result;
   result.x = x0;
   result.y = x1;
@@ -159,11 +156,9 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
 }
 
 // Pack eight fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                       fp8_e5_t x2, fp8_e5_t x3,
-                                                       fp8_e5_t x4, fp8_e5_t x5,
-                                                       fp8_e5_t x6,
-                                                       fp8_e5_t x7) {
+TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                     fp8_e5_t x6, fp8_e5_t x7) {
   fp8_e5_8_t result;
   result.x = make_fp8_e5_4_t(x0, x1, x2, x3);
   result.y = make_fp8_e5_4_t(x4, x5, x6, x7);
@@ -171,11 +166,12 @@ __forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
 }
 
 // Pack sixteen fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_16_t
-make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
-                 fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
-                 fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
-                 fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) {
+TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                       fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                       fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t y0,
+                                       fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
+                                       fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6,
+                                       fp8_e5_t y7) {
   fp8_e5_16_t result;
   result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
   result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
@@ -183,7 +179,7 @@ make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
 }
 
 // Pack thirty-two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
+TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(
     fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4,
     fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9,
     fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14,
@@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
                               y12, y13, y14, y15);
   return result;
 }
+
+// fp8x2 (e4m3 or e5m2) -> float2
+TL_DEVICE float2
+__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
+                         const __nv_fp8_interpretation_t fp8_interpretation) {
+  half2 tmp = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation);
+  float2 result;
+  result.x = (float)tmp.x;
+  result.y = (float)tmp.y;
+  return result;
+}
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index e505bc6ea..dbbdb5cce 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -20,6 +20,7 @@
 #include "../op/copy.h"
 #include "../op/parallel.h"
 #include "../op/region.h"
+#include "../target/utils.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
@@ -1170,9 +1171,15 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     // If a cast operation exists, vectorization may still be required
     bool has_cast_operations = false;
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
+      if (const auto *cast = obj.as<CastNode>()) {
         // Check if this is a non-reducer store with Cast operation
-        if (store->value.as<CastNode>()) {
+        DataType src_type = cast->value.dtype();
+        DataType dst_type = cast->dtype;
+        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
+                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
+        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
+                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
+        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
           has_cast_operations = true;
         }
       }
diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py
index adb59a6bd..2fd1554a8 100644
--- a/testing/python/language/test_tilelang_language_vectorized_cast.py
+++ b/testing/python/language/test_tilelang_language_vectorized_cast.py
@@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
     kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
 
-    A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
-    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
-    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    A_float = torch.randn(M, dtype=torch.float32, device="cuda")
+    A = A_float.to(str2dtype[src_dtype_str])
+    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
 
     kernel(A, B)
     kernel_parallel(A, C)
@@ -101,6 +102,14 @@ def test_vectorized_cast():
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
 
+    # fp8_e4m3 -> fp32
+    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
+    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)
+
+    # fp8_e5m2 -> fp32
+    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
+    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)
+
 
 if __name__ == "__main__":
     tilelang.testing.main()
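
For reference, a minimal sketch of how the new __tl_cvt_fp8x2_to_float2 helper can be invoked from device code once the tl_templates cuda_fp8.h header (and its TL_DEVICE macro) is in scope. The wrapper name demo_cast_e4m3x2 and the use of __nv_fp8x2_e4m3 as the packed input type are illustrative only, not part of the patch:

// Hypothetical usage sketch: convert one packed pair of e4m3 values to float2,
// mirroring the pattern the codegen emits for the 2-lane cast.
TL_DEVICE float2 demo_cast_e4m3x2(__nv_fp8x2_e4m3 packed) {
  // Reinterpret the packed fp8x2 value as its raw 16-bit storage and let the
  // helper widen both lanes through half2 to float2.
  return __tl_cvt_fp8x2_to_float2(
      *reinterpret_cast<__nv_fp8x2_storage_t *>(&packed), __NV_E4M3);
}

The 4- and 8-lane branches in codegen_cuda.cc apply the same two-lane conversion to consecutive __nv_fp8x2_storage_t slices of the source vector, writing one float2 chunk of the result per call.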