Skip to content

Commit bddcd20

Browse files
authored
Use the float flavors of the cmath functions in the extended floating point fallbacks (#2106)
Fixes #2078
1 parent a2a3824 commit bddcd20

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

libcudacxx/include/cuda/std/__cuda/cmath_nvbf16.h

+9-9
Original file line numberDiff line numberDiff line change
@@ -37,47 +37,47 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
3737
// trigonometric functions
3838
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
3939
{
40-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sin(__bfloat162float(__v)));))
40+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sinf(__bfloat162float(__v)));))
4141
}
4242

4343
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
4444
{
45-
return __float2bfloat16(::sinh(__bfloat162float(__v)));
45+
return __float2bfloat16(::sinhf(__bfloat162float(__v)));
4646
}
4747

4848
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
4949
{
50-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cos(__bfloat162float(__v)));))
50+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cosf(__bfloat162float(__v)));))
5151
}
5252

5353
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
5454
{
55-
return __float2bfloat16(::cosh(__bfloat162float(__v)));
55+
return __float2bfloat16(::coshf(__bfloat162float(__v)));
5656
}
5757

5858
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
5959
{
60-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::exp(__bfloat162float(__v)));))
60+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::expf(__bfloat162float(__v)));))
6161
}
6262

6363
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
6464
{
65-
return __float2bfloat16(::hypot(__bfloat162float(__x), __bfloat162float(__y)));
65+
return __float2bfloat16(::hypotf(__bfloat162float(__x), __bfloat162float(__y)));
6666
}
6767

6868
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
6969
{
70-
return __float2bfloat16(::atan2(__bfloat162float(__x), __bfloat162float(__y)));
70+
return __float2bfloat16(::atan2f(__bfloat162float(__x), __bfloat162float(__y)));
7171
}
7272

7373
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
7474
{
75-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::log(__bfloat162float(__x)));))
75+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::logf(__bfloat162float(__x)));))
7676
}
7777

7878
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
7979
{
80-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrt(__bfloat162float(__x)));))
80+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrtf(__bfloat162float(__x)));))
8181
}
8282

8383
// floating point helper

libcudacxx/include/cuda/std/__cuda/cmath_nvfp16.h

+9-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half sin(__half __v)
3636
{
3737
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (return ::hsin(__v);), ({
3838
float __vf = __half2float(__v);
39-
__vf = ::sin(__vf);
39+
__vf = ::sinf(__vf);
4040
__half_raw __ret_repr = ::__float2half_rn(__vf);
4141

4242
uint16_t __repr = __half_raw(__v).x;
@@ -61,7 +61,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half sin(__half __v)
6161

6262
inline _LIBCUDACXX_INLINE_VISIBILITY __half sinh(__half __v)
6363
{
64-
return __float2half(::sinh(__half2float(__v)));
64+
return __float2half(::sinhf(__half2float(__v)));
6565
}
6666

6767
// clang-format off
@@ -72,7 +72,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v)
7272
), (
7373
{
7474
float __vf = __half2float(__v);
75-
__vf = ::cos(__vf);
75+
__vf = ::cosf(__vf);
7676
__half_raw __ret_repr = ::__float2half_rn(__vf);
7777

7878
uint16_t __repr = __half_raw(__v).x;
@@ -94,7 +94,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v)
9494

9595
inline _LIBCUDACXX_INLINE_VISIBILITY __half cosh(__half __v)
9696
{
97-
return __float2half(::cosh(__half2float(__v)));
97+
return __float2half(::coshf(__half2float(__v)));
9898
}
9999

100100
// clang-format off
@@ -105,7 +105,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v)
105105
), (
106106
{
107107
float __vf = __half2float(__v);
108-
__vf = ::exp(__vf);
108+
__vf = ::expf(__vf);
109109
__half_raw __ret_repr = ::__float2half_rn(__vf);
110110

111111
uint16_t __repr = __half_raw(__v).x;
@@ -127,12 +127,12 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v)
127127

128128
inline _LIBCUDACXX_INLINE_VISIBILITY __half hypot(__half __x, __half __y)
129129
{
130-
return __float2half(::hypot(__half2float(__x), __half2float(__y)));
130+
return __float2half(::hypotf(__half2float(__x), __half2float(__y)));
131131
}
132132

133133
inline _LIBCUDACXX_INLINE_VISIBILITY __half atan2(__half __x, __half __y)
134134
{
135-
return __float2half(::atan2(__half2float(__x), __half2float(__y)));
135+
return __float2half(::atan2f(__half2float(__x), __half2float(__y)));
136136
}
137137

138138
// clang-format off
@@ -143,7 +143,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x)
143143
), (
144144
{
145145
float __vf = __half2float(__x);
146-
__vf = ::log(__vf);
146+
__vf = ::logf(__vf);
147147
__half_raw __ret_repr = ::__float2half_rn(__vf);
148148

149149
uint16_t __repr = __half_raw(__x).x;
@@ -164,7 +164,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x)
164164

165165
inline _LIBCUDACXX_INLINE_VISIBILITY __half sqrt(__half __x)
166166
{
167-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(::sqrt(__half2float(__x)));))
167+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(::sqrtf(__half2float(__x)));))
168168
}
169169

170170
// floating point helper

0 commit comments

Comments
 (0)