Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}

if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) &&
target_ty.is_float()) {
// FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
// (fp8x2 -> float2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
// fp8x2 -> float2
PrintIndent();
stream << "*reinterpret_cast<float2*>(&(" << sret
<< ")) = "
"__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
"t*>(&("
<< src << ")), "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
// fp8x4 -> float4
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// fp8x8 -> float8
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+2) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[2], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+3) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[3], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
}
}

// Fallback: elementwise cast
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
Expand Down
71 changes: 39 additions & 32 deletions src/tl_templates/cuda/cuda_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
fp8_e4_16_t x;
fp8_e4_16_t y;

__device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
TL_DEVICE fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e4_8_t *)&rhs.x;
x.y = *(fp8_e4_8_t *)&rhs.y;
y.x = *(fp8_e4_8_t *)&rhs.z;
Expand Down Expand Up @@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
fp8_e5_16_t x;
fp8_e5_16_t y;

__device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
TL_DEVICE fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e5_8_t *)&rhs.x;
x.y = *(fp8_e5_8_t *)&rhs.y;
y.x = *(fp8_e5_8_t *)&rhs.z;
Expand All @@ -78,17 +78,16 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
};

// Pack two fp8_e4_t values.
__forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
fp8_e4_2_t result;
result.x = x;
result.y = y;
return result;
}

// Pack four fp8_e4_t values.
__forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
fp8_e4_t x2,
fp8_e4_t x3) {
TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
fp8_e4_t x3) {
fp8_e4_4_t result;
result.x = x0;
result.y = x1;
Expand All @@ -98,31 +97,30 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
}

// Pack eight fp8_e4_t values.
__forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
fp8_e4_t x2, fp8_e4_t x3,
fp8_e4_t x4, fp8_e4_t x5,
fp8_e4_t x6,
fp8_e4_t x7) {
TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
fp8_e4_t x6, fp8_e4_t x7) {
fp8_e4_8_t result;
result.x = make_fp8_e4_4_t(x0, x1, x2, x3);
result.y = make_fp8_e4_4_t(x4, x5, x6, x7);
return result;
}

// Pack sixteen fp8_e4_t values.
__forceinline__ __device__ fp8_e4_16_t
make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
TL_DEVICE fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
fp8_e4_t y7) {
fp8_e4_16_t result;
result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
}

// Pack thirty-two fp8_e4_t values.
__forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t(
fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4,
fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9,
fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14,
Expand All @@ -139,17 +137,16 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
}

// Pack two fp8_e5_t values.
__forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
fp8_e5_2_t result;
result.x = x;
result.y = y;
return result;
}

// Pack four fp8_e5_t values.
__forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
fp8_e5_t x2,
fp8_e5_t x3) {
TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
fp8_e5_t x3) {
fp8_e5_4_t result;
result.x = x0;
result.y = x1;
Expand All @@ -159,31 +156,30 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
}

// Pack eight fp8_e5_t values.
__forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
fp8_e5_t x2, fp8_e5_t x3,
fp8_e5_t x4, fp8_e5_t x5,
fp8_e5_t x6,
fp8_e5_t x7) {
TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
fp8_e5_t x6, fp8_e5_t x7) {
fp8_e5_8_t result;
result.x = make_fp8_e5_4_t(x0, x1, x2, x3);
result.y = make_fp8_e5_4_t(x4, x5, x6, x7);
return result;
}

// Pack sixteen fp8_e5_t values.
__forceinline__ __device__ fp8_e5_16_t
make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) {
TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t y0,
fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6,
fp8_e5_t y7) {
fp8_e5_16_t result;
result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
}

// Pack thirty-two fp8_e5_t values.
__forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(
fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4,
fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9,
fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14,
Expand All @@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
y12, y13, y14, y15);
return result;
}

// fp8x2 -> float2: widen a packed pair of fp8 values to two floats.
// The fp8_interpretation argument selects the encoding (__NV_E4M3 or
// __NV_E5M2) and is forwarded directly to the CUDA conversion intrinsic;
// conversion goes fp8x2 -> half2 -> float2.
TL_DEVICE float2
__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
                         const __nv_fp8_interpretation_t fp8_interpretation) {
  const half2 h2 = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation);
  float2 out;
  out.x = static_cast<float>(h2.x);
  out.y = static_cast<float>(h2.y);
  return out;
}
11 changes: 9 additions & 2 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/region.h"
#include "../target/utils.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
Expand Down Expand Up @@ -1170,9 +1171,15 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
// If a cast operation exists, vectorization may still be required
bool has_cast_operations = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (const auto *cast = obj.as<CastNode>()) {
// Check if this is a non-reducer store with Cast operation
if (store->value.as<CastNode>()) {
DataType src_type = cast->value.dtype();
DataType dst_type = cast->dtype;
bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
has_cast_operations = true;
}
}
Expand Down
15 changes: 12 additions & 3 deletions testing/python/language/test_tilelang_language_vectorized_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)

A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
A_float = torch.randn(M, dtype=torch.float32, device="cuda")
A = A_float.to(str2dtype[src_dtype_str])
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")

kernel(A, B)
kernel_parallel(A, C)
Expand Down Expand Up @@ -101,6 +102,14 @@ def test_vectorized_cast():
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)

# fp8_e4m3 -> fp32
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)

# fp8_e5m2 -> fp32
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)


if __name__ == "__main__":
tilelang.testing.main()
Loading