diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index 22cc0460c..93c844bb9 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -178,13 +178,13 @@ TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const __nv_fp4x2_storage_t src) { // half -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_half_to_fp4(const __half src) { __half_raw raw = *reinterpret_cast(&src); - return __nv_cvt_halfraw_to_fp4(raw, __NV_E2M1, cudaRoundZero); + return __nv_cvt_halfraw_to_fp4(raw, __NV_E2M1, cudaRoundNearest); } // half2 -> fp4_e2m1x2 (1 byte) TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_half2_to_fp4x2(const half2 src) { __half2_raw raw = *reinterpret_cast(&src); - return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero); + return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_E2M1, cudaRoundNearest); } // ============================================================================ @@ -207,12 +207,12 @@ TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const __nv_fp4x2_storage_t src) { // float -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_float_to_fp4(const float src) { - return __nv_cvt_float_to_fp4(src, __NV_E2M1, cudaRoundZero); + return __nv_cvt_float_to_fp4(src, __NV_E2M1, cudaRoundNearest); } // float2 -> fp4_e2m1x2 (1 byte) TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_float2_to_fp4x2(const float2 src) { - return __nv_cvt_float2_to_fp4x2(src, __NV_E2M1, cudaRoundZero); + return __nv_cvt_float2_to_fp4x2(src, __NV_E2M1, cudaRoundNearest); } // ============================================================================ @@ -235,12 +235,12 @@ TL_DEVICE double2 __tl_cvt_fp4x2_to_double2(const __nv_fp4x2_storage_t src) { // double -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_double_to_fp4(const double src) { - return __nv_cvt_double_to_fp4(src, __NV_E2M1, cudaRoundZero); + return __nv_cvt_double_to_fp4(src, __NV_E2M1, cudaRoundNearest); } // double2 -> fp4_e2m1x2 TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_double2_to_fp4x2(const double2 src) { - return __nv_cvt_double2_to_fp4x2(src, __NV_E2M1, cudaRoundZero); + return __nv_cvt_double2_to_fp4x2(src, __NV_E2M1, cudaRoundNearest); } // ============================================================================ @@ -262,14 +262,14 @@ __tl_cvt_fp4x2_to_bfloat162(const __nv_fp4x2_storage_t src) { // bfloat16 -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) { __nv_bfloat16_raw raw = *reinterpret_cast(&src); - return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_E2M1, cudaRoundZero); + return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_E2M1, cudaRoundNearest); } // bfloat162 -> fp4_e2m1x2 TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_bfloat162_to_fp4x2(const __nv_bfloat162 src) { __nv_bfloat162_raw raw = *reinterpret_cast(&src); - return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero); + return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_E2M1, cudaRoundNearest); } #endif