From 1ae458c1e29b657e46dc7f01bfed64f81974ff9d Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Sat, 5 Oct 2019 10:32:54 -0400 Subject: [PATCH 1/5] Refactoring --- clients/benchmarks/client.cpp | 5 +- clients/gtest/blas1_gtest.cpp | 3 +- clients/include/near.hpp | 48 +- clients/include/rocblas_datatype2string.hpp | 8 +- clients/include/testing_gemm.hpp | 10 +- clients/include/unit.hpp | 35 +- library/include/rocblas-auxiliary.h | 4 +- library/include/tensile_host.hpp | 185 ++-- library/src/CMakeLists.txt | 10 +- library/src/blas3/Tensile/gemm.cpp | 779 ++++++++------ library/src/blas3/Tensile/gemm_batched.cpp | 458 ++++---- .../blas3/Tensile/gemm_strided_batched.cpp | 992 +++++++++++------- library/src/blas_ex/rocblas_gemm_ex.hpp | 481 +++++---- library/src/handle.cpp | 7 +- library/src/include/handle.h | 2 +- library/src/rocblas_auxiliary.cpp | 36 +- library/src/tensile_host.cpp | 441 ++++---- 17 files changed, 1875 insertions(+), 1629 deletions(-) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index f4818c7d3..65be2031d 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -563,7 +563,7 @@ try options_description desc("rocblas-bench command line options"); desc.add_options() - // clang-format off + // clang-format off #ifdef USE_TENSILE_HOST ("lib", @@ -806,7 +806,8 @@ try throw std::invalid_argument("Invalid value for --function"); #ifdef USE_TENSILE_HOST - int copied_host = snprintf(arg.host_lib_path, sizeof(arg.host_lib_path), "%s", host_lib_path.c_str()); + int copied_host + = snprintf(arg.host_lib_path, sizeof(arg.host_lib_path), "%s", host_lib_path.c_str()); if(copied_host <= 0 || copied_host >= sizeof(arg.host_lib_path)) throw std::invalid_argument("Invalid value for --lib"); #endif diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index b2d2a9a42..f6f708bb7 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -288,6 +288,7 @@ TEST_P(NAME, blas1) \ { \ rocblas_blas1_dispatch(GetParam()); \ } \ +// clang-format on \ INSTANTIATE_TEST_CATEGORIES(NAME) @@ -324,6 +325,4 @@ BLAS1_TESTING(rotg, ARG2) BLAS1_TESTING(rotm, ARG1) BLAS1_TESTING(rotmg, ARG1) - // clang-format on - } // namespace diff --git a/clients/include/near.hpp b/clients/include/near.hpp index 66a69ab39..77bc4855c 100644 --- a/clients/include/near.hpp +++ b/clients/include/near.hpp @@ -41,33 +41,33 @@ static constexpr double sum_error_tolerance = 1 / 100000 #define NEAR_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, err, NEAR_ASSERT) #define NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, err, NEAR_ASSERT) #else -// clang-format off -#define NEAR_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, err, NEAR_ASSERT) \ - do \ - { \ - for(size_t k = 0; k < batch_count; k++) \ - for(size_t j = 0; j < N; j++) \ - for(size_t i = 0; i < M; i++) \ - NEAR_ASSERT(hCPU[i + j * lda + k * strideA], \ - hGPU[i + j * lda + k * strideA], \ - err); \ + +#define NEAR_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, err, NEAR_ASSERT) \ + do \ + { \ + for(size_t k = 0; k < batch_count; k++) \ + for(size_t j = 0; j < N; j++) \ + for(size_t i = 0; i < M; i++) \ + NEAR_ASSERT( \ + hCPU[i + j * lda + k * strideA], hGPU[i + j * lda + k * strideA], err); \ } while(0) -#define NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, err, NEAR_ASSERT) \ - do \ - { \ - for(size_t k = 0; k < batch_count; k++) \ - for(size_t j = 0; j < N; j++) \ - for(size_t i = 0; i < M; i++) \ - if(rocblas_isnan(hCPU[k][i + j * lda])) { \ - ASSERT_TRUE(rocblas_isnan(hGPU[k][i + j * lda])); \ - } else { \ - NEAR_ASSERT(hCPU[k][i + j * lda], \ - hGPU[k][i + j * lda], \ - err); \ - } \ +#define NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, err, NEAR_ASSERT) \ + do \ + { \ + for(size_t k = 0; k < batch_count; k++) \ + for(size_t j = 0; j < N; j++) \ + for(size_t i = 0; i < M; i++) \ + if(rocblas_isnan(hCPU[k][i + j * lda])) \ + { \ + ASSERT_TRUE(rocblas_isnan(hGPU[k][i + j * lda])); \ + } \ + else \ + { \ + NEAR_ASSERT(hCPU[k][i + j * lda], hGPU[k][i + j * lda], err); \ + } \ } while(0) -// clang-format on + #endif #define NEAR_ASSERT_HALF(a, b, err) ASSERT_NEAR(half_to_float(a), half_to_float(b), err) diff --git a/clients/include/rocblas_datatype2string.hpp b/clients/include/rocblas_datatype2string.hpp index 7f0555ba7..cd7977fbb 100644 --- a/clients/include/rocblas_datatype2string.hpp +++ b/clients/include/rocblas_datatype2string.hpp @@ -195,20 +195,18 @@ constexpr rocblas_side char2rocblas_side(char value) } } -inline rocblas_initialization string2rocblas_initialization(const std::string& value) +// clang-format off +nline rocblas_initialization string2rocblas_initialization(const std::string& value) { - // clang-format off return value == "rand_int" ? rocblas_initialization_random_int : value == "trig_float" ? rocblas_initialization_trig_float : value == "hpl" ? rocblas_initialization_hpl : static_cast(-1); - // clang-format on } inline rocblas_datatype string2rocblas_datatype(const std::string& value) { - // clang-format off return value == "f16_r" || value == "h" ? rocblas_datatype_f16_r : value == "f32_r" || value == "s" ? rocblas_datatype_f32_r : @@ -227,7 +225,7 @@ inline rocblas_datatype string2rocblas_datatype(const std::string& value) value == "u8_c" ? rocblas_datatype_u8_c : value == "u32_c" ? rocblas_datatype_u32_c : static_cast(-1); - // clang-format on } +// clang-format on #endif diff --git a/clients/include/testing_gemm.hpp b/clients/include/testing_gemm.hpp index 46f25d0a5..7fb89a2e2 100644 --- a/clients/include/testing_gemm.hpp +++ b/clients/include/testing_gemm.hpp @@ -92,14 +92,14 @@ void testing_gemm(const Arguments& arg) T h_alpha = arg.get_alpha(); T h_beta = arg.get_beta(); - double gpu_time_used, cpu_time_used; - double rocblas_gflops, cblas_gflops; - double rocblas_error = 0.0; + double gpu_time_used, cpu_time_used; + double rocblas_gflops, cblas_gflops; + double rocblas_error = 0.0; #ifdef USE_TENSILE_HOST - const char* host_lib_path = arg.host_lib_path; + const char* host_lib_path = arg.host_lib_path; rocblas_local_handle handle(host_lib_path); #else - rocblas_local_handle handle; + rocblas_local_handle handle; #endif rocblas_int A_row = transA == rocblas_operation_none ? M : K; rocblas_int A_col = transA == rocblas_operation_none ? K : M; diff --git a/clients/include/unit.hpp b/clients/include/unit.hpp index 595a75054..be4bf4ae9 100644 --- a/clients/include/unit.hpp +++ b/clients/include/unit.hpp @@ -20,34 +20,37 @@ #define UNIT_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, UNIT_ASSERT_EQ) #define UNIT_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, UNIT_ASSERT_EQ) #else -// clang-format off #define UNIT_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, UNIT_ASSERT_EQ) \ do \ { \ for(size_t k = 0; k < batch_count; k++) \ for(size_t j = 0; j < N; j++) \ for(size_t i = 0; i < M; i++) \ - if (rocblas_isnan(hCPU[i + j * lda + k * strideA])) { \ + if(rocblas_isnan(hCPU[i + j * lda + k * strideA])) \ + { \ ASSERT_TRUE(rocblas_isnan(hGPU[i + j * lda + k * strideA])); \ - } else { \ + } \ + else \ + { \ UNIT_ASSERT_EQ(hCPU[i + j * lda + k * strideA], \ hGPU[i + j * lda + k * strideA]); \ } \ } while(0) -#define UNIT_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, UNIT_ASSERT_EQ) \ - do \ - { \ - for(size_t k = 0; k < batch_count; k++) \ - for(size_t j = 0; j < N; j++) \ - for(size_t i = 0; i < M; i++) \ - if (rocblas_isnan(hCPU[k][i + j * lda])) { \ - ASSERT_TRUE(rocblas_isnan(hGPU[k][i + j * lda])); \ - } else { \ - UNIT_ASSERT_EQ(hCPU[k][i + j * lda], \ - hGPU[k][i + j * lda]); \ - } \ +#define UNIT_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, UNIT_ASSERT_EQ) \ + do \ + { \ + for(size_t k = 0; k < batch_count; k++) \ + for(size_t j = 0; j < N; j++) \ + for(size_t i = 0; i < M; i++) \ + if(rocblas_isnan(hCPU[k][i + j * lda])) \ + { \ + ASSERT_TRUE(rocblas_isnan(hGPU[k][i + j * lda])); \ + } \ + else \ + { \ + UNIT_ASSERT_EQ(hCPU[k][i + j * lda], hGPU[k][i + j * lda]); \ + } \ } while(0) -// clang-format on #endif #define ASSERT_HALF_EQ(a, b) ASSERT_FLOAT_EQ(half_to_float(a), half_to_float(b)) diff --git a/library/include/rocblas-auxiliary.h b/library/include/rocblas-auxiliary.h index 3639f5084..f2907b678 100644 --- a/library/include/rocblas-auxiliary.h +++ b/library/include/rocblas-auxiliary.h @@ -19,9 +19,7 @@ extern "C" { /*! \brief create handle */ ROCBLAS_EXPORT rocblas_status rocblas_create_handle(rocblas_handle* handle); -#ifdef USE_TENSILE_HOST -ROCBLAS_EXPORT rocblas_status rocblas_create_host_handle(rocblas_handle* handle, const char*); -#endif + /*! \brief destroy handle */ ROCBLAS_EXPORT rocblas_status rocblas_destroy_handle(rocblas_handle handle); diff --git a/library/include/tensile_host.hpp b/library/include/tensile_host.hpp index 351503327..d085ba7ae 100644 --- a/library/include/tensile_host.hpp +++ b/library/include/tensile_host.hpp @@ -1,119 +1,142 @@ - - - - - #pragma once #ifndef __TENSILE_HOST_HPP__ #define __TENSILE_HOST_HPP__ #ifdef USE_TENSILE_HOST +#include "rocblas.h" -enum ContractionProblemType { - GEMM, GEMMStridedBatch +enum ContractionProblemType +{ + GEMM, + GEMMStridedBatch, }; template class RocblasContractionProblem { public: - ContractionProblemType problem_type; - rocblas_operation trans_a; - rocblas_operation trans_b; - unsigned long m; - unsigned long n; - unsigned long k; - const T* alpha; - const T* A; - const unsigned long ld_a; - unsigned long stride_a; - const T* B; - unsigned long ld_b; - unsigned long stride_b; - const T* beta; - T* C; - unsigned long ld_c; - unsigned long stride_c; - unsigned long batch_size; + ContractionProblemType problem_type; + rocblas_operation trans_a; + rocblas_operation trans_b; + unsigned long m; + unsigned long n; + unsigned long k; + const T* alpha; + const T* A; + const unsigned long ld_a; + unsigned long stride_a; + const T* B; + unsigned long ld_b; + unsigned long stride_b; + const T* beta; + T* C; + unsigned long ld_c; + unsigned long stride_c; + unsigned long batch_size; RocblasContractionProblem(ContractionProblemType problem_type, - rocblas_operation trans_a, - rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, - const T* alpha, - const T* A, - unsigned long ld_a, - const T* B, - unsigned long ld_b, - const T* beta, - T* C, - unsigned long ld_c): problem_type(problem_type), - trans_a(trans_a),trans_b(trans_b), - m(m),n(n),k(k),alpha(alpha), - A(A),ld_a(ld_a),B(B),stride_a(1),ld_b(ld_b),stride_b(1),beta(beta), - C(C),ld_c(ld_c),stride_c(1),batch_size(1) + rocblas_operation trans_a, + rocblas_operation trans_b, + unsigned long m, + unsigned long n, + unsigned long k, + const T* alpha, + const T* A, + unsigned long ld_a, + const T* B, + unsigned long ld_b, + const T* beta, + T* C, + unsigned long ld_c) + : problem_type(problem_type) + , trans_a(trans_a) + , trans_b(trans_b) + , m(m) + , n(n) + , k(k) + , alpha(alpha) + , A(A) + , ld_a(ld_a) + , B(B) + , stride_a(1) + , ld_b(ld_b) + , stride_b(1) + , beta(beta) + , C(C) + , ld_c(ld_c) + , stride_c(1) + , batch_size(1) { } RocblasContractionProblem(ContractionProblemType problem_type, - rocblas_operation trans_a, - rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, - const T* alpha, - const T* A, - unsigned long ld_a, - unsigned long stride_a, - const T* B, - unsigned long ld_b, - unsigned long stride_b, - const T* beta, - T* C, - unsigned long ld_c, - unsigned long stride_c, - unsigned long batch_size): problem_type(problem_type), - trans_a(trans_a),trans_b(trans_b), - m(m),n(n),k(k),alpha(alpha), - A(A),ld_a(ld_a),B(B),stride_a(stride_a),ld_b(ld_b),stride_b(stride_b),beta(beta), - C(C),ld_c(ld_c),stride_c(stride_c),batch_size(batch_size) + rocblas_operation trans_a, + rocblas_operation trans_b, + unsigned long m, + unsigned long n, + unsigned long k, + const T* alpha, + const T* A, + unsigned long ld_a, + unsigned long stride_a, + const T* B, + unsigned long ld_b, + unsigned long stride_b, + const T* beta, + T* C, + unsigned long ld_c, + unsigned long stride_c, + unsigned long batch_size) + : problem_type(problem_type) + , trans_a(trans_a) + , trans_b(trans_b) + , m(m) + , n(n) + , k(k) + , alpha(alpha) + , A(A) + , ld_a(ld_a) + , B(B) + , stride_a(stride_a) + , ld_b(ld_b) + , stride_b(stride_b) + , beta(beta) + , C(C) + , ld_c(ld_c) + , stride_c(stride_c) + , batch_size(batch_size) { } - }; - - class TensileHost { - public: - virtual void initializeHost(const char*) {} +public: + virtual void initializeHost(const char*) {} }; template class TensileHostCall { public: - rocblas_status runContractionProblem(RocblasContractionProblem *problem, TensileHost *host); + rocblas_status runContractionProblem(RocblasContractionProblem* problem, TensileHost* host); }; +TensileHost* createTensileHost(); -TensileHost *createTensileHost(); - -rocblas_status callTensileContraction_half(RocblasContractionProblem *problem, TensileHost *host); -rocblas_status callTensileContraction_float(RocblasContractionProblem *problem, TensileHost *host); -rocblas_status callTensileContraction_double(RocblasContractionProblem *problem, TensileHost *host); -rocblas_status callTensileContraction_float_complex(RocblasContractionProblem *problem, TensileHost *host); -rocblas_status callTensileContraction_double_complex(RocblasContractionProblem *problem, TensileHost *host); - +rocblas_status callTensileContraction_half(RocblasContractionProblem* problem, + TensileHost* host); +rocblas_status callTensileContraction_float(RocblasContractionProblem* problem, + TensileHost* host); +rocblas_status callTensileContraction_double(RocblasContractionProblem* problem, + TensileHost* host); +rocblas_status + callTensileContraction_float_complex(RocblasContractionProblem* problem, + TensileHost* host); +rocblas_status callTensileContraction_double_complex( + RocblasContractionProblem* problem, TensileHost* host); #endif #endif // __TENSILE_HOST_HPP__ - - - - diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index 68e052dd3..547dfa25d 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -26,12 +26,8 @@ set( package_targets rocblas ) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) -# Set up Tensile Dependency +# Set up Tensile Dependency if( BUILD_WITH_TENSILE ) - # build gemm with Tensile Host library - if( BUILD_WITH_TENSILE_HOST ) - target_compile_definitions( Tensile PUBLIC USE_TENSILE_HOST ) - endif() # If we want to build a shared rocblas lib, force Tensile to build as a static lib to absorb into rocblas if( BUILD_SHARED_LIBS ) set( ROCBLAS_SHARED_LIBS ON ) @@ -54,6 +50,10 @@ if( BUILD_WITH_TENSILE ) # Create a unique name for Tensile compiled for rocBLAS set_target_properties( Tensile PROPERTIES OUTPUT_NAME tensile-rocblas CXX_EXTENSIONS NO ) + # build gemm with Tensile Host library + if( BUILD_WITH_TENSILE_HOST ) + target_compile_definitions( Tensile PUBLIC USE_TENSILE_HOST ) + endif() target_compile_features( Tensile PRIVATE cxx_static_assert cxx_nullptr cxx_auto_type ) # Remove this check when we no longer build with older rocm stack(ie < 1.8.2) if(TARGET hip::device) diff --git a/library/src/blas3/Tensile/gemm.cpp b/library/src/blas3/Tensile/gemm.cpp index 0f081b7b2..c56d83a8d 100644 --- a/library/src/blas3/Tensile/gemm.cpp +++ b/library/src/blas3/Tensile/gemm.cpp @@ -11,7 +11,6 @@ namespace { - template static constexpr char rocblas_gemm_name[] = "unknown"; template <> @@ -44,7 +43,6 @@ namespace T* C, rocblas_int ld_c) { - // clang-format off // Perform logging if(!handle) return rocblas_status_invalid_handle; @@ -54,8 +52,9 @@ namespace return rocblas_status_invalid_pointer; auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -64,74 +63,74 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - B, - ld_b, - *beta, - C, - ld_c); + rocblas_gemm_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + B, + ld_b, + *beta, + C, + ld_c); if(layer_mode & rocblas_layer_mode_log_bench) { std::stringstream alphass; alphass << "--alpha " << std::real(*alpha); - if (std::imag(*alpha) != 0) + if(std::imag(*alpha) != 0) alphass << " --alphai " << std::imag(*alpha); std::stringstream betass; betass << "--beta " << std::real(*beta); - if (std::imag(*beta) != 0) + if(std::imag(*beta) != 0) betass << " --betai " << std::imag(*beta); log_bench(handle, - "./rocblas-bench -f gemm -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - alphass.str(), - "--lda", - ld_a, - "--ldb", - ld_b, - betass.str(), - "--ldc", - ld_c); + "./rocblas-bench -f gemm -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + alphass.str(), + "--lda", + ld_a, + "--ldb", + ld_b, + betass.str(), + "--ldc", + ld_c); } } else { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - B, - ld_b, - beta, - C, - ld_c); + rocblas_gemm_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); } if(layer_mode & rocblas_layer_mode_log_profile) @@ -155,46 +154,65 @@ namespace ld_c); } - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, 0, - B, ld_b, 0, beta, - C, ld_c, 0, 1); + rocblas_status validArgs = validateArgs( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, 0, B, ld_b, 0, beta, C, ld_c, 0, 1); if(validArgs != rocblas_status_success) return validArgs; - return rocblas_gemm_template(handle, trans_a, trans_b, m, n, k, alpha, 0, A, 0, ld_a, 0, B, 0, ld_b, 0, beta, 0, C, 0, ld_c, 0, 1); + return rocblas_gemm_template(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + 0, + A, + 0, + ld_a, + 0, + B, + 0, + ld_b, + 0, + beta, + 0, + C, + 0, + ld_c, + 0, + 1); } template rocblas_status rocblas_gemm_kernel_name_impl(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const T* alpha, - const T* A, - rocblas_int ld_a, - rocblas_stride stride_a, - const T* B, - rocblas_int ld_b, - rocblas_stride stride_b, - const T* beta, - T* C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const T* alpha, + const T* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const T* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const T* beta, + T* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - // clang-format off if(!handle) return rocblas_status_invalid_handle; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -203,74 +221,74 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - B, - ld_b, - *beta, - C, - ld_c); + rocblas_gemm_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + B, + ld_b, + *beta, + C, + ld_c); if(layer_mode & rocblas_layer_mode_log_bench) { std::stringstream alphass; alphass << "--alpha " << std::real(*alpha); - if (std::imag(*alpha) != 0) + if(std::imag(*alpha) != 0) alphass << " --alphai " << std::imag(*alpha); std::stringstream betass; betass << "--beta " << std::real(*beta); - if (std::imag(*beta) != 0) + if(std::imag(*beta) != 0) betass << " --betai " << std::imag(*beta); log_bench(handle, - "./rocblas-bench -f gemm -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - alphass.str(), - "--lda", - ld_a, - "--ldb", - ld_b, - betass.str(), - "--ldc", - ld_c); + "./rocblas-bench -f gemm -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + alphass.str(), + "--lda", + ld_a, + "--ldb", + ld_b, + betass.str(), + "--ldc", + ld_c); } } else { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - B, - ld_b, - beta, - C, - ld_c); + rocblas_gemm_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); } if(layer_mode & rocblas_layer_mode_log_profile) @@ -294,297 +312,366 @@ namespace ld_c); } - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, beta, - C, ld_c, stride_c, b_c); + rocblas_status validArgs = validateArgs(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); if(validArgs != rocblas_status_success) return validArgs; - rocblas_gemm_kernel_name_template(trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); + rocblas_gemm_kernel_name_template( + trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); return validArgs; } } - extern "C" { /******************************************************************************* * GEMM APIs ******************************************************************************/ -rocblas_status rocblas_hgemm(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_half *alpha, - const rocblas_half *A, - rocblas_int ld_a, - const rocblas_half *B, - rocblas_int ld_b, - const rocblas_half *beta, - rocblas_half *C, - rocblas_int ld_c) +rocblas_status rocblas_hgemm(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_half* alpha, + const rocblas_half* A, + rocblas_int ld_a, + const rocblas_half* B, + rocblas_int ld_b, + const rocblas_half* beta, + rocblas_half* C, + rocblas_int ld_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMM,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a, - B,ld_b, - beta, - C,ld_c); - - return callTensileContraction_half( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMM, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); + + return callTensileContraction_half(&problem, handle->host); #else - return rocblas_gemm_impl(handle, trans_a, trans_b, - m, n, k, alpha, A, ld_a, - B, ld_b, beta, C, ld_c); + return rocblas_gemm_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c); #endif } -rocblas_status rocblas_sgemm(rocblas_handle handle, +rocblas_status rocblas_sgemm(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const float *alpha, - const float *A, - rocblas_int ld_a, - const float *B, - rocblas_int ld_b, - const float *beta, - float *C, - rocblas_int ld_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const float* alpha, + const float* A, + rocblas_int ld_a, + const float* B, + rocblas_int ld_b, + const float* beta, + float* C, + rocblas_int ld_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMM,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a, - B,ld_b, - beta, - C,ld_c); - - return callTensileContraction_float( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMM, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); + + return callTensileContraction_float(&problem, handle->host); #else - return rocblas_gemm_impl(handle, trans_a, trans_b, - m, n, k, alpha, A, ld_a, - B, ld_b, beta, C, ld_c); + return rocblas_gemm_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c); #endif } -rocblas_status rocblas_dgemm(rocblas_handle handle, +rocblas_status rocblas_dgemm(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const double *alpha, - const double *A, - rocblas_int ld_a, - const double *B, - rocblas_int ld_b, - const double *beta, - double *C, - rocblas_int ld_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const double* alpha, + const double* A, + rocblas_int ld_a, + const double* B, + rocblas_int ld_b, + const double* beta, + double* C, + rocblas_int ld_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMM,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a, - B,ld_b, - beta, - C,ld_c); - - return callTensileContraction_double( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMM, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); + + return callTensileContraction_double(&problem, handle->host); #else - return rocblas_gemm_impl(handle, trans_a, trans_b, - m, n, k, alpha, A, ld_a, - B, ld_b, beta, C, ld_c); + return rocblas_gemm_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c); #endif } -rocblas_status rocblas_cgemm(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_float_complex *alpha, - const rocblas_float_complex *A, - rocblas_int ld_a, - const rocblas_float_complex *B, - rocblas_int ld_b, - const rocblas_float_complex *beta, - rocblas_float_complex *C, - rocblas_int ld_c) +rocblas_status rocblas_cgemm(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_float_complex* alpha, + const rocblas_float_complex* A, + rocblas_int ld_a, + const rocblas_float_complex* B, + rocblas_int ld_b, + const rocblas_float_complex* beta, + rocblas_float_complex* C, + rocblas_int ld_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMM,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a, - B,ld_b, - beta, - C,ld_c); - - return callTensileContraction_float_complex( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMM, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); + + return callTensileContraction_float_complex(&problem, handle->host); #else - return rocblas_gemm_impl(handle, trans_a, trans_b, - m, n, k, alpha, A, ld_a, - B, ld_b, beta, C, ld_c); + return rocblas_gemm_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c); #endif } - -rocblas_status rocblas_zgemm(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_double_complex *alpha, - const rocblas_double_complex *A, - rocblas_int ld_a, - const rocblas_double_complex *B, - rocblas_int ld_b, - const rocblas_double_complex *beta, - rocblas_double_complex *C, - rocblas_int ld_c) +rocblas_status rocblas_zgemm(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_double_complex* alpha, + const rocblas_double_complex* A, + rocblas_int ld_a, + const rocblas_double_complex* B, + rocblas_int ld_b, + const rocblas_double_complex* beta, + rocblas_double_complex* C, + rocblas_int ld_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMM,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a, - B,ld_b, - beta, - C,ld_c); - - return callTensileContraction_double_complex( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMM, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c); + + return callTensileContraction_double_complex(&problem, handle->host); #else - return rocblas_gemm_impl(handle, trans_a, trans_b, - m, n, k, alpha, A, ld_a, - B, ld_b, beta, C, ld_c); + return rocblas_gemm_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c); #endif } /******************************************************************************* * GEMM Kernel name APIs ******************************************************************************/ -rocblas_status rocblas_hgemm_kernel_name(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_half *alpha, - const rocblas_half *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const rocblas_half *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const rocblas_half *beta, - rocblas_half *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) +rocblas_status rocblas_hgemm_kernel_name(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_half* alpha, + const rocblas_half* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const rocblas_half* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const rocblas_half* beta, + rocblas_half* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } -rocblas_status rocblas_sgemm_kernel_name(rocblas_handle handle, +rocblas_status rocblas_sgemm_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const float *alpha, - const float *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const float *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const float *beta, - float *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const float* alpha, + const float* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const float* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const float* beta, + float* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } -rocblas_status rocblas_dgemm_kernel_name(rocblas_handle handle, +rocblas_status rocblas_dgemm_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const double *alpha, - const double *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const double *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const double *beta, - double *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const double* alpha, + const double* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const double* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const double* beta, + double* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } - - } - diff --git a/library/src/blas3/Tensile/gemm_batched.cpp b/library/src/blas3/Tensile/gemm_batched.cpp index 44130ae39..c24f83c09 100644 --- a/library/src/blas3/Tensile/gemm_batched.cpp +++ b/library/src/blas3/Tensile/gemm_batched.cpp @@ -47,7 +47,6 @@ namespace rocblas_int ld_c, rocblas_int b_c) { - // clang-format off // Perform logging if(!handle) return rocblas_status_invalid_handle; @@ -57,8 +56,9 @@ namespace return rocblas_status_invalid_pointer; auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -67,68 +67,68 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_batched_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - B, - ld_b, - *beta, - C, - ld_c, - b_c); + rocblas_gemm_batched_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + B, + ld_b, + *beta, + C, + ld_c, + b_c); if(layer_mode & rocblas_layer_mode_log_bench) log_bench(handle, - "./rocblas-bench -f gemm_batched -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - "--alpha", - *alpha, - "--lda", - ld_a, - "--ldb", - ld_b, - "--beta", - *beta, - "--ldc", - ld_c, - "--batch", - b_c); + "./rocblas-bench -f gemm_batched -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + "--alpha", + *alpha, + "--lda", + ld_a, + "--ldb", + ld_b, + "--beta", + *beta, + "--ldc", + ld_c, + "--batch", + b_c); } else { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_batched_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - B, - ld_b, - beta, - C, - ld_c, - b_c); + rocblas_gemm_batched_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c, + b_c); } if(layer_mode & rocblas_layer_mode_log_profile) @@ -154,48 +154,65 @@ namespace b_c); } - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, - B, ld_b, beta, - C, ld_c, b_c); + rocblas_status validArgs = validateArgs( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); if(validArgs != rocblas_status_success) return validArgs; - return rocblas_gemm_template(handle, trans_a, trans_b, m, n, k, alpha, 0, A, 0, ld_a, 0, - B, 0, ld_b, 0, beta, 0, C, 0, ld_c, 0, b_c); + return rocblas_gemm_template(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + 0, + A, + 0, + ld_a, + 0, + B, + 0, + ld_b, + 0, + beta, + 0, + C, + 0, + ld_c, + 0, + b_c); } - /** * Kernel Name Function. */ template rocblas_status rocblas_gemm_batched_kernel_name_impl(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const T* alpha, - const T* A[], - rocblas_int ld_a, - const T* B[], - rocblas_int ld_b, - const T* beta, - T* C[], - rocblas_int ld_c, - rocblas_int b_c) + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const T* alpha, + const T* A[], + rocblas_int ld_a, + const T* B[], + rocblas_int ld_b, + const T* beta, + T* C[], + rocblas_int ld_c, + rocblas_int b_c) { - // clang-format off if(!handle) return rocblas_status_invalid_handle; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -204,68 +221,68 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_batched_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - B, - ld_b, - *beta, - C, - ld_c, - b_c); + rocblas_gemm_batched_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + B, + ld_b, + *beta, + C, + ld_c, + b_c); if(layer_mode & rocblas_layer_mode_log_bench) log_bench(handle, - "./rocblas-bench -f gemm_batched -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - "--alpha", - *alpha, - "--lda", - ld_a, - "--ldb", - ld_b, - "--beta", - *beta, - "--ldc", - ld_c, - "--batch", - b_c); + "./rocblas-bench -f gemm_batched -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + "--alpha", + *alpha, + "--lda", + ld_a, + "--ldb", + ld_b, + "--beta", + *beta, + "--ldc", + ld_c, + "--batch", + b_c); } else { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_batched_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - B, - ld_b, - beta, - C, - ld_c, - b_c); + rocblas_gemm_batched_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + B, + ld_b, + beta, + C, + ld_c, + b_c); } if(layer_mode & rocblas_layer_mode_log_profile) @@ -291,31 +308,43 @@ namespace b_c); } - rocblas_stride stride_a; rocblas_stride stride_b; rocblas_stride stride_c; - infer_batch_strides(trans_a, trans_b, m, n, k, ld_a, - &stride_a, ld_b, &stride_b, ld_c, &stride_c); - - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, beta, - C, ld_c, stride_c, b_c); + infer_batch_strides( + trans_a, trans_b, m, n, k, ld_a, &stride_a, ld_b, &stride_b, ld_c, &stride_c); + + rocblas_status validArgs = validateArgs(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); if(validArgs != rocblas_status_success) return validArgs; - rocblas_gemm_kernel_name_template(trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); + rocblas_gemm_kernel_name_template( + trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); return validArgs; } } - /******************************************************************************* * Batched GEMM APIs ******************************************************************************/ @@ -337,10 +366,11 @@ rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_int ld_c, rocblas_int b_c) { - return rocblas_gemm_batched_impl(handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); + return rocblas_gemm_batched_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_sgemm_batched(rocblas_handle handle, +rocblas_status rocblas_sgemm_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, rocblas_int m, @@ -356,10 +386,11 @@ rocblas_status rocblas_sgemm_batched(rocblas_handle handle, rocblas_int ld_c, rocblas_int b_c) { - return rocblas_gemm_batched_impl(handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); + return rocblas_gemm_batched_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_dgemm_batched(rocblas_handle handle, +rocblas_status rocblas_dgemm_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, rocblas_int m, @@ -375,10 +406,11 @@ rocblas_status rocblas_dgemm_batched(rocblas_handle handle, rocblas_int ld_c, rocblas_int b_c) { - return rocblas_gemm_batched_impl(handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); + return rocblas_gemm_batched_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_cgemm_batched(rocblas_handle handle, +rocblas_status rocblas_cgemm_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, rocblas_int m, @@ -394,10 +426,11 @@ rocblas_status rocblas_cgemm_batched(rocblas_handle handle, rocblas_int ld_c, rocblas_int b_c) { - return rocblas_gemm_batched_impl(handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); + return rocblas_gemm_batched_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_zgemm_batched(rocblas_handle handle, +rocblas_status rocblas_zgemm_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, rocblas_int m, @@ -413,91 +446,70 @@ rocblas_status rocblas_zgemm_batched(rocblas_handle handle, rocblas_int ld_c, rocblas_int b_c) { - return rocblas_gemm_batched_impl(handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); + return rocblas_gemm_batched_impl( + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } - /******************************************************************************* * Batched GEMM Kernel name APIs ******************************************************************************/ -rocblas_status rocblas_hgemm_batched_kernel_name(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_half *alpha, - const rocblas_half *A[], - rocblas_int ld_a, - const rocblas_half *B[], - rocblas_int ld_b, - const rocblas_half *beta, - rocblas_half *C[], - rocblas_int ld_c, - rocblas_int b_c) +rocblas_status rocblas_hgemm_batched_kernel_name(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_half* alpha, + const rocblas_half* A[], + rocblas_int ld_a, + const rocblas_half* B[], + rocblas_int ld_b, + const rocblas_half* beta, + rocblas_half* C[], + rocblas_int ld_c, + rocblas_int b_c) { return rocblas_gemm_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, - B, ld_b, - beta, - C, ld_c, b_c); + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_sgemm_batched_kernel_name(rocblas_handle handle, +rocblas_status rocblas_sgemm_batched_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const float *alpha, - const float *A[], - rocblas_int ld_a, - const float *B[], - rocblas_int ld_b, - const float *beta, - float *C[], - rocblas_int ld_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const float* alpha, + const float* A[], + rocblas_int ld_a, + const float* B[], + rocblas_int ld_b, + const float* beta, + float* C[], + rocblas_int ld_c, + rocblas_int b_c) { return rocblas_gemm_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, - B, ld_b, - beta, - C, ld_c, b_c); + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); } -rocblas_status rocblas_dgemm_batched_kernel_name(rocblas_handle handle, +rocblas_status rocblas_dgemm_batched_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const double *alpha, - const double *A[], - rocblas_int ld_a, - const double *B[], - rocblas_int ld_b, - const double *beta, - double *C[], - rocblas_int ld_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const double* alpha, + const double* A[], + rocblas_int ld_a, + const double* B[], + rocblas_int ld_b, + const double* beta, + double* C[], + rocblas_int ld_c, + rocblas_int b_c) { return rocblas_gemm_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, - B, ld_b, - beta, - C, ld_c, b_c); + handle, trans_a, trans_b, m, n, k, alpha, A, ld_a, B, ld_b, beta, C, ld_c, b_c); +} } - - - -} \ No newline at end of file diff --git a/library/src/blas3/Tensile/gemm_strided_batched.cpp b/library/src/blas3/Tensile/gemm_strided_batched.cpp index 76bb4877d..32eb36e11 100644 --- a/library/src/blas3/Tensile/gemm_strided_batched.cpp +++ b/library/src/blas3/Tensile/gemm_strided_batched.cpp @@ -59,15 +59,15 @@ namespace rocblas_int b_c) { - // clang-format off if(!handle) return rocblas_status_invalid_handle; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -76,66 +76,66 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_strided_batched_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - stride_a, - B, - ld_b, - stride_b, - *beta, - C, - ld_c, - stride_c, - b_c); + rocblas_gemm_strided_batched_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + *beta, + C, + ld_c, + stride_c, + b_c); if(layer_mode & rocblas_layer_mode_log_bench) { std::stringstream alphass; alphass << "--alpha " << std::real(*alpha); - if (std::imag(*alpha) != 0) + if(std::imag(*alpha) != 0) alphass << " --alphai " << std::imag(*alpha); std::stringstream betass; betass << "--beta " << std::real(*beta); - if (std::imag(*beta) != 0) + if(std::imag(*beta) != 0) betass << " --betai " << std::imag(*beta); log_bench(handle, - "./rocblas-bench -f gemm_strided_batched -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - alphass.str(), - "--lda", - ld_a, - "--stride_a", - stride_a, - "--ldb", - ld_b, - "--stride_b", - stride_b, - betass.str(), - "--ldc", - ld_c, - "--stride_c", - stride_c, - "--batch", - b_c); + "./rocblas-bench -f gemm_strided_batched -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + alphass.str(), + "--lda", + ld_a, + "--stride_a", + stride_a, + "--ldb", + ld_b, + "--stride_b", + stride_b, + betass.str(), + "--ldc", + ld_c, + "--stride_c", + stride_c, + "--batch", + b_c); } } else @@ -143,24 +143,24 @@ namespace if(layer_mode & rocblas_layer_mode_log_trace) { log_trace(handle, - rocblas_gemm_strided_batched_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - stride_a, - B, - ld_b, - stride_b, - beta, - C, - ld_c, - stride_c, - b_c); + rocblas_gemm_strided_batched_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } } @@ -195,18 +195,51 @@ namespace } } - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, beta, - C, ld_c, stride_c, b_c); + rocblas_status validArgs = validateArgs(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); if(validArgs != rocblas_status_success) return validArgs; - return rocblas_gemm_template(handle, trans_a, trans_b, m, n, k, alpha, 0, A, 0, ld_a, stride_a, B, 0, ld_b, stride_b, beta, 0, C, 0, ld_c, stride_c, b_c); - - // clang-format on + return rocblas_gemm_template(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + 0, + A, + 0, + ld_a, + stride_a, + B, + 0, + ld_b, + stride_b, + beta, + 0, + C, + 0, + ld_c, + stride_c, + b_c); } /******************************************************************************* @@ -232,15 +265,15 @@ namespace rocblas_stride stride_c, rocblas_int b_c) { - // clang-format off if(!handle) return rocblas_status_invalid_handle; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); auto layer_mode = handle->layer_mode; - if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | - rocblas_layer_mode_log_profile)) + if(layer_mode + & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench + | rocblas_layer_mode_log_profile)) { auto trans_a_letter = rocblas_transpose_letter(trans_a); auto trans_b_letter = rocblas_transpose_letter(trans_b); @@ -249,80 +282,80 @@ namespace { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_strided_batched_name, - trans_a, - trans_b, - m, - n, - k, - *alpha, - A, - ld_a, - stride_a, - B, - ld_b, - stride_b, - *beta, - C, - ld_c, - stride_c, - b_c); + rocblas_gemm_strided_batched_name, + trans_a, + trans_b, + m, + n, + k, + *alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + *beta, + C, + ld_c, + stride_c, + b_c); if(layer_mode & rocblas_layer_mode_log_bench) log_bench(handle, - "./rocblas-bench -f gemm_strided_batched -r", - rocblas_precision_string, - "--transposeA", - trans_a_letter, - "--transposeB", - trans_b_letter, - "-m", - m, - "-n", - n, - "-k", - k, - "--alpha", - *alpha, - "--lda", - ld_a, - "--bsa", - stride_a, - "--ldb", - ld_b, - "--bsb", - stride_b, - "--beta", - *beta, - "--ldc", - ld_c, - "--bsc", - stride_c, - "--batch", - b_c); + "./rocblas-bench -f gemm_strided_batched -r", + rocblas_precision_string, + "--transposeA", + trans_a_letter, + "--transposeB", + trans_b_letter, + "-m", + m, + "-n", + n, + "-k", + k, + "--alpha", + *alpha, + "--lda", + ld_a, + "--bsa", + stride_a, + "--ldb", + ld_b, + "--bsb", + stride_b, + "--beta", + *beta, + "--ldc", + ld_c, + "--bsc", + stride_c, + "--batch", + b_c); } else { if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, - rocblas_gemm_strided_batched_name, - trans_a, - trans_b, - m, - n, - k, - alpha, - A, - ld_a, - stride_a, - B, - ld_b, - stride_b, - beta, - C, - ld_c, - stride_c, - b_c); + rocblas_gemm_strided_batched_name, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } if(layer_mode & rocblas_layer_mode_log_profile) @@ -354,16 +387,30 @@ namespace b_c); } - rocblas_status validArgs = validateArgs(handle, trans_a, trans_b, - m, n, k, alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, beta, - C, ld_c, stride_c, b_c); + rocblas_status validArgs = validateArgs(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); if(validArgs != rocblas_status_success) return validArgs; - rocblas_gemm_kernel_name_template(trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); + rocblas_gemm_kernel_name_template( + trans_a, trans_b, m, n, k, ld_a, stride_a, ld_b, stride_b, ld_c, stride_c, b_c); return validArgs; } @@ -376,316 +423,445 @@ extern "C" { * Strided_Batched GEMM APIs ******************************************************************************/ -rocblas_status rocblas_hgemm_strided_batched(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_half *alpha, - const rocblas_half *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const rocblas_half *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const rocblas_half *beta, - rocblas_half *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) +rocblas_status rocblas_hgemm_strided_batched(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_half* alpha, + const rocblas_half* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const rocblas_half* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const rocblas_half* beta, + rocblas_half* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a,stride_a, - B,ld_b,stride_b, - beta, - C,ld_c,stride_c, - b_c); - - return callTensileContraction_half( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); + + return callTensileContraction_half(&problem, handle->host); #else - return rocblas_gemm_strided_batched_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); #endif } -rocblas_status rocblas_sgemm_strided_batched(rocblas_handle handle, +rocblas_status rocblas_sgemm_strided_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const float *alpha, - const float *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const float *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const float *beta, - float *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const float* alpha, + const float* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const float* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const float* beta, + float* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a,stride_a, - B,ld_b,stride_b, - beta, - C,ld_c,stride_c, - b_c); - - return callTensileContraction_float( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); + + return callTensileContraction_float(&problem, handle->host); #else - return rocblas_gemm_strided_batched_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); #endif } -rocblas_status rocblas_dgemm_strided_batched(rocblas_handle handle, +rocblas_status rocblas_dgemm_strided_batched(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const double *alpha, - const double *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const double *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const double *beta, - double *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const double* alpha, + const double* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const double* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const double* beta, + double* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a,stride_a, - B,ld_b,stride_b, - beta, - C,ld_c,stride_c, - b_c); - - return callTensileContraction_double( &problem, handle->host); + TensileHostCall hostCall; + RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); + + return callTensileContraction_double(&problem, handle->host); #else - return rocblas_gemm_strided_batched_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); #endif } -rocblas_status rocblas_cgemm_strided_batched(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_float_complex *alpha, - const rocblas_float_complex *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const rocblas_float_complex *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const rocblas_float_complex *beta, - rocblas_float_complex *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) +rocblas_status rocblas_cgemm_strided_batched(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_float_complex* alpha, + const rocblas_float_complex* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const rocblas_float_complex* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const rocblas_float_complex* beta, + rocblas_float_complex* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a,stride_a, - B,ld_b,stride_b, - beta, - C,ld_c,stride_c, - b_c); - - return callTensileContraction_float_complex( &problem, handle->host); -#else - return rocblas_gemm_strided_batched_impl( - handle, trans_a, trans_b, - m, n, k, + TensileHostCall hostCall; + RocblasContractionProblem problem( + ContractionProblemType::GEMMStridedBatch, + trans_a, + trans_b, + m, + n, + k, alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, beta, - C, ld_c, stride_c, b_c); + C, + ld_c, + stride_c, + b_c); + + return callTensileContraction_float_complex(&problem, handle->host); +#else + return rocblas_gemm_strided_batched_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); #endif } -rocblas_status rocblas_zgemm_strided_batched(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_double_complex *alpha, - const rocblas_double_complex *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const rocblas_double_complex *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const rocblas_double_complex *beta, - rocblas_double_complex *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) +rocblas_status rocblas_zgemm_strided_batched(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_double_complex* alpha, + const rocblas_double_complex* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const rocblas_double_complex* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const rocblas_double_complex* beta, + rocblas_double_complex* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { #ifdef USE_TENSILE_HOST - TensileHostCall hostCall; - RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch,trans_a,trans_b, - m,n,k, - alpha, - A,ld_a,stride_a, - B,ld_b,stride_b, - beta, - C,ld_c,stride_c, - b_c); + TensileHostCall hostCall; + RocblasContractionProblem problem( + ContractionProblemType::GEMMStridedBatch, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); - return callTensileContraction_double_complex( &problem, handle->host); + return callTensileContraction_double_complex(&problem, handle->host); #else - return rocblas_gemm_strided_batched_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); #endif } /******************************************************************************* * Strided Batched GEMM Kernel name APIs ******************************************************************************/ -rocblas_status rocblas_hgemm_strided_batched_kernel_name(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const rocblas_half *alpha, - const rocblas_half *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const rocblas_half *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const rocblas_half *beta, - rocblas_half *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) +rocblas_status rocblas_hgemm_strided_batched_kernel_name(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + rocblas_int m, + rocblas_int n, + rocblas_int k, + const rocblas_half* alpha, + const rocblas_half* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const rocblas_half* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const rocblas_half* beta, + rocblas_half* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_strided_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } -rocblas_status rocblas_sgemm_strided_batched_kernel_name(rocblas_handle handle, +rocblas_status rocblas_sgemm_strided_batched_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const float *alpha, - const float *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const float *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const float *beta, - float *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const float* alpha, + const float* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const float* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const float* beta, + float* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_strided_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } -rocblas_status rocblas_dgemm_strided_batched_kernel_name(rocblas_handle handle, +rocblas_status rocblas_dgemm_strided_batched_kernel_name(rocblas_handle handle, rocblas_operation trans_a, rocblas_operation trans_b, - rocblas_int m, - rocblas_int n, - rocblas_int k, - const double *alpha, - const double *A, - rocblas_int ld_a, - rocblas_stride stride_a, - const double *B, - rocblas_int ld_b, - rocblas_stride stride_b, - const double *beta, - double *C, - rocblas_int ld_c, - rocblas_stride stride_c, - rocblas_int b_c) + rocblas_int m, + rocblas_int n, + rocblas_int k, + const double* alpha, + const double* A, + rocblas_int ld_a, + rocblas_stride stride_a, + const double* B, + rocblas_int ld_b, + rocblas_stride stride_b, + const double* beta, + double* C, + rocblas_int ld_c, + rocblas_stride stride_c, + rocblas_int b_c) { - return rocblas_gemm_strided_batched_kernel_name_impl( - handle, trans_a, trans_b, - m, n, k, - alpha, - A, ld_a, stride_a, - B, ld_b, stride_b, - beta, - C, ld_c, stride_c, b_c); + return rocblas_gemm_strided_batched_kernel_name_impl(handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A, + ld_a, + stride_a, + B, + ld_b, + stride_b, + beta, + C, + ld_c, + stride_c, + b_c); } - - - } diff --git a/library/src/blas_ex/rocblas_gemm_ex.hpp b/library/src/blas_ex/rocblas_gemm_ex.hpp index bfc6b7233..a96c1d4bf 100644 --- a/library/src/blas_ex/rocblas_gemm_ex.hpp +++ b/library/src/blas_ex/rocblas_gemm_ex.hpp @@ -15,14 +15,13 @@ ///////////////// // Device Side // ///////////////// -// clang-format off void device_matrix_copy(const void* src, rocblas_int ld_src, - void* dst, + void* dst, rocblas_int ld_dst, rocblas_int n1, rocblas_int n2, - size_t elem_size) + size_t elem_size) { if(src != dst || ld_src != ld_dst) // no copy if src matrix == dst matrix { @@ -118,21 +117,18 @@ static void device_strided_batched_matrix_copy(const void* src, } //------------------------------------------------------------------------------ -#define TENSILE_IN_ARGS(Ti, To, Tc) \ - To* dataD, const To* dataC, const Ti* dataA, const Ti* dataB, \ - Tc alpha, Tc beta, \ - unsigned int strideD1J, unsigned int strideD2K, \ - unsigned int strideC1J, unsigned int strideC2K, \ - unsigned int strideA1L, unsigned int strideA2K, \ - unsigned int strideB1J, unsigned int strideB2K, \ - unsigned int sizeI, unsigned int sizeJ, unsigned int sizeK, unsigned int sizeL, hipStream_t stream, \ - unsigned int numInputEvents, void* dummy1, void* dummy2 - -#define TENSILE_OUT_ARGS \ - dataD, dataC, dataA, dataB, alpha, beta, \ - strideD1J, strideD2K, strideC1J, strideC2K, \ - strideA1L, strideA2K, strideB1J, strideB2K, \ - sizeI, sizeJ, sizeK, sizeL, stream, 0, nullptr, nullptr +#define TENSILE_IN_ARGS(Ti, To, Tc) \ + To *dataD, const To *dataC, const Ti *dataA, const Ti *dataB, Tc alpha, Tc beta, \ + unsigned int strideD1J, unsigned int strideD2K, unsigned int strideC1J, \ + unsigned int strideC2K, unsigned int strideA1L, unsigned int strideA2K, \ + unsigned int strideB1J, unsigned int strideB2K, unsigned int sizeI, unsigned int sizeJ, \ + unsigned int sizeK, unsigned int sizeL, hipStream_t stream, unsigned int numInputEvents, \ + void *dummy1, void *dummy2 + +#define TENSILE_OUT_ARGS \ + dataD, dataC, dataA, dataB, alpha, beta, strideD1J, strideD2K, strideC1J, strideC2K, \ + strideA1L, strideA2K, strideB1J, strideB2K, sizeI, sizeJ, sizeK, sizeL, stream, 0, \ + nullptr, nullptr // Ti is typename for input data, To is typename for output data, Tc is typename for compute template @@ -162,31 +158,31 @@ inline TensileStatus tensile_Cijk_Alik_Bjlk_B(TENSILE_IN_ARGS(Ti, To, Tc)) template inline TensileStatus tensile_Cijk_Ailk_BjlkC_B(TENSILE_IN_ARGS(Ti, To, Tc)) { - return tensile_Cijk_Ailk_Bjlk_B(TENSILE_OUT_ARGS); + return tensile_Cijk_Ailk_Bjlk_B(TENSILE_OUT_ARGS); } template inline TensileStatus tensile_Cijk_AlikC_Bljk_B(TENSILE_IN_ARGS(Ti, To, Tc)) { - return tensile_Cijk_Alik_Bljk_B(TENSILE_OUT_ARGS); + return tensile_Cijk_Alik_Bljk_B(TENSILE_OUT_ARGS); } template inline TensileStatus tensile_Cijk_Alik_BjlkC_B(TENSILE_IN_ARGS(Ti, To, Tc)) { - return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); + return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); } template inline TensileStatus tensile_Cijk_AlikC_Bjlk_B(TENSILE_IN_ARGS(Ti, To, Tc)) { - return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); + return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); } template inline TensileStatus tensile_Cijk_AlikC_BjlkC_B(TENSILE_IN_ARGS(Ti, To, Tc)) { - return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); + return tensile_Cijk_Alik_Bjlk_B(TENSILE_OUT_ARGS); } //----- typename_data = tensile_bfloat16 ----- typename_compute = float ----------------------- @@ -216,16 +212,14 @@ inline TensileStatus tensile_Cijk_Alik_Bjlk_B -inline TensileStatus tensile_Cijk_Ailk_Bljk_B(TENSILE_IN_ARGS(TensileHalf, - TensileHalf, - float)) +inline TensileStatus tensile_Cijk_Ailk_Bljk_B( + TENSILE_IN_ARGS(TensileHalf, TensileHalf, float)) { //TODO: alpha and beta need to have precision equal to compute type, not data type TensileHalf alpha_half(alpha); @@ -233,9 +227,8 @@ inline TensileStatus tensile_Cijk_Ailk_Bljk_B(T return tensile_Cijk_Ailk_Bljk_HBH(TENSILE_OUT_ARGS_HALF); } template <> -inline TensileStatus tensile_Cijk_Ailk_Bjlk_B(TENSILE_IN_ARGS(TensileHalf, - TensileHalf, - float)) +inline TensileStatus tensile_Cijk_Ailk_Bjlk_B( + TENSILE_IN_ARGS(TensileHalf, TensileHalf, float)) { //TODO: alpha and beta need to have precision equal to compute type, not data type TensileHalf alpha_half(alpha); @@ -243,9 +236,8 @@ inline TensileStatus tensile_Cijk_Ailk_Bjlk_B(T return tensile_Cijk_Ailk_Bjlk_HBH(TENSILE_OUT_ARGS_HALF); } template <> -inline TensileStatus tensile_Cijk_Alik_Bljk_B(TENSILE_IN_ARGS(TensileHalf, - TensileHalf, - float)) +inline TensileStatus tensile_Cijk_Alik_Bljk_B( + TENSILE_IN_ARGS(TensileHalf, TensileHalf, float)) { //TODO: alpha and beta need to have precision equal to compute type, not data type TensileHalf alpha_half(alpha); @@ -253,9 +245,8 @@ inline TensileStatus tensile_Cijk_Alik_Bljk_B(T return tensile_Cijk_Alik_Bljk_HBH(TENSILE_OUT_ARGS_HALF); } template <> -inline TensileStatus tensile_Cijk_Alik_Bjlk_B(TENSILE_IN_ARGS(TensileHalf, - TensileHalf, - float)) +inline TensileStatus tensile_Cijk_Alik_Bjlk_B( + TENSILE_IN_ARGS(TensileHalf, TensileHalf, float)) { //TODO: alpha and beta need to have precision equal to compute type, not data type TensileHalf alpha_half(alpha); @@ -292,22 +283,26 @@ inline TensileStatus tensile_Cijk_Alik_Bjlk_B -inline TensileStatus tensile_Cijk_Ailk_Bljk_B(TENSILE_IN_ARGS(float, float, float)) +inline TensileStatus + tensile_Cijk_Ailk_Bljk_B(TENSILE_IN_ARGS(float, float, float)) { return tensile_Cijk_Ailk_Bljk_SB(TENSILE_OUT_ARGS); } template <> -inline TensileStatus tensile_Cijk_Ailk_Bjlk_B(TENSILE_IN_ARGS(float, float, float)) +inline TensileStatus + tensile_Cijk_Ailk_Bjlk_B(TENSILE_IN_ARGS(float, float, float)) { return tensile_Cijk_Ailk_Bjlk_SB(TENSILE_OUT_ARGS); } template <> -inline TensileStatus tensile_Cijk_Alik_Bljk_B(TENSILE_IN_ARGS(float, float, float)) +inline TensileStatus + tensile_Cijk_Alik_Bljk_B(TENSILE_IN_ARGS(float, float, float)) { return tensile_Cijk_Alik_Bljk_SB(TENSILE_OUT_ARGS); } template <> -inline TensileStatus tensile_Cijk_Alik_Bjlk_B(TENSILE_IN_ARGS(float, float, float)) +inline TensileStatus + tensile_Cijk_Alik_Bjlk_B(TENSILE_IN_ARGS(float, float, float)) { return tensile_Cijk_Alik_Bjlk_SB(TENSILE_OUT_ARGS); } @@ -365,77 +360,93 @@ inline TensileStatus tensile_Cijk_Alik_Bjlk_B{}, - "TensileComplexFloat is not a standard layout type, and thus is " - "incompatible with C."); + "TensileComplexFloat is not a standard layout type, and thus is " + "incompatible with C."); static_assert(std::is_trivial{}, - "TensileComplexFloat is not a trivial type, and thus is " - "incompatible with C."); + "TensileComplexFloat is not a trivial type, and thus is " + "incompatible with C."); static_assert(sizeof(rocblas_float_complex) == sizeof(TensileComplexFloat), - "TensileComplexFloat does not match public rocblas_float_complex"); + "TensileComplexFloat does not match public rocblas_float_complex"); template <> -inline TensileStatus tensile_Cijk_Ailk_Bljk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Ailk_Bljk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Ailk_Bljk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Ailk_Bljk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_Ailk_Bjlk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Ailk_Bjlk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Ailk_Bjlk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Ailk_Bjlk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_Alik_Bljk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Alik_Bljk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Alik_Bljk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Alik_Bljk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_Alik_Bjlk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Alik_Bjlk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Alik_Bjlk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Alik_Bjlk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } // Complex Conjugate template <> -inline TensileStatus tensile_Cijk_Ailk_BjlkC_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Ailk_BjlkC_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Ailk_BjlkC_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Ailk_BjlkC_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_AlikC_Bljk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_AlikC_Bljk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_AlikC_Bljk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_AlikC_Bljk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_Alik_BjlkC_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_Alik_BjlkC_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_Alik_BjlkC_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_Alik_BjlkC_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_AlikC_Bjlk_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_AlikC_Bjlk_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_AlikC_Bjlk_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_AlikC_Bjlk_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } template <> -inline TensileStatus tensile_Cijk_AlikC_BjlkC_B( - TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) +inline TensileStatus + tensile_Cijk_AlikC_BjlkC_B( + TENSILE_IN_ARGS(rocblas_float_complex, rocblas_float_complex, rocblas_float_complex)) { - return tensile_Cijk_AlikC_BjlkC_CB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); + return tensile_Cijk_AlikC_BjlkC_CB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexFloat, TensileComplexFloat, TensileComplexFloat)); } //----- typename_data = rocblas_double_complex ---------- typename_compute = rocblas_double_complex -------------------------- @@ -450,80 +461,108 @@ static_assert(std::is_trivial{}, static_assert(sizeof(rocblas_double_complex) == sizeof(TensileComplexDouble), "TensileComplexDouble does not match rocblas_double_complex"); template <> -inline TensileStatus tensile_Cijk_Ailk_Bljk_B( +inline TensileStatus tensile_Cijk_Ailk_Bljk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Ailk_Bljk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Ailk_Bljk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_Ailk_Bjlk_B( +inline TensileStatus tensile_Cijk_Ailk_Bjlk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Ailk_Bjlk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Ailk_Bjlk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_Alik_Bljk_B( +inline TensileStatus tensile_Cijk_Alik_Bljk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Alik_Bljk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Alik_Bljk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_Alik_Bjlk_B( +inline TensileStatus tensile_Cijk_Alik_Bjlk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Alik_Bjlk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Alik_Bjlk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } // Complex Conjugate template <> -inline TensileStatus tensile_Cijk_Ailk_BjlkC_B( +inline TensileStatus tensile_Cijk_Ailk_BjlkC_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Ailk_BjlkC_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Ailk_BjlkC_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_AlikC_Bljk_B( +inline TensileStatus tensile_Cijk_AlikC_Bljk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_AlikC_Bljk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_AlikC_Bljk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_Alik_BjlkC_B( +inline TensileStatus tensile_Cijk_Alik_BjlkC_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_Alik_BjlkC_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_Alik_BjlkC_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_AlikC_Bjlk_B( +inline TensileStatus tensile_Cijk_AlikC_Bjlk_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_AlikC_Bjlk_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_AlikC_Bjlk_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template <> -inline TensileStatus tensile_Cijk_AlikC_BjlkC_B( +inline TensileStatus tensile_Cijk_AlikC_BjlkC_B( TENSILE_IN_ARGS(rocblas_double_complex, rocblas_double_complex, rocblas_double_complex)) { - return tensile_Cijk_AlikC_BjlkC_ZB(TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); + return tensile_Cijk_AlikC_BjlkC_ZB( + TENSILE_COMPLEX_OUT_ARGS(TensileComplexDouble, TensileComplexDouble, TensileComplexDouble)); } template -inline TensileStatus call_tensile_ex(To* dataD, - const To* dataC, - const Ti* dataA, - const Ti* dataB, - Tc alpha, Tc beta, - unsigned int strideD1J, - unsigned int strideD2K, - unsigned int strideC1J, - unsigned int strideC2K, - unsigned int strideA1L, - unsigned int strideA2K, - unsigned int strideB1J, - unsigned int strideB2K, - unsigned int sizeI, - unsigned int sizeJ, - unsigned int sizeK, - unsigned int sizeL, - hipStream_t stream, +inline TensileStatus call_tensile_ex(To* dataD, + const To* dataC, + const Ti* dataA, + const Ti* dataB, + Tc alpha, + Tc beta, + unsigned int strideD1J, + unsigned int strideD2K, + unsigned int strideC1J, + unsigned int strideC2K, + unsigned int strideA1L, + unsigned int strideA2K, + unsigned int strideB1J, + unsigned int strideB2K, + unsigned int sizeI, + unsigned int sizeJ, + unsigned int sizeK, + unsigned int sizeL, + hipStream_t stream, transpose_mode transposeMode) { switch(transposeMode) @@ -557,7 +596,6 @@ inline TensileStatus call_tensile_ex(To* dataD, //------------------------------------------------------------------------------ - /////////////// // Host Side // /////////////// @@ -611,42 +649,42 @@ rocblas_status gemm_ex_handle_transpose(rocblas_handle handle, return get_rocblas_status_for_hip_status(errC); else if(get_rocblas_status_for_hip_status(errD) != rocblas_status_success) return get_rocblas_status_for_hip_status(errD); - stride_a = trans_a == rocblas_operation_none ? lda * k : lda * m; - stride_b = trans_b == rocblas_operation_none ? ldb * n : ldb * k; - stride_c = ldc * n; - stride_d = ldd * n; + stride_a = trans_a == rocblas_operation_none ? lda * k : lda * m; + stride_b = trans_b == rocblas_operation_none ? ldb * n : ldb * k; + stride_c = ldc * n; + stride_d = ldd * n; rocblas_status status = rocblas_status_internal_error; for(int bi = 0; bi < batch_count; bi++) { // Tensile does not support batched gemm_ex yet, must do naive version status = gemm_ex_handle_transpose(handle, - trans_a, - trans_b, - m, - n, - k, - alpha + bi * stride_alpha, - 0, // using single alpha ^ - hostA[bi], - offset_a, - lda, - stride_a, - hostB[bi], - offset_b, - ldb, - stride_b, - beta + bi * stride_beta, - 0, // see ^ - hostC[bi], - offset_c, - ldc, - stride_c, - hostD[bi], - offset_d, - ldd, - stride_d, - 1); + trans_a, + trans_b, + m, + n, + k, + alpha + bi * stride_alpha, + 0, // using single alpha ^ + hostA[bi], + offset_a, + lda, + stride_a, + hostB[bi], + offset_b, + ldb, + stride_b, + beta + bi * stride_beta, + 0, // see ^ + hostC[bi], + offset_c, + ldc, + stride_c, + hostD[bi], + offset_d, + ldd, + stride_d, + 1); if(status != rocblas_status_success) return status; } @@ -691,11 +729,12 @@ rocblas_status gemm_ex_handle_transpose(rocblas_handle handle, rocblas_status rb_status; static const bool arch_lt906 = handle->device_arch_id() < 906; - const To* c_in; - unsigned int ldi, stride_i; + const To* c_in; + unsigned int ldi, stride_i; - if(!arch_lt906 && (std::is_same{} || std::is_same{}) && - ((ldc >= ldd && stride_c >= stride_d && m == ldd) || (ldc == ldd && stride_c == stride_d))) + if(!arch_lt906 && (std::is_same{} || std::is_same{}) + && ((ldc >= ldd && stride_c >= stride_d && m == ldd) + || (ldc == ldd && stride_c == stride_d))) { c_in = c; ldi = ldc; @@ -714,76 +753,87 @@ rocblas_status gemm_ex_handle_transpose(rocblas_handle handle, { for(int bi = 0; bi < batch_count; bi++) { - t_status = call_tensile_ex((To*)(d + bi * stride_d), - (const To*)(c_in + bi * stride_c), - (const Ti*)(a + bi * stride_a), - (const Ti*)(b + bi * stride_b), - alpha[bi * stride_alpha], beta[bi * stride_beta], - unsigned(ldd), stride_d, - unsigned(ldi), stride_i, - unsigned(lda), stride_a, - unsigned(ldb), stride_b, - unsigned(m), - unsigned(n), - unsigned(1), - unsigned(k), - handle->rocblas_stream, GetTransposeMode(trans_a, trans_b)); + t_status = call_tensile_ex((To*)(d + bi * stride_d), + (const To*)(c_in + bi * stride_c), + (const Ti*)(a + bi * stride_a), + (const Ti*)(b + bi * stride_b), + alpha[bi * stride_alpha], + beta[bi * stride_beta], + unsigned(ldd), + stride_d, + unsigned(ldi), + stride_i, + unsigned(lda), + stride_a, + unsigned(ldb), + stride_b, + unsigned(m), + unsigned(n), + unsigned(1), + unsigned(k), + handle->rocblas_stream, + GetTransposeMode(trans_a, trans_b)); } } else { // single alpha/beta - t_status = call_tensile_ex((To*)d, - (const To*)c_in, - (const Ti*)a, - (const Ti*)b, - *alpha, *beta, - unsigned(ldd), stride_d, - unsigned(ldi), stride_i, - unsigned(lda), stride_a, - unsigned(ldb), stride_b, - unsigned(m), - unsigned(n), - unsigned(batch_count), - unsigned(k), - handle->rocblas_stream, GetTransposeMode(trans_a, trans_b)); + t_status = call_tensile_ex((To*)d, + (const To*)c_in, + (const Ti*)a, + (const Ti*)b, + *alpha, + *beta, + unsigned(ldd), + stride_d, + unsigned(ldi), + stride_i, + unsigned(lda), + stride_a, + unsigned(ldb), + stride_b, + unsigned(m), + unsigned(n), + unsigned(batch_count), + unsigned(k), + handle->rocblas_stream, + GetTransposeMode(trans_a, trans_b)); } - - rb_status = (t_status == tensileStatusSuccess) ? rocblas_status_success : rocblas_status_internal_error; + rb_status = (t_status == tensileStatusSuccess) ? rocblas_status_success + : rocblas_status_internal_error; return rb_status; } #if defined(USE_CHUNKING) template -rocblas_status gemm_ex_chunking(rocblas_handle handle, - rocblas_operation trans_a, - rocblas_operation trans_b, - unsigned int m, - unsigned int n, - unsigned int k, - Tc* alpha, - unsigned int stride_alpha - Ti a, - unsigned int offsetAin, - unsigned int lda, - unsigned int stride_a, - Ti b, - unsigned int offsetBin, - unsigned int ldb, - unsigned int stride_b, - Tc* beta, - unsigned int stride_beta, - To c, - unsigned int offsetCin, - unsigned int ldc, - unsigned int stride_c, - To2 d, - unsigned int offsetDin, - unsigned int ldd, - unsigned int stride_d, - unsigned int batch_count) +rocblas_status gemm_ex_chunking(rocblas_handle handle, + rocblas_operation trans_a, + rocblas_operation trans_b, + unsigned int m, + unsigned int n, + unsigned int k, + Tc* alpha, + unsigned int stride_alpha Ti a, + unsigned int offsetAin, + unsigned int lda, + unsigned int stride_a, + Ti b, + unsigned int offsetBin, + unsigned int ldb, + unsigned int stride_b, + Tc* beta, + unsigned int stride_beta, + To c, + unsigned int offsetCin, + unsigned int ldc, + unsigned int stride_c, + To2 d, + unsigned int offsetDin, + unsigned int ldd, + unsigned int stride_d, + unsigned int batch_count) { unsigned int int_limit = std::numeric_limits::max() / sizeof(To); unsigned int m_chunk_size = m; @@ -922,8 +972,9 @@ rocblas_status gemm_ex_typecasting(rocblas_handle handle, { // copy alpha and beta from device to host and convert type for(int b = 0; b < batch_count; b++) - hipMemcpy(&h_alpha[b], (Tc*)alpha + b * stride_alpha, sizeof(Tc), hipMemcpyDeviceToHost); - + hipMemcpy( + &h_alpha[b], (Tc*)alpha + b * stride_alpha, sizeof(Tc), hipMemcpyDeviceToHost); + for(int b = 0; b < batch_count; b++) hipMemcpy(&h_beta[b], (Tc*)beta + b * stride_beta, sizeof(Tc), hipMemcpyDeviceToHost); } @@ -939,7 +990,7 @@ rocblas_status gemm_ex_typecasting(rocblas_handle handle, if(BATCHED) { if(!isAligned(a, sizeof(Ti*)) || !isAligned(b, sizeof(Ti*)) || !isAligned(c, sizeof(To*)) - || !isAligned(d, sizeof(To*))) + || !isAligned(d, sizeof(To*))) return rocblas_status_invalid_size; // Pass alpha and beta as simple array (stride of 1) @@ -974,7 +1025,7 @@ rocblas_status gemm_ex_typecasting(rocblas_handle handle, else { if(!isAligned(a, sizeof(Ti)) || !isAligned(b, sizeof(Ti)) || !isAligned(c, sizeof(To)) - || !isAligned(d, sizeof(To))) + || !isAligned(d, sizeof(To))) return rocblas_status_invalid_size; return gemm_ex_chunking(handle, @@ -1005,8 +1056,6 @@ rocblas_status gemm_ex_typecasting(rocblas_handle handle, unsigned(stride_d), unsigned(batch_count)); } - - } #endif diff --git a/library/src/handle.cpp b/library/src/handle.cpp index a32bf5a3b..4010e0a62 100644 --- a/library/src/handle.cpp +++ b/library/src/handle.cpp @@ -12,9 +12,10 @@ _rocblas_handle::_rocblas_handle() { #ifdef USE_TENSILE_HOST host = createTensileHost(); - if (host == nullptr) + if(!host) throw rocblas_status_internal_error; #endif + // default device is active device THROW_IF_HIP_ERROR(hipGetDevice(&device)); THROW_IF_HIP_ERROR(hipGetDeviceProperties(&device_properties, device)); @@ -68,9 +69,9 @@ _rocblas_handle::~_rocblas_handle() } if(device_memory) (hipFree)(device_memory); + #ifdef USE_TENSILE_HOST - if(host != nullptr) - delete host; + delete host; #endif } diff --git a/library/src/include/handle.h b/library/src/include/handle.h index fa46479c7..5d8660ecd 100644 --- a/library/src/include/handle.h +++ b/library/src/include/handle.h @@ -40,7 +40,7 @@ struct _rocblas_handle public: #ifdef USE_TENSILE_HOST - TensileHost *host = nullptr; + TensileHost* host = nullptr; #endif _rocblas_handle(); ~_rocblas_handle(); diff --git a/library/src/rocblas_auxiliary.cpp b/library/src/rocblas_auxiliary.cpp index 26b08a0ca..cdb4fef0b 100644 --- a/library/src/rocblas_auxiliary.cpp +++ b/library/src/rocblas_auxiliary.cpp @@ -59,28 +59,6 @@ extern "C" rocblas_status rocblas_set_pointer_mode(rocblas_handle handle, rocbla return rocblas_status_success; } -#ifdef USE_TENSILE_HOST -extern "C" rocblas_status rocblas_create_host_handle(rocblas_handle* handle, const char* lib_path) -{ - //std::cout << lib_path << std::endl; - rocblas_status status = rocblas_create_handle(handle); - - if (status == rocblas_status_success) - { - try - { - (*handle)->host->initializeHost(lib_path); - } - catch (...) - { - return rocblas_status_internal_error; - } - } - - return status; -} -#endif - /******************************************************************************* * ! \brief create rocblas handle called before any rocblas library routines ******************************************************************************/ @@ -88,7 +66,8 @@ extern "C" rocblas_status rocblas_create_handle(rocblas_handle* handle) { // if handle not valid if(!handle) - return rocblas_status_invalid_pointer; + return rocblas_status_invalid_handle; + // allocate on heap try { @@ -99,10 +78,17 @@ extern "C" rocblas_status rocblas_create_handle(rocblas_handle* handle) if((*handle)->layer_mode & rocblas_layer_mode_log_trace) log_trace(*handle, "rocblas_create_handle"); + +#ifdef USE_TENSILE_HOST + const char* lib_path = getenv("ROCBLAS_TENSILE_LIBPATH"); + if(!lib_path) + lib_path = "/opt/rocm/"; // TODO: Set default path + (*handle)->host->initializeHost(lib_path); +#endif } - catch(rocblas_status status) + catch(...) { - return status; + return rocblas_status_internal_error; } return rocblas_status_success; } diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index 910acecae..1817e5302 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -1,135 +1,104 @@ #ifdef USE_TENSILE_HOST -#include "rocblas.h" #include "tensile_host.hpp" - - -#include +#include "rocblas.h" #include #include #include -#include +#include +#include +#include #include +#include #include - -#include -#include - -#include -#include -//#include - #include +#include +#include template -Tensile::DataType tensile_datatype() -{ - throw std::runtime_error("undefined datatype"); - return Tensile::DataType::Float; -} +static constexpr auto tensile_datatype = nullptr; template <> -Tensile::DataType tensile_datatype() -{ - return Tensile::DataType::Half; -} +static constexpr auto tensile_datatype = Tensile::DataType::Half; template <> -Tensile::DataType tensile_datatype() -{ - return Tensile::DataType::Float; -} - +static constexpr auto tensile_datatype = Tensile::DataType::Float; + template <> -Tensile::DataType tensile_datatype() -{ - return Tensile::DataType::Double; -} +static constexpr auto tensile_datatype = Tensile::DataType::Double; template <> -Tensile::DataType tensile_datatype() -{ - return Tensile::DataType::ComplexFloat; -} +static constexpr auto tensile_datatype = Tensile::DataType::ComplexFloat; template <> -Tensile::DataType tensile_datatype() -{ - return Tensile::DataType::ComplexDouble; -} +static constexpr auto tensile_datatype = Tensile::DataType::ComplexDouble; template -Tensile::ContractionProblem create_gemm_contraction_problem_strided ( - rocblas_operation trans_a, - rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, - const T* alpha, - const T* A, - unsigned long ld_a, - unsigned long stride_a, - const T* B, - unsigned long ld_b, - unsigned long stride_b, - const T* beta, - T* C, - unsigned long ld_c, - unsigned long stride_c, - unsigned long batchSize) +auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, + rocblas_operation trans_b, + unsigned long m, + unsigned long n, + unsigned long k, + const T* alpha, + const T* A, + unsigned long ld_a, + unsigned long stride_a, + const T* B, + unsigned long ld_b, + unsigned long stride_b, + const T* beta, + T* C, + unsigned long ld_c, + unsigned long stride_c, + unsigned long batchSize) { - - bool transposeA = false; - if (trans_a == rocblas_operation_conjugate_transpose) - transposeA = true; - - bool transposeB = false; - if(trans_b == rocblas_operation_conjugate_transpose) - transposeB = true; - - Tensile::DataType dt = tensile_datatype(); - Tensile::ContractionProblem problem = Tensile::ContractionProblem::GEMM_Strides( - transposeA, transposeB, - dt, dt, dt, dt, - m, n, k, batchSize, - ld_a, stride_a, - ld_b, stride_b, - ld_c, stride_c, - ld_c, stride_c, - *beta); + bool transposeA = trans_a != rocblas_operation_none; + bool transposeB = trans_b != rocblas_operation_none; + + Tensile::DataType dt = tensile_datatype; + Tensile::ContractionProblem problem = Tensile::ContractionProblem::GEMM_Strides(transposeA, + transposeB, + dt, + dt, + dt, + dt, + m, + n, + k, + batchSize, + ld_a, + stride_a, + ld_b, + stride_b, + ld_c, + stride_c, + ld_c, + stride_c, + *beta); return problem; } - // construct the gemm contraction problem template -Tensile::ContractionProblem create_gemm_contraction_problem ( - rocblas_operation trans_a, +auto create_gemm_contraction_problem(rocblas_operation trans_a, rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, + unsigned long m, + unsigned long n, + unsigned long k, const T* alpha, const T* A, - unsigned long ld_a, + unsigned long ld_a, const T* B, - unsigned long ld_b, + unsigned long ld_b, const T* beta, T* C, - unsigned long ld_c) + unsigned long ld_c) { + bool transposeA = trans_a != rocblas_operation_none; + bool transposeB = trans_b != rocblas_operation_none; - bool transposeA = false; - if (trans_a == rocblas_operation_conjugate_transpose) - transposeA = true; - - bool transposeB = false; - if(trans_b == rocblas_operation_conjugate_transpose) - transposeB = true; - - - Tensile::ContractionProblem::FreeIndex free; + Tensile::ContractionProblem::FreeIndex free; Tensile::ContractionProblem::BoundIndex bound; free.ca = free.da = 0; @@ -137,34 +106,34 @@ Tensile::ContractionProblem create_gemm_contraction_problem ( Tensile::TensorDescriptor a, b, c, d; - Tensile::DataType dt = tensile_datatype(); + Tensile::DataType dt = tensile_datatype; if(transposeA) { - a = Tensile::TensorDescriptor(dt, {k, m}, {1, ld_a}); - free.a = 1; + a = Tensile::TensorDescriptor(dt, {k, m}, {1, ld_a}); + free.a = 1; bound.a = 0; } else { - a = Tensile::TensorDescriptor(dt, {m, k}, {1, ld_a}); - free.a = 0; + a = Tensile::TensorDescriptor(dt, {m, k}, {1, ld_a}); + free.a = 0; bound.a = 1; } if(transposeB) { - b = Tensile::TensorDescriptor(dt, {n, k}, {1, ld_b}); - free.b = 0; + b = Tensile::TensorDescriptor(dt, {n, k}, {1, ld_b}); + free.b = 0; bound.b = 1; } else { - b = Tensile::TensorDescriptor(dt, {k, n}, {1, ld_b}); - free.b = 1; + b = Tensile::TensorDescriptor(dt, {k, n}, {1, ld_b}); + free.b = 1; bound.b = 0; } - Tensile::ContractionProblem::FreeIndices freeIndices{free}; + Tensile::ContractionProblem::FreeIndices freeIndices{free}; Tensile::ContractionProblem::BatchIndices batchIndices; Tensile::ContractionProblem::BoundIndices boundIndices{bound}; @@ -176,222 +145,166 @@ Tensile::ContractionProblem create_gemm_contraction_problem ( b.appendDim(batchCount); d.appendDim(batchCount); - batchIndices.push_back({2,2,2,2}); + batchIndices.push_back({2, 2, 2, 2}); if(*beta != 0.0) c = d; Tensile::TensorOps nop; - return Tensile::ContractionProblem(a, nop, b, nop, c, nop, d, nop, freeIndices, batchIndices, boundIndices, *beta); + return Tensile::ContractionProblem( + a, nop, b, nop, c, nop, d, nop, freeIndices, batchIndices, boundIndices, *beta); } - -template -Tensile::ContractionProblem ConstructTensileProblem(RocblasContractionProblem *problem) +template +auto ConstructTensileProblem(RocblasContractionProblem* problem) { Tensile::ContractionProblem tensile_problem; switch(problem->problem_type) { - case GEMM: - tensile_problem = create_gemm_contraction_problem ( - problem->trans_a,problem->trans_b, - problem->m,problem->n,problem->k, - problem->alpha, - problem->A,problem->ld_a, - problem->B,problem->ld_b, - problem->beta, - problem->C,problem->ld_c); - break; - case GEMMStridedBatch: - tensile_problem = create_gemm_contraction_problem_strided ( - problem->trans_a,problem->trans_b, - problem->m,problem->n,problem->k, - problem->alpha, - problem->A,problem->ld_a,problem->stride_a, - problem->B,problem->ld_b,problem->stride_b, - problem->beta, - problem->C,problem->ld_c,problem->stride_c, - problem->batch_size); - break; + case GEMM: + tensile_problem = create_gemm_contraction_problem(problem->trans_a, + problem->trans_b, + problem->m, + problem->n, + problem->k, + problem->alpha, + problem->A, + problem->ld_a, + problem->B, + problem->ld_b, + problem->beta, + problem->C, + problem->ld_c); + break; + case GEMMStridedBatch: + tensile_problem = create_gemm_contraction_problem_strided(problem->trans_a, + problem->trans_b, + problem->m, + problem->n, + problem->k, + problem->alpha, + problem->A, + problem->ld_a, + problem->stride_a, + problem->B, + problem->ld_b, + problem->stride_b, + problem->beta, + problem->C, + problem->ld_c, + problem->stride_c, + problem->batch_size); + break; } return tensile_problem; } - template -Tensile::TypedContractionInputs GetTensileInputs(RocblasContractionProblem *problem) +auto GetTensileInputs(RocblasContractionProblem* problem) { Tensile::TypedContractionInputs inputs; switch(problem->problem_type) { - case GEMM: - case GEMMStridedBatch: - inputs.a = problem->A; - inputs.b = problem->B; - inputs.c = problem->C; - inputs.d = problem->C; - inputs.alpha = *(problem->alpha); - inputs.beta = *(problem->beta); - break; + case GEMM: + case GEMMStridedBatch: + inputs.a = problem->A; + inputs.b = problem->B; + inputs.c = problem->C; + inputs.d = problem->C; + inputs.alpha = *problem->alpha; + inputs.beta = *problem->beta; + break; } return inputs; } -class TensileHostImpl : public TensileHost +struct TensileHostImpl : TensileHost { -public: - void initializeHost(const char* lib_path) - { - //std::string dir ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/*hsaco"); - - std::string path (lib_path); - std::string dir = path + "/*co"; - - - glob_t glob_result; - glob(dir.c_str(),GLOB_TILDE,NULL,&glob_result); - //vector files; - for(unsigned int i=0;i>( - Tensile::LoadLibraryFile(filename)); - - - // namespace fs = std::filesystem; - - // for (const auto & entry : fs::directory_iterator(dir)) - // std::cout << entry.path() << std::endl; - - - //std::DIR *dp; - //if((dp = opendir(dir.c_str())) == NULL) { - // std::cout << "Error(" << errno << ") opening " << dir << std::endl; - // return errno; - //} - - //while ((dirp = readdir(dp)) != NULL) { - // std::cout << dirp->d_name << std::endl; - //files.push_back(string(dirp->d_name)); - //} - //closedir(dp); - - //std::string filename ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/TensileLibrary.yaml"); - //library = std::dynamic_pointer_cast>( - // Tensile::LoadLibraryFile(filename)); - - //std::string cofilename ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/TensileLibrary_gfx906.co"); - //adapter.loadCodeObjectFile(cofilename); - - //std::string k906filename ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/Kernels.so-000-gfx906.hsaco"); - //adapter.loadCodeObjectFile(k906filename); - - //std::string k900filename ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/Kernels.so-000-gfx900.hsaco"); - //adapter.loadCodeObjectFile(k900filename); - - //std::string k803filename ("/home/wgilmart/dev/wbgilmartin/tasks/new_client_integration/iteration3/Tensile/build/1_BenchmarkProblems/Cijk_Ailk_Bljk_SB_00/00_Final/source/library/Kernels.so-000-gfx803.hsaco"); - //adapter.loadCodeObjectFile(k803filename); - - hardware = Tensile::hip::GetCurrentDevice(); - } - -//private: - std::shared_ptr> library; - std::shared_ptr hardware; - Tensile::hip::SolutionAdapter adapter; -}; - -template -rocblas_status TensileHostCall::runContractionProblem(RocblasContractionProblem *problem, TensileHost *host) -{ - Tensile::ContractionProblem tensile_problem; - try + void initializeHost(const char* lib_path) { - tensile_problem = ConstructTensileProblem(problem); - } - catch(...) - { - return rocblas_status_internal_error; - } - Tensile::TypedContractionInputs inputs; - try - { - inputs = GetTensileInputs(problem); - } - catch(...) - { - return rocblas_status_internal_error; - } + std::string path(lib_path); + std::string dir = path + "/*co"; - TensileHostImpl * hosti = dynamic_cast(host); - if (hosti == nullptr) - { - return rocblas_status_internal_error; - } + glob_t glob_result; + glob(dir.c_str(), GLOB_TILDE, NULL, &glob_result); + for(unsigned int i = 0; i < glob_result.gl_pathc; ++i) + adapter.loadCodeObjectFile(glob_result.gl_pathv[i]); + globfree(&glob_result); - std::vector result; - try - { - auto solution = hosti->library->findBestSolution(tensile_problem, *(hosti->hardware)); - result = solution->solve(tensile_problem, inputs, *(hosti->hardware)); - } - catch(...) - { - return rocblas_status_internal_error; - } + std::string filename = path + "/TensileLibrary.yaml"; + library = std::dynamic_pointer_cast< + Tensile::MasterSolutionLibrary>( + Tensile::LoadLibraryFile(filename)); - try - { - hosti->adapter.launchKernels(result); + hardware = Tensile::hip::GetCurrentDevice(); } - catch(...) - { + + std::shared_ptr> library; + std::shared_ptr hardware; + Tensile::hip::SolutionAdapter adapter; + + ~TensileHostImpl() override = default; // Ensure virtual base class +}; + +template +rocblas_status TensileHostCall::runContractionProblem(RocblasContractionProblem* problem, + TensileHost* host) +try +{ + auto tensile_problem = ConstructTensileProblem(problem); + auto inputs = GetTensileInputs(problem); + auto hosti = dynamic_cast(host); + if(!hosti) return rocblas_status_internal_error; - } - return rocblas_status_success; + auto solution = hosti->library->findBestSolution(tensile_problem, *hosti->hardware); + auto result = solution->solve(tensile_problem, inputs, *hosti->hardware); + hosti->adapter.launchKernels(result); + return rocblas_status_success; +} +catch(...) +{ + return rocblas_status_internal_error; } -TensileHost *createTensileHost() -{ - - TensileHostImpl *host = new TensileHostImpl(); - return host; +TensileHost* createTensileHost() +{ + return new TensileHostImpl(); } template -rocblas_status callTensileContraction( RocblasContractionProblem *problem, TensileHost *host) +inline rocblas_status callTensileContraction(RocblasContractionProblem* problem, TensileHost* host) { TensileHostCall hostCaller; return hostCaller.runContractionProblem(problem, host); } -rocblas_status callTensileContraction_half(RocblasContractionProblem *problem, TensileHost *host) +rocblas_status callTensileContraction_half(RocblasContractionProblem* problem, + TensileHost* host) { return callTensileContraction(problem, host); } -rocblas_status callTensileContraction_float(RocblasContractionProblem *problem, TensileHost *host) +rocblas_status callTensileContraction_float(RocblasContractionProblem* problem, + TensileHost* host) { return callTensileContraction(problem, host); } -rocblas_status callTensileContraction_double(RocblasContractionProblem *problem, TensileHost *host) +rocblas_status callTensileContraction_double(RocblasContractionProblem* problem, + TensileHost* host) { return callTensileContraction(problem, host); } -rocblas_status callTensileContraction_float_complex(RocblasContractionProblem *problem, TensileHost *host) +rocblas_status + callTensileContraction_float_complex(RocblasContractionProblem* problem, + TensileHost* host) { return callTensileContraction(problem, host); } -rocblas_status callTensileContraction_double_complex(RocblasContractionProblem *problem, TensileHost *host) +rocblas_status callTensileContraction_double_complex( + RocblasContractionProblem* problem, TensileHost* host) { return callTensileContraction(problem, host); } From ab30f404a5edfae26d9f0f71c92651262d72f86d Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Sat, 5 Oct 2019 10:33:50 -0400 Subject: [PATCH 2/5] Move tensile_host.hpp from public C API to private C++ implementation --- library/{ => src}/include/tensile_host.hpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename library/{ => src}/include/tensile_host.hpp (100%) diff --git a/library/include/tensile_host.hpp b/library/src/include/tensile_host.hpp similarity index 100% rename from library/include/tensile_host.hpp rename to library/src/include/tensile_host.hpp From b40411e732464d6029120ff82df2e56e0ef24177 Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Sat, 5 Oct 2019 11:20:41 -0400 Subject: [PATCH 3/5] Work around missing complex; fix formatting --- clients/gtest/blas1_gtest.cpp | 59 +++++++++++---------- clients/include/rocblas_datatype2string.hpp | 2 +- library/src/tensile_host.cpp | 6 ++- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index f6f708bb7..a5b9536f4 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -288,41 +288,42 @@ TEST_P(NAME, blas1) \ { \ rocblas_blas1_dispatch(GetParam()); \ } \ -// clang-format on \ INSTANTIATE_TEST_CATEGORIES(NAME) + // clang-format on + #define ARG1(Ti, To, Tc) Ti #define ARG2(Ti, To, Tc) Ti, To #define ARG3(Ti, To, Tc) Ti, To, Tc -BLAS1_TESTING(asum, ARG1) -BLAS1_TESTING(asum_batched, ARG1) -BLAS1_TESTING(asum_strided_batched, ARG1) -BLAS1_TESTING(nrm2, ARG1) -BLAS1_TESTING(nrm2_batched, ARG1) -BLAS1_TESTING(nrm2_strided_batched, ARG1) -BLAS1_TESTING(iamax, ARG1) -BLAS1_TESTING(iamin, ARG1) -BLAS1_TESTING(axpy, ARG1) -BLAS1_TESTING(copy, ARG1) -BLAS1_TESTING(copy_batched, ARG1) -BLAS1_TESTING(copy_strided_batched, ARG1) -BLAS1_TESTING(dot, ARG1) -BLAS1_TESTING(dotc, ARG1) -BLAS1_TESTING(dot_batched, ARG1) -BLAS1_TESTING(dotc_batched, ARG1) -BLAS1_TESTING(dot_strided_batched, ARG1) -BLAS1_TESTING(dotc_strided_batched, ARG1) -BLAS1_TESTING(scal, ARG2) -BLAS1_TESTING(scal_batched, ARG2) -BLAS1_TESTING(scal_strided_batched, ARG2) -BLAS1_TESTING(swap, ARG1) -BLAS1_TESTING(swap_batched, ARG1) -BLAS1_TESTING(swap_strided_batched, ARG1) -BLAS1_TESTING(rot, ARG3) -BLAS1_TESTING(rotg, ARG2) -BLAS1_TESTING(rotm, ARG1) -BLAS1_TESTING(rotmg, ARG1) + BLAS1_TESTING(asum, ARG1) + BLAS1_TESTING(asum_batched, ARG1) + BLAS1_TESTING(asum_strided_batched, ARG1) + BLAS1_TESTING(nrm2, ARG1) + BLAS1_TESTING(nrm2_batched, ARG1) + BLAS1_TESTING(nrm2_strided_batched, ARG1) + BLAS1_TESTING(iamax, ARG1) + BLAS1_TESTING(iamin, ARG1) + BLAS1_TESTING(axpy, ARG1) + BLAS1_TESTING(copy, ARG1) + BLAS1_TESTING(copy_batched, ARG1) + BLAS1_TESTING(copy_strided_batched, ARG1) + BLAS1_TESTING(dot, ARG1) + BLAS1_TESTING(dotc, ARG1) + BLAS1_TESTING(dot_batched, ARG1) + BLAS1_TESTING(dotc_batched, ARG1) + BLAS1_TESTING(dot_strided_batched, ARG1) + BLAS1_TESTING(dotc_strided_batched, ARG1) + BLAS1_TESTING(scal, ARG2) + BLAS1_TESTING(scal_batched, ARG2) + BLAS1_TESTING(scal_strided_batched, ARG2) + BLAS1_TESTING(swap, ARG1) + BLAS1_TESTING(swap_batched, ARG1) + BLAS1_TESTING(swap_strided_batched, ARG1) + BLAS1_TESTING(rot, ARG3) + BLAS1_TESTING(rotg, ARG2) + BLAS1_TESTING(rotm, ARG1) + BLAS1_TESTING(rotmg, ARG1) } // namespace diff --git a/clients/include/rocblas_datatype2string.hpp b/clients/include/rocblas_datatype2string.hpp index cd7977fbb..bdd3d3578 100644 --- a/clients/include/rocblas_datatype2string.hpp +++ b/clients/include/rocblas_datatype2string.hpp @@ -196,7 +196,7 @@ constexpr rocblas_side char2rocblas_side(char value) } // clang-format off -nline rocblas_initialization string2rocblas_initialization(const std::string& value) +inline rocblas_initialization string2rocblas_initialization(const std::string& value) { return value == "rand_int" ? rocblas_initialization_random_int : diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index 1817e5302..fbc53f1e6 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -74,7 +74,8 @@ auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, stride_c, ld_c, stride_c, - *beta); +// TODO: Must fix this to work with complex + std::real(*beta)); return problem; } @@ -153,7 +154,8 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, Tensile::TensorOps nop; return Tensile::ContractionProblem( - a, nop, b, nop, c, nop, d, nop, freeIndices, batchIndices, boundIndices, *beta); +// TODO: Must fix this to work with complex + a, nop, b, nop, c, nop, d, nop, freeIndices, batchIndices, boundIndices, std::real(*beta)); } template From c6f375c2da160de7bd71095e971a044a32c57024 Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Sat, 5 Oct 2019 12:02:41 -0400 Subject: [PATCH 4/5] Remove --lib option and argument; make use of legacy handle API instead of introducing a new host handle API --- clients/benchmarks/client.cpp | 19 +------------------ clients/include/rocblas_arguments.hpp | 4 ---- clients/include/testing_gemm.hpp | 12 ++++-------- clients/include/utility.hpp | 6 ------ 4 files changed, 5 insertions(+), 36 deletions(-) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 65be2031d..13c83bda2 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -554,22 +554,12 @@ try std::string compute_type; std::string initialization; -#ifdef USE_TENSILE_HOST - std::string host_lib_path; -#endif - rocblas_int device_id; bool datafile = rocblas_parse_data(argc, argv); options_description desc("rocblas-bench command line options"); desc.add_options() - // clang-format off - -#ifdef USE_TENSILE_HOST - ("lib", - value(&host_lib_path), - "Host libriary path") -#endif + // clang-format off ("sizem,m", value(&arg.M)->default_value(128), "Specific matrix size: sizem is only applicable to BLAS-2 & BLAS-3: the number of " @@ -805,13 +795,6 @@ try if(copied <= 0 || copied >= sizeof(arg.function)) throw std::invalid_argument("Invalid value for --function"); -#ifdef USE_TENSILE_HOST - int copied_host - = snprintf(arg.host_lib_path, sizeof(arg.host_lib_path), "%s", host_lib_path.c_str()); - if(copied_host <= 0 || copied_host >= sizeof(arg.host_lib_path)) - throw std::invalid_argument("Invalid value for --lib"); -#endif - return run_bench_test(arg); } catch(const std::invalid_argument& exp) diff --git a/clients/include/rocblas_arguments.hpp b/clients/include/rocblas_arguments.hpp index 120e5b017..177787087 100644 --- a/clients/include/rocblas_arguments.hpp +++ b/clients/include/rocblas_arguments.hpp @@ -80,10 +80,6 @@ struct Arguments rocblas_initialization initialization; -#ifdef USE_TENSILE_HOST - char host_lib_path[4096]; -#endif - // Validate input format. // rocblas_gentest.py is expected to conform to this format. // rocblas_gentest.py uses rocblas_common.yaml to generate this format. diff --git a/clients/include/testing_gemm.hpp b/clients/include/testing_gemm.hpp index 7fb89a2e2..6e4459ae9 100644 --- a/clients/include/testing_gemm.hpp +++ b/clients/include/testing_gemm.hpp @@ -92,15 +92,11 @@ void testing_gemm(const Arguments& arg) T h_alpha = arg.get_alpha(); T h_beta = arg.get_beta(); - double gpu_time_used, cpu_time_used; - double rocblas_gflops, cblas_gflops; - double rocblas_error = 0.0; -#ifdef USE_TENSILE_HOST - const char* host_lib_path = arg.host_lib_path; - rocblas_local_handle handle(host_lib_path); -#else + double gpu_time_used, cpu_time_used; + double rocblas_gflops, cblas_gflops; + double rocblas_error = 0.0; rocblas_local_handle handle; -#endif + rocblas_int A_row = transA == rocblas_operation_none ? M : K; rocblas_int A_col = transA == rocblas_operation_none ? K : M; rocblas_int B_row = transB == rocblas_operation_none ? K : N; diff --git a/clients/include/utility.hpp b/clients/include/utility.hpp index f7a944730..99dab94df 100644 --- a/clients/include/utility.hpp +++ b/clients/include/utility.hpp @@ -28,12 +28,6 @@ class rocblas_local_handle { rocblas_create_handle(&handle); } -#ifdef USE_TENSILE_HOST - rocblas_local_handle(const char* lib_path) - { - rocblas_create_host_handle(&handle, lib_path); - } -#endif ~rocblas_local_handle() { rocblas_destroy_handle(handle); From d7d22db32641713a1e6325f1fa92adc806a0a41c Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Mon, 7 Oct 2019 18:16:25 -0400 Subject: [PATCH 5/5] Map values to value categories currently represented as double --- library/src/tensile_host.cpp | 69 ++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index fbc53f1e6..73480492b 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -33,6 +33,13 @@ static constexpr auto tensile_datatype = Tensile::DataTyp template <> static constexpr auto tensile_datatype = Tensile::DataType::ComplexDouble; +// return the value category for a value, such as whether it's 0 or 1 +template +constexpr auto value_category(const T* beta) +{ + return *beta == T(0) ? 0.0 : *beta == T(1) ? 1.0 : -12345.0; +} + template auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, rocblas_operation trans_b, @@ -55,27 +62,27 @@ auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, bool transposeA = trans_a != rocblas_operation_none; bool transposeB = trans_b != rocblas_operation_none; - Tensile::DataType dt = tensile_datatype; - Tensile::ContractionProblem problem = Tensile::ContractionProblem::GEMM_Strides(transposeA, - transposeB, - dt, - dt, - dt, - dt, - m, - n, - k, - batchSize, - ld_a, - stride_a, - ld_b, - stride_b, - ld_c, - stride_c, - ld_c, - stride_c, -// TODO: Must fix this to work with complex - std::real(*beta)); + Tensile::DataType dt = tensile_datatype; + Tensile::ContractionProblem problem + = Tensile::ContractionProblem::GEMM_Strides(transposeA, + transposeB, + dt, + dt, + dt, + dt, + m, + n, + k, + batchSize, + ld_a, + stride_a, + ld_b, + stride_b, + ld_c, + stride_c, + ld_c, + stride_c, + value_category(beta)); return problem; } @@ -148,14 +155,23 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, batchIndices.push_back({2, 2, 2, 2}); - if(*beta != 0.0) + if(value_category(beta) != 0) c = d; Tensile::TensorOps nop; - return Tensile::ContractionProblem( -// TODO: Must fix this to work with complex - a, nop, b, nop, c, nop, d, nop, freeIndices, batchIndices, boundIndices, std::real(*beta)); + return Tensile::ContractionProblem(a, + nop, + b, + nop, + c, + nop, + d, + nop, + freeIndices, + batchIndices, + boundIndices, + value_category(beta)); } template @@ -278,7 +294,8 @@ TensileHost* createTensileHost() } template -inline rocblas_status callTensileContraction(RocblasContractionProblem* problem, TensileHost* host) +inline rocblas_status callTensileContraction(RocblasContractionProblem* problem, + TensileHost* host) { TensileHostCall hostCaller; return hostCaller.runContractionProblem(problem, host);