diff --git a/library/src/blas3/Tensile/gemm.cpp b/library/src/blas3/Tensile/gemm.cpp index 1fac22fb2..6104c7688 100644 --- a/library/src/blas3/Tensile/gemm.cpp +++ b/library/src/blas3/Tensile/gemm.cpp @@ -163,7 +163,6 @@ namespace hipMemcpy(&beta_h, beta, sizeof(T), hipMemcpyDeviceToHost); } - TensileHostCall hostCall; RocblasContractionProblem problem(ContractionProblemType::GEMM, trans_a, trans_b, @@ -179,7 +178,7 @@ namespace C, ld_c); - return callTensileContraction(&problem, handle->host); + return handle->host->runContractionProblem(problem); #else diff --git a/library/src/blas3/Tensile/gemm_strided_batched.cpp b/library/src/blas3/Tensile/gemm_strided_batched.cpp index 516c658ea..1ea3ef1af 100644 --- a/library/src/blas3/Tensile/gemm_strided_batched.cpp +++ b/library/src/blas3/Tensile/gemm_strided_batched.cpp @@ -197,7 +197,6 @@ namespace hipMemcpy(&beta_h, beta, sizeof(T), hipMemcpyDeviceToHost); } - TensileHostCall hostCall; RocblasContractionProblem problem(ContractionProblemType::GEMMStridedBatch, trans_a, trans_b, @@ -217,7 +216,7 @@ namespace stride_c, b_c); - return callTensileContraction(&problem, handle->host); + return handle->host->runContractionProblem(problem); #else rocblas_status validArgs = validateArgs(handle, diff --git a/library/src/handle.cpp b/library/src/handle.cpp index 4010e0a62..b4be3aa9b 100644 --- a/library/src/handle.cpp +++ b/library/src/handle.cpp @@ -1,6 +1,9 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#if BUILD_WITH_TENSILE +#include "Tensile.h" +#endif #include "handle.h" #include #include @@ -10,10 +13,12 @@ ******************************************************************************/ _rocblas_handle::_rocblas_handle() { +#if BUILD_WITH_TENSILE + static int dummy = (tensileInitialize(), 0); #ifdef USE_TENSILE_HOST - host = createTensileHost(); - if(!host) - throw rocblas_status_internal_error; + static TensileHost* hostImpl = createTensileHost(); + host = hostImpl; +#endif #endif // default device is active device @@ -69,10 +74,6 @@ _rocblas_handle::~_rocblas_handle() } if(device_memory) (hipFree)(device_memory); - -#ifdef USE_TENSILE_HOST - delete host; -#endif } /******************************************************************************* diff --git a/library/src/include/tensile_host.hpp b/library/src/include/tensile_host.hpp index 4b728bed8..93de2f451 100644 --- a/library/src/include/tensile_host.hpp +++ b/library/src/include/tensile_host.hpp @@ -12,42 +12,41 @@ enum ContractionProblemType }; template -class RocblasContractionProblem +struct RocblasContractionProblem { -public: ContractionProblemType problem_type; rocblas_operation trans_a; rocblas_operation trans_b; - unsigned long m; - unsigned long n; - unsigned long k; + size_t m; + size_t n; + size_t k; const T alpha; const T* A; - const unsigned long ld_a; - unsigned long stride_a; + size_t ld_a; + size_t stride_a; const T* B; - unsigned long ld_b; - unsigned long stride_b; + size_t ld_b; + size_t stride_b; const T beta; T* C; - unsigned long ld_c; - unsigned long stride_c; - unsigned long batch_size; + size_t ld_c; + size_t stride_c; + size_t batch_size; RocblasContractionProblem(ContractionProblemType problem_type, rocblas_operation trans_a, rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, + size_t m, + size_t n, + size_t k, const T alpha, const T* A, - unsigned long ld_a, + size_t ld_a, const T* B, - unsigned long ld_b, + size_t ld_b, const T beta, T* C, - unsigned long ld_c) + size_t ld_c) : problem_type(problem_type) , trans_a(trans_a) , trans_b(trans_b) @@ -72,21 +71,21 @@ class RocblasContractionProblem RocblasContractionProblem(ContractionProblemType problem_type, rocblas_operation trans_a, rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, + size_t m, + size_t n, + size_t k, const T alpha, const T* A, - unsigned long ld_a, - unsigned long stride_a, + size_t ld_a, + size_t stride_a, const T* B, - unsigned long ld_b, - unsigned long stride_b, + size_t ld_b, + size_t stride_b, const T beta, T* C, - unsigned long ld_c, - unsigned long stride_c, - unsigned long batch_size) + size_t ld_c, + size_t stride_c, + size_t batch_size) : problem_type(problem_type) , trans_a(trans_a) , trans_b(trans_b) @@ -109,24 +108,17 @@ class RocblasContractionProblem } }; -class TensileHost +struct TensileHost { -public: - virtual void initializeHost(const char*) {} -}; + template + rocblas_status runContractionProblem(const RocblasContractionProblem& problem); -template -class TensileHostCall -{ -public: - rocblas_status runContractionProblem(RocblasContractionProblem* problem, TensileHost* host); +protected: + TensileHost() = default; // Prevent instantiating this class except as base class }; TensileHost* createTensileHost(); -template -rocblas_status callTensileContraction(RocblasContractionProblem* problem, TensileHost* host); - #endif #endif // __TENSILE_HOST_HPP__ diff --git a/library/src/rocblas_auxiliary.cpp b/library/src/rocblas_auxiliary.cpp index cdb4fef0b..48dc91fe3 100644 --- a/library/src/rocblas_auxiliary.cpp +++ b/library/src/rocblas_auxiliary.cpp @@ -2,9 +2,6 @@ * Copyright 2016-2019 Advanced Micro Devices, Inc. * * ************************************************************************ */ -#if BUILD_WITH_TENSILE -#include "Tensile.h" -#endif #include "handle.h" #include "logging.h" #include "rocblas-auxiliary.h" @@ -71,20 +68,10 @@ extern "C" rocblas_status rocblas_create_handle(rocblas_handle* handle) // allocate on heap try { -#if BUILD_WITH_TENSILE - static int dummy = (tensileInitialize(), 0); -#endif *handle = new _rocblas_handle(); if((*handle)->layer_mode & rocblas_layer_mode_log_trace) log_trace(*handle, "rocblas_create_handle"); - -#ifdef USE_TENSILE_HOST - const char* lib_path = getenv("ROCBLAS_TENSILE_LIBPATH"); - if(!lib_path) - lib_path = "/opt/rocm/"; // TODO: Set default path - (*handle)->host->initializeHost(lib_path); -#endif } catch(...) { diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index adad321a0..c8dab7ac3 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -21,10 +21,6 @@ static constexpr auto tensile_datatype = nullptr; template <> static constexpr auto tensile_datatype = Tensile::DataType::Half; -template <> -static constexpr auto tensile_datatype = Tensile::DataType::Half; - - template <> static constexpr auto tensile_datatype = Tensile::DataType::Float; @@ -45,48 +41,47 @@ constexpr auto value_category(const T& beta) } template -auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, - rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, - const T alpha, - const T* A, - unsigned long ld_a, - unsigned long stride_a, - const T* B, - unsigned long ld_b, - unsigned long stride_b, - const T beta, - T* C, - unsigned long ld_c, - unsigned long stride_c, - unsigned long batchSize) +auto create_gemm_contraction_problem_strided_batched(rocblas_operation trans_a, + rocblas_operation trans_b, + size_t m, + size_t n, + size_t k, + T alpha, + const T* A, + size_t ld_a, + size_t stride_a, + const T* B, + size_t ld_b, + size_t stride_b, + T beta, + T* C, + size_t ld_c, + size_t stride_c, + size_t batchSize) { - bool transposeA = trans_a != rocblas_operation_none; - bool transposeB = trans_b != rocblas_operation_none; - - Tensile::DataType dt = tensile_datatype; - Tensile::ContractionProblem problem - = Tensile::ContractionProblem::GEMM_Strides(transposeA, - transposeB, - dt, - dt, - dt, - dt, - m, - n, - k, - batchSize, - ld_a, - stride_a, - ld_b, - stride_b, - ld_c, - stride_c, - ld_c, - stride_c, - value_category(beta)); + auto transposeA = trans_a != rocblas_operation_none; + auto transposeB = trans_b != rocblas_operation_none; + + auto dt = tensile_datatype; + auto problem = Tensile::ContractionProblem::GEMM_Strides(transposeA, + transposeB, + dt, + dt, + dt, + dt, + m, + n, + k, + batchSize, + ld_a, + stride_a, + ld_b, + stride_b, + ld_c, + stride_c, + ld_c, + stride_c, + value_category(beta)); return problem; } @@ -95,20 +90,20 @@ auto create_gemm_contraction_problem_strided(rocblas_operation trans_a, template auto create_gemm_contraction_problem(rocblas_operation trans_a, rocblas_operation trans_b, - unsigned long m, - unsigned long n, - unsigned long k, - const T alpha, + size_t m, + size_t n, + size_t k, + T alpha, const T* A, - unsigned long ld_a, + size_t ld_a, const T* B, - unsigned long ld_b, - const T beta, + size_t ld_b, + T beta, T* C, - unsigned long ld_c) + size_t ld_c) { - bool transposeA = trans_a != rocblas_operation_none; - bool transposeB = trans_b != rocblas_operation_none; + auto transposeA = trans_a != rocblas_operation_none; + auto transposeB = trans_b != rocblas_operation_none; Tensile::ContractionProblem::FreeIndices freeIndex(2); Tensile::ContractionProblem::BoundIndex boundIndex; @@ -118,31 +113,31 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, freeIndex[1].isA = false; freeIndex[1].i = freeIndex[1].c = freeIndex[1].d = 1; - Tensile::TensorDescriptor a, b, c, d; + Tensile::TensorDescriptor a, b; + auto dt = tensile_datatype; - Tensile::DataType dt = tensile_datatype; if(transposeA) { - a = Tensile::TensorDescriptor(dt, {k, m}, {1, ld_a}); + a = {dt, {k, m}, {1, ld_a}}; freeIndex[0].i = 1; boundIndex.a = 0; } else { - a = Tensile::TensorDescriptor(dt, {m, k}, {1, ld_a}); + a = {dt, {m, k}, {1, ld_a}}; freeIndex[0].i = 0; boundIndex.a = 1; } if(transposeB) { - b = Tensile::TensorDescriptor(dt, {n, k}, {1, ld_b}); + b = {dt, {n, k}, {1, ld_b}}; freeIndex[1].i = 0; boundIndex.b = 1; } else { - b = Tensile::TensorDescriptor(dt, {k, n}, {1, ld_b}); + b = {dt, {k, n}, {1, ld_b}}; freeIndex[1].i = 1; boundIndex.b = 0; } @@ -151,77 +146,76 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, Tensile::ContractionProblem::BatchIndices batchIndices; Tensile::ContractionProblem::BoundIndices boundIndices{boundIndex}; - unsigned int batchCount = 1; + batchIndices.push_back({2, 2, 2, 2}); + + Tensile::TensorDescriptor c{dt, {m, n}, {1, ld_c}}; - d = Tensile::TensorDescriptor(dt, {m, n}, {1, ld_c}); + auto batchCount = 1; a.appendDim(batchCount); b.appendDim(batchCount); - d.appendDim(batchCount); - - c = d; + c.appendDim(batchCount); - batchIndices.push_back({2, 2, 2, 2}); + Tensile::TensorOps aops; + if(is_complex && trans_a == rocblas_operation_conjugate_transpose) + aops = {Tensile::TensorOp::Type::ComplexConjugate}; - Tensile::TensorOps nop; + Tensile::TensorOps bops; + if(is_complex && trans_b == rocblas_operation_conjugate_transpose) + bops = {Tensile::TensorOp::Type::ComplexConjugate}; return Tensile::ContractionProblem(a, - nop, + aops, b, - nop, + bops, + c, + {}, c, - nop, - d, - nop, + {}, freeIndices, batchIndices, boundIndices, value_category(beta)); } -template -auto ConstructTensileProblem(RocblasContractionProblem* problem) +template +auto ConstructTensileProblem(const PROB& problem) { - Tensile::ContractionProblem tensile_problem; - switch(problem->problem_type) + switch(problem.problem_type) { case GEMM: - tensile_problem = create_gemm_contraction_problem(problem->trans_a, - problem->trans_b, - problem->m, - problem->n, - problem->k, - problem->alpha, - problem->A, - problem->ld_a, - problem->B, - problem->ld_b, - problem->beta, - problem->C, - problem->ld_c); - break; + return create_gemm_contraction_problem(problem.trans_a, + problem.trans_b, + problem.m, + problem.n, + problem.k, + problem.alpha, + problem.A, + problem.ld_a, + problem.B, + problem.ld_b, + problem.beta, + problem.C, + problem.ld_c); case GEMMStridedBatch: - tensile_problem = create_gemm_contraction_problem_strided(problem->trans_a, - problem->trans_b, - problem->m, - problem->n, - problem->k, - problem->alpha, - problem->A, - problem->ld_a, - problem->stride_a, - problem->B, - problem->ld_b, - problem->stride_b, - problem->beta, - problem->C, - problem->ld_c, - problem->stride_c, - problem->batch_size); - break; + return create_gemm_contraction_problem_strided_batched(problem.trans_a, + problem.trans_b, + problem.m, + problem.n, + problem.k, + problem.alpha, + problem.A, + problem.ld_a, + problem.stride_a, + problem.B, + problem.ld_b, + problem.stride_b, + problem.beta, + problem.C, + problem.ld_c, + problem.stride_c, + problem.batch_size); } - - return tensile_problem; } template @@ -249,20 +243,20 @@ struct rocblas_to_tensile_type }; template -auto GetTensileInputs(RocblasContractionProblem* problem) +auto GetTensileInputs(const RocblasContractionProblem& problem) { using tensile_type = typename rocblas_to_tensile_type::type; Tensile::TypedContractionInputs inputs; - switch(problem->problem_type) + switch(problem.problem_type) { case GEMM: case GEMMStridedBatch: - inputs.a = reinterpret_cast(problem->A); - inputs.b = reinterpret_cast(problem->B); - inputs.c = reinterpret_cast(problem->C); - inputs.d = reinterpret_cast(problem->C); - inputs.alpha = static_cast(problem->alpha); - inputs.beta = static_cast(problem->beta); + inputs.a = reinterpret_cast(problem.A); + inputs.b = reinterpret_cast(problem.B); + inputs.c = reinterpret_cast(problem.C); + inputs.d = reinterpret_cast(problem.C); + memcpy(&inputs.alpha, &problem.alpha, sizeof(T)); + memcpy(&inputs.beta, &problem.beta, sizeof(T)); break; } @@ -271,46 +265,50 @@ auto GetTensileInputs(RocblasContractionProblem* problem) struct TensileHostImpl : TensileHost { - void initializeHost(const char* lib_path) + TensileHostImpl() { + const char* lib_path = getenv("ROCBLAS_TENSILE_LIBPATH"); + if(!lib_path) + lib_path = "/opt/rocm/"; // TODO: Set default path + std::string path(lib_path); - std::string dir = path + "/*co"; + auto dir = path + "/*co"; glob_t glob_result; glob(dir.c_str(), GLOB_TILDE, NULL, &glob_result); - for(unsigned int i = 0; i < glob_result.gl_pathc; ++i) + for(size_t i = 0; i < glob_result.gl_pathc; ++i) adapter.loadCodeObjectFile(glob_result.gl_pathv[i]); globfree(&glob_result); - std::string filename = path + "/TensileLibrary.yaml"; - library = std::dynamic_pointer_cast< + library = std::dynamic_pointer_cast< Tensile::MasterSolutionLibrary>( - Tensile::LoadLibraryFile(filename)); + Tensile::LoadLibraryFile(path + "/TensileLibrary.yaml")); hardware = Tensile::hip::GetCurrentDevice(); } +private: std::shared_ptr> library; std::shared_ptr hardware; Tensile::hip::SolutionAdapter adapter; - - ~TensileHostImpl() override = default; // Ensure virtual base class + friend class TensileHost; }; +TensileHost* createTensileHost() +{ + return new TensileHostImpl; +} + template -rocblas_status TensileHostCall::runContractionProblem(RocblasContractionProblem* problem, - TensileHost* host) +rocblas_status TensileHost::runContractionProblem(const RocblasContractionProblem& problem) try { - auto tensile_problem = ConstructTensileProblem(problem); - auto inputs = GetTensileInputs(problem); - auto hosti = dynamic_cast(host); - if(!hosti) - return rocblas_status_internal_error; - - auto solution = hosti->library->findBestSolution(tensile_problem, *hosti->hardware); - auto result = solution->solve(tensile_problem, inputs, *hosti->hardware); - hosti->adapter.launchKernels(result); + auto host = static_cast(this); + auto tensile_problem = ConstructTensileProblem(problem); + auto inputs = GetTensileInputs(problem); + auto solution = host->library->findBestSolution(tensile_problem, *host->hardware); + auto result = solution->solve(tensile_problem, inputs, *host->hardware); + host->adapter.launchKernels(result); return rocblas_status_success; } catch(...) @@ -318,33 +316,19 @@ catch(...) return rocblas_status_internal_error; } -TensileHost* createTensileHost() -{ - return new TensileHostImpl(); -} - -template -rocblas_status callTensileContraction(RocblasContractionProblem* problem, TensileHost* host) -{ - TensileHostCall hostCaller; - return hostCaller.runContractionProblem(problem, host); -} - -template rocblas_status callTensileContraction(RocblasContractionProblem* problem, - TensileHost* host); - -template rocblas_status callTensileContraction(RocblasContractionProblem* problem, - TensileHost* host); - -template rocblas_status callTensileContraction(RocblasContractionProblem* problem, - TensileHost* host); +template rocblas_status + TensileHost::runContractionProblem(const RocblasContractionProblem& problem); template rocblas_status - callTensileContraction(RocblasContractionProblem* problem, - TensileHost* host); + TensileHost::runContractionProblem(const RocblasContractionProblem& problem); template rocblas_status - callTensileContraction(RocblasContractionProblem* problem, - TensileHost* host); + TensileHost::runContractionProblem(const RocblasContractionProblem& problem); + +template rocblas_status TensileHost::runContractionProblem( + const RocblasContractionProblem& problem); + +template rocblas_status TensileHost::runContractionProblem( + const RocblasContractionProblem& problem); #endif