diff --git a/library/include/rocblas_bfloat16.h b/library/include/rocblas_bfloat16.h index 4ef90b08b..eb3987adb 100644 --- a/library/include/rocblas_bfloat16.h +++ b/library/include/rocblas_bfloat16.h @@ -63,7 +63,7 @@ struct rocblas_bfloat16 } // zero extend lower 16 bits of bfloat16 to convert to IEEE float - explicit constexpr __host__ __device__ operator float() const + constexpr __host__ __device__ operator float() const { union { @@ -73,11 +73,6 @@ struct rocblas_bfloat16 return u.fp32; } - explicit constexpr __host__ __device__ operator double() const - { - return double(float(*this)); - } - private: static constexpr __host__ __device__ uint16_t float_to_bfloat16(float f) { @@ -240,11 +235,6 @@ constexpr __host__ __device__ bool iszero(rocblas_bfloat16 a) { return !(a.data & 0x7fff); } -constexpr __host__ __device__ rocblas_bfloat16 abs(rocblas_bfloat16 a) -{ - a.data &= 0x7fff; - return a; -} inline rocblas_bfloat16 sin(rocblas_bfloat16 a) { return rocblas_bfloat16(sinf(float(a))); @@ -254,15 +244,6 @@ inline rocblas_bfloat16 cos(rocblas_bfloat16 a) return rocblas_bfloat16(cosf(float(a))); } -// Inject standard functions into namespace std -namespace std -{ - __device__ __host__ inline rocblas_bfloat16 abs(const rocblas_bfloat16& z) - { - return rocblas_bfloat16(z.data & 0x7fff); - } -} - #endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) #endif // _ROCBLAS_BFLOAT16_H_ diff --git a/library/src/blas1/rocblas_rotg.cpp b/library/src/blas1/rocblas_rotg.cpp index 3c1daf63e..d293a10cf 100644 --- a/library/src/blas1/rocblas_rotg.cpp +++ b/library/src/blas1/rocblas_rotg.cpp @@ -11,7 +11,7 @@ namespace template , int>::type = 0> __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) { - T scale = std::abs(a) + std::abs(b); + T scale = rocblas_abs(a) + rocblas_abs(b); if(scale == 0.0) { c = 1.0; @@ -24,14 +24,14 @@ namespace T sa = a / scale; T sb = b / scale; T r = scale * sqrt(sa * sa + sb * sb); - T roe = (std::abs(a) > std::abs(b)) ? a : b; + T roe = rocblas_abs(a) > rocblas_abs(b) ? a : b; r = copysign(r, roe); c = a / r; s = b / r; T z = 1.0; - if(std::abs(a) > std::abs(b)) + if(rocblas_abs(a) > rocblas_abs(b)) z = s; - if(std::abs(b) >= std::abs(a) && c != 0.0) + if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0) z = 1.0 / c; a = r; b = z; @@ -41,7 +41,7 @@ namespace template , int>::type = 0> __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) { - if(std::abs(a) == 0.0) + if(!rocblas_abs(a)) { c = 0; s = {1, 0}; @@ -49,12 +49,12 @@ namespace } else { - auto scale = std::abs(a) + std::abs(b); - auto sa = std::abs(a / scale); - auto sb = std::abs(b / scale); + auto scale = rocblas_abs(a) + rocblas_abs(b); + auto sa = rocblas_abs(a / scale); + auto sb = rocblas_abs(b / scale); auto norm = scale * sqrt(sa * sa + sb * sb); - auto alpha = a / std::abs(a); - c = std::abs(a) / norm; + auto alpha = a / rocblas_abs(a); + c = rocblas_abs(a) / norm; s = alpha * conj(b) / norm; a = alpha * norm; } diff --git a/library/src/blas1/rocblas_rotmg.cpp b/library/src/blas1/rocblas_rotmg.cpp index 148841103..f01aca7fa 100644 --- a/library/src/blas1/rocblas_rotmg.cpp +++ b/library/src/blas1/rocblas_rotmg.cpp @@ -34,7 +34,7 @@ namespace T p1 = d1 * x1; T q2 = p2 * y1; T q1 = p1 * x1; - if(fabs(q1) > fabs(q2)) + if(rocblas_abs(q1) > rocblas_abs(q2)) { h21 = -y1 / x1; h12 = p2 / p1; @@ -100,7 +100,7 @@ namespace if(d2 != 0) { - while((fabs(d2) <= rgamsq) || (fabs(d2) >= gamsq)) + while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq)) { if(flag == 0) { @@ -113,7 +113,7 @@ namespace h12 = 1; flag = -1; } - if(fabs(d2) <= rgamsq) + if(rocblas_abs(d2) <= rgamsq) { d2 *= gamsq; h21 /= gam; diff --git a/library/src/include/utility.h b/library/src/include/utility.h index ac0e85559..ed7672166 100644 --- a/library/src/include/utility.h +++ b/library/src/include/utility.h @@ -6,6 +6,7 @@ #define UTILITY_H #include "definitions.h" #include "rocblas.h" +#include #include #include #include @@ -240,4 +241,26 @@ constexpr auto get_rocblas_status_for_hip_status(hipError_t status) return rocblas_status_internal_error; } } + +// Absolute value +template , int>::type = 0> +__device__ __host__ inline auto rocblas_abs(T x) +{ + return x < 0 ? -x : x; +} + +// For complex, we have defined a __device__ __host__ compatible std::abs +template , int>::type = 0> +__device__ __host__ inline auto rocblas_abs(T x) +{ + return std::abs(x); +} + +// rocblas_bfloat16 is handled specially +__device__ __host__ inline auto rocblas_abs(rocblas_bfloat16 x) +{ + x.data &= 0x7fff; + return x; +} + #endif