Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions library/include/rocblas_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct rocblas_bfloat16
}

// zero extend lower 16 bits of bfloat16 to convert to IEEE float
explicit constexpr __host__ __device__ operator float() const
constexpr __host__ __device__ operator float() const
{
union
{
Expand All @@ -73,11 +73,6 @@ struct rocblas_bfloat16
return u.fp32;
}

// Explicit conversion to double. The value is first converted to float via
// this type's float conversion operator and the resulting float is then
// widened to double. Being explicit, it is only used when a cast is spelled out.
explicit constexpr __host__ __device__ operator double() const
{
    return static_cast<double>(static_cast<float>(*this));
}

private:
static constexpr __host__ __device__ uint16_t float_to_bfloat16(float f)
{
Expand Down Expand Up @@ -240,11 +235,6 @@ constexpr __host__ __device__ bool iszero(rocblas_bfloat16 a)
{
return !(a.data & 0x7fff);
}
// Absolute value of a bfloat16, computed on the raw storage: the argument is
// taken by value, its `data` field is masked with 0x7fff (the same mask
// iszero() uses to ignore the top bit), and the modified copy is returned.
constexpr __host__ __device__ rocblas_bfloat16 abs(rocblas_bfloat16 a)
{
    a.data = static_cast<decltype(a.data)>(a.data & 0x7fff);
    return a;
}
inline rocblas_bfloat16 sin(rocblas_bfloat16 a)
{
return rocblas_bfloat16(sinf(float(a)));
Expand All @@ -254,15 +244,6 @@ inline rocblas_bfloat16 cos(rocblas_bfloat16 a)
return rocblas_bfloat16(cosf(float(a)));
}

// Inject standard functions into namespace std
// NOTE(review): the C++ standard does not list adding new function overloads
// to namespace std among the permitted extensions; prefer a library-owned
// absolute-value helper found outside std where callers allow it.
namespace std
{
    // Absolute value of a bfloat16.
    //
    // Works on the raw storage: copies the argument, masks its `data` field
    // with 0x7fff to drop the top (sign) bit, and returns the copy. This
    // mirrors the by-value abs(rocblas_bfloat16) overload in this header.
    //
    // The masked bits are deliberately NOT passed to a rocblas_bfloat16
    // constructor: the constructor usage visible in this header takes a
    // numeric float value (e.g. rocblas_bfloat16(sinf(...))), so handing it
    // an integer bit pattern would presumably be treated as a value to
    // convert rather than as storage bits -- TODO confirm against the
    // struct's constructor list.
    __device__ __host__ inline rocblas_bfloat16 abs(const rocblas_bfloat16& z)
    {
        rocblas_bfloat16 magnitude = z;
        magnitude.data &= 0x7fff;
        return magnitude;
    }
}

#endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__))

#endif // _ROCBLAS_BFLOAT16_H_
20 changes: 10 additions & 10 deletions library/src/blas1/rocblas_rotg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace
template <typename T, typename U, typename std::enable_if<!is_complex<T>, int>::type = 0>
__device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s)
{
T scale = std::abs(a) + std::abs(b);
T scale = rocblas_abs(a) + rocblas_abs(b);
if(scale == 0.0)
{
c = 1.0;
Expand All @@ -24,14 +24,14 @@ namespace
T sa = a / scale;
T sb = b / scale;
T r = scale * sqrt(sa * sa + sb * sb);
T roe = (std::abs(a) > std::abs(b)) ? a : b;
T roe = rocblas_abs(a) > rocblas_abs(b) ? a : b;
r = copysign(r, roe);
c = a / r;
s = b / r;
T z = 1.0;
if(std::abs(a) > std::abs(b))
if(rocblas_abs(a) > rocblas_abs(b))
z = s;
if(std::abs(b) >= std::abs(a) && c != 0.0)
if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0)
z = 1.0 / c;
a = r;
b = z;
Expand All @@ -41,20 +41,20 @@ namespace
template <typename T, typename U, typename std::enable_if<is_complex<T>, int>::type = 0>
__device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s)
{
if(std::abs(a) == 0.0)
if(!rocblas_abs(a))
{
c = 0;
s = {1, 0};
a = b;
}
else
{
auto scale = std::abs(a) + std::abs(b);
auto sa = std::abs(a / scale);
auto sb = std::abs(b / scale);
auto scale = rocblas_abs(a) + rocblas_abs(b);
auto sa = rocblas_abs(a / scale);
auto sb = rocblas_abs(b / scale);
auto norm = scale * sqrt(sa * sa + sb * sb);
auto alpha = a / std::abs(a);
c = std::abs(a) / norm;
auto alpha = a / rocblas_abs(a);
c = rocblas_abs(a) / norm;
s = alpha * conj(b) / norm;
a = alpha * norm;
}
Expand Down
6 changes: 3 additions & 3 deletions library/src/blas1/rocblas_rotmg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace
T p1 = d1 * x1;
T q2 = p2 * y1;
T q1 = p1 * x1;
if(fabs(q1) > fabs(q2))
if(rocblas_abs(q1) > rocblas_abs(q2))
{
h21 = -y1 / x1;
h12 = p2 / p1;
Expand Down Expand Up @@ -100,7 +100,7 @@ namespace

if(d2 != 0)
{
while((fabs(d2) <= rgamsq) || (fabs(d2) >= gamsq))
while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq))
{
if(flag == 0)
{
Expand All @@ -113,7 +113,7 @@ namespace
h12 = 1;
flag = -1;
}
if(fabs(d2) <= rgamsq)
if(rocblas_abs(d2) <= rgamsq)
{
d2 *= gamsq;
h21 /= gam;
Expand Down
23 changes: 23 additions & 0 deletions library/src/include/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define UTILITY_H
#include "definitions.h"
#include "rocblas.h"
#include <cmath>
#include <complex>
#include <hip/hip_runtime.h>
#include <type_traits>
Expand Down Expand Up @@ -240,4 +241,26 @@ constexpr auto get_rocblas_status_for_hip_status(hipError_t status)
return rocblas_status_internal_error;
}
}

// Absolute value for non-complex types.
// Participates in overload resolution only when is_complex<T> is false.
// Yields -x when x compares less than zero and x otherwise; the return type
// is whatever the conditional expression produces for T.
// NOTE(review): if T is a floating-point type, negative zero does not compare
// less than zero and is therefore returned with its sign unchanged.
template <typename T, typename std::enable_if<!is_complex<T>, int>::type = 0>
__device__ __host__ inline auto rocblas_abs(T x)
{
    const bool is_negative = x < 0;
    return is_negative ? -x : x;
}

// Absolute value for complex types.
// Participates in overload resolution only when is_complex<T> is true and
// simply forwards to std::abs, for which a __device__ __host__ compatible
// version has been defined for the complex types.
template <typename T, typename std::enable_if<is_complex<T>, int>::type = 0>
__device__ __host__ inline auto rocblas_abs(T x)
{
    auto magnitude = std::abs(x);
    return magnitude;
}

// rocblas_bfloat16 gets its own non-template overload: the argument is taken
// by value, its raw `data` field is masked with 0x7fff to drop the top bit,
// and the modified copy is returned.
__device__ __host__ inline auto rocblas_abs(rocblas_bfloat16 x)
{
    x.data = static_cast<decltype(x.data)>(x.data & 0x7fff);
    return x;
}

#endif