diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 13c83bda2..71a263b8d 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -21,9 +21,17 @@ #include "testing_nrm2_batched.hpp" #include "testing_nrm2_strided_batched.hpp" #include "testing_rot.hpp" +#include "testing_rot_batched.hpp" +#include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" +#include "testing_rotg_batched.hpp" +#include "testing_rotg_strided_batched.hpp" #include "testing_rotm.hpp" +#include "testing_rotm_batched.hpp" +#include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" +#include "testing_rotmg_batched.hpp" +#include "testing_rotmg_strided_batched.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" #include "testing_scal_strided_batched.hpp" @@ -174,14 +182,18 @@ struct perf_blas< testing_set_get_vector(arg); else if(!strcmp(arg.function, "set_get_matrix")) testing_set_get_matrix(arg); - else if(!strcmp(arg.function, "rot")) - testing_rot(arg); - else if(!strcmp(arg.function, "rotg")) - testing_rotg(arg); else if(!strcmp(arg.function, "rotm")) testing_rotm(arg); + else if(!strcmp(arg.function, "rotm_batched")) + testing_rotm_batched(arg); + else if(!strcmp(arg.function, "rotm_strided_batched")) + testing_rotm_strided_batched(arg); else if(!strcmp(arg.function, "rotmg")) testing_rotmg(arg); + else if(!strcmp(arg.function, "rotmg_batched")) + testing_rotmg_batched(arg); + else if(!strcmp(arg.function, "rotmg_strided_batched")) + testing_rotmg_strided_batched(arg); else if(!strcmp(arg.function, "gemv")) testing_gemv(arg); else if(!strcmp(arg.function, "gemv_batched")) @@ -326,6 +338,47 @@ struct perf_blas +struct perf_blas_rot : rocblas_test_invalid +{ +}; + +template +struct perf_blas_rot< + Ti, + To, + Tc, + typename std::enable_if<( + (std::is_same{} && std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{} + && 
std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}))>::type> +{ + explicit operator bool() + { + return true; + } + + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "rot")) + testing_rot(arg); + else if(!strcmp(arg.function, "rot_batched")) + testing_rot_batched(arg); + else if(!strcmp(arg.function, "rot_strided_batched")) + testing_rot_strided_batched(arg); + else + throw std::invalid_argument("Invalid combination --function "s + arg.function + + " --a_type "s + rocblas_datatype2string(arg.a_type)); + } +}; + template struct perf_blas_scal : rocblas_test_invalid { @@ -361,6 +414,40 @@ struct perf_blas_scal< } }; +template +struct perf_blas_rotg : rocblas_test_invalid +{ +}; + +template +struct perf_blas_rotg< + Ta, + Tb, + typename std::enable_if< + (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{})>::type> +{ + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "rotg")) + testing_rotg(arg); + else if(!strcmp(arg.function, "rotg_batched")) + testing_rotg_batched(arg); + else if(!strcmp(arg.function, "rotg_strided_batched")) + testing_rotg_strided_batched(arg); + else + throw std::invalid_argument("Invalid combination --function "s + arg.function + + " --a_type " + rocblas_datatype2string(arg.a_type) + + " --b_type " + rocblas_datatype2string(arg.b_type)); + } +}; + int run_bench_test(Arguments& arg) { // disable unit_check in client benchmark, it is only used in gtest unit test @@ -523,6 +610,12 @@ int run_bench_test(Arguments& arg) if(!strcmp(function, "scal") || !strcmp(function, "scal_batched") || !strcmp(function, "scal_strided_batched")) rocblas_blas1_dispatch(arg); + else if(!strcmp(function, "rotg") || !strcmp(function, 
"rotg_batched") + || !strcmp(function, "rotg_strided_batched")) + rocblas_blas1_dispatch(arg); + else if(!strcmp(function, "rot") || !strcmp(function, "rot_batched") + || !strcmp(function, "rot_strided_batched")) + rocblas_blas1_dispatch(arg); else rocblas_simple_dispatch(arg); } diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 383eef4e5..b159ce65a 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -199,13 +199,17 @@ def setdefaults(test): if test['function'] in ('asum_strided_batched', 'nrm2_strided_batched', 'scal_strided_batched', 'swap_strided_batched', 'copy_strided_batched', 'dot_strided_batched', - 'dotc_strided_batched'): + 'dotc_strided_batched', 'rot_strided_batched', + 'rotm_strided_batched'): if all([x in test for x in ('N', 'incx', 'stride_scale')]): test.setdefault('stride_x', int(test['N'] * abs(test['incx']) * test['stride_scale'])) if all([x in test for x in ('N', 'incy', 'stride_scale')]): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) + # we are using stride_c for param in rotm + if all([x in test for x in ('stride_scale')]): + test.setdefault('stride_c', int(test['stride_scale']) * 5) if test['function'] in ('ger_strided_batched'): if all([x in test for x in ('M', 'incx', 'stride_scale')]): @@ -215,6 +219,26 @@ def setdefaults(test): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) + # we are using stride_c for arg c and stride_d for arg s in rotg + # these are single values for each batch + if test['function'] in ('rotg_strided_batched'): + if 'stride_scale' in test: + test.setdefault('stride_a', int(test['stride_scale'])) + test.setdefault('stride_b', int(test['stride_scale'])) + test.setdefault('stride_c', int(test['stride_scale'])) + test.setdefault('stride_d', int(test['stride_scale'])) + + # we are using stride_a for d1, stride_b for d2, and stride_c for param in + # rotmg. 
These are single values for each batch, except param which is + # a 5 element array + if test['function'] in ('rotmg_strided_batched'): + if 'stride_scale' in test: + test.setdefault('stride_a', int(test['stride_scale'])) + test.setdefault('stride_b', int(test['stride_scale'])) + test.setdefault('stride_c', int(test['stride_scale']) * 5) + test.setdefault('stride_x', int(test['stride_scale'])) + test.setdefault('stride_y', int(test['stride_scale'])) + test.setdefault('stride_x', 0) test.setdefault('stride_y', 0) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index b2d2a9a42..f2a1e7ca1 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -18,9 +18,17 @@ #include "testing_nrm2_batched.hpp" #include "testing_nrm2_strided_batched.hpp" #include "testing_rot.hpp" +#include "testing_rot_batched.hpp" +#include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" +#include "testing_rotg_batched.hpp" +#include "testing_rotg_strided_batched.hpp" #include "testing_rotm.hpp" +#include "testing_rotm_batched.hpp" +#include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" +#include "testing_rotmg_batched.hpp" +#include "testing_rotmg_strided_batched.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" #include "testing_scal_strided_batched.hpp" @@ -59,9 +67,17 @@ namespace swap_batched, swap_strided_batched, rot, + rot_batched, + rot_strided_batched, rotg, + rotg_batched, + rotg_strided_batched, rotm, + rotm_batched, + rotm_strided_batched, rotmg, + rotmg_batched, + rotmg_strided_batched, }; // ---------------------------------------------------------------------------- @@ -93,32 +109,45 @@ namespace { bool is_scal = (BLAS1 == blas1::scal || BLAS1 == blas1::scal_batched || BLAS1 == blas1::scal_strided_batched); + bool is_rot = (BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched); + bool is_rotg = (BLAS1 == blas1::rotg || BLAS1 == 
blas1::rotg_batched + || BLAS1 == blas1::rotg_strided_batched); + bool is_rotmg = (BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched + || BLAS1 == blas1::rotmg_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched || BLAS1 == blas1::copy_batched || BLAS1 == blas1::dot_batched - || BLAS1 == blas1::dotc_batched); + || BLAS1 == blas1::dotc_batched || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotg_batched + || BLAS1 == blas1::rotmg_batched); bool is_strided = (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched || BLAS1 == blas1::scal_strided_batched || BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched || BLAS1 == blas1::dot_strided_batched - || BLAS1 == blas1::dotc_strided_batched); + || BLAS1 == blas1::dotc_strided_batched + || BLAS1 == blas1::rot_strided_batched + || BLAS1 == blas1::rotm_strided_batched + || BLAS1 == blas1::rotg_strided_batched + || BLAS1 == blas1::rotmg_strided_batched); - if((is_scal || BLAS1 == blas1::rot || BLAS1 == blas1::rotg) - && arg.a_type != arg.b_type) + if((is_scal || is_rotg || is_rot) && arg.a_type != arg.b_type) name << '_' << rocblas_datatype2string(arg.b_type); - if(BLAS1 == blas1::rot && arg.compute_type != arg.a_type) + if(is_rot && arg.compute_type != arg.a_type) name << '_' << rocblas_datatype2string(arg.compute_type); - name << '_' << arg.N; + if(!is_rotg && !is_rotmg) + name << '_' << arg.N; if(BLAS1 == blas1::axpy || is_scal) name << '_' << arg.alpha << "_" << arg.alphai; - name << '_' << arg.incx; + if(!is_rotg && !is_rotmg) + name << '_' << arg.incx; - if(is_strided) + if(is_strided && !is_rotg) { name << '_' << arg.stride_x; } @@ -129,17 +158,31 @@ namespace || BLAS1 == blas1::dotc_batched || BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched || BLAS1 == blas1::swap || BLAS1 == blas1::swap_batched 
|| BLAS1 == blas1::swap_strided_batched - || BLAS1 == blas1::rot || BLAS1 == blas1::rotm) + || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm + || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotm_strided_batched) { name << '_' << arg.incy; } if(BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched - || BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched) + || BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched + || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm_strided_batched) { name << '_' << arg.stride_y; } + if(BLAS1 == blas1::rotg_strided_batched) + { + name << '_' << arg.stride_a << '_' << arg.stride_b << '_' << arg.stride_c << '_' + << arg.stride_d; + } + + if(BLAS1 == blas1::rotm_strided_batched || BLAS1 == blas1::rotmg_strided_batched) + { + name << '_' << arg.stride_c; + } + if(is_batched || is_strided) { name << "_" << arg.batch_count; @@ -220,7 +263,8 @@ namespace || std::is_same{} || std::is_same{})) - || (BLAS1 == blas1::rot + || ((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched) && ((std::is_same{} && std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{} && std::is_same{}) @@ -233,16 +277,22 @@ namespace || (std::is_same{} && std::is_same{} && std::is_same{}))) - || (BLAS1 == blas1::rotg && std::is_same{} + || ((BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched + || BLAS1 == blas1::rotg_strided_batched) + && std::is_same{} && ((std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}))) - || (BLAS1 == blas1::rotm && std::is_same{} && std::is_same{} + || ((BLAS1 == blas1::rotm || BLAS1 == blas1::rotm_batched + || BLAS1 == blas1::rotm_strided_batched) + && std::is_same{} && std::is_same{} && (std::is_same{} || std::is_same{})) - || (BLAS1 == 
blas1::rotmg && std::is_same{} && std::is_same{} + || ((BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched + || BLAS1 == blas1::rotmg_strided_batched) + && std::is_same{} && std::is_same{} && (std::is_same{} || std::is_same{}))>; // Creates tests for one of the BLAS 1 functions @@ -320,9 +370,17 @@ BLAS1_TESTING(swap, ARG1) BLAS1_TESTING(swap_batched, ARG1) BLAS1_TESTING(swap_strided_batched, ARG1) BLAS1_TESTING(rot, ARG3) +BLAS1_TESTING(rot_batched, ARG3) +BLAS1_TESTING(rot_strided_batched, ARG3) BLAS1_TESTING(rotg, ARG2) +BLAS1_TESTING(rotg_batched, ARG2) +BLAS1_TESTING(rotg_strided_batched, ARG2) BLAS1_TESTING(rotm, ARG1) +BLAS1_TESTING(rotm_batched, ARG1) +BLAS1_TESTING(rotm_strided_batched, ARG1) BLAS1_TESTING(rotmg, ARG1) +BLAS1_TESTING(rotmg_batched, ARG1) +BLAS1_TESTING(rotmg_strided_batched, ARG1) // clang-format on diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 6b474eb31..65d92fa6d 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -46,6 +46,21 @@ Tests: - rotg: *rotg_precisions - rotmg: *single_double_precisions_complex_real + - name: blas1_batched + category: quick + batch_count: [-1, 0, 5] + function: + - rotg_batched: *rotg_precisions + - rotmg_batched: *single_double_precisions_complex_real + + - name: blas1_strided_batched + category: quick + batch_count: [-1, 0, 5] + stride_scale: [ 1.5 ] + function: + - rotg_strided_batched: *rotg_precisions + - rotmg_strided_batched: *single_double_precisions_complex_real + # All functions with alpha and incx and incy # quick @@ -313,6 +328,8 @@ Tests: function: - swap_batched: *single_double_precisions_complex_real - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: quick @@ -324,6 +341,8 @@ Tests: function: - swap_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: 
*single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real # pre_checkin - name: blas1 @@ -346,6 +365,8 @@ Tests: function: - swap_batched: *single_double_precisions_complex_real - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: pre_checkin @@ -356,6 +377,8 @@ Tests: function: - swap_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real # nightly - name: blas1 @@ -378,6 +401,8 @@ Tests: function: - swap_batched: *single_double_precisions_complex_real - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: nightly @@ -388,6 +413,8 @@ Tests: function: - swap_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real # all functions bad arg # for bad_arg no arguments should be used by test code @@ -421,6 +448,14 @@ Tests: - rotg_bad_arg: *rotg_precisions - rotm_bad_arg: *single_double_precisions_complex_real - rotmg_bad_arg: *single_double_precisions_complex_real + - rot_batched_bad_arg: *rot_precisions + - rotg_batched_bad_arg: *rotg_precisions + - rotm_batched_bad_arg: *single_double_precisions_complex_real + - rotmg_batched_bad_arg: *single_double_precisions_complex_real + - rot_strided_batched_bad_arg: *rot_precisions + - rotg_strided_batched_bad_arg: *rotg_precisions + - rotm_strided_batched_bad_arg: *single_double_precisions_complex_real + - rotmg_strided_batched_bad_arg: 
*single_double_precisions_complex_real ... diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index 49ee0c56f..ae96ec3c5 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -610,6 +610,83 @@ static constexpr auto template <> static constexpr auto rocblas_rot = rocblas_zdrot; +// rot_batched +template +rocblas_status (*rocblas_rot_batched)(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const U* c, + const V* s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rot_batched = rocblas_srot_batched; + +template <> +static constexpr auto rocblas_rot_batched = rocblas_drot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_crot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_csrot_batched; + +template <> +static constexpr auto rocblas_rot_batched = rocblas_zrot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_zdrot_batched; + +// rot_strided_batched +template +rocblas_status (*rocblas_rot_strided_batched)(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const U* c, + const V* s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_srot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_drot_strided_batched; + +template <> +static constexpr auto + rocblas_rot_strided_batched = rocblas_crot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_csrot_strided_batched; + +template <> +static constexpr auto + rocblas_rot_strided_batched = rocblas_zrot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_zdrot_strided_batched; + // rotg template rocblas_status 
(*rocblas_rotg)(rocblas_handle handle, T* a, T* b, U* c, T* s); @@ -626,6 +703,54 @@ static constexpr auto rocblas_rotg = rocblas_crotg template <> static constexpr auto rocblas_rotg = rocblas_zrotg; +// rotg_batched +template +rocblas_status (*rocblas_rotg_batched)(rocblas_handle handle, + T* const a[], + T* const b[], + U* const c[], + T* const s[], + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_srotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_drotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_crotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_zrotg_batched; + +//rotg_strided_batched +template +rocblas_status (*rocblas_rotg_strided_batched)(rocblas_handle handle, + T* a, + rocblas_stride stride_a, + T* b, + rocblas_stride stride_b, + U* c, + rocblas_stride stride_c, + T* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotg_strided_batched = rocblas_srotg_strided_batched; + +template <> +static constexpr auto rocblas_rotg_strided_batched = rocblas_drotg_strided_batched; + +template <> +static constexpr auto + rocblas_rotg_strided_batched = rocblas_crotg_strided_batched; + +template <> +static constexpr auto + rocblas_rotg_strided_batched = rocblas_zrotg_strided_batched; + //rotm template rocblas_status (*rocblas_rotm)(rocblas_handle handle, @@ -642,6 +767,41 @@ static constexpr auto rocblas_rotm = rocblas_srotm; template <> static constexpr auto rocblas_rotm = rocblas_drotm; +// rotm_batched +template +rocblas_status (*rocblas_rotm_batched)(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const T* const param[], + rocblas_int batch_count); +template <> +static constexpr auto rocblas_rotm_batched = rocblas_srotm_batched; + +template <> +static constexpr auto rocblas_rotm_batched = 
rocblas_drotm_batched; + +// rotm_strided_batched +template +rocblas_status (*rocblas_rotm_strided_batched)(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const T* param, + rocblas_stride stride_param, + rocblas_int batch_count); +template <> +static constexpr auto rocblas_rotm_strided_batched = rocblas_srotm_strided_batched; + +template <> +static constexpr auto rocblas_rotm_strided_batched = rocblas_drotm_strided_batched; + //rotmg template rocblas_status (*rocblas_rotmg)(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param); @@ -652,6 +812,43 @@ static constexpr auto rocblas_rotmg = rocblas_srotmg; template <> static constexpr auto rocblas_rotmg = rocblas_drotmg; +//rotmg_batched +template +rocblas_status (*rocblas_rotmg_batched)(rocblas_handle handle, + T* const d1[], + T* const d2[], + T* const x1[], + const T* const y1[], + T* const param[], + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotmg_batched = rocblas_srotmg_batched; + +template <> +static constexpr auto rocblas_rotmg_batched = rocblas_drotmg_batched; + +//rotmg_strided_batched +template +rocblas_status (*rocblas_rotmg_strided_batched)(rocblas_handle handle, + T* d1, + rocblas_stride stride_d1, + T* d2, + rocblas_stride stride_d2, + T* x1, + rocblas_stride stride_x1, + const T* y1, + rocblas_stride stride_y1, + T* param, + rocblas_stride stride_param, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotmg_strided_batched = rocblas_srotmg_strided_batched; + +template <> +static constexpr auto rocblas_rotmg_strided_batched = rocblas_drotmg_strided_batched; + /* * =========================================================================== * level 2 BLAS diff --git a/clients/include/rocblas_common.yaml b/clients/include/rocblas_common.yaml index 49bca04c0..620099283 100644 --- a/clients/include/rocblas_common.yaml +++ 
b/clients/include/rocblas_common.yaml @@ -317,10 +317,6 @@ Defaults: uplo: '*' diag: '*' batch_count: -1 - stride_a: 0 - stride_b: 0 - stride_c: 0 - stride_d: 0 norm_check: 0 unit_check: 1 timing: 0 diff --git a/clients/include/testing_rot.hpp b/clients/include/testing_rot.hpp index f238dde65..d9db8e480 100644 --- a/clients/include/testing_rot.hpp +++ b/clients/include/testing_rot.hpp @@ -96,8 +96,14 @@ void testing_rot(const Arguments& arg) rocblas_seedrand(); rocblas_init(hx, 1, N, incx); rocblas_init(hy, 1, N, incy); - rocblas_init(hc, 1, 1, 1); - rocblas_init(hs, 1, 1, 1); + + // Random alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); // CPU BLAS reference data host_vector cx = hx; diff --git a/clients/include/testing_rot_batched.hpp b/clients/include/testing_rot_batched.hpp new file mode 100644 index 000000000..fef7ad3ac --- /dev/null +++ b/clients/include/testing_rot_batched.hpp @@ -0,0 +1,272 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rot_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_int incy = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(nullptr, N, dx, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, nullptr, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, nullptr, incy, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rot_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count 
<= 0) + { + device_vector dx(1); + device_vector dy(1); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); + return; + } + + size_t size_x = N * size_t(incx); + size_t size_y = N * size_t(incy); + + device_vector dx(batch_count); + device_vector dy(batch_count); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx[batch_count]; //(size_x); + host_vector hy[batch_count]; //(size_y); + host_vector hc(1); + host_vector hs(1); + + device_batch_vector bx(batch_count, size_x); + device_batch_vector by(batch_count, size_y); + + for(int i = 0; i < batch_count; i++) + { + hx[i] = host_vector(size_x); + hy[i] = host_vector(size_y); + } + + int last = batch_count - 1; + if((!bx[last] && size_x) || (!by[last] && size_y)) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hx[b], 1, N, incx); + rocblas_init(hy[b], 1, N, incy); + } + + // Random alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); + + // CPU BLAS reference data + host_vector cx[batch_count]; + host_vector cy[batch_count]; + for(int b = 0; b < batch_count; b++) + { + cx[b] = hx[b]; + cy[b] = hy[b]; + } + // cblas_rotg(cx, cy, hc, hs); + // cx[0] = hx[0]; + // cy[0] = hy[0]; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rot(N, cx[b], 
incx, cy[b], incy, hc, hs); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, cx, rx); + unit_check_general(1, N, batch_count, incy, cy, ry); + } + if(arg.norm_check) + { + norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_HIP_ERROR(hipMemcpy(dc, hc, 
sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(V), hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, cx, rx); + unit_check_general(1, N, batch_count, incy, cy, ry); + } + if(arg.norm_check) + { + norm_error_device_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_device_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout + << 
",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rot_strided_batched.hpp b/clients/include/testing_rot_strided_batched.hpp new file mode 100644 index 000000000..9d9ff2ae5 --- /dev/null +++ b/clients/include/testing_rot_strided_batched.hpp @@ -0,0 +1,240 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rot_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_stride stride_x = 1; + rocblas_int incy = 1; + rocblas_stride stride_y = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, nullptr, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + 
(rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, nullptr, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rot_strided_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int stride_x = arg.stride_x; + rocblas_int stride_y = arg.stride_y; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rot_strided_batched)(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dc, + ds, + batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count))); + return; + } + + size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); + size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + + device_vector dx(size_x); + device_vector 
dy(size_y); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx(size_x); + host_vector hy(size_y); + host_vector hc(1); + host_vector hs(1); + rocblas_seedrand(); + rocblas_init(hx, 1, N, incx, stride_x, batch_count); + rocblas_init(hy, 1, N, incy, stride_y, batch_count); + + // Random alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); + + // CPU BLAS reference data + host_vector cx = hx; + host_vector cy = hy; + // cblas_rotg(cx, cy, hc, hs); + // cx[0] = hx[0]; + // cy[0] = hy[0]; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rot(N, cx + b * stride_x, incx, cy + b * stride_y, incy, hc, hs); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, stride_x, cx, rx); + unit_check_general(1, N, batch_count, incy, stride_y, cy, ry); + } + if(arg.norm_check) + { + norm_error_host_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_host_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + + // Test
rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(V), hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, stride_x, cx, rx); + unit_check_general(1, N, batch_count, incy, stride_y, cy, ry); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + 
if(arg.norm_check) + std::cout + << ",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotg.hpp b/clients/include/testing_rotg.hpp index b64ce34bd..7c15b9604 100644 --- a/clients/include/testing_rotg.hpp +++ b/clients/include/testing_rotg.hpp @@ -120,10 +120,10 @@ void testing_rotg(const Arguments& arg) if(arg.unit_check) { - // unit_check_general(1, 1, 1, ca, ha); - // unit_check_general(1, 1, 1, cb, hb); - // unit_check_general(1, 1, 1, cc, hc); - // unit_check_general(1, 1, 1, cs, hs); + unit_check_general(1, 1, 1, ca, ha); + unit_check_general(1, 1, 1, cb, hb); + unit_check_general(1, 1, 1, cc, hc); + unit_check_general(1, 1, 1, cs, hs); } if(arg.norm_check) diff --git a/clients/include/testing_rotg_batched.hpp b/clients/include/testing_rotg_batched.hpp new file mode 100644 index 000000000..8a30dbd32 --- /dev/null +++ b/clients/include/testing_rotg_batched.hpp @@ -0,0 +1,272 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotg_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 1; + + rocblas_local_handle handle; + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(nullptr, da, db, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, nullptr, db, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, nullptr, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotg_batched(const Arguments& arg) +{ + const int TEST_COUNT = 100; + rocblas_int batch_count = arg.batch_count; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector da(safe_size); + device_vector db(safe_size); + device_vector dc(safe_size); + device_vector ds(safe_size); + + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + 
CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count))); + return; + } + + // Initial Data on CPU + host_vector ha[batch_count]; + host_vector hb[batch_count]; + host_vector hc[batch_count]; + host_vector hs[batch_count]; + + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + + for(int b = 0; b < batch_count; b++) + { + ha[b] = host_vector(1); + hb[b] = host_vector(1); + hc[b] = host_vector(1); + hs[b] = host_vector(1); + } + + for(int i = 0; i < TEST_COUNT; i++) + { + host_vector ca[batch_count]; + host_vector cb[batch_count]; + host_vector cc[batch_count]; + host_vector cs[batch_count]; + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(ha[b], 1, 1, 1); + rocblas_init(hb[b], 1, 1, 1); + rocblas_init(hc[b], 1, 1, 1); + rocblas_init(hs[b], 1, 1, 1); + ca[b] = ha[b]; + cb[b] = hb[b]; + cc[b] = hc[b]; + cs[b] = hs[b]; + } + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotg(ca[b], cb[b], cc[b], cs[b]); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector ra[batch_count]; + host_vector rb[batch_count]; + host_vector rc[batch_count]; + host_vector rs[batch_count]; + T* ra_in[batch_count]; + T* rb_in[batch_count]; + U* rc_in[batch_count]; + T* rs_in[batch_count]; + for(int b = 0; b < batch_count; b++) + { + ra_in[b] = ra[b] = ha[b]; + rb_in[b] = rb[b] = hb[b]; + rc_in[b] = rc[b] = hc[b]; + rs_in[b] = rs[b] = hs[b]; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rotg_batched(handle, ra_in, rb_in, rc_in, rs_in, batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, 
batch_count, 1, ra, ca); + unit_check_general(1, 1, batch_count, 1, rb, cb); + unit_check_general(1, 1, batch_count, 1, rc, cc); + unit_check_general(1, 1, batch_count, 1, rs, cs); + } + + if(arg.norm_check) + { + norm_error_host = norm_check_general('F', 1, 1, batch_count, 1, ra, ca); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rb, cb); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rc, cc); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rs, cs); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + device_batch_vector bc(batch_count, 1); + device_batch_vector bs(batch_count, 1); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(ba[b], ha[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bb[b], hb[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bc[b], hc[b], sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bs[b], hs[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(da, ba, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, bb, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, bc, sizeof(U*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, bs, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count))); + + host_vector ra[batch_count]; + host_vector rb[batch_count]; + host_vector rc[batch_count]; + host_vector rs[batch_count]; + for(int b = 0; b < batch_count; b++) + { + ra[b] = host_vector(1); + rb[b] = host_vector(1); + rc[b] = host_vector(1); + 
rs[b] = host_vector(1); + CHECK_HIP_ERROR(hipMemcpy(ra[b], ba[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rb[b], bb[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rc[b], bc[b], sizeof(U), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rs[b], bs[b], sizeof(T), hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, ra, ca); + unit_check_general(1, 1, batch_count, 1, rb, cb); + unit_check_general(1, 1, batch_count, 1, rc, cc); + unit_check_general(1, 1, batch_count, 1, rs, cs); + } + + if(arg.norm_check) + { + norm_error_device = norm_check_general('F', 1, 1, batch_count, 1, ra, ca); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rb, cb); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rc, cc); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rs, cs); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be much quicker + // (TODO: or is there another reason we are typically using host_mode for timing?) 
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + device_batch_vector bc(batch_count, 1); + device_batch_vector bs(batch_count, 1); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(ba[b], ha[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bb[b], hb[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bc[b], hc[b], sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bs[b], hs[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(da, ba, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, bb, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, bc, sizeof(U*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, bs, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotg_batched(handle, da, db, dc, ds, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotg_batched(handle, da, db, dc, ds, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotg_strided_batched.hpp b/clients/include/testing_rotg_strided_batched.hpp new file mode 100644 index 000000000..baeec8112 --- /dev/null +++ 
b/clients/include/testing_rotg_strided_batched.hpp @@ -0,0 +1,255 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotg_strided_batched_bad_arg(const Arguments& arg) +{ + static const size_t safe_size = 1; + rocblas_int batch_count = 5; + rocblas_stride stride_a = 10; + rocblas_stride stride_b = 10; + rocblas_stride stride_c = 10; + rocblas_stride stride_s = 10; + + rocblas_local_handle handle; + device_vector da(batch_count * stride_a); + device_vector db(batch_count * stride_b); + device_vector dc(batch_count * stride_c); + device_vector ds(batch_count * stride_s); + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + nullptr, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, nullptr, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, nullptr, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, nullptr, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, nullptr, stride_s, batch_count)), + 
rocblas_status_invalid_pointer); +} + +template +void testing_rotg_strided_batched(const Arguments& arg) +{ + const int TEST_COUNT = 100; + rocblas_int stride_a = arg.stride_a; + rocblas_int stride_b = arg.stride_b; + rocblas_int stride_c = arg.stride_c; + rocblas_int stride_s = arg.stride_d; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + static const size_t safe_size = 1; // arbitrarily set to 1 + device_vector da(safe_size); + device_vector db(safe_size); + device_vector dc(safe_size); + device_vector ds(safe_size); + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotg_strided_batched)(handle, + da, + stride_a, + db, + stride_b, + dc, + stride_c, + ds, + stride_s, + batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count))); + return; + } + + size_t size_a = size_t(stride_a) * size_t(batch_count); + size_t size_b = size_t(stride_b) * size_t(batch_count); + size_t size_c = size_t(stride_c) * size_t(batch_count); + size_t size_s = size_t(stride_s) * size_t(batch_count); + + host_vector ha(size_a); + host_vector hb(size_b); + host_vector hc(size_c); + host_vector hs(size_s); + + for(int i = 0; i < TEST_COUNT; i++) + { + // Initial data on CPU + rocblas_seedrand(); + rocblas_init(ha, 1, 1, 1, stride_a, batch_count); + rocblas_init(hb, 1, 1, 1, stride_b, batch_count); + rocblas_init(hc, 1, 1, 1, stride_c, batch_count); + rocblas_init(hs, 1, 1, 1, stride_s, batch_count); + + // CPU_BLAS + host_vector ca = ha; + host_vector cb = hb; + host_vector cc =
hc; + host_vector cs = hs; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotg( + ca + b * stride_a, cb + b * stride_b, cc + b * stride_c, cs + b * stride_s); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector ra = ha; + host_vector rb = hb; + host_vector rc = hc; + host_vector rs = hs; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, ra, stride_a, rb, stride_b, rc, stride_c, rs, stride_s, batch_count))); + + if(arg.unit_check) + { + double rel_error = std::numeric_limits::epsilon() * 100; + near_check_general(1, 1, batch_count, 1, stride_a, ca, ra, rel_error); + near_check_general(1, 1, batch_count, 1, stride_b, cb, rb, rel_error); + near_check_general(1, 1, batch_count, 1, stride_c, cc, rc, rel_error); + near_check_general(1, 1, batch_count, 1, stride_s, cs, rs, rel_error); + } + + if(arg.norm_check) + { + norm_error_host + = norm_check_general('F', 1, 1, 1, stride_a, batch_count, ca, ra); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_b, batch_count, cb, rb); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_c, batch_count, cc, rc); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_s, batch_count, cs, rs); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector da(size_a); + device_vector db(size_b); + device_vector dc(size_c); + device_vector ds(size_s); + CHECK_HIP_ERROR(hipMemcpy(da, ha, sizeof(T) * size_a, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, hb, sizeof(T) * size_b, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U) * size_c, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(T) * size_s, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, 
da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count))); + host_vector ra(size_a); + host_vector rb(size_b); + host_vector rc(size_c); + host_vector rs(size_s); + CHECK_HIP_ERROR(hipMemcpy(ra, da, sizeof(T) * size_a, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rb, db, sizeof(T) * size_b, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rc, dc, sizeof(U) * size_c, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rs, ds, sizeof(T) * size_s, hipMemcpyDeviceToHost)); + + if(arg.unit_check) + { + double rel_error = std::numeric_limits::epsilon() * 100; + near_check_general(1, 1, batch_count, 1, stride_a, ca, ra, rel_error); + near_check_general(1, 1, batch_count, 1, stride_b, cb, rb, rel_error); + near_check_general(1, 1, batch_count, 1, stride_c, cc, rc, rel_error); + near_check_general(1, 1, batch_count, 1, stride_s, cs, rs, rel_error); + } + + if(arg.norm_check) + { + norm_error_device + = norm_check_general('F', 1, 1, 1, stride_a, batch_count, ca, ra); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_b, batch_count, cb, rb); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_c, batch_count, cc, rc); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_s, batch_count, cs, rs); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be quicker + // (TODO: or is there another reason we are typically using host_mode for timing?)
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector da(size_a); + device_vector db(size_b); + device_vector dc(size_c); + device_vector ds(size_s); + CHECK_HIP_ERROR(hipMemcpy(da, ha, sizeof(T) * size_a, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, hb, sizeof(T) * size_b, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U) * size_c, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(T) * size_s, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotm_batched.hpp b/clients/include/testing_rotm_batched.hpp new file mode 100644 index 000000000..9ffedcd98 --- /dev/null +++ b/clients/include/testing_rotm_batched.hpp @@ -0,0 +1,289 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotm_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_int incy = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(safe_size); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(nullptr, N, dx, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, nullptr, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, nullptr, incy, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotm_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector 
dparam(safe_size); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); + return; + } + + size_t size_x = N * size_t(incx); + size_t size_y = N * size_t(incy); + + device_vector dx(batch_count); + device_vector dy(batch_count); + device_vector dparam(batch_count); + + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx[batch_count]; + host_vector hy[batch_count]; + host_vector hdata[batch_count]; //(4); + host_vector hparam[batch_count]; //(5); + + device_batch_vector bx(batch_count, size_x); + device_batch_vector by(batch_count, size_y); + device_batch_vector bdata(batch_count, 4); + device_batch_vector bparam(batch_count, 5); + + for(int b = 0; b < batch_count; b++) + { + hx[b] = host_vector(size_x); + hy[b] = host_vector(size_y); + hdata[b] = host_vector(4); + hparam[b] = host_vector(5); + } + + int last = batch_count - 1; + if((!bx[last] && size_x) || (!by[last] && size_y) || !bdata[last] || !bparam[last]) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hx[b], 1, N, incx); + rocblas_init(hy[b], 1, N, incy); + rocblas_init(hdata[b], 1, 4, 1); + + // CPU BLAS reference data + cblas_rotmg(&hdata[b][0], &hdata[b][1], &hdata[b][2], &hdata[b][3], hparam[b]); + } + + constexpr int FLAG_COUNT = 4; + const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; + + for(int i = 0; i < FLAG_COUNT; i++) + { + for(int b = 0; b < batch_count; b++) + hparam[b][0] = FLAGS[i]; + + host_vector cx[batch_count]; + host_vector cy[batch_count]; + 
cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cx[b] = hx[b]; + cy[b] = hy[b]; + + cblas_rotm(N, cx[b], incx, cy[b], incy, hparam[b]); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + // TODO: THIS IS NO LONGER SUPPORTED + // { + // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + // for(int b = 0; b < batch_count; b++) + // { + // CHECK_HIP_ERROR( + // hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR( + // hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + // } + // CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + // CHECK_ROCBLAS_ERROR( + // (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count))); + + // host_vector rx[batch_count]; + // host_vector ry[batch_count]; + // for(int b = 0; b < batch_count; b++) + // { + // rx[b] = host_vector(size_x); + // ry[b] = host_vector(size_y); + // CHECK_HIP_ERROR( + // hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + // CHECK_HIP_ERROR( + // hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + // } + + // if(arg.unit_check) + // { + // T rel_error = std::numeric_limits::epsilon() * 1000; + // near_check_general(1, N, batch_count, incx, cx, rx, rel_error); + // near_check_general(1, N, batch_count, incy, cy, ry, rel_error); + // } + // if(arg.norm_check) + // { + // norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + // norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + // } + // } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR( + 
hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparam[b], hparam[b], sizeof(T) * 5, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, bparam, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR( + hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparam[b], hparam[b], sizeof(T) * 5, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, bparam, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout << ",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_" + "device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } + } +} diff --git a/clients/include/testing_rotm_strided_batched.hpp b/clients/include/testing_rotm_strided_batched.hpp new file mode 100644 index 000000000..7f67e3bef --- /dev/null +++ b/clients/include/testing_rotm_strided_batched.hpp @@ -0,0 +1,300 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotm_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_stride stride_x = 1; + rocblas_int incy = 1; + rocblas_stride stride_y = 1; + rocblas_stride stride_param = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(safe_size); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dparam, stride_param, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched(handle, + N, + nullptr, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + nullptr, + incy, + stride_y, + dparam, + stride_param, + batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, stride_param, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotm_strided_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int stride_x = arg.stride_x; + rocblas_int stride_y = arg.stride_y; + rocblas_int stride_param = arg.stride_c; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + 
rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(safe_size); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched)(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count))); + return; + } + + size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); + size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + size_t size_param = 5 + size_t(stride_param) * size_t(batch_count - 1); + + device_vector dx(size_x); + device_vector dy(size_y); + device_vector dparam(size_param); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx(size_x); + host_vector hy(size_y); + host_vector hdata(4 * batch_count); + host_vector hparam(size_param); + rocblas_seedrand(); + rocblas_init(hx, 1, N, incx, stride_x, batch_count); + rocblas_init(hy, 1, N, incy, stride_y, batch_count); + rocblas_init(hdata, 1, 4, 1, 4, batch_count); + + // CPU BLAS reference data + for(int b = 0; b < batch_count; b++) + cblas_rotmg(hdata + b * 4, + hdata + b * 4 + 1, + hdata + b * 4 + 2, + hdata + b * 4 + 3, + hparam + b * stride_param); + + constexpr int 
FLAG_COUNT = 4; + const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; + + for(int i = 0; i < FLAG_COUNT; i++) + { + for(int b = 0; b < batch_count; b++) + (hparam + b * stride_param)[0] = FLAGS[i]; + + host_vector cx = hx; + host_vector cy = hy; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotm( + N, cx + b * stride_x, incx, cy + b * stride_y, incy, hparam + b * stride_param); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + // TODO: THIS IS NO LONGER SUPPORTED + // { + // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + // CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + // CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( + // handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count))); + // host_vector rx(size_x); + // host_vector ry(size_y); + // CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + // CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + // if(arg.unit_check) + // { + // T rel_error = std::numeric_limits::epsilon() * 1000; + // near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); + // near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); + // } + // if(arg.norm_check) + // { + // norm_error_host_x + // = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + // norm_error_host_y + // = norm_check_general('F', 1, N, incy, stride_x, batch_count, cy, ry); + // } + // } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR( + hipMemcpy(dparam, hparam, sizeof(T) * size_param, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, hparam, sizeof(T) * size_param, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout << 
",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_" + "device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } + } +} diff --git a/clients/include/testing_rotmg_batched.hpp b/clients/include/testing_rotmg_batched.hpp new file mode 100644 index 000000000..728362813 --- /dev/null +++ b/clients/include/testing_rotmg_batched.hpp @@ -0,0 +1,319 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotmg_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 5; + + rocblas_local_handle handle; + device_vector d1(batch_count); + device_vector d2(batch_count); + device_vector x1(batch_count); + device_vector y1(batch_count); + device_vector param(batch_count); + + if(!d1 || !d2 || !x1 || !y1 || !param) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_batched(nullptr, d1, d2, x1, y1, param, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, nullptr, d2, x1, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, nullptr, x1, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + 
(rocblas_rotmg_batched(handle, d1, d2, nullptr, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, d2, x1, nullptr, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_batched(handle, d1, d2, x1, y1, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotmg_batched(const Arguments& arg) +{ + const int TEST_COUNT = 100; + rocblas_int batch_count = arg.batch_count; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector params(safe_size); + + if(!d1 || !d2 || !x1 || !y1 || !params) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, d2, x1, y1, params, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR( + (rocblas_rotmg_batched(handle, d1, d2, x1, y1, params, batch_count))); + + return; + } + + // Initial Data on CPU + host_vector hd1[batch_count]; + host_vector hd2[batch_count]; + host_vector hx1[batch_count]; + host_vector hy1[batch_count]; + host_vector hparams[batch_count]; + + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + device_batch_vector bparams(batch_count, 5); + + for(int b = 0; b < batch_count; b++) + { + hd1[b] = host_vector(1); + hd2[b] = host_vector(1); + hx1[b] = host_vector(1); + hy1[b] = host_vector(1); + hparams[b] = host_vector(5); + } + + for(int i = 0; i < 
TEST_COUNT; i++) + { + host_vector cd1[batch_count]; + host_vector cd2[batch_count]; + host_vector cx1[batch_count]; + host_vector cy1[batch_count]; + host_vector cparams[batch_count]; + + rocblas_seedrand(); + + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hd1[b], 1, 1, 1); + rocblas_init(hd2[b], 1, 1, 1); + rocblas_init(hx1[b], 1, 1, 1); + rocblas_init(hy1[b], 1, 1, 1); + rocblas_init(hparams[b], 1, 5, 1); + cd1[b] = hd1[b]; + cd2[b] = hd2[b]; + cx1[b] = hx1[b]; + cy1[b] = hy1[b]; + cparams[b] = hparams[b]; + } + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotmg(cd1[b], cd2[b], cx1[b], cy1[b], cparams[b]); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector rd1[batch_count]; + host_vector rd2[batch_count]; + host_vector rx1[batch_count]; + host_vector ry1[batch_count]; + host_vector rparams[batch_count]; + T* rd1_in[batch_count]; + T* rd2_in[batch_count]; + T* rx1_in[batch_count]; + T* ry1_in[batch_count]; + T* rparams_in[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rd1_in[b] = rd1[b] = hd1[b]; + rd2_in[b] = rd2[b] = hd2[b]; + rx1_in[b] = rx1[b] = hx1[b]; + ry1_in[b] = ry1[b] = hy1[b]; + rparams_in[b] = rparams[b] = hparams[b]; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR((rocblas_rotmg_batched( + handle, rd1_in, rd2_in, rx1_in, ry1_in, rparams_in, batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, rparams, cparams); + } + + if(arg.norm_check) + { + norm_error_host = norm_check_general('F', 1, 1, batch_count, 1, rd1, cd1); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); + norm_error_host += 
norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + norm_error_host + += norm_check_general('F', 1, 5, batch_count, 1, rparams, cparams); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector dd1(batch_count); + device_vector dd2(batch_count); + device_vector dx1(batch_count); + device_vector dy1(batch_count); + device_vector dparams(batch_count); + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + device_batch_vector bparams(batch_count, 5); + + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bd1[b], hd1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparams[b], hparams[b], sizeof(T) * 5, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dd1, bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, bparams, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR( + (rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count))); + + host_vector rd1[batch_count]; + host_vector rd2[batch_count]; + host_vector rx1[batch_count]; + host_vector ry1[batch_count]; + host_vector rparams[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rd1[b] = 
host_vector(1); + rd2[b] = host_vector(1); + rx1[b] = host_vector(1); + ry1[b] = host_vector(1); + rparams[b] = host_vector(5); + CHECK_HIP_ERROR(hipMemcpy(rd1[b], bd1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rd2[b], bd2[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rx1[b], bx1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry1[b], by1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(rparams[b], bparams[b], sizeof(T) * 5, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, rparams, cparams); + } + + if(arg.norm_check) + { + norm_error_device = norm_check_general('F', 1, 1, batch_count, 1, rd1, cx1); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + norm_error_device + += norm_check_general('F', 1, 5, batch_count, 1, rparams, cparams); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector dd1(batch_count); + device_vector dd2(batch_count); + device_vector dx1(batch_count); + device_vector dy1(batch_count); + device_vector dparams(batch_count); + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + device_batch_vector bparams(batch_count, 5); + + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bd1[b], hd1[b], sizeof(T), hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparams[b], hparams[b], sizeof(T) * 5, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dd1, bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, bparams, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotmg_strided_batched.hpp b/clients/include/testing_rotmg_strided_batched.hpp new file mode 100644 index 000000000..127660796 --- /dev/null +++ b/clients/include/testing_rotmg_strided_batched.hpp @@ -0,0 +1,333 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotmg_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 5; + + rocblas_local_handle handle; + device_vector d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector param(safe_size); + + if(!d1 || !d2 || !x1 || !y1 || !param) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + nullptr, d1, 0, d2, 0, x1, 0, y1, 0, param, 0, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, nullptr, 0, d2, 0, x1, 0, y1, 0, param, 0, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, nullptr, 0, x1, 0, y1, 0, param, 0, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, d2, 0, nullptr, 0, y1, 0, param, 0, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, d2, 0, x1, 0, nullptr, 0, param, 0, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, d2, 0, x1, 0, y1, 0, nullptr, 0, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotmg_strided_batched(const Arguments& arg) +{ + const int TEST_COUNT = 100; + rocblas_int batch_count = arg.batch_count; + rocblas_int stride_d1 = arg.stride_a; + rocblas_int stride_d2 = arg.stride_b; + rocblas_int stride_x1 = 
arg.stride_x; + rocblas_int stride_y1 = arg.stride_y; + rocblas_int stride_param = arg.stride_c; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector params(safe_size); + + if(!d1 || !d2 || !x1 || !y1 || !params) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + params, + stride_param, + batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + params, + stride_param, + batch_count))); + + return; + } + + size_t size_d1 = batch_count * stride_d1; + size_t size_d2 = batch_count * stride_d2; + size_t size_x1 = batch_count * stride_x1; + size_t size_y1 = batch_count * stride_y1; + size_t size_param = batch_count * stride_param; + + // Initial Data on CPU + host_vector hd1(size_d1); + host_vector hd2(size_d2); + host_vector hx1(size_x1); + host_vector hy1(size_y1); + host_vector hparams(size_param); + + for(int i = 0; i < TEST_COUNT; i++) + { + rocblas_seedrand(); + rocblas_init(hparams, 1, 5, 1, stride_param, batch_count); + rocblas_init(hd1, 1, 1, 1, stride_d1, batch_count); + rocblas_init(hd2, 1, 1, 1, stride_d2, batch_count); + rocblas_init(hx1, 1, 1, 1, stride_x1, batch_count); + rocblas_init(hy1, 1, 1, 1, stride_y1, batch_count); + + host_vector cparams = hparams; + host_vector cd1 = hd1; + host_vector cd2 = hd2; + host_vector cx1 = hx1; + host_vector cy1 = 
hy1; + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotmg(cd1 + b * stride_d1, + cd2 + b * stride_d2, + cx1 + b * stride_x1, + cy1 + b * stride_y1, + cparams + b * stride_param); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector rd1 = hd1; + host_vector rd2 = hd2; + host_vector rx1 = hx1; + host_vector ry1 = hy1; + host_vector rparams = hparams; + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + rd1, + stride_d1, + rd2, + stride_d2, + rx1, + stride_x1, + ry1, + stride_y1, + rparams, + stride_param, + batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, stride_d1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, stride_param, rparams, cparams); + } + + if(arg.norm_check) + { + norm_error_host + = norm_check_general('F', 1, 1, 1, stride_d1, batch_count, rd1, cd1); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_d2, batch_count, rd2, cd2); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + norm_error_host += norm_check_general( + 'F', 1, 5, 1, stride_param, batch_count, rparams, cparams); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector dd1(size_d1); + device_vector dd2(size_d2); + device_vector dx1(size_x1); + device_vector dy1(size_y1); + device_vector dparams(size_param); + + CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, 
sizeof(T) * size_x1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, hparams, sizeof(T) * size_param, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + stride_param, + batch_count))); + + host_vector rd1(size_d1); + host_vector rd2(size_d2); + host_vector rx1(size_x1); + host_vector ry1(size_y1); + host_vector rparams(size_param); + + CHECK_HIP_ERROR(hipMemcpy(rd1, dd1, sizeof(T) * size_d1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rd2, dd2, sizeof(T) * size_d2, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rx1, dx1, sizeof(T) * size_x1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry1, dy1, sizeof(T) * size_y1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(rparams, dparams, sizeof(T) * size_param, hipMemcpyDeviceToHost)); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, stride_d1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, stride_param, rparams, cparams); + } + + if(arg.norm_check) + { + norm_error_device + = norm_check_general('F', 1, 1, 1, stride_d1, batch_count, rd1, cd1); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_d2, batch_count, rd2, cd2); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + norm_error_host += norm_check_general( + 'F', 1, 5, 1, stride_param, batch_count, rparams, cparams); + } + } + } + + if(arg.timing) + { + int 
number_cold_calls = 2; + int number_hot_calls = 100; + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector dd1(size_d1); + device_vector dd2(size_d2); + device_vector dx1(size_x1); + device_vector dy1(size_y1); + device_vector dparams(size_param); + + CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, sizeof(T) * size_x1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * size_param, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + stride_param, + batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + stride_param, + batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index 4717ddf64..d70b3af6d 100644 --- a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -49,7 +49,7 @@ auto rocblas_blas1_dispatch(const Arguments& arg) if(Tb == Ti) return rocblas_simple_dispatch(arg); else - { // for csscal and zdscal only + { // for csscal 
and zdscal and complex rotg only if(Ti == rocblas_datatype_f32_c && Tb == rocblas_datatype_f32_r) return TEST{}(arg); else if(Ti == rocblas_datatype_f64_c && Tb == rocblas_datatype_f64_r) @@ -60,6 +60,10 @@ auto rocblas_blas1_dispatch(const Arguments& arg) return TEST{}(arg); else if(Ti == rocblas_datatype_f64_c && Tb == rocblas_datatype_f64_r) return TEST{}(arg); + else if(Ti == rocblas_datatype_f32_r && Tb == rocblas_datatype_f32_r) + return TEST{}(arg); + else if(Ti == rocblas_datatype_f64_r && Tb == rocblas_datatype_f64_r) + return TEST{}(arg); // else if(Ti == rocblas_datatype_f16_c && To == rocblas_datatype_f16_r) // return TEST{}(arg); diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index be1dedd1c..f6b208612 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1434,12 +1434,12 @@ ROCBLAS_EXPORT rocblas_status rocblas_izamin(rocblas_handle handl n rocblas_int number of elements in the x and y vectors. @param[inout] - x pointer storing vector x on the GPU. + x pointer storing vector x in device memory. @param[in] incx rocblas_int specifies the increment between elements of x. @param[inout] - y pointer storing vector y on the GPU. + y pointer storing vector y in device memory. @param[in] incy rocblas_int specifies the increment between elements of y. @@ -1504,6 +1504,208 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot(rocblas_handle handle, const double* c, const double* s); +/*! \brief BLAS Level 1 API + + \details + rot_batched applies the Givens rotation matrix defined by c=cos(alpha) and s=sin(alpha) to batched vectors x and y. + Scalars c and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x array of pointers storing vector x in device memory. 
+ @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[inout] + y array of pointers storing vector y in device memory. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + c scalar cosine component of the rotation matrix, may be stored in host or device memory. + @param[in] + s scalar sine component of the rotation matrix, may be stored in host or device memory. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srot_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drot_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int incy, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status 
rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count); + +/*! \brief BLAS Level 1 API + + \details + rot_strided_batched applies the Givens rotation matrix defined by c=cos(alpha) and s=sin(alpha) to strided batched vectors x and y. + Scalars c and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x pointer storing strided vectors x in device memory. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[in] + stride_x rocblas_stride + specifies the increment from the beginning of x_i to the beginning of x_(i+1) + @param[inout] + y pointer storing strided vectors y in device memory. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + stride_y rocblas_stride + specifies the increment from the beginning of y_i to the beginning of y_(i+1) + @param[in] + c scalar cosine component of the rotation matrix, may be stored in host or device memory. + @param[in] + s scalar sine component of the rotation matrix, may be stored in host or device memory. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srot_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drot_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_csrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zdrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count); + /*! 
\brief BLAS Level 1 API \details @@ -1544,6 +1746,140 @@ ROCBLAS_EXPORT rocblas_status rocblas_zrotg(rocblas_handle handle, double* c, rocblas_double_complex* s); +/*! \brief BLAS Level 1 API + + \details + rotg_batched creates the Givens rotation matrix for the batched vectors (a b). + a, b, c, and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + a batched array of single input vector elements, overwritten with r. + @param[inout] + b batched array of single input vector elements, overwritten with z. + @param[inout] + c batched array of cosine elements of Givens rotations. + @param[inout] + s batched array of sine elements of Givens rotations. + @param[in] + batch_count rocblas_int + number of batches (length of arrays a, b, c, and s). 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotg_batched(rocblas_handle handle, + float* const a[], + float* const b[], + float* const c[], + float* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotg_batched(rocblas_handle handle, + double* const a[], + double* const b[], + double* const c[], + double* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crotg_batched(rocblas_handle handle, + rocblas_float_complex* const a[], + rocblas_float_complex* const b[], + float* const c[], + rocblas_float_complex* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrotg_batched(rocblas_handle handle, + rocblas_double_complex* const a[], + rocblas_double_complex* const b[], + double* const c[], + rocblas_double_complex* const s[], + rocblas_int batch_count); + +/*! \brief BLAS Level 1 API + + \details + rotg_strided_batched creates the Givens rotation matrix for the strided batched vectors (a b). + a, b, c, and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + a strided_batched pointer to single input vector elements, overwritten with r. + @param[in] + stride_a rocblas_stride + distance between elements of a in batch (distance between a_i and a_(i + 1)) + @param[inout] + b strided_batched pointer to single input vector elements, overwritten with z. 
+ @param[in] + stride_b rocblas_stride + distance between elements of b in batch (distance between b_i and b_(i + 1)) + @param[inout] + c strided_batched pointer to cosine elements of Givens rotations. + @param[in] + stride_c rocblas_stride + distance between elements of c in batch (distance between c_i and c_(i + 1)) + @param[inout] + s strided_batched pointer to sine elements of Givens rotations. + @param[in] + stride_s rocblas_stride + distance between elements of s in batch (distance between s_i and s_(i + 1)) + @param[in] + batch_count rocblas_int + number of batches (length of arrays a, b, c, and s). + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotg_strided_batched(rocblas_handle handle, + float* a, + rocblas_stride stride_a, + float* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + float* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotg_strided_batched(rocblas_handle handle, + double* a, + rocblas_stride stride_a, + double* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + double* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crotg_strided_batched(rocblas_handle handle, + rocblas_float_complex* a, + rocblas_stride stride_a, + rocblas_float_complex* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + rocblas_float_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrotg_strided_batched(rocblas_handle handle, + rocblas_double_complex* a, + rocblas_stride stride_a, + rocblas_double_complex* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + rocblas_double_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count); + /*! 
\brief BLAS Level 1 API \details @@ -1597,6 +1933,137 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, rocblas_int incy, const double* param); +/*! \brief BLAS Level 1 API + + \details + rotm_batched applies the modified Givens rotation matrix defined by param to batched vectors x and y. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x array of pointers storing vectors x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[inout] + y array of pointers storing vectors y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + param array of vectors of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may ONLY be stored on the device for the batched version of this function. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* const param[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* const param[], + rocblas_int batch_count); + +/*! 
\brief BLAS Level 1 API + + \details + rotm_strided_batched applies the modified Givens rotation matrix defined by param to strided batched vectors x and y. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x pointers storing strided batched vectors x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[in] + stride_x rocblas_stride + specifies the increment between the beginning of x_i and x_(i + 1) + @param[inout] + y pointers storing strided batched vectors y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + stride_y rocblas_stride + specifies the increment between the beginning of y_i and y_(i + 1) + @param[in] + param strided_batched array of vectors of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may ONLY be stored on the device for the strided_batched version of this function. + @param[in] + stride_param rocblas_stride + specifies the increment between the beginning of param_i and param_(i + 1) + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* param, + rocblas_stride stride_param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* param, + rocblas_stride stride_param, + rocblas_int batch_count); + /*! \brief BLAS Level 1 API \details @@ -1629,6 +2096,9 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, flag = 1 => H = ( H11 1.0 -1.0 H22 ) flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + stride_param rocblas_stride + specifies the increment between the beginning of param_i and param_(i + 1) ********************************************************************/ @@ -1638,6 +2108,136 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg( ROCBLAS_EXPORT rocblas_status rocblas_drotmg( rocblas_handle handle, double* d1, double* d2, double* x1, const double* y1, double* param); +/*! \brief BLAS Level 1 API + + \details + rotmg_batched creates the modified Givens rotation matrix for the batched vectors (d1 * x1, d2 * y1). + Parameters may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. 
+ + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + d1 batched array of input scalars that is overwritten. + @param[inout] + d2 batched array of input scalars that is overwritten. + @param[inout] + x1 batched array of input scalars that is overwritten. + @param[in] + y1 batched array of input scalars. + @param[out] + param batched array of vectors of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of instances in the batch. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, + float* const d1[], + float* const d2[], + float* const x1[], + const float* const y1[], + float* const param[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, + double* const d1[], + double* const d2[], + double* const x1[], + const double* const y1[], + double* const param[], + rocblas_int batch_count); + +/*! \brief BLAS Level 1 API + + \details + rotmg_strided_batched creates the modified Givens rotation matrix for the batched vectors (d1 * x1, d2 * y1). + Parameters may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. 
+ If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + d1 batched array of input scalars that is overwritten. + @param[in] + stride_d1 rocblas_stride + specifies the increment between the beginning of d1_i and d1_(i+1) + @param[inout] + d2 batched array of input scalars that is overwritten. + @param[in] + stride_d2 rocblas_stride + specifies the increment between the beginning of d2_i and d2_(i+1) + @param[inout] + x1 batched array of input scalars that is overwritten. + @param[in] + stride_x1 rocblas_stride + specifies the increment between the beginning of x1_i and x1_(i+1) + @param[in] + y1 batched array of input scalars. + @param[in] + stride_y1 rocblas_stride + specifies the increment between the beginning of y1_i and y1_(i+1) + @param[out] + param batched array of vectors of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of instances in the batch. 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle handle, + float* d1, + rocblas_stride stride_d1, + float* d2, + rocblas_stride stride_d2, + float* x1, + rocblas_stride stride_x1, + const float* y1, + rocblas_stride stride_y1, + float* param, + rocblas_stride stride_param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, + double* d1, + rocblas_stride stride_d1, + double* d2, + rocblas_stride stride_d2, + double* x1, + rocblas_stride stride_x1, + const double* y1, + rocblas_stride stride_y1, + double* param, + rocblas_stride stride_param, + rocblas_int batch_count); + /* * =========================================================================== * level 2 BLAS diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index 3d7deeed5..c01702ee6 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -142,9 +142,17 @@ set( rocblas_blas1_source blas1/rocblas_scal_strided_batched.cpp blas1/rocblas_swap.cpp blas1/rocblas_rot.cpp + blas1/rocblas_rot_batched.cpp + blas1/rocblas_rot_strided_batched.cpp blas1/rocblas_rotg.cpp + blas1/rocblas_rotg_batched.cpp + blas1/rocblas_rotg_strided_batched.cpp blas1/rocblas_rotm.cpp + blas1/rocblas_rotm_batched.cpp + blas1/rocblas_rotm_strided_batched.cpp blas1/rocblas_rotmg.cpp + blas1/rocblas_rotmg_batched.cpp + blas1/rocblas_rotmg_strided_batched.cpp blas1/rocblas_swap_batched.cpp blas1/rocblas_swap_strided_batched.cpp ) diff --git a/library/src/blas1/rocblas_rot.cpp b/library/src/blas1/rocblas_rot.cpp index ba4b18fb6..d7869b840 100644 --- a/library/src/blas1/rocblas_rot.cpp +++ b/library/src/blas1/rocblas_rot.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. 
* ************************************************************************ */ +#include "rocblas_rot.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -10,58 +11,6 @@ namespace { constexpr int NB = 512; - template , int>::type = 0> - __global__ void rot_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U c_device_host, - V s_device_host) - { - auto c = load_scalar(c_device_host); - auto s = load_scalar(s_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto temp = c * x[ix] + s * y[iy]; - y[iy] = c * y[iy] - s * x[ix]; - x[ix] = temp; - } - } - - template , int>::type = 0> - __global__ void rot_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U c_device_host, - V s_device_host) - { - auto c = load_scalar(c_device_host); - auto s = load_scalar(s_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto temp = c * x[ix] + s * y[iy]; - y[iy] = c * y[iy] - conj(s) * x[ix]; - x[ix] = temp; - } - } - template constexpr char rocblas_rot_name[] = "unknown"; template <> @@ -78,14 +27,14 @@ namespace constexpr char rocblas_rot_name[] = "rocblas_zdrot"; template - rocblas_status rocblas_rot(rocblas_handle handle, - rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - const U* c, - const V* s) + rocblas_status rocblas_rot_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + T* y, + rocblas_int incy, + const U* c, + const V* s) { if(!handle) return rocblas_status_invalid_handle; @@ -95,8 +44,12 @@ namespace log_trace(handle, rocblas_rot_name, n, x, incx, y, incy, c, s); if(layer_mode & rocblas_layer_mode_log_bench) log_bench(handle, - "./rocblas-bench -f rot -r", + "./rocblas-bench -f rot --a_type", rocblas_precision_string, + "--b_type", + 
rocblas_precision_string, + "--c_type", + rocblas_precision_string, "-n", n, "--incx", @@ -111,22 +64,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0) - return rocblas_status_success; - - dim3 blocks((n - 1) / NB + 1); - dim3 threads(NB); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - hipLaunchKernelGGL( - rot_kernel, blocks, threads, 0, rocblas_stream, n, x, incx, y, incy, c, s); - else // c and s are on host - hipLaunchKernelGGL( - rot_kernel, blocks, threads, 0, rocblas_stream, n, x, incx, y, incy, *c, *s); - - return rocblas_status_success; + return rocblas_rot_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, c, 0, s, 0, 1); } } // namespace @@ -148,7 +86,7 @@ rocblas_status rocblas_srot(rocblas_handle handle, const float* c, const float* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_drot(rocblas_handle handle, @@ -160,7 +98,7 @@ rocblas_status rocblas_drot(rocblas_handle handle, const double* c, const double* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_crot(rocblas_handle handle, @@ -172,7 +110,7 @@ rocblas_status rocblas_crot(rocblas_handle handle, const float* c, const rocblas_float_complex* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_csrot(rocblas_handle handle, @@ -184,7 +122,7 @@ rocblas_status rocblas_csrot(rocblas_handle handle, const float* c, const float* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_zrot(rocblas_handle handle, @@ -196,7 +134,7 @@ rocblas_status rocblas_zrot(rocblas_handle 
handle, const double* c, const rocblas_double_complex* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_zdrot(rocblas_handle handle, @@ -208,7 +146,7 @@ rocblas_status rocblas_zdrot(rocblas_handle handle, const double* c, const double* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } } // extern "C" diff --git a/library/src/blas1/rocblas_rot.hpp b/library/src/blas1/rocblas_rot.hpp new file mode 100644 index 000000000..6d7b0028f --- /dev/null +++ b/library/src/blas1/rocblas_rot.hpp @@ -0,0 +1,143 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "rocblas.h" +#include "utility.h" + +template , int>::type = 0> +__global__ void rot_kernel(rocblas_int n, + T2 x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U c_device_host, + rocblas_stride c_stride, + V s_device_host, + rocblas_stride s_stride) +{ + auto c = load_scalar(c_device_host, hipBlockIdx_y, c_stride); + auto s = load_scalar(s_device_host, hipBlockIdx_y, s_stride); + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - s * x[ix]; + x[ix] = temp; + } +} + +template , int>::type = 0> +__global__ void rot_kernel(rocblas_int n, + T2 x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y_in, + rocblas_int offset_y, + rocblas_int incy, + 
rocblas_stride stride_y, + U c_device_host, + rocblas_stride c_stride, + V s_device_host, + rocblas_stride s_stride) +{ + auto c = load_scalar(c_device_host, hipBlockIdx_y, c_stride); + auto s = load_scalar(s_device_host, hipBlockIdx_y, s_stride); + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - conj(s) * x[ix]; + x[ix] = temp; + } +} + +template +rocblas_status rocblas_rot_template(rocblas_handle handle, + rocblas_int n, + T2 x, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U* c, + rocblas_stride c_stride, + V* s, + rocblas_stride s_stride, + rocblas_int batch_count) +{ + // Quick return if possible + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count == 0) + return rocblas_status_success; + + dim3 blocks((n - 1) / NB + 1, batch_count); + dim3 threads(NB); + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + hipLaunchKernelGGL(rot_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + c, + c_stride, + s, + s_stride); + else // c and s are on host + hipLaunchKernelGGL(rot_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + *c, + c_stride, + *s, + s_stride); + + return rocblas_status_success; +} diff --git a/library/src/blas1/rocblas_rot_batched.cpp b/library/src/blas1/rocblas_rot_batched.cpp new file mode 100644 index 000000000..e0c9d2bbf --- /dev/null +++ b/library/src/blas1/rocblas_rot_batched.cpp @@ -0,0 +1,173 @@ +/* 
************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rot.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rot_name[] = "unknown"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_srot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_drot"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_crot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zrot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_csrot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zdrot_batched"; + + template + rocblas_status rocblas_rot_batched_impl(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const U* c, + const V* s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rot_name, n, x, incx, y, incy, c, s, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rot_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--c_type", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--incy", + incy, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rot_name, + "N", + n, + "incx", + incx, + "incy", + incy, + "batch", + batch_count); + + if(!x || !y || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return 
rocblas_rot_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, c, 0, s, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srot_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_drot_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_crot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_zrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int incy, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, 
batch_count); +} + +rocblas_status rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rot_strided_batched.cpp b/library/src/blas1/rocblas_rot_strided_batched.cpp new file mode 100644 index 000000000..65d7b5df8 --- /dev/null +++ b/library/src/blas1/rocblas_rot_strided_batched.cpp @@ -0,0 +1,214 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rot.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rot_name[] = "unknown"; // fallback label when no specialization below matches + template <> + constexpr char rocblas_rot_name[] = "rocblas_srot_strided_batched"; // label must name the strided_batched entry point, like its siblings below + template <> + constexpr char rocblas_rot_name[] = "rocblas_drot_strided_batched"; // label must name the strided_batched entry point, like its siblings below + template <> + constexpr char rocblas_rot_name[] = "rocblas_crot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zrot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] + = "rocblas_csrot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] + = "rocblas_zdrot_strided_batched"; + + template + rocblas_status rocblas_rot_strided_batched_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const U* c, + const V* s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + 
rocblas_rot_name, + n, + x, + incx, + stride_x, + y, + incy, + stride_y, + c, + s, + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rot_strided_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--c_type", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--stride_x", + stride_x, + "--incy", + incy, + "--stride_y", + stride_y, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rot_name, + "N", + n, + "incx", + incx, + "stride_x", + stride_x, + "incy", + incy, + "stride_y", + stride_y, + "batch", + batch_count); + + if(!x || !y || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rot_template( + handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, c, 0, s, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srot_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_drot_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + 
+rocblas_status rocblas_crot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_csrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_zrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_zdrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotg.cpp b/library/src/blas1/rocblas_rotg.cpp index d293a10cf..9884c9d90 100644 --- a/library/src/blas1/rocblas_rotg.cpp +++ b/library/src/blas1/rocblas_rotg.cpp @@ -1,6 +1,7 @@ /* 
************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#include "rocblas_rotg.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -8,64 +9,6 @@ namespace { - template , int>::type = 0> - __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) - { - T scale = rocblas_abs(a) + rocblas_abs(b); - if(scale == 0.0) - { - c = 1.0; - s = 0.0; - a = 0.0; - b = 0.0; - } - else - { - T sa = a / scale; - T sb = b / scale; - T r = scale * sqrt(sa * sa + sb * sb); - T roe = rocblas_abs(a) > rocblas_abs(b) ? a : b; - r = copysign(r, roe); - c = a / r; - s = b / r; - T z = 1.0; - if(rocblas_abs(a) > rocblas_abs(b)) - z = s; - if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0) - z = 1.0 / c; - a = r; - b = z; - } - } - - template , int>::type = 0> - __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) - { - if(!rocblas_abs(a)) - { - c = 0; - s = {1, 0}; - a = b; - } - else - { - auto scale = rocblas_abs(a) + rocblas_abs(b); - auto sa = rocblas_abs(a / scale); - auto sb = rocblas_abs(b / scale); - auto norm = scale * sqrt(sa * sa + sb * sb); - auto alpha = a / rocblas_abs(a); - c = rocblas_abs(a) / norm; - s = alpha * conj(b) / norm; - a = alpha * norm; - } - } - - template - __global__ void rotg_kernel(T* a, T* b, U* c, T* s) - { - rotg_calc(*a, *b, *c, *s); - } - template constexpr char rocblas_rotg_name[] = "unknown"; template <> @@ -78,7 +21,7 @@ namespace constexpr char rocblas_rotg_name[] = "rocblas_zrotg"; template - rocblas_status rocblas_rotg(rocblas_handle handle, T* a, T* b, U* c, T* s) + rocblas_status rocblas_rotg_impl(rocblas_handle handle, T* a, T* b, U* c, T* s) { if(!handle) return rocblas_status_invalid_handle; @@ -87,7 +30,11 @@ namespace if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_rotg_name, a, b, c, s); if(layer_mode & rocblas_layer_mode_log_bench) - 
log_bench(handle, "./rocblas-bench -f rotg -r", rocblas_precision_string); + log_bench(handle, + "./rocblas-bench -f rotg --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string); if(layer_mode & rocblas_layer_mode_log_profile) log_profile(handle, rocblas_rotg_name); @@ -96,19 +43,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - { - hipLaunchKernelGGL(rotg_kernel, 1, 1, 0, rocblas_stream, a, b, c, s); - } - else - { - RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - rotg_calc(*a, *b, *c, *s); - } - - return rocblas_status_success; + return rocblas_rotg_template(handle, a, 0, 0, b, 0, 0, c, 0, 0, s, 0, 0, 1); } } // namespace @@ -123,12 +58,12 @@ extern "C" { rocblas_status rocblas_srotg(rocblas_handle handle, float* a, float* b, float* c, float* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_drotg(rocblas_handle handle, double* a, double* b, double* c, double* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_crotg(rocblas_handle handle, @@ -137,7 +72,7 @@ rocblas_status rocblas_crotg(rocblas_handle handle, float* c, rocblas_float_complex* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_zrotg(rocblas_handle handle, @@ -146,7 +81,7 @@ rocblas_status rocblas_zrotg(rocblas_handle handle, double* c, rocblas_double_complex* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotg.hpp b/library/src/blas1/rocblas_rotg.hpp new file mode 100644 index 000000000..e46da7b23 --- /dev/null +++ b/library/src/blas1/rocblas_rotg.hpp @@ -0,0 +1,139 @@ +/* 
************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template , int>::type = 0> +__device__ __host__ void rocblas_rotg_calc(T& a, T& b, U& c, T& s) +{ + T scale = rocblas_abs(a) + rocblas_abs(b); + if(scale == 0.0) + { + c = 1.0; + s = 0.0; + a = 0.0; + b = 0.0; + } + else + { + T sa = a / scale; + T sb = b / scale; + T r = scale * sqrt(sa * sa + sb * sb); + T roe = rocblas_abs(a) > rocblas_abs(b) ? a : b; + r = copysign(r, roe); + c = a / r; + s = b / r; + T z = 1.0; + if(rocblas_abs(a) > rocblas_abs(b)) + z = s; + if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0) + z = 1.0 / c; + a = r; + b = z; + } +} + +template , int>::type = 0> +__device__ __host__ void rocblas_rotg_calc(T& a, T& b, U& c, T& s) +{ + if(!rocblas_abs(a)) + { + c = 0; + s = {1, 0}; + a = b; + } + else + { + auto scale = rocblas_abs(a) + rocblas_abs(b); + auto sa = rocblas_abs(a / scale); + auto sb = rocblas_abs(b / scale); + auto norm = scale * sqrt(sa * sa + sb * sb); + auto alpha = a / rocblas_abs(a); + c = rocblas_abs(a) / norm; + s = alpha * conj(b) / norm; + a = alpha * norm; + } +} + +template +__global__ void rocblas_rotg_kernel(T a_in, + rocblas_int offset_a, + rocblas_stride stride_a, + T b_in, + rocblas_int offset_b, + rocblas_stride stride_b, + U c_in, + rocblas_int offset_c, + rocblas_stride stride_c, + T s_in, + rocblas_int offset_s, + rocblas_stride stride_s) +{ + auto a = load_ptr_batch(a_in, hipBlockIdx_x, offset_a, stride_a); + auto b = load_ptr_batch(b_in, hipBlockIdx_x, offset_b, stride_b); + auto c = load_ptr_batch(c_in, hipBlockIdx_x, offset_c, stride_c); + auto s = load_ptr_batch(s_in, hipBlockIdx_x, offset_s, stride_s); + rocblas_rotg_calc(*a, *b, *c, *s); +} + +template +rocblas_status rocblas_rotg_template(rocblas_handle handle, + T 
a_in, + rocblas_int offset_a, + rocblas_stride stride_a, + T b_in, + rocblas_int offset_b, + rocblas_stride stride_b, + U c_in, + rocblas_int offset_c, + rocblas_stride stride_c, + T s_in, + rocblas_int offset_s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + if(!batch_count) + return rocblas_status_success; + + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + { + hipLaunchKernelGGL(rocblas_rotg_kernel, + batch_count, + 1, + 0, + rocblas_stream, + a_in, + offset_a, + stride_a, + b_in, + offset_b, + stride_b, + c_in, + offset_c, + stride_c, + s_in, + offset_s, + stride_s); + } + else + { + RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); + // TODO: make this faster for a large number of batches. + for(int i = 0; i < batch_count; i++) + { + auto a = load_ptr_batch(a_in, i, offset_a, stride_a); + auto b = load_ptr_batch(b_in, i, offset_b, stride_b); + auto c = load_ptr_batch(c_in, i, offset_c, stride_c); + auto s = load_ptr_batch(s_in, i, offset_s, stride_s); + + rocblas_rotg_calc(*a, *b, *c, *s); + } + } + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotg_batched.cpp b/library/src/blas1/rocblas_rotg_batched.cpp new file mode 100644 index 000000000..fb4260cde --- /dev/null +++ b/library/src/blas1/rocblas_rotg_batched.cpp @@ -0,0 +1,108 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotg_name[] = "unknown"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_srotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_drotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_crotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_zrotg_batched"; + + template + rocblas_status rocblas_rotg_batched_impl(rocblas_handle handle, + T* const a[], + T* const b[], + U* const c[], + T* const s[], + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotg_name, a, b, c, s, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotg_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotg_name, "batch", batch_count); + + if(!a || !b || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotg_template(handle, a, 0, 0, b, 0, 0, c, 0, 0, s, 0, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srotg_batched(rocblas_handle handle, + float* const a[], + float* const b[], + float* const c[], + float* const s[], + rocblas_int batch_count) +{ + return 
rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_drotg_batched(rocblas_handle handle, + double* const a[], + double* const b[], + double* const c[], + double* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_crotg_batched(rocblas_handle handle, + rocblas_float_complex* const a[], + rocblas_float_complex* const b[], + float* const c[], + rocblas_float_complex* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_zrotg_batched(rocblas_handle handle, + rocblas_double_complex* const a[], + rocblas_double_complex* const b[], + double* const c[], + rocblas_double_complex* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotg_strided_batched.cpp b/library/src/blas1/rocblas_rotg_strided_batched.cpp new file mode 100644 index 000000000..a50f47452 --- /dev/null +++ b/library/src/blas1/rocblas_rotg_strided_batched.cpp @@ -0,0 +1,133 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotg_name[] = "unknown"; // fallback label when no specialization below matches + template <> + constexpr char rocblas_rotg_name[] = "rocblas_srotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_drotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_crotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_zrotg_strided_batched"; + + template + rocblas_status rocblas_rotg_strided_batched_impl(rocblas_handle handle, + T* a, + rocblas_stride stride_a, + T* b, + rocblas_stride stride_b, + U* c, + rocblas_stride stride_c, + T* s, + rocblas_stride stride_s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotg_name, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); // trace every argument of the call, strides included + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotg_strided_batched --a_type", // bench line must name this function, not rotg_batched + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotg_name, "batch", batch_count); + + if(!a || !b || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotg_template( + handle, a, 0, stride_a, b, 0, stride_b, c, 0, stride_c, s, 0, stride_s, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status 
rocblas_srotg_strided_batched(rocblas_handle handle, + float* a, + rocblas_stride stride_a, + float* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + float* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_drotg_strided_batched(rocblas_handle handle, + double* a, + rocblas_stride stride_a, + double* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + double* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_crotg_strided_batched(rocblas_handle handle, + rocblas_float_complex* a, + rocblas_stride stride_a, + rocblas_float_complex* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + rocblas_float_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_zrotg_strided_batched(rocblas_handle handle, + rocblas_double_complex* a, + rocblas_stride stride_a, + rocblas_double_complex* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + rocblas_double_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotm.cpp b/library/src/blas1/rocblas_rotm.cpp index 5ce51b266..96339e12a 100644 --- a/library/src/blas1/rocblas_rotm.cpp +++ b/library/src/blas1/rocblas_rotm.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. 
* ************************************************************************ */ +#include "rocblas_rotm.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -10,49 +11,6 @@ namespace { constexpr int NB = 512; - template - __global__ void rotm_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U flag_device_host, - U h11_device_host, - U h21_device_host, - U h12_device_host, - U h22_device_host) - { - auto flag = load_scalar(flag_device_host); - auto h11 = load_scalar(h11_device_host); - auto h21 = load_scalar(h21_device_host); - auto h12 = load_scalar(h12_device_host); - auto h22 = load_scalar(h22_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n && flag != -2) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto w = x[ix]; - auto z = y[iy]; - if(flag < 0) - { - x[ix] = w * h11 + z * h12; - y[iy] = w * h21 + z * h22; - } - else if(flag == 0) - { - x[ix] = w + z * h12; - y[iy] = w * h21 + z; - } - else - { - x[ix] = w * h11 + z; - y[iy] = -w + z * h22; - } - } - } - template constexpr char rocblas_rotm_name[] = "unknown"; template <> @@ -61,13 +19,13 @@ namespace constexpr char rocblas_rotm_name[] = "rocblas_drotm"; template - rocblas_status rocblas_rotm(rocblas_handle handle, - rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - const T* param) + rocblas_status rocblas_rotm_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + T* y, + rocblas_int incy, + const T* param) { if(!handle) return rocblas_status_invalid_handle; @@ -93,50 +51,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0) - return rocblas_status_success; - if(rocblas_pointer_mode_host == handle->pointer_mode && param[0] == -2) - return rocblas_status_success; - - dim3 blocks((n - 1) / NB + 1); - dim3 threads(NB); - hipStream_t rocblas_stream = handle->rocblas_stream; - - 
if(rocblas_pointer_mode_device == handle->pointer_mode) - hipLaunchKernelGGL(rotm_kernel, - blocks, - threads, - 0, - rocblas_stream, - n, - x, - incx, - y, - incy, - param, - param + 1, - param + 2, - param + 3, - param + 4); - else // c and s are on host - hipLaunchKernelGGL(rotm_kernel, - blocks, - threads, - 0, - rocblas_stream, - n, - x, - incx, - y, - incy, - param[0], - param[1], - param[2], - param[3], - param[4]); - - return rocblas_status_success; + return rocblas_rotm_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 0, 1); } } // namespace @@ -157,7 +73,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotm(rocblas_handle handle, rocblas_int incy, const float* param) { - return rocblas_rotm(handle, n, x, incx, y, incy, param); + return rocblas_rotm_impl(handle, n, x, incx, y, incy, param); } ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, @@ -168,7 +84,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, rocblas_int incy, const double* param) { - return rocblas_rotm(handle, n, x, incx, y, incy, param); + return rocblas_rotm_impl(handle, n, x, incx, y, incy, param); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp new file mode 100644 index 000000000..2d64e6b6d --- /dev/null +++ b/library/src/blas1/rocblas_rotm.hpp @@ -0,0 +1,210 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template +__device__ void rotm_kernel_calc(rocblas_int n, + T x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U flag, + U h11, + U h21, + U h12, + U h22) +{ + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n && flag != -2) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto w = x[ix]; + auto z = y[iy]; + if(flag < 0) + { + x[ix] = w * h11 + z * h12; + y[iy] = w * h21 + z * h22; + } + else if(flag == 0) + { + x[ix] = w + z * h12; + y[iy] = w * h21 + z; + } + else + { + x[ix] = w * h11 + z; + y[iy] = -w + z * h22; + } + } +} + +template +__global__ void rotm_kernel_batched(rocblas_int n, + T x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U param, + rocblas_int offset_param, + rocblas_stride stride_param) +{ + auto p = load_ptr_batch(param, hipBlockIdx_y, offset_param, stride_param); + auto flag = p[0]; + auto h11 = p[1]; + auto h21 = p[2]; + auto h12 = p[3]; + auto h22 = p[4]; + rotm_kernel_calc(n, + x_in, + offset_x, + incx, + stride_x, + y_in, + offset_y, + incy, + stride_y, + flag, + h11, + h21, + h12, + h22); +} + +template +__global__ void rotm_kernel_regular(rocblas_int n, + T* x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T* y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + T flag, + T h11, + T h21, + T h12, + T h22) +{ + rotm_kernel_calc(n, + x_in, + offset_x, + incx, + stride_x, + y_in, + offset_y, + incy, + stride_y, + flag, + h11, + h21, 
+ h12, + h22); +} + +// Workaround to avoid constexpr if - Helper function to quick return when param[0] == -2 +template +bool quick_return_param(rocblas_handle handle, const T* param, rocblas_stride stride_param) +{ + if(rocblas_pointer_mode_host == handle->pointer_mode) + if(param[0] == -2 && stride_param == 0) + return true; + return false; +} + +template +bool quick_return_param(rocblas_handle handle, const T* const param[], rocblas_stride stride_param) +{ + return false; +} + +template +rocblas_status rocblas_rotm_template(rocblas_handle handle, + rocblas_int n, + T x, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U param, + rocblas_int offset_param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + // Quick return if possible + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + return rocblas_status_success; + + if(quick_return_param(handle, param, stride_param)) + return rocblas_status_success; + + dim3 blocks((n - 1) / NB + 1, batch_count); + dim3 threads(NB); + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + hipLaunchKernelGGL(rotm_kernel_batched, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + param, + offset_param, + stride_param); + else if(!BATCHED_OR_STRIDED) + hipLaunchKernelGGL(rotm_kernel_regular, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + param[0], + param[1], + param[2], + param[3], + param[4]); + else // host mode not implemented for (strided_)batched functions + { + // TODO: if desired we can use a host for loop to iterate through + // batches in this scenario. Currently simply not implemented. 
+ return rocblas_status_not_implemented; + } + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotm_batched.cpp b/library/src/blas1/rocblas_rotm_batched.cpp new file mode 100644 index 000000000..64ed7bff8 --- /dev/null +++ b/library/src/blas1/rocblas_rotm_batched.cpp @@ -0,0 +1,106 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotm.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rotm_name[] = "unknown"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_srotm_batched"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_drotm_batched"; + + template + rocblas_status rocblas_rotm_batched_impl(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const T* const param[], + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotm_name, n, x, incx, y, incy, param, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotm_batched -r", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--incy", + incy, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rotm_name, + "N", + n, + "incx", + incx, + "incy", + incy, + "batch", + batch_count); + + if(!x || !y || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotm_template( + handle, n, x, 0, 
incx, 0, y, 0, incy, 0, param, 0, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* const param[], + rocblas_int batch_count) +{ + return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* const param[], + rocblas_int batch_count) +{ + return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotm_strided_batched.cpp b/library/src/blas1/rocblas_rotm_strided_batched.cpp new file mode 100644 index 000000000..0dbeeb97d --- /dev/null +++ b/library/src/blas1/rocblas_rotm_strided_batched.cpp @@ -0,0 +1,147 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotm.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rotm_name[] = "unknown"; // fallback label when no specialization below matches + template <> + constexpr char rocblas_rotm_name[] = "rocblas_srotm_strided_batched"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_drotm_strided_batched"; + + template + rocblas_status rocblas_rotm_strided_batched_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const T* param, + rocblas_stride stride_param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + rocblas_rotm_name, + n, + x, + incx, + stride_x, + y, + incy, + stride_y, + param, stride_param, // stride_param is an argument of this call, so trace it with the other strides + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotm_strided_batched -r", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--stride_x", + stride_x, + "--incy", + incy, + "--stride_y", + stride_y, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rotm_name, + "N", + n, + "incx", + incx, + "stride_x", + stride_x, + "incy", + incy, + "stride_y", + stride_y, + "batch", + batch_count); + + if(!x || !y || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotm_template(handle, + n, + x, + 0, + incx, + stride_x, + y, + 0, + incy, + stride_y, + param, + 0, + stride_param, + batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * 
=========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + return rocblas_rotm_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, param, stride_param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + return rocblas_rotm_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, param, stride_param, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotmg.cpp b/library/src/blas1/rocblas_rotmg.cpp index f01aca7fa..e7a8e3e64 100644 --- a/library/src/blas1/rocblas_rotmg.cpp +++ b/library/src/blas1/rocblas_rotmg.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. 
* ************************************************************************ */ +#include "rocblas_rotmg.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -8,153 +9,6 @@ namespace { - template - __device__ __host__ void rotmg_calc(T& d1, T& d2, T& x1, const T& y1, T* param) - { - const T gam = 4096; - const T gamsq = gam * gam; - const T rgamsq = 1 / gamsq; - - T flag = -1; - T h11 = 0, h21 = 0, h12 = 0, h22 = 0; - - if(d1 < 0) - { - d1 = d2 = x1 = 0; - } - else - { - T p2 = d2 * y1; - if(p2 == 0) - { - flag = -2; - param[0] = flag; - return; - } - T p1 = d1 * x1; - T q2 = p2 * y1; - T q1 = p1 * x1; - if(rocblas_abs(q1) > rocblas_abs(q2)) - { - h21 = -y1 / x1; - h12 = p2 / p1; - T u = 1 - h12 * h21; - if(u > 0) - { - flag = 0; - d1 /= u; - d2 /= u; - x1 *= u; - } - } - else - { - if(q2 < 0) - { - d1 = d2 = x1 = 0; - } - else - { - flag = 1; - h11 = p1 / p2; - h22 = x1 / y1; - T u = 1 + h11 * h22; - T temp = d2 / u; - d2 = d1 / u; - d1 = temp; - x1 = y1 * u; - } - } - - if(d1 != 0) - { - while((d1 <= rgamsq) || (d1 >= gamsq)) - { - if(flag == 0) - { - h11 = h22 = 1; - flag = -1; - } - else - { - h21 = -1; - h12 = 1; - flag = -1; - } - if(d1 <= rgamsq) - { - d1 *= gamsq; - x1 /= gam; - h11 /= gam; - h12 /= gam; - } - else - { - d1 /= gamsq; - x1 *= gam; - h11 *= gam; - h12 *= gam; - } - } - } - - if(d2 != 0) - { - while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq)) - { - if(flag == 0) - { - h11 = h22 = 1; - flag = -1; - } - else - { - h21 = -1; - h12 = 1; - flag = -1; - } - if(rocblas_abs(d2) <= rgamsq) - { - d2 *= gamsq; - h21 /= gam; - h22 /= gam; - } - else - { - d2 /= gamsq; - h21 *= gam; - h22 *= gam; - } - } - } - } - - if(flag < 0) - { - param[1] = h11; - param[2] = h21; - param[3] = h12; - param[4] = h22; - } - else if(flag == 0) - { - param[2] = h21; - param[3] = h12; - } - else - { - param[1] = h11; - param[4] = h22; - } - param[0] = flag; - } - - template - __global__ void rotmg_kernel(T* d1, T* d2, T* x1, const T* y1, T* 
param) - { - rotmg_calc(*d1, *d2, *x1, *y1, param); - } - template constexpr char rocblas_rotmg_name[] = "unknown"; template <> @@ -163,7 +17,8 @@ namespace constexpr char rocblas_rotmg_name[] = "rocblas_drotmg"; template - rocblas_status rocblas_rotmg(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param) + rocblas_status + rocblas_rotmg_impl(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param) { if(!handle) return rocblas_status_invalid_handle; @@ -172,7 +27,7 @@ namespace if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_rotmg_name, d1, d2, x1, y1, param); if(layer_mode & rocblas_layer_mode_log_bench) - log_trace(handle, "./rocblas-bench -f rotmg -r", rocblas_precision_string); + log_bench(handle, "./rocblas-bench -f rotmg -r", rocblas_precision_string); if(layer_mode & rocblas_layer_mode_log_profile) log_profile(handle, rocblas_rotmg_name); @@ -181,19 +36,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - { - hipLaunchKernelGGL(rotmg_kernel, 1, 1, 0, rocblas_stream, d1, d2, x1, y1, param); - } - else - { - RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - rotmg_calc(*d1, *d2, *x1, *y1, param); - } - - return rocblas_status_success; + return rocblas_rotmg_template( + handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 0, 1); } } // namespace @@ -209,13 +53,13 @@ extern "C" { ROCBLAS_EXPORT rocblas_status rocblas_srotmg( rocblas_handle handle, float* d1, float* d2, float* x1, const float* y1, float* param) { - return rocblas_rotmg(handle, d1, d2, x1, y1, param); + return rocblas_rotmg_impl(handle, d1, d2, x1, y1, param); } ROCBLAS_EXPORT rocblas_status rocblas_drotmg( rocblas_handle handle, double* d1, double* d2, double* x1, const double* y1, double* param) { - return rocblas_rotmg(handle, d1, d2, x1, y1, param); + return rocblas_rotmg_impl(handle, d1, d2, x1, y1, 
param); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotmg.hpp b/library/src/blas1/rocblas_rotmg.hpp new file mode 100644 index 000000000..49b1e2e31 --- /dev/null +++ b/library/src/blas1/rocblas_rotmg.hpp @@ -0,0 +1,240 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template +__device__ __host__ void rocblas_rotmg_calc(T& d1, T& d2, T& x1, const T& y1, T* param) +{ + const T gam = 4096; + const T gamsq = gam * gam; + const T rgamsq = 1 / gamsq; + + T flag = -1; + T h11 = 0, h21 = 0, h12 = 0, h22 = 0; + + if(d1 < 0) + { + d1 = d2 = x1 = 0; + } + else + { + T p2 = d2 * y1; + if(p2 == 0) + { + flag = -2; + param[0] = flag; + return; + } + T p1 = d1 * x1; + T q2 = p2 * y1; + T q1 = p1 * x1; + if(rocblas_abs(q1) > rocblas_abs(q2)) + { + h21 = -y1 / x1; + h12 = p2 / p1; + T u = 1 - h12 * h21; + if(u > 0) + { + flag = 0; + d1 /= u; + d2 /= u; + x1 *= u; + } + } + else + { + if(q2 < 0) + { + d1 = d2 = x1 = 0; + } + else + { + flag = 1; + h11 = p1 / p2; + h22 = x1 / y1; + T u = 1 + h11 * h22; + T temp = d2 / u; + d2 = d1 / u; + d1 = temp; + x1 = y1 * u; + } + } + + if(d1 != 0) + { + while((d1 <= rgamsq) || (d1 >= gamsq)) + { + if(flag == 0) + { + h11 = h22 = 1; + flag = -1; + } + else + { + h21 = -1; + h12 = 1; + flag = -1; + } + if(d1 <= rgamsq) + { + d1 *= gamsq; + x1 /= gam; + h11 /= gam; + h12 /= gam; + } + else + { + d1 /= gamsq; + x1 *= gam; + h11 *= gam; + h12 *= gam; + } + } + } + + if(d2 != 0) + { + while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq)) + { + if(flag == 0) + { + h11 = h22 = 1; + flag = -1; + } + else + { + h21 = -1; + h12 = 1; + flag = -1; + } + if(rocblas_abs(d2) <= rgamsq) + { + d2 *= gamsq; + h21 /= gam; + h22 /= gam; + } + else + { + d2 /= gamsq; + h21 *= gam; + h22 *= 
gam; + } + } + } + } + + if(flag < 0) + { + param[1] = h11; + param[2] = h21; + param[3] = h12; + param[4] = h22; + } + else if(flag == 0) + { + param[2] = h21; + param[3] = h12; + } + else + { + param[1] = h11; + param[4] = h22; + } + param[0] = flag; +} + +template +__global__ void rocblas_rotmg_kernel(T d1_in, + rocblas_int offset_d1, + rocblas_stride stride_d1, + T d2_in, + rocblas_int offset_d2, + rocblas_stride stride_d2, + T x1_in, + rocblas_int offset_x1, + rocblas_stride stride_x1, + U y1_in, + rocblas_int offset_y1, + rocblas_stride stride_y1, + T param, + rocblas_int offset_param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + auto d1 = load_ptr_batch(d1_in, hipBlockIdx_x, offset_d1, stride_d1); + auto d2 = load_ptr_batch(d2_in, hipBlockIdx_x, offset_d2, stride_d2); + auto x1 = load_ptr_batch(x1_in, hipBlockIdx_x, offset_x1, stride_x1); + auto y1 = load_ptr_batch(y1_in, hipBlockIdx_x, offset_y1, stride_y1); + auto p = load_ptr_batch(param, hipBlockIdx_x, offset_param, stride_param); + rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); +} + +template +rocblas_status rocblas_rotmg_template(rocblas_handle handle, + T d1_in, + rocblas_int offset_d1, + rocblas_stride stride_d1, + T d2_in, + rocblas_int offset_d2, + rocblas_stride stride_d2, + T x1_in, + rocblas_int offset_x1, + rocblas_stride stride_x1, + U y1_in, + rocblas_int offset_y1, + rocblas_stride stride_y1, + T param, + rocblas_int offset_param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + if(!batch_count) + return rocblas_status_success; + + hipStream_t rocblas_stream = handle->rocblas_stream; + if(rocblas_pointer_mode_device == handle->pointer_mode) + { + hipLaunchKernelGGL(rocblas_rotmg_kernel, + batch_count, + 1, + 0, + rocblas_stream, + d1_in, + offset_d1, + stride_d1, + d2_in, + offset_d2, + stride_d2, + x1_in, + offset_x1, + stride_x1, + y1_in, + offset_y1, + stride_y1, + param, + offset_param, + stride_param, + batch_count); + } + else + { + 
RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); + // TODO: make this faster for a large number of batches. + for(int i = 0; i < batch_count; i++) + { + auto d1 = load_ptr_batch(d1_in, i, offset_d1, stride_d1); + auto d2 = load_ptr_batch(d2_in, i, offset_d2, stride_d2); + auto x1 = load_ptr_batch(x1_in, i, offset_x1, stride_x1); + auto y1 = load_ptr_batch(y1_in, i, offset_y1, stride_y1); + auto p = load_ptr_batch(param, i, offset_param, stride_param); + + rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); + } + } + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotmg_batched.cpp b/library/src/blas1/rocblas_rotmg_batched.cpp new file mode 100644 index 000000000..dc0856322 --- /dev/null +++ b/library/src/blas1/rocblas_rotmg_batched.cpp @@ -0,0 +1,86 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotmg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotmg_name[] = "unknown"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_srotmg_batched"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_drotmg_batched"; + + template + rocblas_status rocblas_rotmg_batched_impl(rocblas_handle handle, + T* const d1[], + T* const d2[], + T* const x1[], + const T* const y1[], + T* const param[], + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotmg_name, d1, d2, x1, y1, param, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotmg_batched -r", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & 
rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotmg_name, "batch", batch_count); + + if(!d1 || !d2 || !x1 || !y1 || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotmg_template( + handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, + float* const d1[], + float* const d2[], + float* const x1[], + const float* const y1[], + float* const param[], + rocblas_int batch_count) +{ + return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, + double* const d1[], + double* const d2[], + double* const x1[], + const double* const y1[], + double* const param[], + rocblas_int batch_count) +{ + return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotmg_strided_batched.cpp b/library/src/blas1/rocblas_rotmg_strided_batched.cpp new file mode 100644 index 000000000..0f63df43b --- /dev/null +++ b/library/src/blas1/rocblas_rotmg_strided_batched.cpp @@ -0,0 +1,157 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotmg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotmg_name[] = "unknown"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_srotmg_strided_batched"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_drotmg_strided_batched"; + + template + rocblas_status rocblas_rotmg_strided_batched_impl(rocblas_handle handle, + T* d1, + rocblas_stride stride_d1, + T* d2, + rocblas_stride stride_d2, + T* x1, + rocblas_stride stride_x1, + const T* y1, + rocblas_stride stride_y1, + T* param, + rocblas_stride stride_param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + rocblas_rotmg_name, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotmg_strided_batched -r", + rocblas_precision_string, + "--batch", + batch_count, + "--stride_a", + stride_d1, + "--stride_b", + stride_d2, + "--stride_x", + stride_x1, + "--stride_y", + stride_y1); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotmg_name, "batch", batch_count); + + if(!d1 || !d2 || !x1 || !y1 || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotmg_template(handle, + d1, + 0, + stride_d1, + d2, + 0, + stride_d2, + x1, + 0, + stride_x1, + y1, + 0, + stride_y1, + param, + 0, + stride_param, + batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * 
=========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle handle, + float* d1, + rocblas_stride stride_d1, + float* d2, + rocblas_stride stride_d2, + float* x1, + rocblas_stride stride_x1, + const float* y1, + rocblas_stride stride_y1, + float* param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + return rocblas_rotmg_strided_batched_impl(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + stride_param, + batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, + double* d1, + rocblas_stride stride_d1, + double* d2, + rocblas_stride stride_d2, + double* x1, + rocblas_stride stride_x1, + const double* y1, + rocblas_stride stride_y1, + double* param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + return rocblas_rotmg_strided_batched_impl(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + stride_param, + batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_scal.hpp b/library/src/blas1/rocblas_scal.hpp index 69b20ce33..4e5a94881 100644 --- a/library/src/blas1/rocblas_scal.hpp +++ b/library/src/blas1/rocblas_scal.hpp @@ -39,7 +39,7 @@ rocblas_status rocblas_scal_template(rocblas_handle handle, // outside of rocblas) if(handle->is_device_memory_size_query()) { - if(stride_alpha && rocblas_pointer_mode_device == handle->pointer_mode && n > 0 && incx > 0 + if(stride_alpha && rocblas_pointer_mode_host == handle->pointer_mode && n > 0 && incx > 0 && batch_count > 0) return handle->set_optimal_device_memory_size(sizeof(V) * batch_count * stride_alpha); else