diff --git a/README.md b/README.md
index 4f26d94b4..c8b704924 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,10 @@ oneMKL interfaces is an open-source implementation of oneMKL Data Parallel C++ (
      Intel(R) oneAPI Math Kernel Library for Intel GPU
      Intel GPU
+
+     NVIDIA cuBLAS for NVIDIA GPU
+     NVIDIA GPU
+
@@ -81,13 +85,13 @@ cl::sycl::queue cpu_queue(cpu_dev);
 cl::sycl::queue gpu_queue(gpu_dev);
 
 onemkl::blas::gemm(cpu_queue, transA, transB, m, ...);
-onemkl::blas::gemm(gpu_queue, transA, transB, m, ...);
+onemkl::blas::gemm(gpu_queue, transA, transB, m, ...);
 ```
 
 How to build an application with run-time dispatching:
 ```cmd
 $> clang++ -fsycl -I$ONEMKL/include app.cpp
-$> clang++ -fsycl app.o -L$ONEMKL/lib -lonemkl_blas_mklcpu -lonemkl_blas_mklgpu
+$> clang++ -fsycl app.o -L$ONEMKL/lib -lonemkl_blas_mklcpu -lonemkl_blas_cublas
 ```
 
 ### Supported Configurations:
@@ -100,6 +104,7 @@ Supported domains: BLAS
 :------| :-------| :------------------
 Intel CPU | Intel(R) oneAPI Math Kernel Library | Dynamic, Static
 Intel GPU | Intel(R) oneAPI Math Kernel Library | Dynamic, Static
+ NVIDIA GPU | NVIDIA cuBLAS | Dynamic, Static
 
 ---
 
@@ -114,18 +119,19 @@ Supported domains: BLAS
 - Intel(R) Xeon(R) Processor Family
 - Accelerators
   - Intel(R) Processor Graphics GEN9
+  - NVIDIA(R) TITAN RTX(TM) (Not tested with other NVIDIA GPU families and products.)
 
 ---
 
 ### Supported Operating Systems
 
 #### Linux*
 
-Operating System | CPU Host/Target | Integrated Graphics from Intel (Intel GPU)
-:--- | :--- | :---
-Ubuntu | 18.04.3, 19.04 | 18.04.3, 19.10
-SUSE Linux Enterprise Server* | 15 | *Not supported*
-Red Hat Enterprise Linux* (RHEL*) | 8 | *Not supported*
-Linux* kernel | *N/A* | 4.11 or higher
+Operating System | CPU Host/Target | Integrated Graphics from Intel (Intel GPU) | NVIDIA GPU
+:--- | :--- | :--- | :---
+Ubuntu | 18.04.3, 19.04 | 18.04.3, 19.10 | 18.04.3
+SUSE Linux Enterprise Server* | 15 | *Not supported* | *Not supported*
+Red Hat Enterprise Linux* (RHEL*) | 8 | *Not supported* | *Not supported*
+Linux* kernel | *N/A* | 4.11 or higher | *N/A*
 
 ---
 
@@ -174,7 +180,7 @@ Linux* kernel | *N/A* | 4.11 or higher
-     Linux*
+     Linux*
      Any
      GNU* GCC 5.1 or higher
@@ -192,6 +198,11 @@ Linux* kernel | *N/A* | 4.11 or higher
      Intel(R) oneAPI Math Kernel Library
+     NVIDIA GPU
+     Intel project for LLVM* technology
+
+     NVIDIA CUDA SDK
+
@@ -206,7 +217,9 @@ Python | 3.6 or higher | [PSF](https://docs.python.org/3.6/license.html)
 [GNU* FORTRAN Compiler](https://gcc.gnu.org/wiki/GFortran) | 7.4.0 or higher | [GNU General Public License, version 3](https://gcc.gnu.org/onlinedocs/gcc-7.5.0/gfortran/Copying.html)
 [Intel(R) oneAPI DPC++ Compiler](https://software.intel.com/en-us/oneapi/dpc-compiler) | 2021.1-beta05 | [End User License Agreement for the Intel(R) Software Development Products](https://software.intel.com/en-us/license/eula-for-intel-software-development-products)
 [Intel project for LLVM* technology binary for Intel CPU](https://github.com/intel/llvm/releases) | Daily builds (experimental) tested with [20200331](https://github.com/intel/llvm/releases/download/20200331/dpcpp-compiler.tar.gz) | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT)
+[Intel project for LLVM* technology source for NVIDIA GPU](https://github.com/intel/llvm/releases) | Daily source releases: tested with [20200421](https://github.com/intel/llvm/tree/20200421) | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT)
 [Intel(R) oneAPI Math Kernel Library](https://software.intel.com/en-us/oneapi/onemkl) | 2021.1-beta05 | [Intel Simplified Software License](https://software.intel.com/en-us/license/intel-simplified-software-license)
+[NVIDIA CUDA SDK](https://developer.nvidia.com/cublas) | 10.2 | [End User License Agreement](https://docs.nvidia.com/cuda/eula/index.html)
 [NETLIB LAPACK](https://github.com/Reference-LAPACK/lapack) | 3.7.1 | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt)
 [Sphinx](https://www.sphinx-doc.org/en/master/) | 2.4.4 | [BSD License](https://github.com/sphinx-doc/sphinx/blob/3.x/LICENSE)
@@ -248,6 +261,7 @@ You can specify build options using `-D<option>=<value>`. The following table
 CMake Option | Supported Values | Default Value
 :----------- | :--------------- | :---
 BUILD_SHARED_LIBS | True, False | True
+ENABLE_CUBLAS_BACKEND | True, False | False
 ENABLE_MKLCPU_BACKEND | True, False | True
 ENABLE_MKLGPU_BACKEND | True, False | True
 ENABLE_MKLCPU_THREAD_TBB | True, False | True
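For context, a configure line that enables the new backend might look like the sketch below. Only the `ENABLE_CUBLAS_BACKEND` and `ENABLE_MKLGPU_BACKEND` option names come from the table above; the compiler choice, paths, and the rest of the command are assumptions that will vary by environment (targeting NVIDIA GPUs additionally requires a DPC++ compiler built from the Intel project for LLVM* technology sources with CUDA support).

```cmd
$> cd build
$> cmake .. -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang -DENABLE_CUBLAS_BACKEND=True -DENABLE_MKLGPU_BACKEND=False
$> cmake --build .
```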
diff --git a/cmake/FindcuBLAS.cmake b/cmake/FindcuBLAS.cmake
new file mode 100644
index 000000000..06fe6fe59
--- /dev/null
+++ b/cmake/FindcuBLAS.cmake
@@ -0,0 +1,55 @@
+#==========================================================================
+# Copyright (C) Codeplay Software Limited
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# For your convenience, a copy of the License has been included in this
+# repository.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#=========================================================================
+find_package(CUDA 10.0 REQUIRED)
+find_path(CUBLAS_INCLUDE_DIR "cublas_v2.h" HINTS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+get_filename_component(SYCL_BINARY_DIR ${CMAKE_CXX_COMPILER} DIRECTORY)
+# The OpenCL headers shipped with CUDA are OpenCL 1.1, which is not compatible with DPC++;
+# OpenCL headers 1.2 or newer are required. This is used to bypass the NVIDIA OpenCL headers.
+find_path(OPENCL_INCLUDE_DIR CL/cl.h OpenCL/cl.h
+HINTS
+${OPENCL_INCLUDE_DIR}
+${SYCL_BINARY_DIR}/../include/sycl/
+)
+find_library(CUBLAS_LIBRARY cublas)
+find_library(CUDA_DRIVER_LIBRARY cuda)
+# This is a workaround to avoid the half type being defined by both CUDA and SYCL.
+add_compile_definitions(CUDA_NO_HALF)
+
+find_package(Threads REQUIRED)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(cuBLAS
+  REQUIRED_VARS
+    CUBLAS_INCLUDE_DIR
+    CUDA_INCLUDE_DIRS
+    CUBLAS_LIBRARY
+    CUDA_LIBRARIES
+    CUDA_DRIVER_LIBRARY
+    OPENCL_INCLUDE_DIR
+)
+if(NOT TARGET ONEMKL::cuBLAS::cuBLAS)
+  add_library(ONEMKL::cuBLAS::cuBLAS SHARED IMPORTED)
+  set_target_properties(ONEMKL::cuBLAS::cuBLAS PROPERTIES
+      IMPORTED_LOCATION ${CUBLAS_LIBRARY}
+      INTERFACE_INCLUDE_DIRECTORIES "${OPENCL_INCLUDE_DIR};${CUDA_INCLUDE_DIRS}"
+      INTERFACE_LINK_LIBRARIES "Threads::Threads;${CUDA_DRIVER_LIBRARY};${CUDA_LIBRARIES}"
+  )
+
+endif()
diff --git a/include/onemkl/blas/blas.hpp b/include/onemkl/blas/blas.hpp
index 6dd01af6c..d705048ad 100644
--- a/include/onemkl/blas/blas.hpp
+++ b/include/onemkl/blas/blas.hpp
@@ -31,6 +31,7 @@
 #include "onemkl/blas/predicates.hpp"
 
 #include "onemkl/blas/detail/blas_loader.hpp"
+#include "onemkl/blas/detail/cublas/blas_ct.hpp"
 #include "onemkl/blas/detail/mklcpu/blas_ct.hpp"
 #include "onemkl/blas/detail/mklgpu/blas_ct.hpp"
diff --git a/include/onemkl/blas/detail/cublas/blas_ct.hpp b/include/onemkl/blas/detail/cublas/blas_ct.hpp
new file mode 100644
index 000000000..4ac19b7f8
--- /dev/null
+++ b/include/onemkl/blas/detail/cublas/blas_ct.hpp
@@ -0,0 +1,3022 @@
+/***************************************************************************
+* Copyright (C) Codeplay Software Limited
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* For your convenience, a copy of the License has been included in this
+* repository.
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+* +**************************************************************************/ +// +// Generated based on onemkl/blas/blas.hpp +// + +#ifndef _DETAIL_CUBLAS_BLAS_HPP_ +#define _DETAIL_CUBLAS_BLAS_HPP_ + +#include +#include +#include + +#include "onemkl/detail/backends.hpp" +#include "onemkl/detail/libraries.hpp" +#include "onemkl/types.hpp" + +#include "onemkl_blas_cublas.hpp" + +namespace onemkl { +namespace blas { + +template +static inline void syr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda); +template <> +void syr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda) { + syr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::syr2(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + syr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void syr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda); +template <> +void syr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda) { + syr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::syr2(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + syr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, 
float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + scal_precondition(queue, n, alpha, x, incx); + onemkl::cublas::scal(queue, n, alpha, x, incx); + scal_postcondition(queue, n, alpha, x, incx); +} + +template +static inline void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void trmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx) { + trmv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trmv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void trmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx) { + trmv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trmv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx); +template <> +void trmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + std::int64_t lda, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + trmv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trmv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx); +template <> +void trmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + std::int64_t lda, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + trmv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trmv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, 
cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void tpmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx) { + tpmv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpmv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void tpmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx) { + tpmv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpmv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tpmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + tpmv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpmv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tpmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + tpmv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpmv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpmv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void spr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &a); +template <> +void spr(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &a) { + spr_precondition(queue, upper_lower, n, alpha, x, incx, a); + onemkl::cublas::spr(queue, upper_lower, n, alpha, x, incx, a); + spr_postcondition(queue, upper_lower, n, alpha, x, incx, a); +} + +template +static inline void spr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &a); +template <> +void spr(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &a) { + spr_precondition(queue, upper_lower, n, alpha, x, incx, a); + onemkl::cublas::spr(queue, upper_lower, n, alpha, x, incx, a); + spr_postcondition(queue, upper_lower, n, alpha, x, incx, a); +} + +template +static inline void hpmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + 
cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); +template <> +void hpmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + hpmv_precondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + onemkl::cublas::hpmv(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + hpmv_postcondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); +} + +template +static inline void hpmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); +template <> +void hpmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + hpmv_precondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + onemkl::cublas::hpmv(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + hpmv_postcondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); +} + +template +static inline void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, float beta, cl::sycl::buffer &c, + std::int64_t ldc); +template <> +void syrk(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer &a, + std::int64_t lda, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + syrk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::syrk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + syrk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +template +static inline void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, double beta, cl::sycl::buffer &c, + std::int64_t ldc); +template <> +void syrk(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer &a, + std::int64_t lda, double beta, + cl::sycl::buffer &c, std::int64_t ldc) { + syrk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::syrk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + syrk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +template +static inline void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void syrk( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + syrk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::syrk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + syrk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, 
ldc); +} + +template +static inline void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void syrk( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + syrk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::syrk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + syrk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +template +static inline void her2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a, + std::int64_t lda); +template <> +void her2( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda) { + her2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::her2(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + her2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void her2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a, + std::int64_t lda); +template <> +void her2( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda) { + her2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::her2(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); + her2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void hbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void hbmv( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + hbmv_precondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::hbmv(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + hbmv_postcondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void hbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void hbmv( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + 
cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + hbmv_precondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::hbmv(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + hbmv_postcondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, float c, + float s); +template <> +void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, float c, float s) { + rot_precondition(queue, n, x, incx, y, incy, c, s); + onemkl::cublas::rot(queue, n, x, incx, y, incy, c, s); + rot_postcondition(queue, n, x, incx, y, incy, c, s); +} + +template +static inline void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, double c, + double s); +template <> +void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, double c, double s) { + rot_precondition(queue, n, x, incx, y, incy, c, s); + onemkl::cublas::rot(queue, n, x, incx, y, incy, c, s); + rot_postcondition(queue, n, x, incx, y, incy, c, s); +} + +template +static inline void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, float c, + float s); +template <> +void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + float c, float s) { + rot_precondition(queue, n, x, incx, y, incy, c, s); + onemkl::cublas::rot(queue, n, x, incx, y, incy, c, s); + rot_postcondition(queue, n, x, incx, y, incy, c, s); +} + +template +static inline void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + double c, double s); +template <> +void rot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + double c, double s) { + rot_precondition(queue, n, x, incx, y, incy, c, s); + onemkl::cublas::rot(queue, n, x, incx, y, incy, c, s); + rot_postcondition(queue, n, x, incx, y, incy, c, s); +} + +template +static inline void axpy(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void axpy(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + axpy_precondition(queue, n, alpha, x, incx, y, incy); + onemkl::cublas::axpy(queue, n, alpha, x, incx, y, incy); + axpy_postcondition(queue, n, alpha, x, incx, y, incy); +} + +template +static inline void axpy(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void axpy(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + axpy_precondition(queue, n, alpha, x, incx, y, incy); + onemkl::cublas::axpy(queue, n, alpha, x, incx, y, incy); + axpy_postcondition(queue, n, alpha, x, incx, y, incy); +} + +template +static inline void 
axpy(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void axpy(cl::sycl::queue &queue, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + axpy_precondition(queue, n, alpha, x, incx, y, incy); + onemkl::cublas::axpy(queue, n, alpha, x, incx, y, incy); + axpy_postcondition(queue, n, alpha, x, incx, y, incy); +} + +template +static inline void axpy(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void axpy(cl::sycl::queue &queue, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + axpy_precondition(queue, n, alpha, x, incx, y, incy); + onemkl::cublas::axpy(queue, n, alpha, x, incx, y, incy); + axpy_postcondition(queue, n, alpha, x, incx, y, incy); +} + +template +static inline void gerc(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a, + std::int64_t lda); +template <> +void gerc( + cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda) { + gerc_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::gerc(queue, m, n, alpha, x, incx, y, incy, a, lda); + gerc_postcondition(queue, m, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void gerc(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a, + std::int64_t lda); +template <> +void gerc( + cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda) { + gerc_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::gerc(queue, m, n, alpha, x, incx, y, incy, a, lda); + gerc_postcondition(queue, m, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, std::int64_t ldc); +template <> +void syr2k(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + syr2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::syr2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + syr2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + double beta, 
cl::sycl::buffer &c, std::int64_t ldc); +template <> +void syr2k(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc) { + syr2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::syr2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + syr2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void syr2k( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + syr2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::syr2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + syr2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void syr2k( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + syr2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::syr2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + syr2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, float beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void gemv(cl::sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, + std::int64_t incy) { + gemv_precondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gemv(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + gemv_postcondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, double beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void gemv(cl::sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + 
cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, + std::int64_t incy) { + gemv_precondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gemv(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + gemv_postcondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void gemv( + cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + gemv_precondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gemv(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + gemv_postcondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void gemv( + cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + gemv_precondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gemv(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + gemv_postcondition(queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void her(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a, std::int64_t lda); +template <> +void her(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &a, + std::int64_t lda) { + her_precondition(queue, upper_lower, n, alpha, x, incx, a, lda); + onemkl::cublas::her(queue, upper_lower, n, alpha, x, incx, a, lda); + her_postcondition(queue, upper_lower, n, alpha, x, incx, a, lda); +} + +template +static inline void her(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a, std::int64_t lda); +template <> +void her(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &a, + std::int64_t lda) { + her_precondition(queue, upper_lower, n, alpha, x, incx, a, lda); + onemkl::cublas::her(queue, upper_lower, n, alpha, x, incx, a, lda); + her_postcondition(queue, upper_lower, n, alpha, x, incx, a, lda); +} + +template +static inline void hpr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a); +template <> +void hpr(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &a) { + 
hpr_precondition(queue, upper_lower, n, alpha, x, incx, a); + onemkl::cublas::hpr(queue, upper_lower, n, alpha, x, incx, a); + hpr_postcondition(queue, upper_lower, n, alpha, x, incx, a); +} + +template +static inline void hpr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a); +template <> +void hpr(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &a) { + hpr_precondition(queue, upper_lower, n, alpha, x, incx, a); + onemkl::cublas::hpr(queue, upper_lower, n, alpha, x, incx, a); + hpr_postcondition(queue, upper_lower, n, alpha, x, incx, a); +} + +template +static inline void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); +template <> +void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result) { + iamin_precondition(queue, n, x, incx, result); + onemkl::cublas::iamin(queue, n, x, incx, result); + iamin_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); +template <> +void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result) { + iamin_precondition(queue, n, x, incx, result); + onemkl::cublas::iamin(queue, n, x, incx, result); + iamin_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer &result); +template <> +void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer &result) { + iamin_precondition(queue, n, x, incx, result); + onemkl::cublas::iamin(queue, n, x, incx, result); + iamin_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer &result); +template <> +void iamin(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer &result) { + iamin_precondition(queue, n, x, incx, result); + onemkl::cublas::iamin(queue, n, x, incx, result); + iamin_postcondition(queue, n, x, incx, result); +} + +template +static inline void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, + cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, + cl::sycl::buffer &beta, cl::sycl::buffer &c, + cl::sycl::buffer &ldc, std::int64_t group_count, + cl::sycl::buffer &group_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, std::int64_t group_count, + cl::sycl::buffer &group_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + onemkl::cublas::gemm_batch(queue, transa, 
transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); +} + +template +static inline void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); +} + +template +static inline void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); +} + +template +static inline void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, 
cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size); +} + +template +static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + cl::sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, cl::sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +template +static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, double beta, + cl::sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, double alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + double beta, cl::sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +template +static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, + std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, 
cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +template +static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, + std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +template <> +void gemm_batch( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + onemkl::cublas::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); + gemm_batch_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +template +static inline void spmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, + std::int64_t incx, float beta, cl::sycl::buffer &y, + std::int64_t incy); +template <> +void spmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, + std::int64_t incy) { + spmv_precondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + onemkl::cublas::spmv(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + spmv_postcondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); +} + +template +static inline void spmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, + std::int64_t incx, double beta, cl::sycl::buffer &y, + std::int64_t incy); +template <> +void spmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, + std::int64_t incy) { + spmv_precondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + onemkl::cublas::spmv(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); + spmv_postcondition(queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, 
std::int64_t ldc); +template <> +void gemm_ext(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + int8_t ao, cl::sycl::buffer &b, std::int64_t ldb, + uint8_t bo, float beta, cl::sycl::buffer &c, + std::int64_t ldc, cl::sycl::buffer &co); +template <> +void gemm_ext( + cl::sycl::queue &queue, transpose transa, transpose transb, offset offsetc, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, + int8_t ao, cl::sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + cl::sycl::buffer &c, std::int64_t ldc, cl::sycl::buffer &co) { + gemm_ext_precondition(queue, transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, + beta, c, ldc, co); + onemkl::cublas::gemm_ext(queue, transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, + beta, c, ldc, co); + gemm_ext_postcondition(queue, transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, + beta, c, ldc, co); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm_ext(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm_ext(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + double beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t 
m, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); +template <> +void gemm_ext( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, + std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void gemm_ext( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm_ext(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + half beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_ext_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm_ext(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_ext_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); +template <> +void swap(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + swap_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::swap(queue, n, x, incx, y, incy); + swap_postcondition(queue, n, x, incx, y, incy); +} + +template +static inline void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); +template <> +void swap(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + swap_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::swap(queue, n, x, incx, y, incy); + swap_postcondition(queue, n, x, incx, y, incy); +} 
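+// Every routine in this header follows the same compile-time dispatch pattern: a templated
+// declaration over (library, backend), followed by a specialization for the cuBLAS backend
+// that runs the precondition checks, forwards the call to the onemkl::cublas implementation,
+// and then runs the postcondition checks.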
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void swap(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<std::complex<float>, 1> &y, std::int64_t incy);
+template <>
+void swap<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<float>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<std::complex<float>, 1> &y,
+                                               std::int64_t incy) {
+    swap_precondition(queue, n, x, incx, y, incy);
+    onemkl::cublas::swap(queue, n, x, incx, y, incy);
+    swap_postcondition(queue, n, x, incx, y, incy);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void swap(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<double>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<std::complex<double>, 1> &y, std::int64_t incy);
+template <>
+void swap<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<double>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<std::complex<double>, 1> &y,
+                                               std::int64_t incy) {
+    swap_precondition(queue, n, x, incx, y, incy);
+    onemkl::cublas::swap(queue, n, x, incx, y, incy);
+    swap_postcondition(queue, n, x, incx, y, incy);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void geru(cl::sycl::queue &queue, std::int64_t m, std::int64_t n,
+                        std::complex<float> alpha, cl::sycl::buffer<std::complex<float>, 1> &x,
+                        std::int64_t incx, cl::sycl::buffer<std::complex<float>, 1> &y,
+                        std::int64_t incy, cl::sycl::buffer<std::complex<float>, 1> &a,
+                        std::int64_t lda);
+template <>
+void geru<library::cublas, backend::nvidiagpu>(
+    cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex<float> alpha,
+    cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx,
+    cl::sycl::buffer<std::complex<float>, 1> &y, std::int64_t incy,
+    cl::sycl::buffer<std::complex<float>, 1> &a, std::int64_t lda) {
+    geru_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda);
+    onemkl::cublas::geru(queue, m, n, alpha, x, incx, y, incy, a, lda);
+    geru_postcondition(queue, m, n, alpha, x, incx, y, incy, a, lda);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void geru(cl::sycl::queue &queue, std::int64_t m, std::int64_t n,
+                        std::complex<double> alpha, cl::sycl::buffer<std::complex<double>, 1> &x,
+                        std::int64_t incx, cl::sycl::buffer<std::complex<double>, 1> &y,
+                        std::int64_t incy, cl::sycl::buffer<std::complex<double>, 1> &a,
+                        std::int64_t lda);
+template <>
+void geru<library::cublas, backend::nvidiagpu>(
+    cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex<double> alpha,
+    cl::sycl::buffer<std::complex<double>, 1> &x, std::int64_t incx,
+    cl::sycl::buffer<std::complex<double>, 1> &y, std::int64_t incy,
+    cl::sycl::buffer<std::complex<double>, 1> &a, std::int64_t lda) {
+    geru_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda);
+    onemkl::cublas::geru(queue, m, n, alpha, x, incx, y, incy, a, lda);
+    geru_postcondition(queue, m, n, alpha, x, incx, y, incy, a, lda);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void nrm2(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<float, 1> &result);
+template <>
+void nrm2<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<float>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<float, 1> &result) {
+    nrm2_precondition(queue, n, x, incx, result);
+    onemkl::cublas::nrm2(queue, n, x, incx, result);
+    nrm2_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void nrm2(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<double>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<double, 1> &result);
+template <>
+void nrm2<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<double>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<double, 1> &result) {
+    nrm2_precondition(queue, n, x, incx, result);
+    onemkl::cublas::nrm2(queue, n, x, incx, result);
+    nrm2_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer<float, 1> &x,
+                        std::int64_t incx, cl::sycl::buffer<float, 1> &result);
+template <>
+void nrm2<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<float, 1> &x, std::int64_t incx,
+                                               cl::sycl::buffer<float, 1> &result) {
nrm2_precondition(queue, n, x, incx, result); + onemkl::cublas::nrm2(queue, n, x, incx, result); + nrm2_postcondition(queue, n, x, incx, result); +} + +template +static inline void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); +template <> +void nrm2(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result) { + nrm2_precondition(queue, n, x, incx, result); + onemkl::cublas::nrm2(queue, n, x, incx, result); + nrm2_postcondition(queue, n, x, incx, result); +} + +template +static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + double beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void gemm( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> 
+void gemm( + cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemm(cl::sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + half beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::gemm(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void herk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, float alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, float beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void herk( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, float beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + herk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::herk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + herk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +template +static inline void herk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, double alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, double beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void herk( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, double beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + herk_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + onemkl::cublas::herk(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); + herk_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +template +static inline void ger(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda); +template <> +void ger(cl::sycl::queue &queue, std::int64_t m, + std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda) { + ger_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::ger(queue, m, n, alpha, x, incx, y, incy, a, lda); + ger_postcondition(queue, m, n, alpha, x, incx, y, incy, 
a, lda); +} + +template +static inline void ger(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda); +template <> +void ger(cl::sycl::queue &queue, std::int64_t m, + std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a, std::int64_t lda) { + ger_precondition(queue, m, n, alpha, x, incx, y, incy, a, lda); + onemkl::cublas::ger(queue, m, n, alpha, x, incx, y, incy, a, lda); + ger_postcondition(queue, m, n, alpha, x, incx, y, incy, a, lda); +} + +template +static inline void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb); +template <> +void trsm(cl::sycl::queue &queue, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb) { + trsm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trsm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trsm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb); +template <> +void trsm(cl::sycl::queue &queue, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb) { + trsm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trsm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trsm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); +template <> +void trsm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb) { + trsm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trsm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trsm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); +template <> +void trsm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag 
unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb) { + trsm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trsm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trsm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void dotu(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); +template <> +void dotu(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, + cl::sycl::buffer, 1> &result) { + dotu_precondition(queue, n, x, incx, y, incy, result); + onemkl::cublas::dotu(queue, n, x, incx, y, incy, result); + dotu_postcondition(queue, n, x, incx, y, incy, result); +} + +template +static inline void dotu(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); +template <> +void dotu(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, + cl::sycl::buffer, 1> &result) { + dotu_precondition(queue, n, x, incx, y, incy, result); + onemkl::cublas::dotu(queue, n, x, incx, y, incy, result); + dotu_postcondition(queue, n, x, incx, y, incy, result); +} + +template +static inline void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void hemm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + hemm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::hemm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + hemm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void hemm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + hemm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::hemm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + hemm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void hpr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + 
std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a); +template <> +void hpr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, + cl::sycl::buffer, 1> &a) { + hpr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); + onemkl::cublas::hpr2(queue, upper_lower, n, alpha, x, incx, y, incy, a); + hpr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); +} + +template +static inline void hpr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, + std::int64_t incy, cl::sycl::buffer, 1> &a); +template <> +void hpr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy, + cl::sycl::buffer, 1> &a) { + hpr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); + onemkl::cublas::hpr2(queue, upper_lower, n, alpha, x, incx, y, incy, a); + hpr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); +} + +template +static inline void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, float beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void gbmv(cl::sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, + std::int64_t incy) { + gbmv_precondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gbmv(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + gbmv_postcondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, double beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void gbmv(cl::sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, + std::int64_t incy) { + gbmv_precondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gbmv(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + gbmv_postcondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); +template <> +void gbmv( + cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, 
cl::sycl::buffer, 1> &y, std::int64_t incy) { + gbmv_precondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gbmv(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + gbmv_postcondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); +template <> +void gbmv( + cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, std::int64_t incy) { + gbmv_precondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::gbmv(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + gbmv_postcondition(queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, std::int64_t k, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &x, std::int64_t incx); +template <> +void tbmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx) { + tbmv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + onemkl::cublas::tbmv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + tbmv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, std::int64_t k, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &x, std::int64_t incx); +template <> +void tbmv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx) { + tbmv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + onemkl::cublas::tbmv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + tbmv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, std::int64_t k, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx); +template <> +void tbmv( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx) { + tbmv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + onemkl::cublas::tbmv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + tbmv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, std::int64_t k, + cl::sycl::buffer, 1> &a, std::int64_t lda, + 
cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tbmv( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx) { + tbmv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + onemkl::cublas::tbmv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + tbmv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, std::int64_t ldc); +template <> +void symm(cl::sycl::queue &queue, side left_right, + uplo upper_lower, std::int64_t m, std::int64_t n, + float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + symm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::symm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + symm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + double beta, cl::sycl::buffer &c, std::int64_t ldc); +template <> +void symm(cl::sycl::queue &queue, side left_right, + uplo upper_lower, std::int64_t m, std::int64_t n, + double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc) { + symm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::symm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + symm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void symm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + symm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::symm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + symm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, + std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template <> +void symm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, 
cl::sycl::buffer<std::complex<double>, 1> &a, std::int64_t lda,
+    cl::sycl::buffer<std::complex<double>, 1> &b, std::int64_t ldb, std::complex<double> beta,
+    cl::sycl::buffer<std::complex<double>, 1> &c, std::int64_t ldc) {
+    symm_precondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
+    onemkl::cublas::symm(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
+    symm_postcondition(queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void dotc(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<std::complex<float>, 1> &y, std::int64_t incy,
+                        cl::sycl::buffer<std::complex<float>, 1> &result);
+template <>
+void dotc<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<float>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<std::complex<float>, 1> &y,
+                                               std::int64_t incy,
+                                               cl::sycl::buffer<std::complex<float>, 1> &result) {
+    dotc_precondition(queue, n, x, incx, y, incy, result);
+    onemkl::cublas::dotc(queue, n, x, incx, y, incy, result);
+    dotc_postcondition(queue, n, x, incx, y, incy, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void dotc(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<double>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<std::complex<double>, 1> &y, std::int64_t incy,
+                        cl::sycl::buffer<std::complex<double>, 1> &result);
+template <>
+void dotc<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<double>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<std::complex<double>, 1> &y,
+                                               std::int64_t incy,
+                                               cl::sycl::buffer<std::complex<double>, 1> &result) {
+    dotc_precondition(queue, n, x, incx, y, incy, result);
+    onemkl::cublas::dotc(queue, n, x, incx, y, incy, result);
+    dotc_postcondition(queue, n, x, incx, y, incy, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void syr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha,
+                       cl::sycl::buffer<float, 1> &x, std::int64_t incx,
+                       cl::sycl::buffer<float, 1> &a, std::int64_t lda);
+template <>
+void syr<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, uplo upper_lower,
+                                              std::int64_t n, float alpha,
+                                              cl::sycl::buffer<float, 1> &x, std::int64_t incx,
+                                              cl::sycl::buffer<float, 1> &a, std::int64_t lda) {
+    syr_precondition(queue, upper_lower, n, alpha, x, incx, a, lda);
+    onemkl::cublas::syr(queue, upper_lower, n, alpha, x, incx, a, lda);
+    syr_postcondition(queue, upper_lower, n, alpha, x, incx, a, lda);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void syr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha,
+                       cl::sycl::buffer<double, 1> &x, std::int64_t incx,
+                       cl::sycl::buffer<double, 1> &a, std::int64_t lda);
+template <>
+void syr<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, uplo upper_lower,
+                                              std::int64_t n, double alpha,
+                                              cl::sycl::buffer<double, 1> &x, std::int64_t incx,
+                                              cl::sycl::buffer<double, 1> &a, std::int64_t lda) {
+    syr_precondition(queue, upper_lower, n, alpha, x, incx, a, lda);
+    onemkl::cublas::syr(queue, upper_lower, n, alpha, x, incx, a, lda);
+    syr_postcondition(queue, upper_lower, n, alpha, x, incx, a, lda);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans,
+                        diag unit_diag, std::int64_t m, std::int64_t n, float alpha,
+                        cl::sycl::buffer<float, 1> &a, std::int64_t lda,
+                        cl::sycl::buffer<float, 1> &b, std::int64_t ldb);
+template <>
+void trmm<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, side left_right,
+                                               uplo upper_lower, transpose trans, diag unit_diag,
+                                               std::int64_t m, std::int64_t n, float alpha,
+                                               cl::sycl::buffer<float, 1> &a, std::int64_t lda,
+                                               cl::sycl::buffer<float, 1> &b, std::int64_t ldb) {
+    trmm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b,
+                      ldb);
+    onemkl::cublas::trmm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b,
+                         ldb);
+    trmm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n,
alpha, a, lda, b, + ldb); +} + +template +static inline void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb); +template <> +void trmm(cl::sycl::queue &queue, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb) { + trmm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trmm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trmm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); +template <> +void trmm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb) { + trmm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trmm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trmm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); +template <> +void trmm( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb) { + trmm_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + onemkl::cublas::trmm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); + trmm_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb); +} + +template +static inline void rotmg(cl::sycl::queue &queue, cl::sycl::buffer &d1, + cl::sycl::buffer &d2, cl::sycl::buffer &x1, float y1, + cl::sycl::buffer ¶m); +template <> +void rotmg(cl::sycl::queue &queue, + cl::sycl::buffer &d1, + cl::sycl::buffer &d2, + cl::sycl::buffer &x1, float y1, + cl::sycl::buffer ¶m) { + rotmg_precondition(queue, d1, d2, x1, y1, param); + onemkl::cublas::rotmg(queue, d1, d2, x1, y1, param); + rotmg_postcondition(queue, d1, d2, x1, y1, param); +} + +template +static inline void rotmg(cl::sycl::queue &queue, cl::sycl::buffer &d1, + cl::sycl::buffer &d2, cl::sycl::buffer &x1, + double y1, cl::sycl::buffer ¶m); +template <> +void rotmg(cl::sycl::queue &queue, + cl::sycl::buffer &d1, + cl::sycl::buffer &d2, + cl::sycl::buffer &x1, double y1, + cl::sycl::buffer ¶m) { + rotmg_precondition(queue, d1, d2, x1, y1, param); + onemkl::cublas::rotmg(queue, d1, d2, x1, y1, param); + rotmg_postcondition(queue, d1, d2, x1, y1, param); +} + +template +static inline 
void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void tpsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx) { + tpsv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpsv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void tpsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, + cl::sycl::buffer &x, std::int64_t incx) { + tpsv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpsv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tpsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + tpsv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpsv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tpsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + tpsv_precondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); + onemkl::cublas::tpsv(queue, upper_lower, trans, unit_diag, n, a, x, incx); + tpsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +template +static inline void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void trsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx) { + trsv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trsv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); +template <> +void trsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx) { + trsv_precondition(queue, upper_lower, 
trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trsv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx); +template <> +void trsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + std::int64_t lda, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + trsv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trsv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx); +template <> +void trsv(cl::sycl::queue &queue, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + std::int64_t lda, + cl::sycl::buffer, 1> &x, + std::int64_t incx) { + trsv_precondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + onemkl::cublas::trsv(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); + trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +template +static inline void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); +template <> +void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + copy_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::copy(queue, n, x, incx, y, incy); + copy_postcondition(queue, n, x, incx, y, incy); +} + +template +static inline void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); +template <> +void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy) { + copy_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::copy(queue, n, x, incx, y, incy); + copy_postcondition(queue, n, x, incx, y, incy); +} + +template +static inline void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + copy_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::copy(queue, n, x, incx, y, incy); + copy_postcondition(queue, n, x, incx, y, incy); +} + +template +static inline void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void copy(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer, 1> &y, + std::int64_t incy) { + copy_precondition(queue, n, x, incx, y, incy); + onemkl::cublas::copy(queue, n, x, incx, y, incy); + copy_postcondition(queue, n, x, incx, y, incy); +} + +template +static inline void hemv(cl::sycl::queue &queue, uplo upper_lower, 
std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void hemv( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + hemv_precondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::hemv(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + hemv_postcondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void hemv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); +template <> +void hemv( + cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy) { + hemv_precondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::hemv(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + hemv_postcondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemmt(cl::sycl::queue &queue, uplo upper_lower, + transpose transa, transpose transb, std::int64_t n, + std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + float beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemmt_precondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + onemkl::cublas::gemmt(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + gemmt_postcondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +} + +template +static inline void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); +template <> +void gemmt(cl::sycl::queue &queue, uplo upper_lower, + transpose transa, transpose transb, std::int64_t n, + std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, + double beta, cl::sycl::buffer &c, + std::int64_t ldc) { + gemmt_precondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + onemkl::cublas::gemmt(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + gemmt_postcondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +} + +template +static inline void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, 
cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); +template <> +void gemmt( + cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemmt_precondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + onemkl::cublas::gemmt(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + gemmt_postcondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +} + +template +static inline void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); +template <> +void gemmt( + cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc) { + gemmt_precondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + onemkl::cublas::gemmt(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); + gemmt_postcondition(queue, upper_lower, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +} + +template +static inline void sbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, float beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void sbmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, + std::int64_t incy) { + sbmv_precondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::sbmv(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + sbmv_postcondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void sbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, double beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void sbmv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, + std::int64_t incy) { + sbmv_precondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::sbmv(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); + sbmv_postcondition(queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void asum(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer &result); +template <> +void asum(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer &result) 
{
+    asum_precondition(queue, n, x, incx, result);
+    onemkl::cublas::asum(queue, n, x, incx, result);
+    asum_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void asum(cl::sycl::queue &queue, std::int64_t n,
+                        cl::sycl::buffer<std::complex<double>, 1> &x, std::int64_t incx,
+                        cl::sycl::buffer<double, 1> &result);
+template <>
+void asum<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<std::complex<double>, 1> &x,
+                                               std::int64_t incx,
+                                               cl::sycl::buffer<double, 1> &result) {
+    asum_precondition(queue, n, x, incx, result);
+    onemkl::cublas::asum(queue, n, x, incx, result);
+    asum_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer<float, 1> &x,
+                        std::int64_t incx, cl::sycl::buffer<float, 1> &result);
+template <>
+void asum<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<float, 1> &x, std::int64_t incx,
+                                               cl::sycl::buffer<float, 1> &result) {
+    asum_precondition(queue, n, x, incx, result);
+    onemkl::cublas::asum(queue, n, x, incx, result);
+    asum_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer<double, 1> &x,
+                        std::int64_t incx, cl::sycl::buffer<double, 1> &result);
+template <>
+void asum<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, std::int64_t n,
+                                               cl::sycl::buffer<double, 1> &x, std::int64_t incx,
+                                               cl::sycl::buffer<double, 1> &result) {
+    asum_precondition(queue, n, x, incx, result);
+    onemkl::cublas::asum(queue, n, x, incx, result);
+    asum_postcondition(queue, n, x, incx, result);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag,
+                        std::int64_t n, std::int64_t k, cl::sycl::buffer<float, 1> &a,
+                        std::int64_t lda, cl::sycl::buffer<float, 1> &x, std::int64_t incx);
+template <>
+void tbsv<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, uplo upper_lower,
+                                               transpose trans, diag unit_diag, std::int64_t n,
+                                               std::int64_t k, cl::sycl::buffer<float, 1> &a,
+                                               std::int64_t lda, cl::sycl::buffer<float, 1> &x,
+                                               std::int64_t incx) {
+    tbsv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    onemkl::cublas::tbsv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    tbsv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag,
+                        std::int64_t n, std::int64_t k, cl::sycl::buffer<double, 1> &a,
+                        std::int64_t lda, cl::sycl::buffer<double, 1> &x, std::int64_t incx);
+template <>
+void tbsv<library::cublas, backend::nvidiagpu>(cl::sycl::queue &queue, uplo upper_lower,
+                                               transpose trans, diag unit_diag, std::int64_t n,
+                                               std::int64_t k, cl::sycl::buffer<double, 1> &a,
+                                               std::int64_t lda, cl::sycl::buffer<double, 1> &x,
+                                               std::int64_t incx) {
+    tbsv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    onemkl::cublas::tbsv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    tbsv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+}
+
+template <onemkl::library lib, onemkl::backend backend>
+static inline void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag,
+                        std::int64_t n, std::int64_t k,
+                        cl::sycl::buffer<std::complex<float>, 1> &a, std::int64_t lda,
+                        cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx);
+template <>
+void tbsv<library::cublas, backend::nvidiagpu>(
+    cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n,
+    std::int64_t k, cl::sycl::buffer<std::complex<float>, 1> &a, std::int64_t lda,
+    cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx) {
+    tbsv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    onemkl::cublas::tbsv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
+    tbsv_postcondition(queue, upper_lower, trans,
unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t n, std::int64_t k, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); +template <> +void tbsv( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx) { + tbsv_precondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + onemkl::cublas::tbsv(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); + tbsv_postcondition(queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +template +static inline void spr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a); +template <> +void spr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a) { + spr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); + onemkl::cublas::spr2(queue, upper_lower, n, alpha, x, incx, y, incy, a); + spr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); +} + +template +static inline void spr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a); +template <> +void spr2(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &a) { + spr2_precondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); + onemkl::cublas::spr2(queue, upper_lower, n, alpha, x, incx, y, incy, a); + spr2_postcondition(queue, upper_lower, n, alpha, x, incx, y, incy, a); +} + +template +static inline void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); +template <> +void iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result) { + iamax_precondition(queue, n, x, incx, result); + onemkl::cublas::iamax(queue, n, x, incx, result); + iamax_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); +template <> +void iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result) { + iamax_precondition(queue, n, x, incx, result); + onemkl::cublas::iamax(queue, n, x, incx, result); + iamax_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer &result); +template <> +void iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer &result) { + iamax_precondition(queue, n, x, incx, result); + onemkl::cublas::iamax(queue, n, x, incx, result); + iamax_postcondition(queue, n, x, incx, result); +} + +template +static inline void iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer &result); +template <> +void 
iamax(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer, 1> &x, + std::int64_t incx, + cl::sycl::buffer &result) { + iamax_precondition(queue, n, x, incx, result); + onemkl::cublas::iamax(queue, n, x, incx, result); + iamax_postcondition(queue, n, x, incx, result); +} + +template +static inline void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, + cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, + cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, std::int64_t group_count, + cl::sycl::buffer &group_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, std::int64_t group_count, + cl::sycl::buffer &group_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, b, ldb, group_count, group_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); +} + +template +static inline void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, b, ldb, group_count, group_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); +} + +template +static inline void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + 
trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, b, ldb, group_count, group_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); +} + +template +static inline void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, b, ldb, group_count, group_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + b, ldb, group_count, group_size); +} + +template +static inline void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); +} + +template +static inline void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + double alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, double alpha, cl::sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); + 
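/* Two batch interfaces appear in these wrappers: a "group" API, where every
   parameter (sizes, alpha, leading dimensions, ...) is itself a buffer with one
   entry per group plus group_count/group_size, and a "strided" API, where a single
   parameter set is reused and the i-th problem sits at a fixed offset inside one
   buffer.  For the strided trsm_batch with side == left this means, conceptually,

       op(A_i) * X_i = alpha * B_i,   A_i at offset i * stride_a in a,
                                      B_i at offset i * stride_b in b (solved in place),

   for i = 0 .. batch_size - 1.  The offset arithmetic is shown only for illustration;
   the API itself operates on whole cl::sycl::buffer objects.
*/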
onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); +} + +template +static inline void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); +} + +template +static inline void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +template <> +void trsm_batch( + cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, + std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + trsm_batch_precondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); + onemkl::cublas::trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + trsm_batch_postcondition(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size); +} + +template +static inline void rotm(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer ¶m); +template <> +void rotm(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer ¶m) { + rotm_precondition(queue, n, x, incx, y, incy, param); + onemkl::cublas::rotm(queue, n, x, incx, y, incy, param); + rotm_postcondition(queue, n, x, incx, y, incy, param); +} + +template +static inline void rotm(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer ¶m); +template <> +void rotm(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer ¶m) { + rotm_precondition(queue, n, x, incx, y, incy, param); + onemkl::cublas::rotm(queue, n, x, incx, y, incy, param); + 
rotm_postcondition(queue, n, x, incx, y, incy, param); +} + +template +static inline void rotg(cl::sycl::queue &queue, cl::sycl::buffer &a, + cl::sycl::buffer &b, cl::sycl::buffer &c, + cl::sycl::buffer &s); +template <> +void rotg(cl::sycl::queue &queue, + cl::sycl::buffer &a, + cl::sycl::buffer &b, + cl::sycl::buffer &c, + cl::sycl::buffer &s) { + rotg_precondition(queue, a, b, c, s); + onemkl::cublas::rotg(queue, a, b, c, s); + rotg_postcondition(queue, a, b, c, s); +} + +template +static inline void rotg(cl::sycl::queue &queue, cl::sycl::buffer &a, + cl::sycl::buffer &b, cl::sycl::buffer &c, + cl::sycl::buffer &s); +template <> +void rotg(cl::sycl::queue &queue, + cl::sycl::buffer &a, + cl::sycl::buffer &b, + cl::sycl::buffer &c, + cl::sycl::buffer &s) { + rotg_precondition(queue, a, b, c, s); + onemkl::cublas::rotg(queue, a, b, c, s); + rotg_postcondition(queue, a, b, c, s); +} + +template +static inline void rotg(cl::sycl::queue &queue, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s); +template <> +void rotg(cl::sycl::queue &queue, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, + cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s) { + rotg_precondition(queue, a, b, c, s); + onemkl::cublas::rotg(queue, a, b, c, s); + rotg_postcondition(queue, a, b, c, s); +} + +template +static inline void rotg(cl::sycl::queue &queue, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, + cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s); +template <> +void rotg(cl::sycl::queue &queue, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, + cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s) { + rotg_precondition(queue, a, b, c, s); + onemkl::cublas::rotg(queue, a, b, c, s); + rotg_postcondition(queue, a, b, c, s); +} + +template +static inline void sdsdot(cl::sycl::queue &queue, std::int64_t n, float sb, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result); +template <> +void sdsdot(cl::sycl::queue &queue, std::int64_t n, float sb, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result) { + sdsdot_precondition(queue, n, sb, x, incx, y, incy, result); + onemkl::cublas::sdsdot(queue, n, sb, x, incx, y, incy, result); + sdsdot_postcondition(queue, n, sb, x, incx, y, incy, result); +} + +template +static inline void her2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, float beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); +template <> +void her2k( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, float beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + her2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::her2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + her2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void her2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + double beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); +template 
<> +void her2k( + cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, double beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + her2k_precondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + onemkl::cublas::her2k(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + her2k_postcondition(queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +template +static inline void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result); +template <> +void dot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result) { + dot_precondition(queue, n, x, incx, y, incy, result); + onemkl::cublas::dot(queue, n, x, incx, y, incy, result); + dot_postcondition(queue, n, x, incx, y, incy, result); +} + +template +static inline void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result); +template <> +void dot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result) { + dot_precondition(queue, n, x, incx, y, incy, result); + onemkl::cublas::dot(queue, n, x, incx, y, incy, result); + dot_postcondition(queue, n, x, incx, y, incy, result); +} + +template +static inline void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result); +template <> +void dot(cl::sycl::queue &queue, std::int64_t n, + cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result) { + dot_precondition(queue, n, x, incx, y, incy, result); + onemkl::cublas::dot(queue, n, x, incx, y, incy, result); + dot_postcondition(queue, n, x, incx, y, incy, result); +} + +template +static inline void symv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, float beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void symv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, + std::int64_t incy) { + symv_precondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::symv(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + symv_postcondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); +} + +template +static inline void symv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, double beta, + cl::sycl::buffer &y, std::int64_t incy); +template <> +void symv(cl::sycl::queue &queue, uplo upper_lower, + std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, + std::int64_t incy) { + symv_precondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); + onemkl::cublas::symv(queue, upper_lower, n, alpha, a, lda, x, 
incx, beta, y, incy); + symv_postcondition(queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); +} + +} //namespace blas +} //namespace onemkl + +#endif //_DETAIL_CUBLAS_BLAS_HPP_ diff --git a/include/onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp b/include/onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp new file mode 100644 index 000000000..2336bbbd4 --- /dev/null +++ b/include/onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp @@ -0,0 +1,847 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMKL_BLAS_CUBLAS_HPP_ +#define _ONEMKL_BLAS_CUBLAS_HPP_ +#include +#include +#include +#include +#include "onemkl/types.hpp" + +namespace onemkl { +using onemkl::diag; +using onemkl::offset; +using onemkl::side; +using onemkl::transpose; +using onemkl::uplo; +namespace cublas { +// Level 1 + +void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void asum(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void axpy(cl::sycl::queue &queue, std::int64_t n, float alpha, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); + +void axpy(cl::sycl::queue &queue, std::int64_t n, double alpha, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy); + +void axpy(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); + +void axpy(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy); + +void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); + +void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); + +void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy); + +void copy(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy); + +void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &result); + +void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + 
cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &result); + +void dot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &result); + +void dotc(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); + +void dotc(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); + +void dotu(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); + +void dotu(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &result); + +void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); + +void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void iamin(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &result); + +void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void iamax(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer &result); + +void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void nrm2(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &result); + +void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, float c, + float s); + +void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy, + double c, double s); + +void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, float c, float s); + +void rot(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, double c, double s); + +void rotg(cl::sycl::queue &queue, cl::sycl::buffer &a, cl::sycl::buffer &b, + cl::sycl::buffer &c, cl::sycl::buffer &s); + +void rotg(cl::sycl::queue &queue, cl::sycl::buffer &a, cl::sycl::buffer &b, + cl::sycl::buffer &c, cl::sycl::buffer &s); + +void rotg(cl::sycl::queue &queue, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s); + +void rotg(cl::sycl::queue &queue, cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &b, 
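/* The Level-1 result buffers follow the standard BLAS conventions: asum and nrm2
   always produce a real-valued scalar (even for complex input vectors), iamax and
   iamin produce a std::int64_t index, and dot is provided for float, double, and the
   mixed-precision case (float inputs, double result).  Two representative
   declarations written out with their element types:

   void nrm2(cl::sycl::queue &queue, std::int64_t n,
             cl::sycl::buffer<std::complex<float>, 1> &x, std::int64_t incx,
             cl::sycl::buffer<float, 1> &result);

   void iamax(cl::sycl::queue &queue, std::int64_t n,
              cl::sycl::buffer<double, 1> &x, std::int64_t incx,
              cl::sycl::buffer<std::int64_t, 1> &result);
*/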
cl::sycl::buffer &c, + cl::sycl::buffer, 1> &s); + +void rotm(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer ¶m); + +void rotm(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer ¶m); + +void rotmg(cl::sycl::queue &queue, cl::sycl::buffer &d1, cl::sycl::buffer &d2, + cl::sycl::buffer &x1, float y1, cl::sycl::buffer ¶m); + +void rotmg(cl::sycl::queue &queue, cl::sycl::buffer &d1, cl::sycl::buffer &d2, + cl::sycl::buffer &x1, double y1, cl::sycl::buffer ¶m); + +void scal(cl::sycl::queue &queue, std::int64_t n, float alpha, cl::sycl::buffer &x, + std::int64_t incx); + +void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, cl::sycl::buffer &x, + std::int64_t incx); + +void scal(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void scal(cl::sycl::queue &queue, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void scal(cl::sycl::queue &queue, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void scal(cl::sycl::queue &queue, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void sdsdot(cl::sycl::queue &queue, std::int64_t n, float sb, cl::sycl::buffer &x, + std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, + cl::sycl::buffer &result); + +void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); + +void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer &x, std::int64_t incx, + cl::sycl::buffer &y, std::int64_t incy); + +void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy); + +void swap(cl::sycl::queue &queue, std::int64_t n, cl::sycl::buffer, 1> &x, + std::int64_t incx, cl::sycl::buffer, 1> &y, std::int64_t incy); + +// Level 2 + +void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, float beta, + cl::sycl::buffer &y, std::int64_t incy); + +void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, double alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx, double beta, + cl::sycl::buffer &y, std::int64_t incy); + +void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, std::int64_t incy); + +void gbmv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, std::int64_t kl, + std::int64_t ku, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, float beta, cl::sycl::buffer &y, std::int64_t incy); + +void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, double alpha, + 
cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, double beta, cl::sycl::buffer &y, std::int64_t incy); + +void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); + +void gemv(cl::sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void ger(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda); + +void ger(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda); + +void gerc(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void gerc(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void geru(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void geru(cl::sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void hbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); + +void hbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void hemv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + cl::sycl::buffer, 1> &y, std::int64_t incy); + +void hemv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void her(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void her(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void her2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + 
cl::sycl::buffer, 1> &a, std::int64_t lda); + +void her2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a, std::int64_t lda); + +void hpmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer, 1> &x, + std::int64_t incx, std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void hpmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, cl::sycl::buffer, 1> &y, + std::int64_t incy); + +void hpr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a); + +void hpr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &a); + +void hpr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a); + +void hpr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &x, std::int64_t incx, + cl::sycl::buffer, 1> &y, std::int64_t incy, + cl::sycl::buffer, 1> &a); + +void sbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, float beta, cl::sycl::buffer &y, std::int64_t incy); + +void sbmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, double alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, double beta, cl::sycl::buffer &y, std::int64_t incy); + +void spmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx, + float beta, cl::sycl::buffer &y, std::int64_t incy); + +void spmv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx, + double beta, cl::sycl::buffer &y, std::int64_t incy); + +void spr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &a); + +void spr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &a); + +void spr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a); + +void spr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a); + +void symv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, float beta, cl::sycl::buffer &y, std::int64_t incy); + +void symv(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx, double beta, cl::sycl::buffer &y, std::int64_t incy); + +void syr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + 
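/* The packed routines (hpmv, hpr, hpr2, spmv, spr, spr2, and tpmv/tpsv below) take
   the matrix argument a without a leading dimension: only the referenced triangle is
   stored, column by column, in a buffer of n*(n+1)/2 elements.  With the conventional
   column-major upper-packed layout the (i, j) element, i <= j, 0-based, is found at

       a[i + j * (j + 1) / 2]

   The banded routines (gbmv, hbmv, sbmv, tbmv, tbsv) keep lda and add the bandwidth
   arguments k (or kl/ku) instead.
*/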
cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &a, + std::int64_t lda); + +void syr(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &a, + std::int64_t lda); + +void syr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda); + +void syr2(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, + cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, + std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda); + +void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); + +void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); + +void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); + +void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &x, std::int64_t incx); + +void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + std::int64_t k, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx); + +void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx); + +void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, cl::sycl::buffer, 1> &x, + std::int64_t incx); + +void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx); + +void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, cl::sycl::buffer &x, std::int64_t incx); + +void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, cl::sycl::buffer, 1> &x, + std::int64_t incx); + +void 
tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx); + +void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx); + +void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx); + +void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &x, + std::int64_t incx); + +void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t n, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &x, std::int64_t incx); + +// Level 3 + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> 
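/* A hedged usage sketch for the Level-3 entry points declared above, calling the
   cuBLAS backend's sgemm directly with SYCL buffers.  An application would normally
   go through the dispatching onemkl::blas layer instead; the gpu_selector is a
   simplification, and the enumerator names assume the usual onemkl::transpose values:

   #include <CL/sycl.hpp>
   #include <vector>
   #include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp"

   int main() {
       const std::int64_t m = 64, n = 64, k = 64;
       std::vector<float> a(m * k, 1.0f), b(k * n, 1.0f), c(m * n, 0.0f);
       cl::sycl::queue q{ cl::sycl::gpu_selector{} };   // expects the NVIDIA GPU device
       {
           cl::sycl::buffer<float, 1> A(a.data(), cl::sycl::range<1>(a.size()));
           cl::sycl::buffer<float, 1> B(b.data(), cl::sycl::range<1>(b.size()));
           cl::sycl::buffer<float, 1> C(c.data(), cl::sycl::range<1>(c.size()));
           // Column-major C := 1.0 * A * B + 0.0 * C
           onemkl::cublas::gemm(q, onemkl::transpose::nontrans, onemkl::transpose::nontrans,
                                m, n, k, 1.0f, A, m, B, k, 0.0f, C, m);
       }   // buffers go out of scope and write back to the host vectors
       return 0;
   }
*/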
&a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void herk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, float beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void herk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, double beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void her2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + float beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void her2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + double beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + double alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void symm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + float alpha, cl::sycl::buffer &a, std::int64_t lda, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + double alpha, cl::sycl::buffer &a, std::int64_t lda, double beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, double alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, double beta, 
+ cl::sycl::buffer &c, std::int64_t ldc); + +void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, + std::int64_t k, std::complex alpha, cl::sycl::buffer, 1> &a, + std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb); + +void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb); + +void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); + +void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); + +void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb); + +void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb); + +void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); + +void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb); +// Batch API + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + 
cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, cl::sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, double beta, + cl::sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size); + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + 
cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, std::int64_t group_count, + cl::sycl::buffer &group_size); + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +// BLAS-like extensions + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc); + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, offset offsetc, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, int8_t ao, + cl::sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + cl::sycl::buffer &c, std::int64_t ldc, cl::sycl::buffer &co); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, 
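/* Among the BLAS-like extensions, gemmt has the gemm parameter list but C is n-by-n
   and only the triangle selected by upper_lower is referenced and updated
   (C := alpha*op(A)*op(B) + beta*C on that triangle).  gemm_ext adds further
   precisions on top of plain gemm, including a half-precision overload and an
   int8_t/uint8_t overload that carries per-matrix offsets (offsetc, ao, bo, co).
*/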
std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc); + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc); + +} // namespace cublas +} // namespace onemkl + +#endif //_ONEMKL_BLAS_CUBLAS_HPP_ diff --git a/include/onemkl/detail/backends.hpp b/include/onemkl/detail/backends.hpp index 19cd80470..68299ef2a 100644 --- a/include/onemkl/detail/backends.hpp +++ b/include/onemkl/detail/backends.hpp @@ -25,12 +25,13 @@ namespace onemkl { -enum class backend { intelcpu, intelgpu, unsupported }; +enum class backend { intelcpu, intelgpu, nvidiagpu, unsupported }; typedef std::map backendmap; static backendmap backend_map = { { backend::intelcpu, "intelcpu" }, { backend::intelgpu, "intelgpu" }, + { backend::nvidiagpu, "nvidiagpu" }, { backend::unsupported, "unsupported" } }; } //namespace onemkl diff --git a/include/onemkl/detail/backends_selector.hpp b/include/onemkl/detail/backends_selector.hpp index 3cde4f886..4ce76a3f7 100644 --- a/include/onemkl/detail/backends_selector.hpp +++ b/include/onemkl/detail/backends_selector.hpp @@ -29,7 +29,8 @@ #define LIB_NAME(a) "lib" a ".so" #endif -#define INTEL_ID 32902 +#define INTEL_ID 32902 +#define NVIDIA_ID 4318 namespace onemkl { inline char *select_backend(cl::sycl::queue &queue) { @@ -45,6 +46,8 @@ inline char *select_backend(cl::sycl::queue &queue) { if (vendor_id == INTEL_ID) return (char *)LIB_NAME("onemkl_blas_mklgpu"); + else if (vendor_id == NVIDIA_ID) + return (char *)LIB_NAME("onemkl_blas_cublas"); return (char *)"unsupported"; } else { diff --git a/include/onemkl/detail/libraries.hpp b/include/onemkl/detail/libraries.hpp index cbb55782f..ccac02da6 100644 --- a/include/onemkl/detail/libraries.hpp +++ b/include/onemkl/detail/libraries.hpp @@ -30,7 +30,8 @@ enum class library { intelmkl, cublas }; typedef std::map librarymap; -static librarymap library_map = { { library::intelmkl, "intelmkl" } }; +static librarymap library_map = { { library::intelmkl, "intelmkl" }, + { library::cublas, "cublas" } }; } //namespace onemkl diff --git a/src/blas/backends/CMakeLists.txt b/src/blas/backends/CMakeLists.txt index 043153e79..82dadb80d 100644 --- a/src/blas/backends/CMakeLists.txt +++ b/src/blas/backends/CMakeLists.txt @@ -24,3 +24,7 @@ endif() if(ENABLE_MKLGPU_BACKEND) add_subdirectory(mklgpu) endif() + +if(ENABLE_CUBLAS_BACKEND AND UNIX) + add_subdirectory(cublas) +endif() diff --git a/src/blas/backends/cublas/CMakeLists.txt 
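/* Run-time dispatch for this backend keys off the queue's device vendor id:
   NVIDIA_ID (4318, i.e. PCI vendor 0x10DE) selects libonemkl_blas_cublas.so.  A small
   application-side sketch (the nvidia_selector class is illustrative, not part of
   this patch) that builds a queue on the NVIDIA device and lets the dispatcher pick
   the cuBLAS backend:

   class nvidia_selector : public cl::sycl::device_selector {
   public:
       int operator()(const cl::sycl::device &dev) const override {
           // Prefer GPUs whose vendor id matches NVIDIA's (0x10DE == 4318).
           if (dev.is_gpu() &&
               dev.get_info<cl::sycl::info::device::vendor_id>() == 4318)
               return 1;
           return -1;
       }
   };

   cl::sycl::queue nvidia_queue{ nvidia_selector{} };
   onemkl::blas::gemm(nvidia_queue, transA, transB, m, n, k, alpha, A, lda, B, ldb,
                      beta, C, ldc);   // resolved to the cuBLAS backend at run time
*/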
b/src/blas/backends/cublas/CMakeLists.txt new file mode 100644 index 000000000..d76fd535c --- /dev/null +++ b/src/blas/backends/cublas/CMakeLists.txt @@ -0,0 +1,59 @@ +#========================================================================== +# Copyright (C) Codeplay Software Limited +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# For your convenience, a copy of the License has been included in this +# repository. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#========================================================================= + +set(LIB_NAME onemkl_blas_cublas) +set(LIB_OBJ ${LIB_NAME}_obj) +find_package(cuBLAS REQUIRED) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + cublas_level1.cpp + cublas_level2.cpp + cublas_level3.cpp + cublas_batch.cpp + cublas_extensions.cpp + cublas_scope_handle.cpp + $<$: mkl_blas_cublas_wrappers.cpp> +) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src +) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::cuBLAS::cuBLAS) +target_compile_features(${LIB_OBJ} PUBLIC cxx_std_11) +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON) + +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp new file mode 100644 index 000000000..e04b86b62 --- /dev/null +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -0,0 +1,184 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+* +**************************************************************************/ +#include +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +namespace onemkl { +namespace cublas { + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer &alpha, cl::sycl::buffer &a, + cl::sycl::buffer &lda, cl::sycl::buffer &b, + cl::sycl::buffer &ldb, cl::sycl::buffer &beta, + cl::sycl::buffer &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch( + cl::sycl::queue &queue, cl::sycl::buffer &transa, + cl::sycl::buffer &transb, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &k, + cl::sycl::buffer, 1> &alpha, cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, cl::sycl::buffer, 1> &beta, + cl::sycl::buffer, 1> &c, cl::sycl::buffer &ldc, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, cl::sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, cl::sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, double beta, + cl::sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t 
m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, cl::sycl::buffer &alpha, + cl::sycl::buffer &a, cl::sycl::buffer &lda, + cl::sycl::buffer &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, cl::sycl::buffer &ldb, + std::int64_t group_count, cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, cl::sycl::buffer &left_right, + cl::sycl::buffer &upper_lower, cl::sycl::buffer &trans, + cl::sycl::buffer &unit_diag, cl::sycl::buffer &m, + cl::sycl::buffer &n, + cl::sycl::buffer, 1> &alpha, + cl::sycl::buffer, 1> &a, + cl::sycl::buffer &lda, + cl::sycl::buffer, 1> &b, + cl::sycl::buffer &ldb, std::int64_t group_count, + cl::sycl::buffer &group_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + cl::sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + cl::sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + cl::sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} + +void trsm_batch(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t 
stride_a, cl::sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + throw std::runtime_error("Not implemented for cublas"); +} +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_extensions.cpp b/src/blas/backends/cublas/cublas_extensions.cpp new file mode 100644 index 000000000..6d17b34e5 --- /dev/null +++ b/src/blas/backends/cublas/cublas_extensions.cpp @@ -0,0 +1,113 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#include +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +namespace onemkl { +namespace cublas { + +// BLAS-like extensions + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + cl::sycl::buffer, 1> &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemmt(cl::sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, offset offsetc, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + cl::sycl::buffer &a, std::int64_t lda, int8_t ao, + cl::sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + cl::sycl::buffer &c, std::int64_t ldc, cl::sycl::buffer &co) { + throw std::runtime_error("Not implemented 
for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, double beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, std::complex alpha, + cl::sycl::buffer, 1> &a, std::int64_t lda, + cl::sycl::buffer, 1> &b, std::int64_t ldb, + std::complex beta, cl::sycl::buffer, 1> &c, + std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +void gemm_ext(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp new file mode 100644 index 000000000..1304acdba --- /dev/null +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -0,0 +1,255 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +/** + * @file mkl_blas_cublas.cpp : contains the implementation of all the routines + * for CUDA backend + */ +#ifndef _MKL_BLAS_CUBLAS_HELPER_HPP_ +#define _MKL_BLAS_CUBLAS_HELPER_HPP_ +#include +#include +#include +#include +#include "onemkl/types.hpp" +namespace onemkl { +namespace cublas { + +// The static assert to make sure that all index types used in +// src/oneMKL/backend/cublas/blas.hpp inteface are int64_t +template +struct is_int64 : std::false_type {}; + +template +struct is_int64 : std::is_same {}; + +template +struct is_int64 + : std::integral_constant::value && + is_int64::value> {}; + +template +struct Overflow { + static void inline check(T...) 
{} +}; + +template +struct Overflow { + static void inline check(Index index, T... next) { + if (std::abs(index) >= (1LL << 31)) { + throw std::runtime_error( + "Cublas index overflow. cublas does not support 64 bit integer as " + "data size. Thus, the data size should not be greater that maximum " + "supported size by 32 bit integer."); + } + Overflow::check(next...); + } +}; + +template +void overflow_check(Index index, Next... indices) { + static_assert(is_int64::value, "oneMKL index type must be 64 bit integer."); + Overflow::check(index, indices...); +} + +class cublas_error : virtual public std::runtime_error { +protected: + inline const char *cublas_error_map(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + + default: + return ""; + } + } + + int error_number; ///< Error number +public: + /** Constructor (C++ STL string, cublasStatus_t). + * @param msg The error message + * @param err_num error number + */ + explicit cublas_error(std::string message, cublasStatus_t result) + : std::runtime_error((message + std::string(cublas_error_map(result)))) { + error_number = static_cast(result); + } + + /** Destructor. + * Virtual to allow for subclassing. + */ + virtual ~cublas_error() throw() {} + + /** Returns error number. + * @return #error_number + */ + virtual int getErrorNumber() const throw() { + return error_number; + } +}; + +class cuda_error : virtual public std::runtime_error { +protected: + inline const char *cuda_error_map(CUresult result) { + switch (result) { + case CUDA_SUCCESS: + return "CUDA_SUCCESS"; + case CUDA_ERROR_NOT_PERMITTED: + return "CUDA_ERROR_NOT_PERMITTED"; + case CUDA_ERROR_INVALID_CONTEXT: + return "CUDA_ERROR_INVALID_CONTEXT"; + case CUDA_ERROR_INVALID_DEVICE: + return "CUDA_ERROR_INVALID_DEVICE"; + case CUDA_ERROR_INVALID_VALUE: + return "CUDA_ERROR_INVALID_VALUE"; + case CUDA_ERROR_OUT_OF_MEMORY: + return "CUDA_ERROR_OUT_OF_MEMORY"; + case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: + return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES"; + default: + return ""; + } + } + int error_number; ///< error number +public: + /** Constructor (C++ STL string, CUresult). + * @param msg The error message + * @param err_num Error number + */ + explicit cuda_error(std::string message, CUresult result) + : std::runtime_error((message + std::string(cuda_error_map(result)))) { + error_number = static_cast(result); + } + + /** Destructor. + * Virtual to allow for subclassing. + */ + virtual ~cuda_error() throw() {} + + /** Returns error number. + * @return #error_number + */ + virtual int getErrorNumber() const throw() { + return error_number; + } +}; + +#define CUDA_ERROR_FUNC(name, err, ...) 
\ + err = name(__VA_ARGS__); \ + if (err != CUDA_SUCCESS) { \ + throw cuda_error(std::string(#name) + std::string(" : "), err); \ + } + +#define CUBLAS_ERROR_FUNC(name, err, ...) \ + err = name(__VA_ARGS__); \ + if (err != CUBLAS_STATUS_SUCCESS) { \ + throw cublas_error(std::string(#name) + std::string(" : "), err); \ + } + +inline cublasOperation_t get_cublas_operation(onemkl::transpose trn) { + switch (trn) { + case onemkl::transpose::nontrans: + return CUBLAS_OP_N; + case onemkl::transpose::trans: + return CUBLAS_OP_T; + case onemkl::transpose::conjtrans: + return CUBLAS_OP_C; + default: + throw "Wrong transpose Operation."; + } +} + +inline cublasFillMode_t get_cublas_fill_mode(onemkl::uplo ul) { + switch (ul) { + case onemkl::uplo::upper: + return CUBLAS_FILL_MODE_UPPER; + case onemkl::uplo::lower: + return CUBLAS_FILL_MODE_LOWER; + default: + throw "Wrong fill mode."; + } +} + +inline cublasDiagType_t get_cublas_diag_type(onemkl::diag un) { + switch (un) { + case onemkl::diag::unit: + return CUBLAS_DIAG_UNIT; + case onemkl::diag::nonunit: + return CUBLAS_DIAG_NON_UNIT; + default: + throw "Wrong diag type."; + } +} + +inline cublasSideMode_t get_cublas_side_mode(onemkl::side lr) { + switch (lr) { + case onemkl::side::left: + return CUBLAS_SIDE_LEFT; + case onemkl::side::right: + return CUBLAS_SIDE_RIGHT; + default: + throw "Wrong side mode."; + } +} + +/*converting std::complex to cuComplex*/ +template +struct CudaEquivalentType { + using Type = T; +}; + +template <> +struct CudaEquivalentType> { + using Type = cuComplex; +}; +template <> +struct CudaEquivalentType> { + using Type = cuDoubleComplex; +}; + +} // namespace cublas +} // namespace onemkl +#endif // _MKL_BLAS_CUBLAS_HELPER_HPP_ diff --git a/src/blas/backends/cublas/cublas_level1.cpp b/src/blas/backends/cublas/cublas_level1.cpp new file mode 100644 index 000000000..e8bf9f937 --- /dev/null +++ b/src/blas/backends/cublas/cublas_level1.cpp @@ -0,0 +1,606 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
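
`cublas_helper.hpp` guards every size and increment with `overflow_check` because the oneMKL interface uses `std::int64_t` while cuBLAS only accepts 32-bit `int` parameters. Below is a minimal standalone sketch of that guard, using a hypothetical helper name rather than the library's variadic `Overflow` template:

```cpp
// Standalone illustration of the 32-bit guard described above: any 64-bit oneMKL
// index forwarded to cuBLAS must fit in a signed 32-bit integer.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>

static void check_fits_int32(std::int64_t index) {
    if (std::abs(index) >= (std::int64_t{ 1 } << 31))
        throw std::runtime_error("cuBLAS index overflow: value does not fit in a 32-bit int");
}

int main() {
    try {
        check_fits_int32(1024);                     // fine
        check_fits_int32(std::int64_t{ 1 } << 32);  // throws
    }
    catch (const std::exception &e) {
        std::cout << e.what() << '\n';
    }
}
```
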
+* +**************************************************************************/ +#include "cublas_helper.hpp" +#include "cublas_scope_handle.hpp" +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +#include + +namespace onemkl { +namespace cublas { +// Level 1 +template +inline void asum(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &result) { + using cuDataType1 = typename CudaEquivalentType::Type; + using cuDataType2 = typename CudaEquivalentType::Type; + overflow_check(n, incx); + + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto res_acc = result.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto res_ = sc.get_mem(ih, res_acc); + cublasStatus_t err; + // ASUM does not support negative index + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, std::abs(incx), res_); + }); + }); +} + +#define ASUM_LAUNCHER(TYPE1, TYPE2, CUBLAS_ROUTINE) \ + void asum(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, \ + const int64_t incx, cl::sycl::buffer &result) { \ + asum(CUBLAS_ROUTINE, queue, n, x, incx, result); \ + } +ASUM_LAUNCHER(float, float, cublasSasum) +ASUM_LAUNCHER(double, double, cublasDasum) +ASUM_LAUNCHER(std::complex, float, cublasScasum) +ASUM_LAUNCHER(std::complex, double, cublasDzasum) +#undef ASUM_LAUNCHER + +template +inline void scal(Func func, cl::sycl::queue &queue, int64_t n, T1 a, cl::sycl::buffer &x, + int64_t incx) { + using cuDataType1 = typename CudaEquivalentType::Type; + using cuDataType2 = typename CudaEquivalentType::Type; + overflow_check(n, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + // SCAL does not support negative incx + CUBLAS_ERROR_FUNC(func, err, handle, n, (cuDataType1 *)&a, x_, std::abs(incx)); + }); + }); +} + +#define SCAL_LAUNCHER(TYPE1, TYPE2, CUBLAS_ROUTINE) \ + void scal(cl::sycl::queue &queue, int64_t n, TYPE1 a, cl::sycl::buffer &x, \ + int64_t incx) { \ + scal(CUBLAS_ROUTINE, queue, n, a, x, incx); \ + } +SCAL_LAUNCHER(float, float, cublasSscal) +SCAL_LAUNCHER(double, double, cublasDscal) +SCAL_LAUNCHER(std::complex, std::complex, cublasCscal) +SCAL_LAUNCHER(std::complex, std::complex, cublasZscal) +SCAL_LAUNCHER(float, std::complex, cublasCsscal) +SCAL_LAUNCHER(double, std::complex, cublasZdscal) +#undef SCAL_LAUNCHER + +template +inline void axpy(Func func, cl::sycl::queue &queue, int64_t n, T alpha, cl::sycl::buffer &x, + int64_t incx, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = 
CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, (cuDataType *)&alpha, x_, incx, y_, incy); + }); + }); +} + +#define AXPY_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void axpy(cl::sycl::queue &queue, int64_t n, TYPE alpha, cl::sycl::buffer &x, \ + int64_t incx, cl::sycl::buffer &y, int64_t incy) { \ + axpy(CUBLAS_ROUTINE, queue, n, alpha, x, incx, y, incy); \ + } + +AXPY_LAUNCHER(float, cublasSaxpy) +AXPY_LAUNCHER(double, cublasDaxpy) +AXPY_LAUNCHER(std::complex, cublasCaxpy) +AXPY_LAUNCHER(std::complex, cublasZaxpy) +#undef AXPY_LAUNCHER + +template +inline void rotg(Func func, cl::sycl::queue &queue, cl::sycl::buffer &a, + cl::sycl::buffer &b, cl::sycl::buffer &c, + cl::sycl::buffer &s) { + using cuDataType1 = typename CudaEquivalentType::Type; + using cuDataType2 = typename CudaEquivalentType::Type; + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + auto s_acc = s.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + auto s_ = sc.get_mem(ih, s_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, a_, b_, c_, s_); + }); + }); +} + +#define ROTG_LAUNCHER(TYPE1, TYPE2, CUBLAS_ROUTINE) \ + void rotg(cl::sycl::queue &queue, cl::sycl::buffer &a, \ + cl::sycl::buffer &b, cl::sycl::buffer &c, \ + cl::sycl::buffer &s) { \ + rotg(CUBLAS_ROUTINE, queue, a, b, c, s); \ + } + +ROTG_LAUNCHER(float, float, cublasSrotg) +ROTG_LAUNCHER(double, double, cublasDrotg) +ROTG_LAUNCHER(std::complex, float, cublasCrotg) +ROTG_LAUNCHER(std::complex, double, cublasZrotg) +#undef ROTG_LAUNCHER + +template +inline void rotm(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer ¶m) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + auto param_acc = param.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. 
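
The comment repeated before each reduction-style routine explains the `CUBLAS_POINTER_MODE_DEVICE` switch: when the result lives in a device buffer, cuBLAS must treat the result pointer as device memory, and the call then runs fully asynchronously. A plain CUDA/cuBLAS program showing the same behaviour outside the SYCL interop machinery (illustrative only, not backend code):

```cpp
// Standalone CUDA illustration of the pointer-mode rule discussed above: the asum
// result is kept on the device, so the handle must be in device pointer mode.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int n = 4;
    std::vector<float> host_x{ 1.f, -2.f, 3.f, -4.f };
    float *x = nullptr, *result = nullptr;
    cudaMalloc(&x, n * sizeof(float));
    cudaMalloc(&result, sizeof(float));
    cudaMemcpy(x, host_x.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    // The default is CUBLAS_POINTER_MODE_HOST; result is device memory, so switch modes.
    cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
    cublasSasum(handle, n, x, 1, result);  // asynchronous: result stays on the device

    float host_result = 0.f;
    cudaMemcpy(&host_result, result, sizeof(float), cudaMemcpyDeviceToHost);  // implicit sync
    std::printf("asum = %f\n", host_result);  // prints 10.0

    cublasDestroy(handle);
    cudaFree(x);
    cudaFree(result);
}
```
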
+ cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + auto param_ = sc.get_mem(ih, param_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, y_, incy, param_); + }); + }); +} + +#define ROTM_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void rotm(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, int64_t incx, \ + cl::sycl::buffer &y, int64_t incy, cl::sycl::buffer ¶m) { \ + rotm(CUBLAS_ROUTINE, queue, n, x, incx, y, incy, param); \ + } + +ROTM_LAUNCHER(float, cublasSrotm) +ROTM_LAUNCHER(double, cublasDrotm) +#undef ROTM_LAUNCHER + +template +inline void copy(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + int64_t incx, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, y_, incy); + }); + }); +} + +#define COPY_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void copy(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, int64_t incx, \ + cl::sycl::buffer &y, int64_t incy) { \ + copy(CUBLAS_ROUTINE, queue, n, x, incx, y, incy); \ + } + +COPY_LAUNCHER(float, cublasScopy) +COPY_LAUNCHER(double, cublasDcopy) +COPY_LAUNCHER(std::complex, cublasCcopy) +COPY_LAUNCHER(std::complex, cublasZcopy) +#undef COPY_LAUNCHER + +template +inline void dot(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &result) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + auto res_acc = result.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. 
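
Every routine in this file follows the same shape: one function template does the interop and error handling, and a `*_LAUNCHER` macro stamps out the concrete overloads by pairing a oneMKL type with the matching cuBLAS symbol. A reduced, self-contained sketch of that template-plus-macro pattern (the names below are illustrative, not the library's):

```cpp
// Reduced illustration of the LAUNCHER pattern used throughout this file: a generic
// template plus a macro that instantiates one typed entry point per backend routine.
#include <cstdio>

// Stand-ins for the typed backend calls (cublasSasum / cublasDasum in the real code).
static void backend_sum_float(int n, const float *x)   { std::printf("float sum of %d\n", n); }
static void backend_sum_double(int n, const double *x) { std::printf("double sum of %d\n", n); }

template <typename Func, typename T>
inline void sum(Func func, int n, const T *x) {
    // Shared argument checking / setup would live here.
    func(n, x);
}

#define SUM_LAUNCHER(TYPE, BACKEND_ROUTINE) \
    void sum(int n, const TYPE *x) {        \
        sum(BACKEND_ROUTINE, n, x);         \
    }

SUM_LAUNCHER(float, backend_sum_float)
SUM_LAUNCHER(double, backend_sum_double)
#undef SUM_LAUNCHER

int main() {
    float xf[2] = { 1.f, 2.f };
    double xd[2] = { 1.0, 2.0 };
    sum(2, xf);  // dispatches to backend_sum_float
    sum(2, xd);  // dispatches to backend_sum_double
}
```
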
+ cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + auto res_ = sc.get_mem(ih, res_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, y_, incy, res_); + }); + }); +} + +#define DOT_LAUNCHER(EXT, TYPE, CUBLAS_ROUTINE) \ + void dot##EXT(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, \ + const int64_t incx, cl::sycl::buffer &y, const int64_t incy, \ + cl::sycl::buffer &result) { \ + dot(CUBLAS_ROUTINE, queue, n, x, incx, y, incy, result); \ + } +DOT_LAUNCHER(, float, cublasSdot) +DOT_LAUNCHER(, double, cublasDdot) +DOT_LAUNCHER(c, std::complex, cublasCdotc) +DOT_LAUNCHER(c, std::complex, cublasZdotc) +DOT_LAUNCHER(u, std::complex, cublasCdotu) +DOT_LAUNCHER(u, std::complex, cublasZdotu) +#undef DOT_LAUNCHER + +template +inline void rot(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &y, int64_t incy, T2 c, T3 s) { + using cuDataType1 = typename CudaEquivalentType::Type; + using cuDataType2 = typename CudaEquivalentType::Type; + using cuDataType3 = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + // cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, y_, incy, (cuDataType2 *)&c, + (cuDataType3 *)&s); + }); + }); +} + +#define ROT_LAUNCHER(TYPE1, TYPE2, TYPE3, CUBLAS_ROUTINE) \ + void rot(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, const int64_t incx, \ + cl::sycl::buffer &y, int64_t incy, TYPE2 c, TYPE3 s) { \ + rot(CUBLAS_ROUTINE, queue, n, x, incx, y, incy, c, s); \ + } + +ROT_LAUNCHER(float, float, float, cublasSrot) +ROT_LAUNCHER(double, double, double, cublasDrot) +ROT_LAUNCHER(std::complex, float, float, cublasCsrot) +ROT_LAUNCHER(std::complex, double, double, cublasZdrot) +#undef ROT_LAUNCHER + +void sdsdot(cl::sycl::queue &queue, int64_t n, float sb, cl::sycl::buffer &x, + int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &result) { + overflow_check(n, incx, incy); + // cuBLAS does not support sdot so we need to mimic sdot. + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.get_access(cgh); + auto y_acc = y.get_access(cgh); + auto res_acc = result.get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. 
+ cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + auto res_ = sc.get_mem(ih, res_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + }); + }); + // Since SB is a host pointer we need to bring the result back to the host and + // add sb to it. + result.get_access()[0] += sb; +} + +void dot(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, int64_t incx, + cl::sycl::buffer &y, int64_t incy, cl::sycl::buffer &result) { + overflow_check(n, incx, incy); + // CuBLAS does not support sdot so we need to mimic sdot + // converting float* to double * is very costly operation as sycl reinterpret + // does not support conversion from two types which is not the same size. + // So in order, to avoid loosing performance we are converting the result to be + // the float. This change may cause failure as the result precision reduces. + // Alternatively we need to write a sycl kernel to elementwise copy the + // data between two buffer. This will be very slow as the two x and y buffer + // need to be converted to double for this reason. + cl::sycl::buffer float_res_buff{ cl::sycl::range<1>(1) }; + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.get_access(cgh); + auto y_acc = y.get_access(cgh); + auto float_res_acc = float_res_buff.get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + auto float_res_ = sc.get_mem(ih, float_res_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(cublasSdot, err, handle, n, x_, incx, y_, incy, float_res_); + }); + }); + /// Since cuBLAS does not have sdot support, we had to do the operation in float and + // convert it back into double. This can result in precision issue. + result.get_access()[0] = + (double)float_res_buff.get_access()[0]; +} + +template +inline void rotmg(Func func, cl::sycl::queue &queue, cl::sycl::buffer &d1, + cl::sycl::buffer &d2, cl::sycl::buffer &x1, T y1, + cl::sycl::buffer ¶m) { + using cuDataType = typename CudaEquivalentType::Type; + cl::sycl::buffer y1_buff(&y1, cl::sycl::range<1>(1)); + queue.submit([&](cl::sycl::handler &cgh) { + auto d1_acc = d1.template get_access(cgh); + auto d2_acc = d2.template get_access(cgh); + auto x1_acc = x1.template get_access(cgh); + auto y1_acc = y1_buff.template get_access(cgh); + auto param_acc = param.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. 
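
Both `sdsdot` and the double-result `dot` above finish their work on the host: the device-side `cublasSdot` result is brought back through a host accessor, then `sb` is added (or the float is widened to `double`), with the precision caveat already noted in the comments. A minimal SYCL sketch of that buffer round-trip, with a host write standing in for the asynchronous cuBLAS call (assumed simplification, not backend code):

```cpp
// Minimal sketch of the host post-processing pattern used by sdsdot above: the device
// result is read back through a host accessor and sb is added on the host.
#include <CL/sycl.hpp>
#include <iostream>

int main() {
    float sb = 0.5f;
    cl::sycl::buffer<float, 1> result{ cl::sycl::range<1>(1) };
    {
        // Stand-in for the asynchronous cuBLAS dot call writing into the buffer.
        auto acc = result.get_access<cl::sycl::access::mode::write>();
        acc[0] = 2.0f;
    }
    // Creating a host accessor waits for the device work; sb is then applied on the host.
    result.get_access<cl::sycl::access::mode::read_write>()[0] += sb;
    std::cout << result.get_access<cl::sycl::access::mode::read>()[0] << "\n";  // prints 2.5
}
```
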
+ cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto d1_ = sc.get_mem(ih, d1_acc); + auto d2_ = sc.get_mem(ih, d2_acc); + auto x1_ = sc.get_mem(ih, x1_acc); + auto y1_ = sc.get_mem(ih, y1_acc); + auto param_ = sc.get_mem(ih, param_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, d1_, d2_, x1_, y1_, param_); + }); + }); +} + +#define ROTMG_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void rotmg(cl::sycl::queue &queue, cl::sycl::buffer &d1, \ + cl::sycl::buffer &d2, cl::sycl::buffer &x1, TYPE y1, \ + cl::sycl::buffer ¶m) { \ + rotmg(CUBLAS_ROUTINE, queue, d1, d2, x1, y1, param); \ + } + +ROTMG_LAUNCHER(float, cublasSrotmg) +ROTMG_LAUNCHER(double, cublasDrotmg) +#undef ROTMG_LAUNCHER + +template +inline void iamax(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &result) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + // cuBLAS does not support int64_t as return type for the data. So we need to + // mimic iamax. We are converting the result to be the int and then we convert + // it back to the actual data on the host. + // This change may cause failure as the result of integer overflow + // based on the size. Alternatively either we need to write a sycl kernel + // to elementwise copy the data between two buffer, or allow reinterpret cast + // to convert to different type with different typesize size. + cl::sycl::buffer int_res_buff{ cl::sycl::range<1>(1) }; + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto int_res_acc = int_res_buff.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto int_res_ = sc.get_mem(ih, int_res_acc); + cublasStatus_t err; + // For negative incx, iamax returns 0. This behaviour is similar to that of + // reference netlib BLAS. 
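
As the comments above describe, cuBLAS `i?amax` (and `i?amin` further below) returns a 1-based `int` index, and 0 on the degenerate paths, while the oneMKL API expects a 0-based `std::int64_t` written into the result buffer; the wrapper therefore copies the `int` result to the host and applies the adjustment shown just below. A standalone sketch of that conversion, with a hypothetical helper name:

```cpp
// Standalone sketch of the index adjustment applied below: convert cuBLAS's 1-based
// int result (0 when nothing is selected) into oneMKL's 0-based int64_t.
#include <algorithm>
#include <cstdint>
#include <iostream>

static std::int64_t to_onemkl_index(int cublas_result) {
    return std::max<std::int64_t>(static_cast<std::int64_t>(cublas_result) - 1, 0);
}

int main() {
    std::cout << to_onemkl_index(5) << ' ' << to_onemkl_index(0) << '\n';  // prints "4 0"
}
```
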
+ CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, int_res_); + }); + }); + // This requires to bring the data to host, copy it, and return it back to + // the device + result.template get_access()[0] = + std::max((int64_t)int_res_buff.template get_access()[0] - 1, + int64_t{ 0 }); +} + +#define IAMAX_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void iamax(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, \ + const int64_t incx, cl::sycl::buffer &result) { \ + iamax(CUBLAS_ROUTINE, queue, n, x, incx, result); \ + } +IAMAX_LAUNCHER(float, cublasIsamax) +IAMAX_LAUNCHER(double, cublasIdamax) +IAMAX_LAUNCHER(std::complex, cublasIcamax) +IAMAX_LAUNCHER(std::complex, cublasIzamax) +#undef IAMAX_LAUNCHER + +template +inline void swap(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + int64_t incx, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, y_, incy); + }); + }); +} + +#define SWAP_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void swap(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, int64_t incx, \ + cl::sycl::buffer &y, int64_t incy) { \ + swap(CUBLAS_ROUTINE, queue, n, x, incx, y, incy); \ + } + +SWAP_LAUNCHER(float, cublasSswap) +SWAP_LAUNCHER(double, cublasDswap) +SWAP_LAUNCHER(std::complex, cublasCswap) +SWAP_LAUNCHER(std::complex, cublasZswap) +#undef SWAP_LAUNCHER + +template +inline void iamin(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &result) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + // cuBLAS does not support int64_t as return type for the data. So we need to + // mimic iamin we are converting the result to be the int and then we convert + // it back to the actual data on the host. + // This change may cause failure as the result of integer overflow + // based on the size. Alternatively, either we need to write a sycl kernel + // to elementwise copy the data between two buffer, or allow reinterpret cast + // to convert to different type with different typesize size. + cl::sycl::buffer int_res_buff{ cl::sycl::range<1>(1) }; + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto int_res_acc = int_res_buff.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto int_res_ = sc.get_mem(ih, int_res_acc); + cublasStatus_t err; + // For negative incx, iamin returns 0. This behaviour is similar to that of + // implemented as a reference IAMIN. 
+ CUBLAS_ERROR_FUNC(func, err, handle, n, x_, incx, int_res_); + }); + }); + result.template get_access()[0] = + std::max((int64_t)int_res_buff.template get_access()[0] - 1, + int64_t{ 0 }); +} + +#define IAMIN_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void iamin(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, \ + const int64_t incx, cl::sycl::buffer &result) { \ + iamin(CUBLAS_ROUTINE, queue, n, x, incx, result); \ + } +IAMIN_LAUNCHER(float, cublasIsamin) +IAMIN_LAUNCHER(double, cublasIdamin) +IAMIN_LAUNCHER(std::complex, cublasIcamin) +IAMIN_LAUNCHER(std::complex, cublasIzamin) +#undef IAMIN_LAUNCHER + +template +inline void nrm2(Func func, cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, + const int64_t incx, cl::sycl::buffer &result) { + using cuDataType1 = typename CudaEquivalentType::Type; + using cuDataType2 = typename CudaEquivalentType::Type; + overflow_check(n, incx); + + queue.submit([&](cl::sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto res_acc = result.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + // By default the pointer mode is the CUBLAS_POINTER_MODE_HOST + // when the data is on buffer, it must be set to + // CUBLAS_POINTER_MODE_DEVICE mode otherwise it causes the segmentation + // fault. When it is set to device it is users responsibility to + // synchronise as the function is completely asynchronous. + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + auto x_ = sc.get_mem(ih, x_acc); + auto res_ = sc.get_mem(ih, res_acc); + cublasStatus_t err; + // NRM2 does not support negative index + CUBLAS_ERROR_FUNC(func, err, handle, n, x_, std::abs(incx), res_); + }); + }); +} + +#define NRM2_LAUNCHER(TYPE1, TYPE2, CUBLAS_ROUTINE) \ + void nrm2(cl::sycl::queue &queue, int64_t n, cl::sycl::buffer &x, \ + const int64_t incx, cl::sycl::buffer &result) { \ + nrm2(CUBLAS_ROUTINE, queue, n, x, incx, result); \ + } +NRM2_LAUNCHER(float, float, cublasSnrm2) +NRM2_LAUNCHER(double, double, cublasDnrm2) +NRM2_LAUNCHER(std::complex, float, cublasScnrm2) +NRM2_LAUNCHER(std::complex, double, cublasDznrm2) +#undef NRM2_LAUNCHER + +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_level2.cpp b/src/blas/backends/cublas/cublas_level2.cpp new file mode 100644 index 000000000..25fda3ba5 --- /dev/null +++ b/src/blas/backends/cublas/cublas_level2.cpp @@ -0,0 +1,844 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+* +**************************************************************************/ +#include +#include "cublas_helper.hpp" +#include "cublas_scope_handle.hpp" +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +namespace onemkl { +namespace cublas { +template +inline void gemv(Func func, cl::sycl::queue &queue, transpose trans, int64_t m, int64_t n, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, int64_t incx, + T beta, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, m, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_operation(trans), m, n, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define GEMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void gemv(cl::sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx, TYPE beta, cl::sycl::buffer &y, int64_t incy) { \ + gemv(CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); \ + } + +GEMV_LAUNCHER(float, cublasSgemv) +GEMV_LAUNCHER(double, cublasDgemv) +GEMV_LAUNCHER(std::complex, cublasCgemv) +GEMV_LAUNCHER(std::complex, cublasZgemv) +#undef GEMV_LAUNCHER + +template +inline void gbmv(Func func, cl::sycl::queue &queue, transpose trans, int64_t m, int64_t n, + int64_t kl, int64_t ku, T alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &x, int64_t incx, T beta, cl::sycl::buffer &y, + int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, m, lda, kl, ku, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_operation(trans), m, n, kl, ku, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define GBMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void gbmv(cl::sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, \ + int64_t ku, TYPE alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &x, int64_t incx, TYPE beta, cl::sycl::buffer &y, \ + int64_t incy) { \ + gbmv(CUBLAS_ROUTINE, queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); \ + } + +GBMV_LAUNCHER(float, cublasSgbmv) +GBMV_LAUNCHER(double, cublasDgbmv) +GBMV_LAUNCHER(std::complex, cublasCgbmv) +GBMV_LAUNCHER(std::complex, cublasZgbmv) +#undef GBMV_LAUNCHER + +template +inline void ger(Func func, cl::sycl::queue &queue, int64_t m, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &a, int64_t lda) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, m, 
lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, m, n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, + lda); + }); + }); +} + +#define GER_LAUNCHER(EXT, TYPE, CUBLAS_ROUTINE) \ + void ger##EXT(cl::sycl::queue &queue, int64_t m, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, \ + int64_t incy, cl::sycl::buffer &a, int64_t lda) { \ + ger(CUBLAS_ROUTINE, queue, m, n, alpha, x, incx, y, incy, a, lda); \ + } + +GER_LAUNCHER(, float, cublasSger) +GER_LAUNCHER(, double, cublasDger) +GER_LAUNCHER(u, std::complex, cublasCgeru) +GER_LAUNCHER(u, std::complex, cublasZgeru) +GER_LAUNCHER(c, std::complex, cublasCgerc) +GER_LAUNCHER(c, std::complex, cublasZgerc) +#undef GER_LAUNCHER + +template +inline void hbmv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, int64_t incx, + T beta, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, k, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define HBMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void hbmv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, TYPE alpha, \ + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx, TYPE beta, cl::sycl::buffer &y, int64_t incy) { \ + hbmv(CUBLAS_ROUTINE, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); \ + } + +HBMV_LAUNCHER(std::complex, cublasChbmv) +HBMV_LAUNCHER(std::complex, cublasZhbmv) +#undef HBMV_LAUNCHER + +template +inline void hemv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, int64_t incx, + T beta, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define HEMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void 
hemv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx, TYPE beta, cl::sycl::buffer &y, int64_t incy) { \ + hemv(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); \ + } + +HEMV_LAUNCHER(std::complex, cublasChemv) +HEMV_LAUNCHER(std::complex, cublasZhemv) +#undef HEMV_LAUNCHER + +template +inline void her(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, ScalarType alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a, + int64_t lda) { + using cuScalarType = typename CudaEquivalentType::Type; + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx); + + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuScalarType *)&alpha, x_, incx, a_, lda); + }); + }); +} + +#define HER_LAUNCHER(SCALAR_TYPE, DATA_TYPE, CUBLAS_ROUTINE) \ + void her(cl::sycl::queue &queue, uplo upper_lower, int64_t n, SCALAR_TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a, \ + int64_t lda) { \ + her(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, a, lda); \ + } + +HER_LAUNCHER(float, std::complex, cublasCher) +HER_LAUNCHER(double, std::complex, cublasZher) + +#undef HER_LAUNCHER + +template +inline void her2(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &a, int64_t lda) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); + }); + }); +} + +#define HER2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void her2(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, \ + int64_t incy, cl::sycl::buffer &a, int64_t lda) { \ + her2(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); \ + } + +HER2_LAUNCHER(std::complex, cublasCher2) +HER2_LAUNCHER(std::complex, cublasZher2) + +#undef HER2_LAUNCHER + +template +inline void hpmv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, int64_t incx, T beta, + cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = 
CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); + }); + }); +} + +#define HPMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void hpmv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &a, cl::sycl::buffer &x, int64_t incx, TYPE beta, \ + cl::sycl::buffer &y, int64_t incy) { \ + hpmv(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); \ + } + +HPMV_LAUNCHER(std::complex, cublasChpmv) +HPMV_LAUNCHER(std::complex, cublasZhpmv) + +#undef HPMV_LAUNCHER + +template +inline void hpr(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, ScalarType alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a) { + using cuScalarType = typename CudaEquivalentType::Type; + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuScalarType *)&alpha, x_, incx, a_); + }); + }); +} + +#define HPR_LAUNCHER(SCALAR_TYPE, DATA_TYPE, CUBLAS_ROUTINE) \ + void hpr(cl::sycl::queue &queue, uplo upper_lower, int64_t n, SCALAR_TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a) { \ + hpr(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, a); \ + } + +HPR_LAUNCHER(float, std::complex, cublasChpr) +HPR_LAUNCHER(double, std::complex, cublasZhpr) + +#undef HPR_LAUNCHER + +template +inline void hpr2(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &a) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, y_, incy, a_); + }); + }); +} + +#define HPR2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void hpr2(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, \ + int64_t incy, cl::sycl::buffer &a) { \ + hpr2(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, y, incy, a); \ + } + +HPR2_LAUNCHER(std::complex, cublasChpr2) +HPR2_LAUNCHER(std::complex, cublasZhpr2) + +#undef HPR2_LAUNCHER + +template +inline void sbmv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, int64_t incx, + T beta, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = 
typename CudaEquivalentType::Type; + overflow_check(n, k, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, k, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define SBMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void sbmv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, TYPE alpha, \ + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx, TYPE beta, cl::sycl::buffer &y, int64_t incy) { \ + sbmv(CUBLAS_ROUTINE, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, incy); \ + } + +SBMV_LAUNCHER(float, cublasSsbmv) +SBMV_LAUNCHER(double, cublasDsbmv) + +#undef SBMV_LAUNCHER + +template +inline void symv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, int64_t incx, + T beta, cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, + incy); + }); + }); +} + +#define SYMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void symv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx, TYPE beta, cl::sycl::buffer &y, int64_t incy) { \ + symv(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); \ + } + +SYMV_LAUNCHER(float, cublasSsymv) +SYMV_LAUNCHER(double, cublasDsymv) + +#undef SYMV_LAUNCHER + +template +inline void syr(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a, int64_t lda) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, a_, lda); + }); + }); +} + +#define SYR_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void syr(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a, \ + int64_t lda) { \ + syr(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, a, lda); \ + } 
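The level-2 wrappers in this file all follow the same two-part pattern: a templated helper that submits a SYCL `interop_task` calling the corresponding cuBLAS routine on the queue's native CUDA stream, plus a thin macro-generated entry point per supported precision. Note that the angle-bracketed template arguments (buffer element types, accessor modes, `CudaEquivalentType` parameters) appear to have been stripped from this rendering of the diff. As a rough sketch, the `SYR_LAUNCHER(float, cublasSsyr)` instantiation emitted just below expands to approximately the following, with the buffer template arguments reconstructed as an assumption:

```cpp
// Illustrative expansion only -- the buffer template arguments below are an
// assumption (the diff rendering has dropped angle-bracketed text).
void syr(cl::sycl::queue &queue, uplo upper_lower, int64_t n, float alpha,
         cl::sycl::buffer<float, 1> &x, int64_t incx,
         cl::sycl::buffer<float, 1> &a, int64_t lda) {
    // Forwards to the templated helper above, which wraps cublasSsyr in an
    // interop_task on the native CUDA stream associated with `queue`.
    syr(cublasSsyr, queue, upper_lower, n, alpha, x, incx, a, lda);
}
```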
+ +SYR_LAUNCHER(float, cublasSsyr) +SYR_LAUNCHER(double, cublasDsyr) +// Intel does not support the following two +SYR_LAUNCHER(std::complex, cublasCsyr) +SYR_LAUNCHER(std::complex, cublasZsyr) +#undef SYR_LAUNCHER + +template +inline void syr2(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &a, int64_t lda) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); + }); + }); +} + +#define SYR2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void syr2(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, \ + int64_t incy, cl::sycl::buffer &a, int64_t lda) { \ + syr2(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); \ + } + +SYR2_LAUNCHER(float, cublasSsyr2) +SYR2_LAUNCHER(double, cublasDsyr2) +// Intel does not support the following two +SYR2_LAUNCHER(std::complex, cublasCsyr2) +SYR2_LAUNCHER(std::complex, cublasZsyr2) + +#undef SYR2_LAUNCHER + +template +inline void spmv(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &a, cl::sycl::buffer &x, int64_t incx, T beta, + cl::sycl::buffer &y, int64_t incy) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); + }); + }); +} + +#define SPMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void spmv(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &a, cl::sycl::buffer &x, int64_t incx, TYPE beta, \ + cl::sycl::buffer &y, int64_t incy) { \ + spmv(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); \ + } + +SPMV_LAUNCHER(float, cublasSspmv) +SPMV_LAUNCHER(double, cublasDspmv) + +#undef SPMV_LAUNCHER + +template +inline void spr(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = 
sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, a_); + }); + }); +} + +#define SPR_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void spr(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &a) { \ + spr(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, a); \ + } + +SPR_LAUNCHER(float, cublasSspr) +SPR_LAUNCHER(double, cublasDspr) + +#undef SPR_LAUNCHER + +template +inline void spr2(Func func, cl::sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, int64_t incy, + cl::sycl::buffer &a) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx, incy); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + auto y_ = sc.get_mem(ih, y_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), n, + (cuDataType *)&alpha, x_, incx, y_, incy, a_); + }); + }); +} + +#define SPR2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void spr2(cl::sycl::queue &queue, uplo upper_lower, int64_t n, TYPE alpha, \ + cl::sycl::buffer &x, int64_t incx, cl::sycl::buffer &y, \ + int64_t incy, cl::sycl::buffer &a) { \ + spr2(CUBLAS_ROUTINE, queue, upper_lower, n, alpha, x, incx, y, incy, a); \ + } + +SPR2_LAUNCHER(float, cublasSspr2) +SPR2_LAUNCHER(double, cublasDspr2) + +#undef SPR2_LAUNCHER + +template +inline void tbmv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, int64_t k, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &x, int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, + a_, lda, x_, incx); + }); + }); +} + +#define TBMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void tbmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, int64_t k, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &x, int64_t incx) { \ + tbmv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); \ + } + +TBMV_LAUNCHER(float, cublasStbmv) +TBMV_LAUNCHER(double, cublasDtbmv) +TBMV_LAUNCHER(std::complex, cublasCtbmv) +TBMV_LAUNCHER(std::complex, cublasZtbmv) + +#undef TBMV_LAUNCHER + +template +inline void tbsv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, int64_t k, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &x, int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template 
get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, + a_, lda, x_, incx); + }); + }); +} + +#define TBSV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void tbsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, int64_t k, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &x, int64_t incx) { \ + tbsv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); \ + } + +TBSV_LAUNCHER(float, cublasStbsv) +TBSV_LAUNCHER(double, cublasDtbsv) +TBSV_LAUNCHER(std::complex, cublasCtbsv) +TBSV_LAUNCHER(std::complex, cublasZtbsv) + +#undef TBSV_LAUNCHER + +template +inline void tpmv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, cl::sycl::buffer &a, cl::sycl::buffer &x, + int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, + x_, incx); + }); + }); +} + +#define TPMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void tpmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, cl::sycl::buffer &a, cl::sycl::buffer &x, \ + int64_t incx) { \ + tpmv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, a, x, incx); \ + } + +TPMV_LAUNCHER(float, cublasStpmv) +TPMV_LAUNCHER(double, cublasDtpmv) +TPMV_LAUNCHER(std::complex, cublasCtpmv) +TPMV_LAUNCHER(std::complex, cublasZtpmv) + +#undef TPMV_LAUNCHER + +template +inline void tpsv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, cl::sycl::buffer &a, cl::sycl::buffer &x, + int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, + x_, incx); + }); + }); +} + +#define TPSV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void tpsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, cl::sycl::buffer &a, cl::sycl::buffer &x, \ + int64_t incx) { \ + tpsv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, a, x, incx); \ + } + +TPSV_LAUNCHER(float, cublasStpsv) +TPSV_LAUNCHER(double, cublasDtpsv) +TPSV_LAUNCHER(std::complex, cublasCtpsv) +TPSV_LAUNCHER(std::complex, cublasZtpsv) + +#undef 
TPSV_LAUNCHER + +template +inline void trmv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &x, int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, + lda, x_, incx); + }); + }); +} + +#define TRMV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void trmv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx) { \ + trmv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); \ + } + +TRMV_LAUNCHER(float, cublasStrmv) +TRMV_LAUNCHER(double, cublasDtrmv) +TRMV_LAUNCHER(std::complex, cublasCtrmv) +TRMV_LAUNCHER(std::complex, cublasZtrmv) + +#undef TRMV_LAUNCHER + +template +inline void trsv(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &x, int64_t incx) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, lda, incx); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto x_ = sc.get_mem(ih, x_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, + lda, x_, incx); + }); + }); +} + +#define TRSV_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void trsv(cl::sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, \ + int64_t n, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &x, \ + int64_t incx) { \ + trsv(CUBLAS_ROUTINE, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); \ + } + +TRSV_LAUNCHER(float, cublasStrsv) +TRSV_LAUNCHER(double, cublasDtrsv) +TRSV_LAUNCHER(std::complex, cublasCtrsv) +TRSV_LAUNCHER(std::complex, cublasZtrsv) + +#undef TRSV_LAUNCHER + +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp new file mode 100644 index 000000000..81f7b7a91 --- /dev/null +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -0,0 +1,379 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. 
+* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#include +#include "cublas_helper.hpp" +#include "cublas_scope_handle.hpp" +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +namespace onemkl { +namespace cublas { +template +inline void gemm(Func func, cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, T alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, T beta, cl::sycl::buffer &c, + int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda, + b_, ldb, (cuDataType *)&beta, c_, ldc); + }); + }); +} + +#define GEMM_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &b, int64_t ldb, TYPE beta, cl::sycl::buffer &c, \ + int64_t ldc) { \ + gemm(CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ + } + +GEMM_LAUNCHER(float, cublasSgemm) +GEMM_LAUNCHER(double, cublasDgemm) +GEMM_LAUNCHER(std::complex, cublasCgemm) +GEMM_LAUNCHER(std::complex, cublasZgemm) + +#undef GEMM_LAUNCHER + +void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, + cl::sycl::buffer &c, std::int64_t ldc) { + throw std::runtime_error("Not implemented for cublas"); +} + +template +inline void symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, + int64_t n, T alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, T beta, cl::sycl::buffer &c, + int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_side_mode(left_right), + get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, + lda, b_, ldb, (cuDataType *)&beta, c_, ldc); + }); + }); +} + +#define SYMM_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void symm(cl::sycl::queue 
&queue, side left_right, uplo upper_lower, int64_t m, int64_t n, \ + TYPE alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, \ + int64_t ldb, TYPE beta, cl::sycl::buffer &c, int64_t ldc) { \ + symm(CUBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, \ + ldc); \ + } + +SYMM_LAUNCHER(float, cublasSsymm) +SYMM_LAUNCHER(double, cublasDsymm) +SYMM_LAUNCHER(std::complex, cublasCsymm) +SYMM_LAUNCHER(std::complex, cublasZsymm) + +#undef SYMM_LAUNCHER + +template +inline void hemm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, + int64_t n, T alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, T beta, cl::sycl::buffer &c, + int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_side_mode(left_right), + get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, + lda, b_, ldb, (cuDataType *)&beta, c_, ldc); + }); + }); +} + +#define HEMM_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, \ + TYPE alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, \ + int64_t ldb, TYPE beta, cl::sycl::buffer &c, int64_t ldc) { \ + hemm(CUBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, \ + ldc); \ + } +HEMM_LAUNCHER(std::complex, cublasChemm) +HEMM_LAUNCHER(std::complex, cublasZhemm) + +#undef HEMM_LAUNCHER + +template +inline void syrk(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, T alpha, cl::sycl::buffer &a, int64_t lda, T beta, + cl::sycl::buffer &c, int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, + (cuDataType *)&beta, c_, ldc); + }); + }); +} + +#define SYRK_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void syrk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + TYPE alpha, cl::sycl::buffer &a, int64_t lda, TYPE beta, \ + cl::sycl::buffer &c, int64_t ldc) { \ + syrk(CUBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); \ + } + +SYRK_LAUNCHER(float, cublasSsyrk) +SYRK_LAUNCHER(double, cublasDsyrk) +SYRK_LAUNCHER(std::complex, cublasCsyrk) +SYRK_LAUNCHER(std::complex, cublasZsyrk) + +#undef SYRK_LAUNCHER + +template +inline void herk(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, ScalarType alpha, cl::sycl::buffer &a, int64_t lda, + ScalarType beta, 
cl::sycl::buffer &c, int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + using cuScalarType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), n, k, (cuScalarType *)&alpha, a_, lda, + (cuScalarType *)&beta, c_, ldc); + }); + }); +} + +#define HERK_LAUNCHER(DATA_TYPE, SCALAR_TYPE, CUBLAS_ROUTINE) \ + void herk(cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + SCALAR_TYPE alpha, cl::sycl::buffer &a, int64_t lda, SCALAR_TYPE beta, \ + cl::sycl::buffer &c, int64_t ldc) { \ + herk(CUBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc); \ + } + +HERK_LAUNCHER(std::complex, float, cublasCherk) +HERK_LAUNCHER(std::complex, double, cublasZherk) + +#undef HERK_LAUNCHER + +template +inline void syr2k(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, T alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, T beta, cl::sycl::buffer &c, + int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, ldb, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, + ldb, (cuDataType *)&beta, c_, ldc); + }); + }); +} + +#define SYR2K_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void syr2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + TYPE alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &b, int64_t ldb, TYPE beta, cl::sycl::buffer &c, \ + int64_t ldc) { \ + syr2k(CUBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc); \ + } +SYR2K_LAUNCHER(float, cublasSsyr2k) +SYR2K_LAUNCHER(double, cublasDsyr2k) +SYR2K_LAUNCHER(std::complex, cublasCsyr2k) +SYR2K_LAUNCHER(std::complex, cublasZsyr2k) + +#undef SYR2K_LAUNCHER + +template +inline void her2k(Func func, cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, DataType alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, ScalarType beta, + cl::sycl::buffer &c, int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + using cuScalarType = typename CudaEquivalentType::Type; + overflow_check(n, k, lda, ldb, ldc); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = 
sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + auto c_ = sc.get_mem(ih, c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(upper_lower), + get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, + ldb, (cuScalarType *)&beta, c_, ldc); + }); + }); +} + +#define HER2K_LAUNCHER(DATA_TYPE, SCALAR_TYPE, CUBLAS_ROUTINE) \ + void her2k(cl::sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + DATA_TYPE alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &b, int64_t ldb, SCALAR_TYPE beta, \ + cl::sycl::buffer &c, int64_t ldc) { \ + her2k(CUBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc); \ + } + +HER2K_LAUNCHER(std::complex, float, cublasCher2k) +HER2K_LAUNCHER(std::complex, double, cublasZher2k) + +#undef HER2K_LAUNCHER + +// NOTE: In cublas TRMM diverted from the netlib blas and for performance +// reason it requires the C matrix to be +// separated from the B matrix. It is possible to use B instead of C, but this +// will slow-down the code. +template +inline void trmm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_side_mode(left_right), + get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), + get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, + b_, ldb, b_, ldb); + }); + }); +} + +#define TRMM_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void trmm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, cl::sycl::buffer &a, \ + int64_t lda, cl::sycl::buffer &b, int64_t ldb) { \ + trmm(CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ + lda, b, ldb); \ + } +TRMM_LAUNCHER(float, cublasStrmm) +TRMM_LAUNCHER(double, cublasDtrmm) +TRMM_LAUNCHER(std::complex, cublasCtrmm) +TRMM_LAUNCHER(std::complex, cublasZtrmm) + +#undef TRMM_LAUNCHER + +template +inline void trsm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](cl::sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + cgh.interop_task([=](cl::sycl::interop_handler ih) { + auto sc = CublasScopedContextHandler(queue); + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(ih, a_acc); + auto b_ = sc.get_mem(ih, b_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_side_mode(left_right), + get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), + get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, + b_, ldb); + }); + }); +} + +#define TRSM_LAUNCHER(TYPE, 
CUBLAS_ROUTINE) \ + void trsm(cl::sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, cl::sycl::buffer &a, \ + int64_t lda, cl::sycl::buffer &b, int64_t ldb) { \ + trsm(CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ + lda, b, ldb); \ + } +TRSM_LAUNCHER(float, cublasStrsm) +TRSM_LAUNCHER(double, cublasDtrsm) +TRSM_LAUNCHER(std::complex, cublasCtrsm) +TRSM_LAUNCHER(std::complex, cublasZtrsm) + +#undef TRSM_LAUNCHER +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp new file mode 100644 index 000000000..ddb6b21bc --- /dev/null +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -0,0 +1,145 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#include "cublas_scope_handle.hpp" +#include + +namespace onemkl { +namespace cublas { + +cublas_handle::~cublas_handle() noexcept(false) { + for (auto &handle_pair : cublas_handle_mapper_) { + cublasStatus_t err; + if (handle_pair.second != nullptr) { + auto handle = handle_pair.second->exchange(nullptr); + if (handle != nullptr) { + CUBLAS_ERROR_FUNC(cublasDestroy, err, handle); + handle = nullptr; + } + delete handle_pair.second; + handle_pair.second = nullptr; + } + } + cublas_handle_mapper_.clear(); +} +/** + * Inserts a new element in the map if its key is unique. This new element + * is constructed in place using args as the arguments for the construction + * of a value_type (which is an object of a pair type). The insertion only + * takes place if no other element in the container has a key equivalent to + * the one being emplaced (keys in a map container are unique). + */ +thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; + +CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue) { + placedContext_ = queue.get_context(); + auto device = queue.get_device(); + auto desired = cl::sycl::get_native(placedContext_); + auto cudaDevice = cl::sycl::get_native(device); + CUresult err; + CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + CUcontext primary; + cuDevicePrimaryCtxRetain(&primary, cudaDevice); + bool isPrimary = primary == desired; + cuDevicePrimaryCtxRelease(cudaDevice); + if (original_ != desired) { + // Sets the desired context as the active one for the thread + CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); + // No context is installed and the suggested context is primary + // This is the most common case. We can activate the context in the + // thread and leave it there until all the PI context referring to the + // same underlying CUDA primary context are destroyed. 
This emulates + // the behaviour of the CUDA runtime api, and avoids costly context + // switches. No action is required on this side of the if. + needToRecover_ = !(original_ == nullptr && isPrimary); + } +} + +CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) { + if (needToRecover_) { + CUresult err; + CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_); + } +} + +void ContextCallback(void *userData) { + auto *ptr = static_cast **>(userData); + if (!ptr) { + return; + } + if (*ptr != nullptr) { + auto handle = (*ptr)->exchange(nullptr); + if (handle != nullptr) { + cublasStatus_t err1; + CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle); + handle = nullptr; + } + delete *ptr; + *ptr = nullptr; + } +} + +cublasHandle_t CublasScopedContextHandler::get_handle(const cl::sycl::queue &queue) { + auto piPlacedContext_ = reinterpret_cast(placedContext_.get()); + CUstream streamId = get_stream(queue); + cublasStatus_t err; + auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_); + if (it != handle_helper.cublas_handle_mapper_.end()) { + if (it->second == nullptr) { + handle_helper.cublas_handle_mapper_.erase(it); + } + else { + auto handle = it->second->load(); + if (handle != nullptr) { + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); + } + return handle; + } + else { + handle_helper.cublas_handle_mapper_.erase(it); + } + } + } + + cublasHandle_t handle; + + CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); + + auto insert_iter = handle_helper.cublas_handle_mapper_.insert( + std::make_pair(piPlacedContext_, new std::atomic(handle))); + + auto ptr = &(insert_iter.first->second); + + sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback, ptr); + + return handle; +} + +CUstream CublasScopedContextHandler::get_stream(const cl::sycl::queue &queue) { + return cl::sycl::get_native(queue); +} +cl::sycl::context CublasScopedContextHandler::get_context(const cl::sycl::queue &queue) { + return queue.get_context(); +} + +} // namespace cublas +} // namespace onemkl diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp new file mode 100644 index 000000000..d5373b481 --- /dev/null +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -0,0 +1,96 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+* +**************************************************************************/ +#ifndef _MKL_BLAS_CUBLAS_SCOPED_HANDLE_HPP_ +#define _MKL_BLAS_CUBLAS_SCOPED_HANDLE_HPP_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_helper.hpp" +namespace onemkl { +namespace cublas { + +struct cublas_handle { + using handle_container_t = std::unordered_map *>; + handle_container_t cublas_handle_mapper_{}; + ~cublas_handle() noexcept(false); +}; + +/** +* @brief NVIDIA advise for handle creation: +https://devtalk.nvidia.com/default/topic/838794/gpu-accelerated libraries/using-cublas-in-different-cuda-streams/ +According to NVIDIA: +1) It is required that different handles to be used for different devices: + http://docs.nvidia.com/cuda/cublas/index.html#cublas-context +2) It is recommended (but not required, if care is taken) that different handles be used for different host threads: +http://docs.nvidia.com/cuda/cublas/index.html#thread-safety2changeme +3) It is neither required nor recommended that different handles be used for different streams on the same device, + using the same host thread. + +However, the 3 above advises are for using cuda runtime API. The NVIDIA runtime API creates a default context for users. +The createHandle function in cuBLAS uses the context located on top of the stack for each thread. Then, the cuBLAS routine +uses this context for resource allocation/access. Calling a cuBLAS function with a handle created for context A and +memories/queue created for context B results in a segmentation fault. Thus we need to create one handle per context +and per thread. A context can have multiple streams, so the important thing here is to have one cublasHandle per driver +context and that cuBLAS handle can switch between multiple streams created for that context. Here, we are dealing with +CUDA driver API, therefore, the SYCL-CUDA backend controls the context. If a queue(equivalent of CUDA stream) is associated +with a context different from the one on top of the thread stack(can be any context which associated at any time by either +the runtime or user for any specific reason), the context associated with the queue must be moved on top of the stack +temporarily for the requested routine operations. However, after the cuBLAS routine execution, the original context must +be restored to prevent intervening with the original user/runtime execution set up. Here, the RAII type context switch +is used to guarantee to recover the original CUDA context. The cuBLAS handle allocates internal resources, therefore, +the handle must be destroyed when the context goes out of scope. This will bind the life of cuBLAS handle to the SYCL context. +**/ + +class CublasScopedContextHandler { + CUcontext original_; + cl::sycl::context placedContext_; + bool needToRecover_; + static thread_local cublas_handle handle_helper; + CUstream get_stream(const cl::sycl::queue &queue); + cl::sycl::context get_context(const cl::sycl::queue &queue); + +public: + CublasScopedContextHandler(cl::sycl::queue queue); + + ~CublasScopedContextHandler() noexcept(false); + /** + * @brief get_handle: creates the handle by implicitely impose the advice + * given by nvidia for creating a cublas_handle. (e.g. one cuStream per device + * per thread). + * @param queue sycl queue. + * @return cublasHandle_t a handle to construct cublas routines + */ + cublasHandle_t get_handle(const cl::sycl::queue &queue); + // This is a work-around function for reinterpret_casting the memory. 
This + // will be fixed when SYCL-2020 has been implemented for Pi backend. + template + inline T get_mem(cl::sycl::interop_handler ih, U acc) { + CUdeviceptr cudaPtr = ih.get_mem(acc); + return reinterpret_cast(cudaPtr); + } +}; + +} // namespace cublas +} // namespace onemkl +#endif //_MKL_BLAS_CUBLAS_SCOPED_HANDLE_HPP_ diff --git a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp new file mode 100644 index 000000000..2a6ee6dab --- /dev/null +++ b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp @@ -0,0 +1,204 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#include "blas/function_table.hpp" +#include "onemkl/blas/detail/cublas/onemkl_blas_cublas.hpp" + +#define WRAPPER_VERSION 1 + +extern "C" function_table_t mkl_blas_table = { + WRAPPER_VERSION, + onemkl::cublas::asum, + onemkl::cublas::asum, + onemkl::cublas::asum, + onemkl::cublas::asum, + onemkl::cublas::axpy, + onemkl::cublas::axpy, + onemkl::cublas::axpy, + onemkl::cublas::axpy, + onemkl::cublas::copy, + onemkl::cublas::copy, + onemkl::cublas::copy, + onemkl::cublas::copy, + onemkl::cublas::dot, + onemkl::cublas::dot, + onemkl::cublas::dot, + onemkl::cublas::dotc, + onemkl::cublas::dotc, + onemkl::cublas::dotu, + onemkl::cublas::dotu, + onemkl::cublas::iamin, + onemkl::cublas::iamin, + onemkl::cublas::iamin, + onemkl::cublas::iamin, + onemkl::cublas::iamax, + onemkl::cublas::iamax, + onemkl::cublas::iamax, + onemkl::cublas::iamax, + onemkl::cublas::nrm2, + onemkl::cublas::nrm2, + onemkl::cublas::nrm2, + onemkl::cublas::nrm2, + onemkl::cublas::rot, + onemkl::cublas::rot, + onemkl::cublas::rot, + onemkl::cublas::rot, + onemkl::cublas::rotg, + onemkl::cublas::rotg, + onemkl::cublas::rotg, + onemkl::cublas::rotg, + onemkl::cublas::rotm, + onemkl::cublas::rotm, + onemkl::cublas::rotmg, + onemkl::cublas::rotmg, + onemkl::cublas::scal, + onemkl::cublas::scal, + onemkl::cublas::scal, + onemkl::cublas::scal, + onemkl::cublas::scal, + onemkl::cublas::scal, + onemkl::cublas::sdsdot, + onemkl::cublas::swap, + onemkl::cublas::swap, + onemkl::cublas::swap, + onemkl::cublas::swap, + onemkl::cublas::gbmv, + onemkl::cublas::gbmv, + onemkl::cublas::gbmv, + onemkl::cublas::gbmv, + onemkl::cublas::gemv, + onemkl::cublas::gemv, + onemkl::cublas::gemv, + onemkl::cublas::gemv, + onemkl::cublas::ger, + onemkl::cublas::ger, + onemkl::cublas::gerc, + onemkl::cublas::gerc, + onemkl::cublas::geru, + onemkl::cublas::geru, + onemkl::cublas::hbmv, + onemkl::cublas::hbmv, + onemkl::cublas::hemv, + onemkl::cublas::hemv, + onemkl::cublas::her, + onemkl::cublas::her, + onemkl::cublas::her2, + onemkl::cublas::her2, + onemkl::cublas::hpmv, + onemkl::cublas::hpmv, + onemkl::cublas::hpr, + onemkl::cublas::hpr, + 
onemkl::cublas::hpr2, + onemkl::cublas::hpr2, + onemkl::cublas::sbmv, + onemkl::cublas::sbmv, + onemkl::cublas::spmv, + onemkl::cublas::spmv, + onemkl::cublas::spr, + onemkl::cublas::spr, + onemkl::cublas::spr2, + onemkl::cublas::spr2, + onemkl::cublas::symv, + onemkl::cublas::symv, + onemkl::cublas::syr, + onemkl::cublas::syr, + onemkl::cublas::syr2, + onemkl::cublas::syr2, + onemkl::cublas::tbmv, + onemkl::cublas::tbmv, + onemkl::cublas::tbmv, + onemkl::cublas::tbmv, + onemkl::cublas::tbsv, + onemkl::cublas::tbsv, + onemkl::cublas::tbsv, + onemkl::cublas::tbsv, + onemkl::cublas::tpmv, + onemkl::cublas::tpmv, + onemkl::cublas::tpmv, + onemkl::cublas::tpmv, + onemkl::cublas::tpsv, + onemkl::cublas::tpsv, + onemkl::cublas::tpsv, + onemkl::cublas::tpsv, + onemkl::cublas::trmv, + onemkl::cublas::trmv, + onemkl::cublas::trmv, + onemkl::cublas::trmv, + onemkl::cublas::trsv, + onemkl::cublas::trsv, + onemkl::cublas::trsv, + onemkl::cublas::trsv, + onemkl::cublas::gemm, + onemkl::cublas::gemm, + onemkl::cublas::gemm, + onemkl::cublas::gemm, + onemkl::cublas::gemm, + onemkl::cublas::hemm, + onemkl::cublas::hemm, + onemkl::cublas::herk, + onemkl::cublas::herk, + onemkl::cublas::her2k, + onemkl::cublas::her2k, + onemkl::cublas::symm, + onemkl::cublas::symm, + onemkl::cublas::symm, + onemkl::cublas::symm, + onemkl::cublas::syrk, + onemkl::cublas::syrk, + onemkl::cublas::syrk, + onemkl::cublas::syrk, + onemkl::cublas::syr2k, + onemkl::cublas::syr2k, + onemkl::cublas::syr2k, + onemkl::cublas::syr2k, + onemkl::cublas::trmm, + onemkl::cublas::trmm, + onemkl::cublas::trmm, + onemkl::cublas::trmm, + onemkl::cublas::trsm, + onemkl::cublas::trsm, + onemkl::cublas::trsm, + onemkl::cublas::trsm, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::gemm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::trsm_batch, + onemkl::cublas::gemmt, + onemkl::cublas::gemmt, + onemkl::cublas::gemmt, + onemkl::cublas::gemmt, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, + onemkl::cublas::gemm_ext, +}; diff --git a/src/config.hpp.in b/src/config.hpp.in index 589b5c95d..85c35924d 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -20,6 +20,7 @@ #ifndef ONEMKL_CONFIG_H #define ONEMKL_CONFIG_H +#cmakedefine ENABLE_CUBLAS_BACKEND #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine BUILD_SHARED_LIBS diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 860d05c8d..348a3215a 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -70,6 +70,17 @@ if(ENABLE_MKLGPU_BACKEND) endif() endif() +if(ENABLE_CUBLAS_BACKEND) + add_dependencies(test_main_ct onemkl_blas_cublas) + if(BUILD_SHARED_LIBS) + list(APPEND ONEMKL_LIBRARIES onemkl_blas_cublas) + else() + list(APPEND ONEMKL_LIBRARIES -foffload-static-lib=${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libonemkl_blas_cublas.a) + find_package(cuBLAS REQUIRED) + list(APPEND ONEMKL_LIBRARIES ONEMKL::cuBLAS::cuBLAS) + endif() +endif() + target_link_libraries(test_main_ct PUBLIC gtest gtest_main diff --git 
a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index 4fd5df293..10d92505c 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -36,6 +36,13 @@ #define TEST_RUN_INTELGPU(q, func, args) #endif +#ifdef ENABLE_CUBLAS_BACKEND + #define TEST_RUN_NVIDIAGPU(q, func, args) \ + func args +#else + #define TEST_RUN_NVIDIAGPU(q, func, args) +#endif + #define TEST_RUN_CT(q, func, args) \ do { \ if (q.is_host() || q.get_device().is_cpu()) \ @@ -45,6 +52,8 @@ q.get_device().get_info()); \ if (vendor_id == INTEL_ID) \ TEST_RUN_INTELGPU(q, func, args); \ + else if (vendor_id == NVIDIA_ID) \ + TEST_RUN_NVIDIAGPU(q, func, args); \ } \ } while (0); @@ -61,6 +70,8 @@ class DeviceNamePrint { switch (vendor_id) { case INTEL_ID: return std::string("INTELGPU"); + case NVIDIA_ID: + return std::string("NVIDIAGPU"); } } if (dev.param.is_accelerator()) diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index 9c95299b7..6f14ae10e 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -98,6 +98,10 @@ int main(int argc, char** argv) { #ifndef ENABLE_MKLGPU_BACKEND if (dev.is_gpu() && vendor_id == INTEL_ID) continue; +#endif +#ifndef ENABLE_CUBLAS_BACKEND + if (dev.is_gpu() && vendor_id == NVIDIA_ID) + continue; #endif if (!dev.is_accelerator()) devices.push_back(dev);
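To summarize the test-side dispatch these hunks add: `TEST_RUN_CT` runs a case through the MKLCPU backend for host/CPU queues and otherwise keys off the device vendor id, with the new `NVIDIA_ID` branch routing to the cuBLAS backend (and `main_test.cpp` skipping NVIDIA GPUs entirely when `ENABLE_CUBLAS_BACKEND` is not defined). Below is a minimal standalone sketch of that selection logic; the numeric vendor ids and the string labels are illustrative assumptions, not values taken from this patch:

```cpp
// Sketch of the vendor-based backend selection performed by TEST_RUN_CT.
// The INTEL_ID / NVIDIA_ID constants are assumed PCI vendor ids, used here
// only for illustration.
#include <CL/sycl.hpp>
#include <string>

std::string backend_for(const cl::sycl::queue &q) {
    if (q.is_host() || q.get_device().is_cpu())
        return "mklcpu"; // TEST_RUN_INTELCPU path
    auto vendor_id = q.get_device().get_info<cl::sycl::info::device::vendor_id>();
    if (vendor_id == 0x8086) // INTEL_ID (assumed)
        return "mklgpu"; // TEST_RUN_INTELGPU path
    if (vendor_id == 0x10DE) // NVIDIA_ID (assumed)
        return "cublas"; // TEST_RUN_NVIDIAGPU path added by this patch
    return "unsupported";
}
```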