Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/tl_templates/cuda/cuda_fp4.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const __nv_fp4x2_storage_t src) {
// half -> fp4_e2m1
// Converts one __half value to a single FP4 (E2M1) storage value, rounding to
// the nearest representable FP4 value.
TL_DEVICE __nv_fp4_storage_t __tl_cvt_half_to_fp4(const __half src) {
  // The conversion routine takes the raw half type, so view src through it.
  __half_raw raw = *reinterpret_cast<const __half_raw *>(&src);
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_halfraw_to_fp4(raw, __NV_E2M1, cudaRoundNearest);
}

// half2 -> fp4_e2m1x2 (1 byte)
// Converts both lanes of a half2 to a packed pair of FP4 (E2M1) values,
// rounding each lane to the nearest representable FP4 value.
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_half2_to_fp4x2(const half2 src) {
  // The conversion routine takes the raw half2 type, so view src through it.
  __half2_raw raw = *reinterpret_cast<const __half2_raw *>(&src);
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_E2M1, cudaRoundNearest);
}

// ============================================================================
Expand All @@ -207,12 +207,12 @@ TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const __nv_fp4x2_storage_t src) {

// float -> fp4_e2m1
// Converts one float to a single FP4 (E2M1) storage value, rounding to the
// nearest representable FP4 value.
TL_DEVICE __nv_fp4_storage_t __tl_cvt_float_to_fp4(const float src) {
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_float_to_fp4(src, __NV_E2M1, cudaRoundNearest);
}

// float2 -> fp4_e2m1x2 (1 byte)
// Converts both components of a float2 to a packed pair of FP4 (E2M1) values,
// rounding each component to the nearest representable FP4 value.
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_float2_to_fp4x2(const float2 src) {
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_float2_to_fp4x2(src, __NV_E2M1, cudaRoundNearest);
}

// ============================================================================
Expand All @@ -235,12 +235,12 @@ TL_DEVICE double2 __tl_cvt_fp4x2_to_double2(const __nv_fp4x2_storage_t src) {

// double -> fp4_e2m1
// Converts one double to a single FP4 (E2M1) storage value, rounding to the
// nearest representable FP4 value.
TL_DEVICE __nv_fp4_storage_t __tl_cvt_double_to_fp4(const double src) {
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_double_to_fp4(src, __NV_E2M1, cudaRoundNearest);
}

// double2 -> fp4_e2m1x2
// Converts both components of a double2 to a packed pair of FP4 (E2M1) values,
// rounding each component to the nearest representable FP4 value.
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_double2_to_fp4x2(const double2 src) {
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_double2_to_fp4x2(src, __NV_E2M1, cudaRoundNearest);
}

// ============================================================================
Expand All @@ -262,14 +262,14 @@ __tl_cvt_fp4x2_to_bfloat162(const __nv_fp4x2_storage_t src) {
// bfloat16 -> fp4_e2m1
// Converts one __nv_bfloat16 value to a single FP4 (E2M1) storage value,
// rounding to the nearest representable FP4 value.
TL_DEVICE __nv_fp4_storage_t __tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) {
  // The conversion routine takes the raw bfloat16 type, so view src through it.
  __nv_bfloat16_raw raw = *reinterpret_cast<const __nv_bfloat16_raw *>(&src);
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_E2M1, cudaRoundNearest);
}

// bfloat162 -> fp4_e2m1x2
// Converts both lanes of a __nv_bfloat162 to a packed pair of FP4 (E2M1)
// values, rounding each lane to the nearest representable FP4 value.
TL_DEVICE __nv_fp4x2_storage_t
__tl_cvt_bfloat162_to_fp4x2(const __nv_bfloat162 src) {
  // The conversion routine takes the raw bfloat162 type, so view src through
  // it.
  __nv_bfloat162_raw raw = *reinterpret_cast<const __nv_bfloat162_raw *>(&src);
  // Round to nearest rather than toward zero (cudaRoundZero), which would
  // shrink the magnitude of every value that is not exactly representable.
  return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_E2M1, cudaRoundNearest);
}

#endif
Loading