Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
os << sret;
};

// A list of casting functions that are supported by TileLang templates.
// To add a new type conversion, you should do the following things:
// 1. Add the new conversion function in tl_templates. (__tl_cvt_xx)
// 2. Add a new if statement like the one below.
// 3. In src/target/utils.cc, allow this vectorizable cast.

// Handle conversion from float16 to float32
if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32) {
// Use __half22float2 for vectorized conversion (half2 -> float2)
Expand Down Expand Up @@ -1245,6 +1251,53 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}

// Handle conversion from float8 (E8M0) to bfloat16
if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16()) {
// Use __tl_cvt_e8m0x2_to_bfloat162 for vectorized conversion (fp8_e8m0x2 ->
// bfloat162)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_e8m0x2_to_bfloat162",
"__nv_fp8x2_storage_t", "__nv_bfloat162", "", true,
false);
return;
}
}

// Handle conversion from bfloat16 to float8 (E8M0)
if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu()) {
// Use __tl_cvt_bfloat162_to_e8m0x2 for vectorized conversion (bfloat162 ->
// fp8_e8m0x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_bfloat162_to_e8m0x2", "__nv_bfloat162",
"__nv_fp8x2_storage_t", "", false, true);
return;
}
}

// Handle conversion from float to float8 (E8M0)
if (from_ty.is_float() && from_ty.bits() == 32 &&
target_ty.is_float8_e8m0fnu()) {
// Use __tl_cvt_float2_to_e8m0x2 for vectorized conversion (float2 ->
// fp8_e8m0x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_float2_to_e8m0x2", "float2",
"__nv_fp8x2_storage_t", "", false, true);
return;
}
}

// Handle conversion from double to float8 (E8M0)
if (from_ty.is_float() && from_ty.bits() == 64 &&
target_ty.is_float8_e8m0fnu()) {
// Use __tl_cvt_double2_to_e8m0x2 for vectorized conversion (double2 ->
// fp8_e8m0x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_double2_to_e8m0x2", "double2",
"__nv_fp8x2_storage_t", "", false, true);
return;
}
}

// Handle conversion from float16 to float4 (E2M1)
if (from_ty.is_float16() && target_ty.is_float4_e2m1fn()) {
// Use __tl_cvt_half2_to_fp4x2 for vectorized conversion (half2 -> fp4x2)
Expand Down
15 changes: 15 additions & 0 deletions src/target/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ int TargetGetWarpSize(Target target) {
}

// Returns true when `dtype` is an FP8 variant eligible for CUDA vectorized
// casting. Only formats representable via __nv_fp8_interpretation_t qualify;
// the exponent-only E8M0 format is deliberately excluded and handled by
// dedicated conversion paths elsewhere.
bool IsCudaVectorizableFP8(DataType dtype) {
  if (dtype.is_float8_e4m3() || dtype.is_float8_e4m3fn())
    return true;
  return dtype.is_float8_e5m2();
}
Expand Down Expand Up @@ -178,6 +181,18 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) {
if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float())
return true;

// float8 (E8M0) -> bfloat16
if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16())
return true;

// bfloat16 -> float8 (E8M0)
if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu())
return true;

// float32/double -> float8 (E8M0)
if (from_ty.is_float() && target_ty.is_float8_e8m0fnu())
return true;

// float4_e2m1fn -> float32
if (from_ty.is_float4_e2m1fn() && target_ty.is_float())
return true;
Expand Down
55 changes: 55 additions & 0 deletions src/tl_templates/cuda/cuda_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,58 @@ __tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
result.y = (float)tmp.y;
return result;
}

// ============================================================================
// FP8 E8M0 Related Conversions
// ============================================================================
#if TL_HAS_FP8_E8M0

// fp8_e8m0 -> bfloat16
// Scalar conversion: one fp8 E8M0 byte -> one bfloat16 value.
TL_DEVICE __nv_bfloat16
__tl_cvt_e8m0_to_bfloat16(const __nv_fp8_storage_t src) {
  // __nv_bfloat16 is constructible from its raw-bits representation, so the
  // intrinsic's result can be wrapped directly without pointer
  // reinterpretation.
  return __nv_bfloat16(__nv_cvt_e8m0_to_bf16raw(src));
}

// fp8_e8m0x2 -> bfloat16x2
// Packed conversion: two fp8 E8M0 bytes -> a bfloat162 pair.
TL_DEVICE __nv_bfloat162
__tl_cvt_e8m0x2_to_bfloat162(const __nv_fp8x2_storage_t src) {
  // __nv_bfloat162 is constructible from its raw-bits representation, so the
  // intrinsic's result can be wrapped directly without pointer
  // reinterpretation.
  return __nv_bfloat162(__nv_cvt_e8m0x2_to_bf162raw(src));
}

// bfloat16 -> fp8_e8m0
// Scalar conversion: one bfloat16 value -> one fp8 E8M0 byte, saturating to
// the finite range.
TL_DEVICE
__nv_fp8_storage_t __tl_cvt_bfloat16_to_e8m0(const __nv_bfloat16 src) {
  // __nv_bfloat16 provides a conversion operator to its raw-bits type, which
  // avoids the strict-aliasing-hostile pointer cast.
  // NOTE(review): the CUDA math-API docs list cudaRoundZero/cudaRoundPosInf as
  // the rounding modes for the e8m0 converters — confirm cudaRoundNearest is
  // accepted here.
  const __nv_bfloat16_raw bits = static_cast<__nv_bfloat16_raw>(src);
  return __nv_cvt_bfloat16raw_to_e8m0(bits, __NV_SATFINITE, cudaRoundNearest);
}

// bfloat162 -> fp8_e8m0x2
// Packed conversion: a bfloat162 pair -> two fp8 E8M0 bytes, saturating to
// the finite range.
TL_DEVICE __nv_fp8x2_storage_t
__tl_cvt_bfloat162_to_e8m0x2(const __nv_bfloat162 src) {
  // __nv_bfloat162 provides a conversion operator to its raw-bits type, which
  // avoids the strict-aliasing-hostile pointer cast.
  const __nv_bfloat162_raw bits = static_cast<__nv_bfloat162_raw>(src);
  return __nv_cvt_bfloat162raw_to_e8m0x2(bits, __NV_SATFINITE,
                                         cudaRoundNearest);
}

// float -> fp8_e8m0
// Scalar conversion: one float -> one fp8 E8M0 byte, saturating to the finite
// range.
TL_DEVICE __nv_fp8_storage_t __tl_cvt_float_to_e8m0(const float src) {
  const __nv_fp8_storage_t out =
      __nv_cvt_float_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest);
  return out;
}

// float2 -> fp8_e8m0x2
// Packed conversion: a float2 pair -> two fp8 E8M0 bytes, saturating to the
// finite range.
TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_float2_to_e8m0x2(const float2 src) {
  const __nv_fp8x2_storage_t out =
      __nv_cvt_float2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest);
  return out;
}

// double -> fp8_e8m0
// Scalar conversion: one double -> one fp8 E8M0 byte, saturating to the
// finite range.
TL_DEVICE __nv_fp8_storage_t __tl_cvt_double_to_e8m0(const double src) {
  const __nv_fp8_storage_t out =
      __nv_cvt_double_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest);
  return out;
}

// double2 -> fp8_e8m0x2
// Packed conversion: a double2 pair -> two fp8 E8M0 bytes, saturating to the
// finite range.
TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_double2_to_e8m0x2(const double2 src) {
  const __nv_fp8x2_storage_t out =
      __nv_cvt_double2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest);
  return out;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str,

code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()
print(code)
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!"

# Requires torch >= 2.8
if src_dtype == T.float8_e8m0fnu or dst_dtype == T.float8_e8m0fnu:
return

if src_dtype == T.float4_e2m1fn or dst_dtype == T.float4_e2m1fn:
return

Expand Down Expand Up @@ -106,6 +109,13 @@ def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
# E8M0 <-> BFloat16
(T.float8_e8m0fnu, T.bfloat16, "__tl_cvt_e8m0x2_to_bfloat162", 2),
(T.bfloat16, T.float8_e8m0fnu, "__tl_cvt_bfloat162_to_e8m0x2", 2),
# Float -> E8M0
(T.float32, T.float8_e8m0fnu, "__tl_cvt_float2_to_e8m0x2", 2),
# Double -> E8M0
(T.float64, T.float8_e8m0fnu, "__tl_cvt_double2_to_e8m0x2", 2),
],
)
def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes):
Expand Down
Loading