From 8744d58e94ed5e653c725b0a88496a243a705c24 Mon Sep 17 00:00:00 2001 From: Torre Zuk <42548444+TorreZuk@users.noreply.github.com> Date: Thu, 7 May 2026 16:07:14 -0600 Subject: [PATCH 1/2] adding OpenBLAS axpy replacements for WIN32 --- .../clients/common/cblas_interface.cpp | 69 +++++++++++++++++++ .../clients/include/cblas_interface.hpp | 2 + 2 files changed, 71 insertions(+) diff --git a/projects/rocblas/clients/common/cblas_interface.cpp b/projects/rocblas/clients/common/cblas_interface.cpp index ef31c586681..00e7bba0262 100644 --- a/projects/rocblas/clients/common/cblas_interface.cpp +++ b/projects/rocblas/clients/common/cblas_interface.cpp @@ -93,6 +93,75 @@ void ref_axpy( // cblas_saxpy(n, alpha, x_float, incx, y_float, incy); } +#if defined(WIN32) && !defined(BLIS_ENABLE_CBLAS) + +template +void ref_axpy(int64_t n, T alpha, T* x, int64_t incx, T* y, int64_t incy) +{ + // Handle negative increments + x += incx < 0 ? incx * (1 - n) : 0; + y += incy < 0 ? incy * (1 - n) : 0; + + for(int64_t i = 0; i < n; i++) + { + y[i * incy] += alpha * x[i * incx]; + } +} + +template void + ref_axpy(int64_t n, float alpha, float* x, int64_t incx, float* y, int64_t incy); + +template void + ref_axpy(int64_t n, double alpha, double* x, int64_t incx, double* y, int64_t incy); + +template void ref_axpy(int64_t n, + rocblas_float_complex alpha, + rocblas_float_complex* x, + int64_t incx, + rocblas_float_complex* y, + int64_t incy); + +template void ref_axpy(int64_t n, + rocblas_double_complex alpha, + rocblas_double_complex* x, + int64_t incx, + rocblas_double_complex* y, + int64_t incy); + +template <> +void ref_axpy(int64_t n, + rocblas_bfloat16 alpha, + rocblas_bfloat16* x, + int64_t incx, + rocblas_bfloat16* y, + int64_t incy) +{ + // Handle negative increments + int64_t abs_incx = incx < 0 ? -incx : incx; + int64_t abs_incy = incy < 0 ? -incy : incy; + + // Convert to float + host_vector x_float(n * abs_incx); + host_vector y_float(n * abs_incy); + + for(int64_t i = 0; i < n; i++) + { + x_float[i * abs_incx] = float(x[i * abs_incx]); + y_float[i * abs_incy] = float(y[i * abs_incy]); + } + + // Compute in float precision + ref_axpy(n, float(alpha), x_float, incx, y_float, incy); + + // Convert back to bfloat16 + for(int64_t i = 0; i < n; i++) + { + y[i * abs_incy] = rocblas_bfloat16(y_float[i * abs_incy]); + } +} + +#endif + template <> void ref_asum(int64_t n, const float* x, int64_t incx, float* result) { diff --git a/projects/rocblas/clients/include/cblas_interface.hpp b/projects/rocblas/clients/include/cblas_interface.hpp index 50820fa83f3..e10f8fa446f 100644 --- a/projects/rocblas/clients/include/cblas_interface.hpp +++ b/projects/rocblas/clients/include/cblas_interface.hpp @@ -94,6 +94,7 @@ inline void ref_asum(int64_t n, const rocblas_double_complex* x, int64_t incx, d template void ref_axpy(int64_t n, T alpha, T* x, int64_t incx, T* y, int64_t incy); +#if !(defined(WIN32) && !defined(BLIS_ENABLE_CBLAS)) // windows OpenBLAS bug override template <> inline void ref_axpy(int64_t n, float alpha, float* x, int64_t incx, float* y, int64_t incy) { @@ -127,6 +128,7 @@ inline void ref_axpy(int64_t n, { cblas_zaxpy(n, &alpha, x, incx, y, incy); } +#endif // copy template From 755ad379056d470535b19a66d4fab84bb3f20402 Mon Sep 17 00:00:00 2001 From: Torre Zuk Date: Thu, 7 May 2026 16:19:17 -0600 Subject: [PATCH 2/2] clang format --- projects/rocblas/clients/common/cblas_interface.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/projects/rocblas/clients/common/cblas_interface.cpp b/projects/rocblas/clients/common/cblas_interface.cpp index 00e7bba0262..e79e1723da1 100644 --- a/projects/rocblas/clients/common/cblas_interface.cpp +++ b/projects/rocblas/clients/common/cblas_interface.cpp @@ -129,12 +129,12 @@ template void ref_axpy(int64_t n, int64_t incy); template <> -void ref_axpy(int64_t n, - rocblas_bfloat16 alpha, - rocblas_bfloat16* x, - int64_t incx, - rocblas_bfloat16* y, - int64_t incy) +void ref_axpy(int64_t n, + rocblas_bfloat16 alpha, + rocblas_bfloat16* x, + int64_t incx, + rocblas_bfloat16* y, + int64_t incy) { // Handle negative increments int64_t abs_incx = incx < 0 ? -incx : incx;