diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 09f4b3c2a..3f7a54d4d 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -50,14 +50,19 @@ using namespace std::literals; #include "testing_trsv.hpp" // Template to dispatch testing_gemm_ex for performance tests -// When Ti == void or complex, the test is marked invalid +// When Ti == void or Ti == To == Tc == bfloat16, the test is marked invalid template struct perf_gemm_ex : rocblas_test_invalid { }; template -struct perf_gemm_ex{}>::type> +struct perf_gemm_ex{} + && !(std::is_same{} && std::is_same{} + && std::is_same{})>::type> { explicit operator bool() { @@ -70,17 +75,20 @@ struct perf_gemm_ex{ }; // Template to dispatch testing_gemm_strided_batched_ex for performance tests -// When Ti == void or complex, the test is marked invalid +// When Ti == void or Ti == To == Tc == bfloat16, the test is marked invalid template struct perf_gemm_strided_batched_ex : rocblas_test_invalid { }; template -struct perf_gemm_strided_batched_ex{}>::type> +struct perf_gemm_strided_batched_ex< + Ti, + To, + Tc, + typename std::enable_if{} + && !(std::is_same{} && std::is_same{} + && std::is_same{})>::type> { explicit operator bool() { @@ -163,6 +171,23 @@ struct perf_blas< } }; +template +struct perf_blas{}>::type> +{ + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "dot")) + testing_dot(arg); + else + throw std::invalid_argument("Invalid combination --function "s + arg.function + + " --a_type "s + rocblas_datatype2string(arg.a_type)); + } +}; + template struct perf_blas{}>::type> { @@ -174,6 +199,8 @@ struct perf_blas{}>: { if(!strcmp(arg.function, "axpy")) testing_axpy(arg); + else if(!strcmp(arg.function, "dot")) + testing_dot(arg); else if(!strcmp(arg.function, "gemm")) testing_gemm(arg); else if(!strcmp(arg.function, "gemm_strided_batched")) diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 86a7fa3b7..74eb44ceb 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -39,6 +39,50 @@ void cblas_axpy(rocblas_int n, } } +template <> +void cblas_dot(rocblas_int n, + const rocblas_half* x, + rocblas_int incx, + const rocblas_half* y, + rocblas_int incy, + rocblas_half* result) +{ + size_t abs_incx = incx >= 0 ? incx : -incx; + size_t abs_incy = incy >= 0 ? incy : -incy; + host_vector x_float(n * abs_incx); + host_vector y_float(n * abs_incy); + + for(size_t i = 0; i < n; i++) + { + x_float[i * abs_incx] = half_to_float(x[i * abs_incx]); + y_float[i * abs_incy] = half_to_float(y[i * abs_incy]); + } + + *result = float_to_half(cblas_sdot(n, x_float, incx, y_float, incy)); +} + +template <> +void cblas_dot(rocblas_int n, + const rocblas_bfloat16* x, + rocblas_int incx, + const rocblas_bfloat16* y, + rocblas_int incy, + rocblas_bfloat16* result) +{ + size_t abs_incx = incx >= 0 ? incx : -incx; + size_t abs_incy = incy >= 0 ? incy : -incy; + host_vector x_float(n * abs_incx); + host_vector y_float(n * abs_incy); + + for(size_t i = 0; i < n; i++) + { + x_float[i * abs_incx] = float(x[i * abs_incx]); + y_float[i * abs_incy] = float(y[i * abs_incy]); + } + + *result = rocblas_bfloat16(cblas_sdot(n, x_float, incx, y_float, incy)); +} + /* * =========================================================================== * level 2 BLAS diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 8eb96a432..d33df705d 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -83,7 +83,8 @@ namespace || std::is_same{})) || (BLAS1 == blas1::dot && std::is_same{} && std::is_same{} - && (std::is_same{} + && (std::is_same{} || std::is_same{} + || std::is_same{} || std::is_same{} || std::is_same{} || std::is_same{})) diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 768cf0273..d6dcd1404 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -34,7 +34,7 @@ Tests: # - iamin: *single_double_precisions_complex_real # broken for now -- cause unknown - axpy: *half_single_precisions_complex_real - copy: *single_double_precisions_complex_real - - dot: *single_double_precisions_complex_real + - dot: *half_bfloat_single_double_complex_real_precisions - dotc: *single_double_precisions_complex - scal: *single_double_precisions_complex_real - scal: *single_double_complex_real_in_complex_out @@ -68,7 +68,7 @@ Tests: - iamin_bad_arg: *single_double_precisions_complex_real - axpy_bad_arg: *half_single_precisions_complex_real - copy_bad_arg: *single_double_precisions_complex_real - - dot_bad_arg: *single_double_precisions_complex_real + - dot_bad_arg: *half_bfloat_single_double_complex_real_precisions - dotc_bad_arg: *single_double_precisions_complex - scal_bad_arg: *single_double_precisions_complex_real - scal_bad_arg: *single_double_complex_real_in_complex_out diff --git a/clients/gtest/gemm_gtest.cpp b/clients/gtest/gemm_gtest.cpp index b32b7594a..5ad7f5390 100644 --- a/clients/gtest/gemm_gtest.cpp +++ b/clients/gtest/gemm_gtest.cpp @@ -113,7 +113,11 @@ namespace // When Ti = To = Tc != void, this test applies. // When converted to bool, this functor returns true. template - struct gemm_testing{}>::type> + struct gemm_testing{} + && !std::is_same{}>::type> { explicit operator bool() { @@ -162,7 +166,13 @@ namespace // When Ti != void, this test applies. // When converted to bool, this functor returns true. template - struct gemm_ex_testing{}>::type> + struct gemm_ex_testing< + Ti, + To, + Tc, + typename std::enable_if{} + && !(std::is_same{} && std::is_same{} + && std::is_same{})>::type> { explicit operator bool() { diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index ff378f9ff..dee81629b 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -100,6 +100,12 @@ static constexpr auto rocblas_dot = rocblas_sdot; template <> static constexpr auto rocblas_dot = rocblas_ddot; +template <> +static constexpr auto rocblas_dot = rocblas_hdot; + +template <> +static constexpr auto rocblas_dot = rocblas_bfdot; + template <> static constexpr auto rocblas_dot = rocblas_cdotu; diff --git a/clients/include/rocblas_common.yaml b/clients/include/rocblas_common.yaml index 22efdc51e..a9f8dd5d9 100644 --- a/clients/include/rocblas_common.yaml +++ b/clients/include/rocblas_common.yaml @@ -40,6 +40,8 @@ Real precisions: &real_precisions { a_type: f64_r, b_type: f64_r, c_type: f64_r, d_type: f64_r, compute_type: f64_r } - &int8_precision { a_type: i8_r, b_type: i8_r, c_type: i32_r, d_type: i32_r, compute_type: i32_r } + - &bf16_precision + { a_type: bf16_r, b_type: bf16_r, c_type: bf16_r, d_type: bf16_r, compute_type: bf16_r } - &hpa_bf16_precision { a_type: bf16_r, b_type: bf16_r, c_type: bf16_r, d_type: bf16_r, compute_type: f32_r } @@ -173,6 +175,19 @@ Single double joined: &single_double_complex_real_in_complex_out - *single_precision_complex_real_in_complex_out - *double_precision_complex_real_in_complex_out +############################################# +# Used for Dot (quick) # +############################################# +Half bfloat single double complex real: &half_bfloat_single_double_complex_real_precisions + - *half_precision + - *bf16_precision + - *single_precision + - *double_precision + - *half_precision_complex + - *single_precision_complex + - *double_precision_complex + + # The Arguments struct passed directly to C++. See rocblas_arguments.hpp. # The order of the entries is significant, so it can't simply be a dictionary. # The types on the RHS are eval'd for Python-recognized types including ctypes diff --git a/clients/include/testing_dot.hpp b/clients/include/testing_dot.hpp index c8583259d..853250826 100644 --- a/clients/include/testing_dot.hpp +++ b/clients/include/testing_dot.hpp @@ -156,8 +156,8 @@ void testing_dot(const Arguments& arg) std::cout << "cpu=" << cpu_result << ", gpu_host_ptr=" << rocblas_result_1 << ", gpu_device_ptr=" << rocblas_result_2 << "\n"; - rocblas_error_1 = std::abs((cpu_result - rocblas_result_1) / cpu_result); - rocblas_error_2 = std::abs((cpu_result - rocblas_result_2) / cpu_result); + rocblas_error_1 = double(std::abs((cpu_result - rocblas_result_1) / cpu_result)); + rocblas_error_2 = double(std::abs((cpu_result - rocblas_result_2) / cpu_result)); } } diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index 22c86f2ea..4717ddf64 100644 --- a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -22,12 +22,14 @@ auto rocblas_simple_dispatch(const Arguments& arg) { case rocblas_datatype_f16_r: return TEST{}(arg); + case rocblas_datatype_bf16_r: + return TEST{}(arg); case rocblas_datatype_f32_r: return TEST{}(arg); case rocblas_datatype_f64_r: return TEST{}(arg); - // case rocblas_datatype_f16_c: - // return TEST{}(arg); + // case rocblas_datatype_f16_c: + // return TEST{}(arg); case rocblas_datatype_f32_c: return TEST{}(arg); case rocblas_datatype_f64_c: diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 50cc13b7a..a9c37a0f3 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -209,6 +209,22 @@ ROCBLAS_EXPORT rocblas_status rocblas_ddot(rocblas_handle handle, rocblas_int incy, double* result); +ROCBLAS_EXPORT rocblas_status rocblas_hdot(rocblas_handle handle, + rocblas_int n, + const rocblas_half* x, + rocblas_int incx, + const rocblas_half* y, + rocblas_int incy, + rocblas_half* result); + +ROCBLAS_EXPORT rocblas_status rocblas_bfdot(rocblas_handle handle, + rocblas_int n, + const rocblas_bfloat16* x, + rocblas_int incx, + const rocblas_bfloat16* y, + rocblas_int incy, + rocblas_bfloat16* result); + ROCBLAS_EXPORT rocblas_status rocblas_cdotu(rocblas_handle handle, rocblas_int n, const rocblas_float_complex* x, diff --git a/library/include/rocblas_bfloat16.h b/library/include/rocblas_bfloat16.h index 86d007fd8..4ef90b08b 100644 --- a/library/include/rocblas_bfloat16.h +++ b/library/include/rocblas_bfloat16.h @@ -254,6 +254,15 @@ inline rocblas_bfloat16 cos(rocblas_bfloat16 a) return rocblas_bfloat16(cosf(float(a))); } +// Inject standard functions into namespace std +namespace std +{ + __device__ __host__ inline rocblas_bfloat16 abs(const rocblas_bfloat16& z) + { + return rocblas_bfloat16(z.data & 0x7fff); + } +} + #endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) #endif // _ROCBLAS_BFLOAT16_H_ diff --git a/library/src/blas1/reduction.h b/library/src/blas1/reduction.h index 32a62ab60..a9a4589f1 100644 --- a/library/src/blas1/reduction.h +++ b/library/src/blas1/reduction.h @@ -215,7 +215,7 @@ __global__ void rocblas_reduction_kernel_part2(rocblas_int nblocks, To* workspac // Store result on device or in workspace if(tx == 0) - *result = FINALIZE{}(tmp[0]); + *result = Tr(FINALIZE{}(tmp[0])); } // At least two kernels are needed to finish the reduction diff --git a/library/src/blas1/rocblas_dot.cpp b/library/src/blas1/rocblas_dot.cpp index 1dcb2c00f..8ef551fee 100644 --- a/library/src/blas1/rocblas_dot.cpp +++ b/library/src/blas1/rocblas_dot.cpp @@ -13,20 +13,20 @@ namespace // setting to 512 for gfx803. constexpr int NB = 512; - template + template __global__ void dot_kernel_part1( - rocblas_int n, const T* x, rocblas_int incx, const T* y, rocblas_int incy, T* workspace) + rocblas_int n, const T* x, rocblas_int incx, const T* y, rocblas_int incy, T2* workspace) { ptrdiff_t tx = hipThreadIdx_x; ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + tx; - __shared__ T tmp[NB]; + __shared__ T2 tmp[NB]; // bound if(tid < n) - tmp[tx] = y[tid * incy] * (CONJ ? conj(x[tid * incx]) : x[tid * incx]); + tmp[tx] = T2(y[tid * incy]) * T2(CONJ ? conj(x[tid * incx]) : x[tid * incx]); else - tmp[tx] = T(0); // pad with zero + tmp[tx] = T2(0); // pad with zero rocblas_sum_reduce(tx, tmp); @@ -36,7 +36,7 @@ namespace // assume workspace has already been allocated, recommened for repeated calling of dot product // routine - template + template rocblas_status rocblas_dot_workspace(rocblas_handle __restrict__ handle, rocblas_int n, const T* x, @@ -44,7 +44,7 @@ namespace const T* y, rocblas_int incy, T* result, - T* workspace, + T2* workspace, rocblas_int blocks) { // At least two kernels are needed to finish the reduction @@ -73,18 +73,33 @@ namespace incy, workspace); - hipLaunchKernelGGL(rocblas_reduction_kernel_part2, - 1, - threads, - 0, - handle->rocblas_stream, - blocks, - workspace, - handle->pointer_mode != rocblas_pointer_mode_device ? workspace - : result); - if(handle->pointer_mode != rocblas_pointer_mode_device) + if(handle->pointer_mode == rocblas_pointer_mode_device) + { + hipLaunchKernelGGL(rocblas_reduction_kernel_part2, + 1, + threads, + 0, + handle->rocblas_stream, + blocks, + workspace, + result); + } + else + { + hipLaunchKernelGGL(rocblas_reduction_kernel_part2, + 1, + threads, + 0, + handle->rocblas_stream, + blocks, + workspace, + workspace); + + T2 res_T2; RETURN_IF_HIP_ERROR( - hipMemcpy(result, workspace, sizeof(*result), hipMemcpyDeviceToHost)); + hipMemcpy(&res_T2, workspace, sizeof(res_T2), hipMemcpyDeviceToHost)); + *result = T(res_T2); + } return rocblas_status_success; } @@ -95,6 +110,10 @@ namespace constexpr char rocblas_dot_name[] = "rocblas_sdot"; template constexpr char rocblas_dot_name[] = "rocblas_ddot"; + template + constexpr char rocblas_dot_name[] = "rocblas_hdot"; + template + constexpr char rocblas_dot_name[] = "rocblas_bfdot"; template <> constexpr char rocblas_dot_name[] = "rocblas_cdotc"; template <> @@ -105,7 +124,7 @@ namespace constexpr char rocblas_dot_name[] = "rocblas_zdotu"; // allocate workspace inside this API - template + template rocblas_status rocblas_dot(rocblas_handle handle, rocblas_int n, const T* x, @@ -152,13 +171,13 @@ namespace auto blocks = (n - 1) / NB + 1; if(handle->is_device_memory_size_query()) - return handle->set_optimal_device_memory_size(sizeof(T) * blocks); + return handle->set_optimal_device_memory_size(sizeof(T2) * blocks); - auto mem = handle->device_malloc(sizeof(T) * blocks); + auto mem = handle->device_malloc(sizeof(T2) * blocks); if(!mem) return rocblas_status_memory_error; - return rocblas_dot_workspace(handle, n, x, incx, y, incy, result, (T*)mem, blocks); + return rocblas_dot_workspace(handle, n, x, incx, y, incy, result, (T2*)mem, blocks); } } // namespace @@ -193,6 +212,29 @@ rocblas_status rocblas_ddot(rocblas_handle handle, return rocblas_dot(handle, n, x, incx, y, incy, result); } +rocblas_status rocblas_hdot(rocblas_handle handle, + rocblas_int n, + const rocblas_half* x, + rocblas_int incx, + const rocblas_half* y, + rocblas_int incy, + rocblas_half* result) +{ + return rocblas_dot( + handle, n, (const _Float16*)x, incx, (const _Float16*)y, incy, (_Float16*)result); +} + +rocblas_status rocblas_bfdot(rocblas_handle handle, + rocblas_int n, + const rocblas_bfloat16* x, + rocblas_int incx, + const rocblas_bfloat16* y, + rocblas_int incy, + rocblas_bfloat16* result) +{ + return rocblas_dot(handle, n, x, incx, y, incy, result); +} + rocblas_status rocblas_cdotu(rocblas_handle handle, rocblas_int n, const rocblas_float_complex* x,