diff --git a/projects/rocblas/clients/common/cblas_interface.cpp b/projects/rocblas/clients/common/cblas_interface.cpp index 84e070d756a..ebe2612bbcb 100644 --- a/projects/rocblas/clients/common/cblas_interface.cpp +++ b/projects/rocblas/clients/common/cblas_interface.cpp @@ -2242,26 +2242,52 @@ void ref_syrk_ex(rocblas_fill uplo, U* C, int64_t ldc) { - float alpha_float = alpha; - float beta_float = beta; + if constexpr(!std::is_same_v) + { + float alpha_float = alpha; + float beta_float = beta; - host_vector A_float, C_float; + host_vector A_float, C_float; - cast_to_buffer(transA, n, k, lda, A, A_float); - cast_to_buffer(rocblas_operation_none, n, n, ldc, C, C_float); + cast_to_buffer(transA, n, k, lda, A, A_float); + cast_to_buffer(rocblas_operation_none, n, n, ldc, C, C_float); - ref_syrk(uplo, - transA, - n, - k, - alpha_float, - (const float*)A_float.data(), - lda, - beta_float, - C_float.data(), - ldc); + ref_syrk(uplo, + transA, + n, + k, + alpha_float, + (const float*)A_float.data(), + lda, + beta_float, + C_float.data(), + ldc); + + cast_from_buffer(n, n, ldc, C_float, C); + } + else + { + double alpha_double = alpha; + double beta_double = beta; + + host_vector A_double, C_double; + + cast_to_buffer(transA, n, k, lda, A, A_double); + cast_to_buffer(rocblas_operation_none, n, n, ldc, C, C_double); - cast_from_buffer(n, n, ldc, C_float, C); + ref_syrk(uplo, + transA, + n, + k, + alpha_double, + (const double*)A_double.data(), + lda, + beta_double, + C_double.data(), + ldc); + + cast_from_buffer(n, n, ldc, C_double, C); + } } #define INSTANTIATE_SYRK_EX_TEMPLATE(T_, U_, Tc_) \ diff --git a/projects/rocblas/clients/include/blas_ex/testing_syrk_ex.hpp b/projects/rocblas/clients/include/blas_ex/testing_syrk_ex.hpp index 373da18e692..27fe010ea67 100644 --- a/projects/rocblas/clients/include/blas_ex/testing_syrk_ex.hpp +++ b/projects/rocblas/clients/include/blas_ex/testing_syrk_ex.hpp @@ -372,8 +372,8 @@ void testing_syrk_ex(const Arguments& arg) // reference is computed on floats double tol = rocblas_handle(handle)->getArchMajor() == 11 ? sum_error_tolerance_for_gfx11 - : sum_error_tolerance; - tol *= K * 4; + : 4 * sum_error_tolerance; + tol = tol * K + 2 * sum_error_tolerance; // add To conversion rounding error near_check_general(N, N, ldc, hC_gold, hC, tol); } else diff --git a/projects/rocblas/clients/include/near.hpp b/projects/rocblas/clients/include/near.hpp index 804fbb2e39f..5f389b2349f 100644 --- a/projects/rocblas/clients/include/near.hpp +++ b/projects/rocblas/clients/include/near.hpp @@ -86,6 +86,12 @@ template <> inline constexpr double sum_error_tolerance_for_gfx11 = 1 / 100.0; +template <> // syrk_ex use +inline constexpr double sum_error_tolerance_for_gfx11 = get_epsilon(); + +template <> // syrk_ex use +inline constexpr double sum_error_tolerance_for_gfx11 = get_epsilon(); + template <> inline constexpr double sum_error_tolerance_for_gfx11