From 00519f334b0b806302615b8d900427bbe5676f6f Mon Sep 17 00:00:00 2001 From: First Last Date: Mon, 30 Sep 2019 15:10:50 -0600 Subject: [PATCH 01/11] Added rot_batched and rot_strided_batched --- clients/benchmarks/client.cpp | 48 +++- clients/common/rocblas_gentest.py | 2 +- clients/gtest/blas1_gtest.cpp | 33 ++- clients/gtest/blas1_gtest.yaml | 10 + clients/include/rocblas.hpp | 77 +++++ clients/include/testing_rot_batched.hpp | 265 ++++++++++++++++++ .../include/testing_rot_strided_batched.hpp | 234 ++++++++++++++++ clients/include/testing_rotg_batched.hpp | 0 .../include/testing_rotg_strided_batched.hpp | 0 clients/include/testing_rotm_batched.hpp | 0 .../include/testing_rotm_strided_batched.hpp | 0 library/include/rocblas-functions.h | 202 +++++++++++++ library/src/CMakeLists.txt | 2 + library/src/blas1/rocblas_rot.cpp | 104 ++----- library/src/blas1/rocblas_rot.hpp | 143 ++++++++++ library/src/blas1/rocblas_rot_batched.cpp | 173 ++++++++++++ .../src/blas1/rocblas_rot_strided_batched.cpp | 214 ++++++++++++++ library/src/include/utility.h | 8 + 18 files changed, 1423 insertions(+), 92 deletions(-) create mode 100644 clients/include/testing_rot_batched.hpp create mode 100644 clients/include/testing_rot_strided_batched.hpp create mode 100644 clients/include/testing_rotg_batched.hpp create mode 100644 clients/include/testing_rotg_strided_batched.hpp create mode 100644 clients/include/testing_rotm_batched.hpp create mode 100644 clients/include/testing_rotm_strided_batched.hpp create mode 100644 library/src/blas1/rocblas_rot.hpp create mode 100644 library/src/blas1/rocblas_rot_batched.cpp create mode 100644 library/src/blas1/rocblas_rot_strided_batched.cpp diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 06fcf8644..681859afd 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -21,6 +21,8 @@ #include "testing_nrm2_batched.hpp" #include "testing_nrm2_strided_batched.hpp" #include "testing_rot.hpp" +#include 
"testing_rot_batched.hpp" +#include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" #include "testing_rotm.hpp" #include "testing_rotmg.hpp" @@ -174,8 +176,6 @@ struct perf_blas< testing_set_get_vector(arg); else if(!strcmp(arg.function, "set_get_matrix")) testing_set_get_matrix(arg); - else if(!strcmp(arg.function, "rot")) - testing_rot(arg); else if(!strcmp(arg.function, "rotg")) testing_rotg(arg); else if(!strcmp(arg.function, "rotm")) @@ -324,6 +324,47 @@ struct perf_blas +struct perf_blas_rot : rocblas_test_invalid +{ +}; + +template +struct perf_blas_rot< + Ti, + To, + Tc, + typename std::enable_if<( + (std::is_same{} && std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}) + || (std::is_same{} && std::is_same{} + && std::is_same{}))>::type> +{ + explicit operator bool() + { + return true; + } + + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "rot")) + testing_rot(arg); + else if(!strcmp(arg.function, "rot_batched")) + testing_rot_batched(arg); + else if(!strcmp(arg.function, "rot_strided_batched")) + testing_rot_strided_batched(arg); + else + throw std::invalid_argument("Invalid combination --function "s + arg.function + + " --a_type "s + rocblas_datatype2string(arg.a_type)); + } +}; + template struct perf_blas_scal : rocblas_test_invalid { @@ -521,6 +562,9 @@ int run_bench_test(Arguments& arg) if(!strcmp(function, "scal") || !strcmp(function, "scal_batched") || !strcmp(function, "scal_strided_batched")) rocblas_blas1_dispatch(arg); + else if(!strcmp(function, "rot") || !strcmp(function, "rot_batched") + || !strcmp(function, "rot_strided_batched")) + rocblas_blas1_dispatch(arg); else rocblas_simple_dispatch(arg); } diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 
cffa95f1f..140cc7e85 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -198,7 +198,7 @@ def setdefaults(test): if test['function'] in ('asum_strided_batched', 'nrm2_strided_batched', 'scal_strided_batched', 'swap_strided_batched', - 'copy_strided_batched'): + 'copy_strided_batched', 'rot_strided_batched'): if all([x in test for x in ('N', 'incx', 'stride_scale')]): test.setdefault('stride_x', int(test['N'] * abs(test['incx']) * test['stride_scale'])) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index dd95a7a8a..fc8994e96 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -16,8 +16,14 @@ #include "testing_nrm2_batched.hpp" #include "testing_nrm2_strided_batched.hpp" #include "testing_rot.hpp" +#include "testing_rot_batched.hpp" +#include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" +#include "testing_rotg_batched.hpp" +#include "testing_rotg_strided_batched.hpp" #include "testing_rotm.hpp" +#include "testing_rotm_batched.hpp" +#include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" @@ -53,8 +59,14 @@ namespace swap_batched, swap_strided_batched, rot, + rot_batched, + rot_strided_batched, rotg, + rotg_batched, + rotg_strided_batched, rotm, + rotm_batched, + rotm_strided_batched, rotmg, }; @@ -89,17 +101,21 @@ namespace || BLAS1 == blas1::scal_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched - || BLAS1 == blas1::copy_batched); + || BLAS1 == blas1::copy_batched || BLAS1 == blas1::rot_batched); bool is_strided = (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched || BLAS1 == blas1::scal_strided_batched || BLAS1 == blas1::swap_strided_batched - || BLAS1 == blas1::copy_strided_batched); + || BLAS1 == blas1::copy_strided_batched + || BLAS1 
== blas1::rot_strided_batched); - if((is_scal || BLAS1 == blas1::rot || BLAS1 == blas1::rotg) + if((is_scal || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotg) && arg.a_type != arg.b_type) name << '_' << rocblas_datatype2string(arg.b_type); - if(BLAS1 == blas1::rot && arg.compute_type != arg.a_type) + if((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched) + && arg.compute_type != arg.a_type) name << '_' << rocblas_datatype2string(arg.compute_type); name << '_' << arg.N; @@ -118,12 +134,14 @@ namespace || BLAS1 == blas1::copy_strided_batched || BLAS1 == blas1::copy_batched || BLAS1 == blas1::dot || BLAS1 == blas1::swap || BLAS1 == blas1::swap_batched || BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::rot + || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm) { name << '_' << arg.incy; } - if(BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched) + if(BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched + || BLAS1 == blas1::rot_strided_batched) { name << '_' << arg.stride_y; } @@ -204,7 +222,8 @@ namespace || std::is_same{} || std::is_same{})) - || (BLAS1 == blas1::rot + || ((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched) && ((std::is_same{} && std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{} && std::is_same{}) @@ -300,6 +319,8 @@ BLAS1_TESTING(swap, ARG1) BLAS1_TESTING(swap_batched, ARG1) BLAS1_TESTING(swap_strided_batched, ARG1) BLAS1_TESTING(rot, ARG3) +BLAS1_TESTING(rot_batched, ARG3) +BLAS1_TESTING(rot_strided_batched, ARG3) BLAS1_TESTING(rotg, ARG2) BLAS1_TESTING(rotm, ARG1) BLAS1_TESTING(rotmg, ARG1) diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 3378b020e..cf0676fe6 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -64,6 
+64,7 @@ Tests: - asum_batched: *single_double_precisions_complex_real - nrm2_batched: *single_double_precisions_complex_real - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions - name: blas1_strided_batched category: quick @@ -79,6 +80,7 @@ Tests: - asum_strided_batched: *single_double_precisions_complex_real - nrm2_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions - name: blas1 @@ -114,6 +116,7 @@ Tests: - asum_batched: *double_precision_complex_real - nrm2_batched: *double_precision_complex_real - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions - name: blas1_strided_batched category: pre_checkin @@ -129,6 +132,7 @@ Tests: - asum_strided_batched: *double_precision_complex_real - nrm2_strided_batched: *double_precision_complex_real - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions - name: blas1 category: nightly @@ -161,6 +165,7 @@ Tests: - scal_batched: *single_double_precisions_complex_real - scal_batched: *single_double_complex_real_in_complex_out - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions - name: blas1_batched category: nightly @@ -173,6 +178,7 @@ Tests: - scal_batched: *single_double_precisions_complex_real - scal_batched: *single_double_complex_real_in_complex_out - copy_batched: *single_double_precisions_complex_real + - rot_batched: *rot_precisions - name: blas1_strided_batched category: nightly @@ -186,6 +192,7 @@ Tests: - scal_strided_batched: *single_double_precisions_complex_real - scal_strided_batched: *single_double_complex_real_in_complex_out - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions - name: blas1_strided_batched category: nightly @@ -199,6 +206,7 @@ Tests: - scal_strided_batched: 
*single_double_precisions_complex_real - scal_strided_batched: *single_double_complex_real_in_complex_out - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions - name: blas1 @@ -314,6 +322,7 @@ Tests: - scal_batched_bad_arg: *single_double_precisions_complex_real - scal_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_batched_bad_arg: *single_double_precisions_complex_real + - rot_batched_bad_arg: *rot_precisions - name: blas1_strided_batched_bad_arg category: pre_checkin @@ -323,5 +332,6 @@ Tests: - scal_strided_batched_bad_arg: *single_double_precisions_complex_real - scal_strided_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_strided_batched_bad_arg: *single_double_precisions_complex_real + - rot_strided_batched_bad_arg: *rot_precisions ... diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index 6894e5e3e..2d94cdcd4 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -510,6 +510,83 @@ static constexpr auto template <> static constexpr auto rocblas_rot = rocblas_zdrot; +// rot_batched +template +rocblas_status (*rocblas_rot_batched)(rocblas_handle handle, + rocblas_int n, + T* x[], + rocblas_int incx, + T* y[], + rocblas_int incy, + const U* c, + const V* s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rot_batched = rocblas_srot_batched; + +template <> +static constexpr auto rocblas_rot_batched = rocblas_drot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_crot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_csrot_batched; + +template <> +static constexpr auto rocblas_rot_batched = rocblas_zrot_batched; + +template <> +static constexpr auto + rocblas_rot_batched = rocblas_zdrot_batched; + +// rot_strided_batched +template +rocblas_status (*rocblas_rot_strided_batched)(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + 
rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const U* c, + const V* s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_srot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_drot_strided_batched; + +template <> +static constexpr auto + rocblas_rot_strided_batched = rocblas_crot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_csrot_strided_batched; + +template <> +static constexpr auto + rocblas_rot_strided_batched = rocblas_zrot_strided_batched; + +template <> +static constexpr auto rocblas_rot_strided_batched = rocblas_zdrot_strided_batched; + // rotg template rocblas_status (*rocblas_rotg)(rocblas_handle handle, T* a, T* b, U* c, T* s); diff --git a/clients/include/testing_rot_batched.hpp b/clients/include/testing_rot_batched.hpp new file mode 100644 index 000000000..37dcc7587 --- /dev/null +++ b/clients/include/testing_rot_batched.hpp @@ -0,0 +1,265 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rot_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_int incy = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(nullptr, N, dx, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, nullptr, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, nullptr, incy, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rot_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count 
<= 0) + { + device_vector dx(1); + device_vector dy(1); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count)), + rocblas_status_invalid_size); + if(!batch_count) + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); + return; + } + + size_t size_x = N * size_t(incx); + size_t size_y = N * size_t(incy); + + device_vector dx(batch_count); + device_vector dy(batch_count); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx[batch_count]; //(size_x); + host_vector hy[batch_count]; //(size_y); + host_vector hc(1); + host_vector hs(1); + + device_batch_vector bx(batch_count, size_x); + device_batch_vector by(batch_count, size_y); + + for(int i = 0; i < batch_count; i++) + { + hx[i] = host_vector(size_x); + hy[i] = host_vector(size_y); + } + + int last = batch_count - 1; + if((!bx[last] && size_x) || (!by[last] && size_y)) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hx[b], 1, N, incx); + rocblas_init(hy[b], 1, N, incy); + } + rocblas_init(hc, 1, 1, 1); + rocblas_init(hs, 1, 1, 1); + + // CPU BLAS reference data + host_vector cx[batch_count]; + host_vector cy[batch_count]; + for(int b = 0; b < batch_count; b++) + { + cx[b] = hx[b]; + cy[b] = hy[b]; + } + // cblas_rotg(cx, cy, hc, hs); + // cx[0] = hx[0]; + // cy[0] = hy[0]; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rot(N, cx[b], incx, cy[b], incy, hc, hs); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check 
|| arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, cx, rx); + unit_check_general(1, N, batch_count, incy, cy, ry); + } + if(arg.norm_check) + { + norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + + // Test rocblas_pointer_mode_device + // { + // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + // for(int b = 0; b < batch_count; b++) + // { + // CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + // } + // CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + // CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U), hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(ds, hs, 
sizeof(V), hipMemcpyHostToDevice)); + + // CHECK_ROCBLAS_ERROR((rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); + + // host_vector rx[batch_count]; + // host_vector ry[batch_count]; + // for(int b = 0; b < batch_count; b++) + // { + // rx[b] = host_vector(size_x); + // ry[b] = host_vector(size_y); + // CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + // CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + // } + + // if(arg.unit_check) + // { + // unit_check_general(1, N, batch_count, incx, cx, rx); + // unit_check_general(1, N, batch_count, incy, cy, ry); + // } + // if(arg.norm_check) + // { + // norm_error_device_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + // norm_error_device_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + // } + // } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rot_batched(handle, N, dx, incx, dy, incy, hc, hs, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout + << 
",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rot_strided_batched.hpp b/clients/include/testing_rot_strided_batched.hpp new file mode 100644 index 000000000..babcf2392 --- /dev/null +++ b/clients/include/testing_rot_strided_batched.hpp @@ -0,0 +1,234 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rot_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_stride stride_x = 1; + rocblas_int incy = 1; + rocblas_stride stride_y = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, nullptr, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + 
(rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, nullptr, incy, stride_y, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rot_strided_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_stride stride_x = arg.stride_x; + rocblas_stride stride_y = arg.stride_y; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rot_strided_batched)(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dc, + ds, + batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count))); + return; + } + + size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); + size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + + device_vector dx(size_x); + device_vector
dy(size_y); + device_vector dc(1); + device_vector ds(1); + if(!dx || !dy || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx(size_x); + host_vector hy(size_y); + host_vector hc(1); + host_vector hs(1); + rocblas_seedrand(); + rocblas_init(hx, 1, N, incx, stride_x, batch_count); + rocblas_init(hy, 1, N, incy, stride_y, batch_count); + rocblas_init(hc, 1, 1, 1); + rocblas_init(hs, 1, 1, 1); + + // CPU BLAS reference data + host_vector cx = hx; + host_vector cy = hy; + // cblas_rotg(cx, cy, hc, hs); + // cx[0] = hx[0]; + // cy[0] = hy[0]; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rot(N, cx + b * stride_x, incx, cy + b * stride_y, incy, hc, hs); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, stride_x, cx, rx); + unit_check_general(1, N, batch_count, incy, stride_y, cy, ry); + } + if(arg.norm_check) + { + norm_error_host_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_host_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); +
CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(V), hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dc, ds, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, stride_x, cx, rx); + unit_check_general(1, N, batch_count, incy, stride_y, cy, ry); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rot_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hc, hs, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout + << ",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_device_y"; + std::cout 
<< std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotg_batched.hpp b/clients/include/testing_rotg_batched.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/clients/include/testing_rotg_strided_batched.hpp b/clients/include/testing_rotg_strided_batched.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/clients/include/testing_rotm_batched.hpp b/clients/include/testing_rotm_batched.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/clients/include/testing_rotm_strided_batched.hpp b/clients/include/testing_rotm_strided_batched.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 0960f6c98..80a8b88e1 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1255,6 +1255,208 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot(rocblas_handle handle, const double* c, const double* s); +/*! \brief BLAS Level 1 API + + \details + rot_batched applies the Givens rotation matrix defined by c=cos(alpha) and s=sin(alpha) to batched vectors x and y. + Scalars c and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x array of pointers storing vector x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[inout] + y array of pointers storing vector y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. 
+ @param[in] + c scalar cosine component of the rotation matrix, may be stored in host or device memory. + @param[in] + s scalar sine component of the rotation matrix, may be stored in host or device memory. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srot_batched(rocblas_handle handle, + rocblas_int n, + float* x[], + rocblas_int incx, + float* y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drot_batched(rocblas_handle handle, + rocblas_int n, + double* x[], + rocblas_int incx, + double* y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x[], + rocblas_int incx, + rocblas_float_complex* y[], + rocblas_int incy, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x[], + rocblas_int incx, + rocblas_float_complex* y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x[], + rocblas_int incx, + rocblas_double_complex* y[], + rocblas_int incy, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x[], + rocblas_int incx, + rocblas_double_complex* y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count); + +/*! 
\brief BLAS Level 1 API + + \details + rot_strided_batched applies the Givens rotation matrix defined by c=cos(alpha) and s=sin(alpha) to strided batched vectors x and y. + Scalars c and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x pointer storing strided vectors x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[in] + stride_x rocblas_stride + specifies the increment from the beginning of x_i to the beginning of x_(i+1) + @param[inout] + y pointer storing strided vectors y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + stride_y rocblas_stride + specifies the increment from the beginning of y_i to the beginning of y_(i+1) + @param[in] + c scalar cosine component of the rotation matrix, may be stored in host or device memory. + @param[in] + s scalar sine component of the rotation matrix, may be stored in host or device memory. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srot_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drot_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_csrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zdrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count); + /*! 
\brief BLAS Level 1 API \details diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index 98d6585c3..b3496b2de 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -139,6 +139,8 @@ set( rocblas_blas1_source blas1/rocblas_scal_strided_batched.cpp blas1/rocblas_swap.cpp blas1/rocblas_rot.cpp + blas1/rocblas_rot_batched.cpp + blas1/rocblas_rot_strided_batched.cpp blas1/rocblas_rotg.cpp blas1/rocblas_rotm.cpp blas1/rocblas_rotmg.cpp diff --git a/library/src/blas1/rocblas_rot.cpp b/library/src/blas1/rocblas_rot.cpp index ba4b18fb6..d7869b840 100644 --- a/library/src/blas1/rocblas_rot.cpp +++ b/library/src/blas1/rocblas_rot.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#include "rocblas_rot.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -10,58 +11,6 @@ namespace { constexpr int NB = 512; - template , int>::type = 0> - __global__ void rot_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U c_device_host, - V s_device_host) - { - auto c = load_scalar(c_device_host); - auto s = load_scalar(s_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto temp = c * x[ix] + s * y[iy]; - y[iy] = c * y[iy] - s * x[ix]; - x[ix] = temp; - } - } - - template , int>::type = 0> - __global__ void rot_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U c_device_host, - V s_device_host) - { - auto c = load_scalar(c_device_host); - auto s = load_scalar(s_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto temp = c * x[ix] + s * y[iy]; - y[iy] = c * y[iy] - conj(s) * x[ix]; - x[ix] = temp; - } - } 
- template constexpr char rocblas_rot_name[] = "unknown"; template <> @@ -78,14 +27,14 @@ namespace constexpr char rocblas_rot_name[] = "rocblas_zdrot"; template - rocblas_status rocblas_rot(rocblas_handle handle, - rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - const U* c, - const V* s) + rocblas_status rocblas_rot_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + T* y, + rocblas_int incy, + const U* c, + const V* s) { if(!handle) return rocblas_status_invalid_handle; @@ -95,8 +44,12 @@ namespace log_trace(handle, rocblas_rot_name, n, x, incx, y, incy, c, s); if(layer_mode & rocblas_layer_mode_log_bench) log_bench(handle, - "./rocblas-bench -f rot -r", + "./rocblas-bench -f rot --a_type", rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--c_type", + rocblas_precision_string, "-n", n, "--incx", @@ -111,22 +64,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0) - return rocblas_status_success; - - dim3 blocks((n - 1) / NB + 1); - dim3 threads(NB); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - hipLaunchKernelGGL( - rot_kernel, blocks, threads, 0, rocblas_stream, n, x, incx, y, incy, c, s); - else // c and s are on host - hipLaunchKernelGGL( - rot_kernel, blocks, threads, 0, rocblas_stream, n, x, incx, y, incy, *c, *s); - - return rocblas_status_success; + return rocblas_rot_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, c, 0, s, 0, 1); } } // namespace @@ -148,7 +86,7 @@ rocblas_status rocblas_srot(rocblas_handle handle, const float* c, const float* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_drot(rocblas_handle handle, @@ -160,7 +98,7 @@ rocblas_status rocblas_drot(rocblas_handle handle, const double* c, const double* s) { - return 
rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_crot(rocblas_handle handle, @@ -172,7 +110,7 @@ rocblas_status rocblas_crot(rocblas_handle handle, const float* c, const rocblas_float_complex* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_csrot(rocblas_handle handle, @@ -184,7 +122,7 @@ rocblas_status rocblas_csrot(rocblas_handle handle, const float* c, const float* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_zrot(rocblas_handle handle, @@ -196,7 +134,7 @@ rocblas_status rocblas_zrot(rocblas_handle handle, const double* c, const rocblas_double_complex* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } rocblas_status rocblas_zdrot(rocblas_handle handle, @@ -208,7 +146,7 @@ rocblas_status rocblas_zdrot(rocblas_handle handle, const double* c, const double* s) { - return rocblas_rot(handle, n, x, incx, y, incy, c, s); + return rocblas_rot_impl(handle, n, x, incx, y, incy, c, s); } } // extern "C" diff --git a/library/src/blas1/rocblas_rot.hpp b/library/src/blas1/rocblas_rot.hpp new file mode 100644 index 000000000..9dde321bc --- /dev/null +++ b/library/src/blas1/rocblas_rot.hpp @@ -0,0 +1,143 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "rocblas.h" +#include "utility.h" + +template , int>::type = 0> +__global__ void rot_kernel(rocblas_int n, + T2 x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U c_device_host, + rocblas_stride c_stride, + V s_device_host, + rocblas_stride s_stride) +{ + auto c = load_scalar(c_device_host, hipBlockIdx_y, c_stride); + auto s = load_scalar(s_device_host, hipBlockIdx_y, s_stride); + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - s * x[ix]; + x[ix] = temp; + } +} + +template , int>::type = 0> +__global__ void rot_kernel(rocblas_int n, + T2 x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U c_device_host, + rocblas_stride c_stride, + V s_device_host, + rocblas_stride s_stride) +{ + auto c = load_scalar(c_device_host, hipBlockIdx_y, c_stride); + auto s = load_scalar(s_device_host, hipBlockIdx_y, s_stride); + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - conj(s) * x[ix]; + x[ix] = temp; + } +} + +template +rocblas_status rocblas_rot_template(rocblas_handle handle, + rocblas_int n, + T2 x, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T2 y, + rocblas_int offset_y, + 
rocblas_int incy, + rocblas_stride stride_y, + U* c, + rocblas_stride c_stride, + V* s, + rocblas_stride s_stride, + rocblas_int batch_count) +{ + // Quick return if possible (an empty batch would otherwise launch a zero-sized grid) + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + return rocblas_status_success; + + dim3 blocks((n - 1) / NB + 1, batch_count); + dim3 threads(NB); + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + hipLaunchKernelGGL(rot_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + c, + c_stride, + s, + s_stride); + else // c and s are on host + hipLaunchKernelGGL(rot_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + *c, + c_stride, + *s, + s_stride); + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rot_batched.cpp b/library/src/blas1/rocblas_rot_batched.cpp new file mode 100644 index 000000000..34f9483cd --- /dev/null +++ b/library/src/blas1/rocblas_rot_batched.cpp @@ -0,0 +1,173 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc.
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rot.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rot_name[] = "unknown"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_srot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_drot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_crot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zrot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_csrot_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zdrot_batched"; + + template + rocblas_status rocblas_rot_batched_impl(rocblas_handle handle, + rocblas_int n, + T* x[], + rocblas_int incx, + T* y[], + rocblas_int incy, + const U* c, + const V* s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rot_name, n, x, incx, y, incy, c, s, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rot_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--c_type", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--incy", + incy, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rot_name, + "N", + n, + "incx", + incx, + "incy", + incy, + "batch", + batch_count); + + if(!x || !y || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rot_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, c, 0, s, 0, batch_count); + } + +} // namespace + +/* + *
=========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srot_batched(rocblas_handle handle, + rocblas_int n, + float* x[], + rocblas_int incx, + float* y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_drot_batched(rocblas_handle handle, + rocblas_int n, + double* x[], + rocblas_int incx, + double* y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_crot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x[], + rocblas_int incx, + rocblas_float_complex* y[], + rocblas_int incy, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x[], + rocblas_int incx, + rocblas_float_complex* y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_zrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x[], + rocblas_int incx, + rocblas_double_complex* y[], + rocblas_int incy, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +rocblas_status rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x[], + rocblas_int incx, + rocblas_double_complex* y[], + rocblas_int 
incy, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rot_strided_batched.cpp b/library/src/blas1/rocblas_rot_strided_batched.cpp new file mode 100644 index 000000000..65d7b5df8 --- /dev/null +++ b/library/src/blas1/rocblas_rot_strided_batched.cpp @@ -0,0 +1,214 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rot.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rot_name[] = "unknown"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_srot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_drot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_crot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] = "rocblas_zrot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] + = "rocblas_csrot_strided_batched"; + template <> + constexpr char rocblas_rot_name[] + = "rocblas_zdrot_strided_batched"; + + template + rocblas_status rocblas_rot_strided_batched_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const U* c, + const V* s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + rocblas_rot_name, + n, + x, + incx, + stride_x, + y, + incy, + stride_y, + c, + s, + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f 
rot_strided_batched
--a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--c_type", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--stride_x", + stride_x, + "--incy", + incy, + "--stride_y", + stride_y, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rot_name, + "N", + n, + "incx", + incx, + "stride_x", + stride_x, + "incy", + incy, + "stride_y", + stride_y, + "batch", + batch_count); + + if(!x || !y || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rot_template( + handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, c, 0, s, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srot_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_drot_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_crot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + 
rocblas_stride stride_y, + const float* c, + const rocblas_float_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_csrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_float_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* c, + const float* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_zrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const rocblas_double_complex* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +rocblas_status rocblas_zdrot_strided_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* x, + rocblas_int incx, + rocblas_stride stride_x, + rocblas_double_complex* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* c, + const double* s, + rocblas_int batch_count) +{ + return rocblas_rot_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/include/utility.h b/library/src/include/utility.h index 0289da55e..d112dc55c 100644 --- a/library/src/include/utility.h +++ b/library/src/include/utility.h @@ -103,6 +103,14 @@ __forceinline__ __device__ __host__ T* return p[block] + offset; } +// For device array of device pointers +template +__forceinline__ __device__ __host__ T* + load_ptr_batch(T** p, rocblas_int block, rocblas_int offset, rocblas_stride stride) +{ + return p[block] + 
offset; +} + #endif // GOOGLE_TEST inline bool isAligned(const void* pointer, size_t byte_count) From 5f9aa7099633403dd48b3adc5a02508fc661a8ef Mon Sep 17 00:00:00 2001 From: First Last Date: Tue, 1 Oct 2019 10:50:15 -0600 Subject: [PATCH 02/11] Rotm batched and strided_batched functions --- clients/benchmarks/client.cpp | 6 + clients/common/rocblas_gentest.py | 3 +- clients/gtest/blas1_gtest.cpp | 17 +- clients/gtest/blas1_gtest.yaml | 10 + clients/include/rocblas.hpp | 38 ++- clients/include/testing_rot_batched.hpp | 65 +++-- clients/include/testing_rotm_batched.hpp | 274 ++++++++++++++++++ .../include/testing_rotm_strided_batched.hpp | 224 ++++++++++++++ library/include/rocblas-functions.h | 178 ++++++++++-- library/src/CMakeLists.txt | 2 + library/src/blas1/rocblas_rot.hpp | 2 +- library/src/blas1/rocblas_rot_batched.cpp | 56 ++-- library/src/blas1/rocblas_rotm.cpp | 107 +------ library/src/blas1/rocblas_rotm.hpp | 128 ++++++++ library/src/blas1/rocblas_rotm_batched.cpp | 106 +++++++ .../blas1/rocblas_rotm_strided_batched.cpp | 132 +++++++++ library/src/include/utility.h | 8 - 17 files changed, 1157 insertions(+), 199 deletions(-) create mode 100644 library/src/blas1/rocblas_rotm.hpp create mode 100644 library/src/blas1/rocblas_rotm_batched.cpp create mode 100644 library/src/blas1/rocblas_rotm_strided_batched.cpp diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 681859afd..a8845dd16 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -25,6 +25,8 @@ #include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" #include "testing_rotm.hpp" +#include "testing_rotm_batched.hpp" +#include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" @@ -180,6 +182,10 @@ struct perf_blas< testing_rotg(arg); else if(!strcmp(arg.function, "rotm")) testing_rotm(arg); + else if(!strcmp(arg.function, "rotm_batched")) + 
testing_rotm_batched(arg); + else if(!strcmp(arg.function, "rotm_strided_batched")) + testing_rotm_strided_batched(arg); else if(!strcmp(arg.function, "rotmg")) testing_rotmg(arg); else if(!strcmp(arg.function, "gemv")) diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 140cc7e85..cc42ac57a 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -198,7 +198,8 @@ def setdefaults(test): if test['function'] in ('asum_strided_batched', 'nrm2_strided_batched', 'scal_strided_batched', 'swap_strided_batched', - 'copy_strided_batched', 'rot_strided_batched'): + 'copy_strided_batched', 'rot_strided_batched', + 'rotm_strided_batched'): if all([x in test for x in ('N', 'incx', 'stride_scale')]): test.setdefault('stride_x', int(test['N'] * abs(test['incx']) * test['stride_scale'])) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index fc8994e96..50e22e484 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -101,13 +101,15 @@ namespace || BLAS1 == blas1::scal_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched - || BLAS1 == blas1::copy_batched || BLAS1 == blas1::rot_batched); + || BLAS1 == blas1::copy_batched || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rotm_batched); bool is_strided = (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched || BLAS1 == blas1::scal_strided_batched || BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched - || BLAS1 == blas1::rot_strided_batched); + || BLAS1 == blas1::rot_strided_batched + || BLAS1 == blas1::rotm_strided_batched); if((is_scal || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotg) @@ -135,13 +137,14 @@ namespace || BLAS1 == blas1::dot || BLAS1 == blas1::swap || BLAS1 == 
blas1::swap_batched || BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched - || BLAS1 == blas1::rotm) + || BLAS1 == blas1::rotm || BLAS1 == blas1::rotm_batched + || BLAS1 == blas1::rotm_strided_batched) { name << '_' << arg.incy; } if(BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched - || BLAS1 == blas1::rot_strided_batched) + || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm_strided_batched) { name << '_' << arg.stride_y; } @@ -242,7 +245,9 @@ namespace || (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}))) - || (BLAS1 == blas1::rotm && std::is_same{} && std::is_same{} + || ((BLAS1 == blas1::rotm || BLAS1 == blas1::rotm_batched + || BLAS1 == blas1::rotm_strided_batched) + && std::is_same{} && std::is_same{} && (std::is_same{} || std::is_same{})) || (BLAS1 == blas1::rotmg && std::is_same{} && std::is_same{} @@ -323,6 +328,8 @@ BLAS1_TESTING(rot_batched, ARG3) BLAS1_TESTING(rot_strided_batched, ARG3) BLAS1_TESTING(rotg, ARG2) BLAS1_TESTING(rotm, ARG1) +BLAS1_TESTING(rotm_batched, ARG1) +BLAS1_TESTING(rotm_strided_batched, ARG1) BLAS1_TESTING(rotmg, ARG1) // clang-format on diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index cf0676fe6..9fd899f15 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -65,6 +65,7 @@ Tests: - nrm2_batched: *single_double_precisions_complex_real - copy_batched: *single_double_precisions_complex_real - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: quick @@ -81,6 +82,7 @@ Tests: - nrm2_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: *single_double_precisions_complex_real - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real - name: blas1 @@ -117,6 +119,7 @@ Tests: - nrm2_batched: 
*double_precision_complex_real - copy_batched: *single_double_precisions_complex_real - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: pre_checkin @@ -133,6 +136,7 @@ Tests: - nrm2_strided_batched: *double_precision_complex_real - copy_strided_batched: *single_double_precisions_complex_real - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real - name: blas1 category: nightly @@ -166,6 +170,7 @@ Tests: - scal_batched: *single_double_complex_real_in_complex_out - copy_batched: *single_double_precisions_complex_real - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_batched category: nightly @@ -179,6 +184,7 @@ Tests: - scal_batched: *single_double_complex_real_in_complex_out - copy_batched: *single_double_precisions_complex_real - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: nightly @@ -193,6 +199,7 @@ Tests: - scal_strided_batched: *single_double_complex_real_in_complex_out - copy_strided_batched: *single_double_precisions_complex_real - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: nightly @@ -207,6 +214,7 @@ Tests: - scal_strided_batched: *single_double_complex_real_in_complex_out - copy_strided_batched: *single_double_precisions_complex_real - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real - name: blas1 @@ -323,6 +331,7 @@ Tests: - scal_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_batched_bad_arg: *single_double_precisions_complex_real - rot_batched_bad_arg: *rot_precisions + - rotm_batched_bad_arg: *single_double_precisions_complex_real - name: blas1_strided_batched_bad_arg category: pre_checkin @@ -333,5 +342,6 @@ Tests: - 
scal_strided_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_strided_batched_bad_arg: *single_double_precisions_complex_real - rot_strided_batched_bad_arg: *rot_precisions + - rotm_strided_batched_bad_arg: *single_double_precisions_complex_real ... diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index 2d94cdcd4..fa569bf0e 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -514,9 +514,9 @@ static constexpr auto rocblas_rot = rocb template rocblas_status (*rocblas_rot_batched)(rocblas_handle handle, rocblas_int n, - T* x[], + T* const x[], rocblas_int incx, - T* y[], + T* const y[], rocblas_int incy, const U* c, const V* s, @@ -619,6 +619,40 @@ static constexpr auto rocblas_rotm = rocblas_srotm; template <> static constexpr auto rocblas_rotm = rocblas_drotm; +// rotm_batched +template +rocblas_status (*rocblas_rotm_batched)(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const T* param, + rocblas_int batch_count); +template <> +static constexpr auto rocblas_rotm_batched = rocblas_srotm_batched; + +template <> +static constexpr auto rocblas_rotm_batched = rocblas_drotm_batched; + +// rotm_strided_batched +template +rocblas_status (*rocblas_rotm_strided_batched)(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const T* param, + rocblas_int batch_count); +template <> +static constexpr auto rocblas_rotm_strided_batched = rocblas_srotm_strided_batched; + +template <> +static constexpr auto rocblas_rotm_strided_batched = rocblas_drotm_strided_batched; + //rotmg template rocblas_status (*rocblas_rotmg)(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param); diff --git a/clients/include/testing_rot_batched.hpp b/clients/include/testing_rot_batched.hpp index 37dcc7587..5866eae21 100644 --- a/clients/include/testing_rot_batched.hpp 
+++ b/clients/include/testing_rot_batched.hpp @@ -188,42 +188,43 @@ void testing_rot_batched(const Arguments& arg) } // Test rocblas_pointer_mode_device - // { - // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); - // for(int b = 0; b < batch_count; b++) - // { - // CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); - // } - // CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U), hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(V), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(V), hipMemcpyHostToDevice)); - // CHECK_ROCBLAS_ERROR((rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); + CHECK_ROCBLAS_ERROR( + (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); - // host_vector rx[batch_count]; - // host_vector ry[batch_count]; - // for(int b = 0; b < batch_count; b++) - // { - // rx[b] = host_vector(size_x); - // ry[b] = host_vector(size_y); - // CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); - // CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, 
hipMemcpyDeviceToHost)); - // } + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } - // if(arg.unit_check) - // { - // unit_check_general(1, N, batch_count, incx, cx, rx); - // unit_check_general(1, N, batch_count, incy, cy, ry); - // } - // if(arg.norm_check) - // { - // norm_error_device_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); - // norm_error_device_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); - // } - // } + if(arg.unit_check) + { + unit_check_general(1, N, batch_count, incx, cx, rx); + unit_check_general(1, N, batch_count, incy, cy, ry); + } + if(arg.norm_check) + { + norm_error_device_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_device_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } } if(arg.timing) diff --git a/clients/include/testing_rotm_batched.hpp b/clients/include/testing_rotm_batched.hpp index e69de29bb..8ab71d6d6 100644 --- a/clients/include/testing_rotm_batched.hpp +++ b/clients/include/testing_rotm_batched.hpp @@ -0,0 +1,274 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotm_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_int incy = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(5); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(nullptr, N, dx, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, nullptr, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, nullptr, incy, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotm_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(5); 
+ if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count)), + rocblas_status_invalid_size); + if(!batch_count) + CHECK_ROCBLAS_ERROR( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); + return; + } + + size_t size_x = N * size_t(incx); + size_t size_y = N * size_t(incy); + + device_vector dx(batch_count); + device_vector dy(batch_count); + device_vector dparam(5); + + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx[batch_count]; + host_vector hy[batch_count]; + host_vector hdata(4); + host_vector hparam(5); + + device_batch_vector bx(batch_count, size_x); + device_batch_vector by(batch_count, size_y); + + for(int b = 0; b < batch_count; b++) + { + hx[b] = host_vector(size_x); + hy[b] = host_vector(size_y); + } + + int last = batch_count - 1; + if((!bx[last] && size_x) || (!by[last] && size_y)) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hx[b], 1, N, incx); + rocblas_init(hy[b], 1, N, incy); + } + rocblas_init(hdata, 1, 4, 1); + + // CPU BLAS reference data + cblas_rotmg(&hdata[0], &hdata[1], &hdata[2], &hdata[3], hparam); + constexpr int FLAG_COUNT = 4; + const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; + + for(int i = 0; i < FLAG_COUNT; i++) + { + hparam[0] = FLAGS[i]; + host_vector cx[batch_count]; + host_vector cy[batch_count]; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cx[b] = hx[b]; + cy[b] = hy[b]; + + cblas_rotm(N, cx[b], incx, cy[b], incy, hparam); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + 
CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR( + hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR( + hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR( + hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparam, hparam, sizeof(T) * 5, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR( + 
(rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); + + host_vector rx[batch_count]; + host_vector ry[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rx[b] = host_vector(size_x); + ry[b] = host_vector(size_y); + CHECK_HIP_ERROR( + hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout << ",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_" + "device_y"; + std::cout 
<< std::endl; + std::cout << N << "," << incx << "," << incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } + } +} diff --git a/clients/include/testing_rotm_strided_batched.hpp b/clients/include/testing_rotm_strided_batched.hpp index e69de29bb..d80632bfd 100644 --- a/clients/include/testing_rotm_strided_batched.hpp +++ b/clients/include/testing_rotm_strided_batched.hpp @@ -0,0 +1,224 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotm_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_stride stride_x = 1; + rocblas_int incy = 1; + rocblas_stride stride_y = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; + + rocblas_local_handle handle; + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(5); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + handle, N, nullptr, incx, stride_x, dy, incy, stride_y, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, nullptr, 
incy, stride_y, dparam, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotm_strided_batched(const Arguments& arg) +{ + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int stride_x = arg.stride_x; + rocblas_int stride_y = arg.stride_y; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host_x = 0.0, norm_error_host_y = 0.0, norm_error_device_x = 0.0, + norm_error_device_y = 0.0; + + // check to prevent undefined memory allocation error + if(N <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) + { + static const size_t safe_size = 100; // arbitrarily set to 100 + device_vector dx(safe_size); + device_vector dy(safe_size); + device_vector dparam(5); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rotm_strided_batched< + T>)(handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count))); + return; + } + + size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); + size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + + device_vector dx(size_x); + device_vector dy(size_y); + device_vector dparam(5); + if(!dx || !dy || !dparam) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Initial Data on CPU + host_vector hx(size_x); + host_vector hy(size_y); + host_vector hdata(4); + host_vector hparam(5); + rocblas_seedrand(); + 
rocblas_init(hx, 1, N, incx, stride_x, batch_count); + rocblas_init(hy, 1, N, incy, stride_y, batch_count); + rocblas_init(hdata, 1, 4, 1); + + // CPU BLAS reference data + cblas_rotmg(&hdata[0], &hdata[1], &hdata[2], &hdata[3], hparam); + constexpr int FLAG_COUNT = 4; + const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; + + for(int i = 0; i < FLAG_COUNT; i++) + { + hparam[0] = FLAGS[i]; + host_vector cx = hx; + host_vector cy = hy; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotm(N, cx + b * stride_x, incx, cy + b * stride_y, incy, hparam); + } + cpu_time_used = get_time_us() - cpu_time_used; + + if(arg.unit_check || arg.norm_check) + { + // Test rocblas_pointer_mode_host + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_host_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_host_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + + // Test rocblas_pointer_mode_device + { + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); +
CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparam, hparam, sizeof(T) * 5, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count))); + host_vector rx(size_x); + host_vector ry(size_y); + CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + if(arg.unit_check) + { + T rel_error = std::numeric_limits::epsilon() * 1000; + near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); + near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); + } + if(arg.norm_check) + { + norm_error_device_x + = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + norm_error_device_y + = norm_check_general('F', 1, N, incy, stride_y, batch_count, cy, ry); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotm_strided_batched( + handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "N,incx,incy,rocblas(us),cpu(us)"; + if(arg.norm_check) + std::cout << ",norm_error_host_x,norm_error_host_y,norm_error_device_x,norm_error_" + "device_y"; + std::cout << std::endl; + std::cout << N << "," << incx << "," << 
incy << "," << gpu_time_used << "," + << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host_x << ',' << norm_error_host_y << "," + << norm_error_device_x << "," << norm_error_device_y; + std::cout << std::endl; + } + } +} diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 80a8b88e1..3e28299fb 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1289,9 +1289,9 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot(rocblas_handle handle, ROCBLAS_EXPORT rocblas_status rocblas_srot_batched(rocblas_handle handle, rocblas_int n, - float* x[], + float* const x[], rocblas_int incx, - float* y[], + float* const y[], rocblas_int incy, const float* c, const float* s, @@ -1299,9 +1299,9 @@ ROCBLAS_EXPORT rocblas_status rocblas_srot_batched(rocblas_handle handle, ROCBLAS_EXPORT rocblas_status rocblas_drot_batched(rocblas_handle handle, rocblas_int n, - double* x[], + double* const x[], rocblas_int incx, - double* y[], + double* const y[], rocblas_int incy, const double* c, const double* s, @@ -1309,43 +1309,43 @@ ROCBLAS_EXPORT rocblas_status rocblas_drot_batched(rocblas_handle handle, ROCBLAS_EXPORT rocblas_status rocblas_crot_batched(rocblas_handle handle, rocblas_int n, - rocblas_float_complex* x[], + rocblas_float_complex* const x[], rocblas_int incx, - rocblas_float_complex* y[], + rocblas_float_complex* const y[], rocblas_int incy, const float* c, const rocblas_float_complex* s, rocblas_int batch_count); -ROCBLAS_EXPORT rocblas_status rocblas_csrot_batched(rocblas_handle handle, - rocblas_int n, - rocblas_float_complex* x[], - rocblas_int incx, - rocblas_float_complex* y[], - rocblas_int incy, - const float* c, - const float* s, - rocblas_int batch_count); +ROCBLAS_EXPORT rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, 
+ const float* s, + rocblas_int batch_count); ROCBLAS_EXPORT rocblas_status rocblas_zrot_batched(rocblas_handle handle, rocblas_int n, - rocblas_double_complex* x[], + rocblas_double_complex* const x[], rocblas_int incx, - rocblas_double_complex* y[], + rocblas_double_complex* const y[], rocblas_int incy, const double* c, const rocblas_double_complex* s, rocblas_int batch_count); -ROCBLAS_EXPORT rocblas_status rocblas_zdrot_batched(rocblas_handle handle, - rocblas_int n, - rocblas_double_complex* x[], - rocblas_int incx, - rocblas_double_complex* y[], - rocblas_int incy, - const double* c, - const double* s, - rocblas_int batch_count); +ROCBLAS_EXPORT rocblas_status rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int incy, + const double* c, + const double* s, + rocblas_int batch_count); /*! \brief BLAS Level 1 API @@ -1550,6 +1550,132 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, rocblas_int incy, const double* param); +/*! \brief BLAS Level 1 API + + \details + rotm_batched applies the modified Givens rotation matrix defined by param to batched vectors x and y. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x array of pointers storing vectors x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[inout] + y array of pointers storing vectors y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + param vector of 5 elements defining the rotation. 
+ param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* param, + rocblas_int batch_count); + +/*! \brief BLAS Level 1 API + + \details + rotm_strided_batched applies the modified Givens rotation matrix defined by param to batched vectors x and y. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[in] + n rocblas_int + number of elements in the x and y vectors. + @param[inout] + x pointers storing batched vectors x on the GPU. + @param[in] + incx rocblas_int + specifies the increment between elements of x. + @param[in] + stride_x rocblas_stride + specifies the increment between the beginning of x_i and x_(i + 1) + @param[inout] + y pointers storing batched vectors y on the GPU. + @param[in] + incy rocblas_int + specifies the increment between elements of y. + @param[in] + stride_y rocblas_stride + specifies the increment between the beginning of y_i and y_(i + 1) + @param[in] + param vector of 5 elements defining the rotation. 
+ param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of x and y arrays, i.e. the number of batches. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* param, + rocblas_int batch_count); + /*! 
\brief BLAS Level 1 API \details diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index b3496b2de..ce47d1ddc 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -143,6 +143,8 @@ set( rocblas_blas1_source blas1/rocblas_rot_strided_batched.cpp blas1/rocblas_rotg.cpp blas1/rocblas_rotm.cpp + blas1/rocblas_rotm_batched.cpp + blas1/rocblas_rotm_strided_batched.cpp blas1/rocblas_rotmg.cpp blas1/rocblas_swap_batched.cpp blas1/rocblas_swap_strided_batched.cpp diff --git a/library/src/blas1/rocblas_rot.hpp b/library/src/blas1/rocblas_rot.hpp index 9dde321bc..6bce83b39 100644 --- a/library/src/blas1/rocblas_rot.hpp +++ b/library/src/blas1/rocblas_rot.hpp @@ -93,7 +93,7 @@ rocblas_status rocblas_rot_template(rocblas_handle handle, rocblas_int batch_count) { // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0) + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count == 0) return rocblas_status_success; dim3 blocks((n - 1) / NB + 1, batch_count); diff --git a/library/src/blas1/rocblas_rot_batched.cpp b/library/src/blas1/rocblas_rot_batched.cpp index 34f9483cd..e0c9d2bbf 100644 --- a/library/src/blas1/rocblas_rot_batched.cpp +++ b/library/src/blas1/rocblas_rot_batched.cpp @@ -29,9 +29,9 @@ namespace template rocblas_status rocblas_rot_batched_impl(rocblas_handle handle, rocblas_int n, - T* x[], + T* const x[], rocblas_int incx, - T* y[], + T* const y[], rocblas_int incy, const U* c, const V* s, @@ -94,9 +94,9 @@ extern "C" { rocblas_status rocblas_srot_batched(rocblas_handle handle, rocblas_int n, - float* x[], + float* const x[], rocblas_int incx, - float* y[], + float* const y[], rocblas_int incy, const float* c, const float* s, @@ -107,9 +107,9 @@ rocblas_status rocblas_srot_batched(rocblas_handle handle, rocblas_status rocblas_drot_batched(rocblas_handle handle, rocblas_int n, - double* x[], + double* const x[], rocblas_int incx, - double* y[], + double* const y[], rocblas_int incy, const double* c, const double* s, @@ 
-120,9 +120,9 @@ rocblas_status rocblas_drot_batched(rocblas_handle handle, rocblas_status rocblas_crot_batched(rocblas_handle handle, rocblas_int n, - rocblas_float_complex* x[], + rocblas_float_complex* const x[], rocblas_int incx, - rocblas_float_complex* y[], + rocblas_float_complex* const y[], rocblas_int incy, const float* c, const rocblas_float_complex* s, @@ -131,24 +131,24 @@ rocblas_status rocblas_crot_batched(rocblas_handle handle, return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); } -rocblas_status rocblas_csrot_batched(rocblas_handle handle, - rocblas_int n, - rocblas_float_complex* x[], - rocblas_int incx, - rocblas_float_complex* y[], - rocblas_int incy, - const float* c, - const float* s, - rocblas_int batch_count) +rocblas_status rocblas_csrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_float_complex* const x[], + rocblas_int incx, + rocblas_float_complex* const y[], + rocblas_int incy, + const float* c, + const float* s, + rocblas_int batch_count) { return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); } rocblas_status rocblas_zrot_batched(rocblas_handle handle, rocblas_int n, - rocblas_double_complex* x[], + rocblas_double_complex* const x[], rocblas_int incx, - rocblas_double_complex* y[], + rocblas_double_complex* const y[], rocblas_int incy, const double* c, const rocblas_double_complex* s, @@ -157,15 +157,15 @@ rocblas_status rocblas_zrot_batched(rocblas_handle handle, return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); } -rocblas_status rocblas_zdrot_batched(rocblas_handle handle, - rocblas_int n, - rocblas_double_complex* x[], - rocblas_int incx, - rocblas_double_complex* y[], - rocblas_int incy, - const double* c, - const double* s, - rocblas_int batch_count) +rocblas_status rocblas_zdrot_batched(rocblas_handle handle, + rocblas_int n, + rocblas_double_complex* const x[], + rocblas_int incx, + rocblas_double_complex* const y[], + rocblas_int 
incy, + const double* c, + const double* s, + rocblas_int batch_count) { return rocblas_rot_batched_impl(handle, n, x, incx, y, incy, c, s, batch_count); } diff --git a/library/src/blas1/rocblas_rotm.cpp b/library/src/blas1/rocblas_rotm.cpp index 5ce51b266..5ffd97723 100644 --- a/library/src/blas1/rocblas_rotm.cpp +++ b/library/src/blas1/rocblas_rotm.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#include "rocblas_rotm.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -10,49 +11,6 @@ namespace { constexpr int NB = 512; - template - __global__ void rotm_kernel(rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - U flag_device_host, - U h11_device_host, - U h21_device_host, - U h12_device_host, - U h22_device_host) - { - auto flag = load_scalar(flag_device_host); - auto h11 = load_scalar(h11_device_host); - auto h21 = load_scalar(h21_device_host); - auto h12 = load_scalar(h12_device_host); - auto h22 = load_scalar(h22_device_host); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - - if(tid < n && flag != -2) - { - auto ix = tid * incx; - auto iy = tid * incy; - auto w = x[ix]; - auto z = y[iy]; - if(flag < 0) - { - x[ix] = w * h11 + z * h12; - y[iy] = w * h21 + z * h22; - } - else if(flag == 0) - { - x[ix] = w + z * h12; - y[iy] = w * h21 + z; - } - else - { - x[ix] = w * h11 + z; - y[iy] = -w + z * h22; - } - } - } - template constexpr char rocblas_rotm_name[] = "unknown"; template <> @@ -61,13 +19,13 @@ namespace constexpr char rocblas_rotm_name[] = "rocblas_drotm"; template - rocblas_status rocblas_rotm(rocblas_handle handle, - rocblas_int n, - T* x, - rocblas_int incx, - T* y, - rocblas_int incy, - const T* param) + rocblas_status rocblas_rotm_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + T* y, + 
rocblas_int incy, + const T* param) { if(!handle) return rocblas_status_invalid_handle; @@ -93,50 +51,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0) - return rocblas_status_success; - if(rocblas_pointer_mode_host == handle->pointer_mode && param[0] == -2) - return rocblas_status_success; - - dim3 blocks((n - 1) / NB + 1); - dim3 threads(NB); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - hipLaunchKernelGGL(rotm_kernel, - blocks, - threads, - 0, - rocblas_stream, - n, - x, - incx, - y, - incy, - param, - param + 1, - param + 2, - param + 3, - param + 4); - else // c and s are on host - hipLaunchKernelGGL(rotm_kernel, - blocks, - threads, - 0, - rocblas_stream, - n, - x, - incx, - y, - incy, - param[0], - param[1], - param[2], - param[3], - param[4]); - - return rocblas_status_success; + return rocblas_rotm_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1); } } // namespace @@ -157,7 +72,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotm(rocblas_handle handle, rocblas_int incy, const float* param) { - return rocblas_rotm(handle, n, x, incx, y, incy, param); + return rocblas_rotm_impl(handle, n, x, incx, y, incy, param); } ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, @@ -168,7 +83,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, rocblas_int incy, const double* param) { - return rocblas_rotm(handle, n, x, incx, y, incy, param); + return rocblas_rotm_impl(handle, n, x, incx, y, incy, param); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp new file mode 100644 index 000000000..d4f7429f9 --- /dev/null +++ b/library/src/blas1/rocblas_rotm.hpp @@ -0,0 +1,128 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template +__global__ void rotm_kernel(rocblas_int n, + T x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U flag_device_host, + U h11_device_host, + U h21_device_host, + U h12_device_host, + U h22_device_host, + rocblas_stride stride_param) +{ + auto flag = load_scalar(flag_device_host); //, hipBlockIdx_y, stride_param); + auto h11 = load_scalar(h11_device_host); //, hipBlockIdx_y, stride_param); + auto h21 = load_scalar(h21_device_host); //, hipBlockIdx_y, stride_param); + auto h12 = load_scalar(h12_device_host); //, hipBlockIdx_y, stride_param); + auto h22 = load_scalar(h22_device_host); //, hipBlockIdx_y, stride_param); + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(tid < n && flag != -2) + { + auto ix = tid * incx; + auto iy = tid * incy; + auto w = x[ix]; + auto z = y[iy]; + if(flag < 0) + { + x[ix] = w * h11 + z * h12; + y[iy] = w * h21 + z * h22; + } + else if(flag == 0) + { + x[ix] = w + z * h12; + y[iy] = w * h21 + z; + } + else + { + x[ix] = w * h11 + z; + y[iy] = -w + z * h22; + } + } +} + +template +rocblas_status rocblas_rotm_template(rocblas_handle handle, + rocblas_int n, + U x, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + U y, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + const T* param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + // Quick return if possible + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count == 0) + return rocblas_status_success; + if(rocblas_pointer_mode_host == handle->pointer_mode && param[0] == -2) + return 
rocblas_status_success; + + dim3 blocks((n - 1) / NB + 1, batch_count); + dim3 threads(NB); + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + hipLaunchKernelGGL(rotm_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + param, + param + 1, + param + 2, + param + 3, + param + 4, + stride_param); + else // c and s are on host + hipLaunchKernelGGL(rotm_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + param[0], + param[1], + param[2], + param[3], + param[4], + stride_param); + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotm_batched.cpp b/library/src/blas1/rocblas_rotm_batched.cpp new file mode 100644 index 000000000..b7e7a650b --- /dev/null +++ b/library/src/blas1/rocblas_rotm_batched.cpp @@ -0,0 +1,106 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotm.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rotm_name[] = "unknown"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_srotm_batched"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_drotm_batched"; + + template + rocblas_status rocblas_rotm_batched_impl(rocblas_handle handle, + rocblas_int n, + T* const x[], + rocblas_int incx, + T* const y[], + rocblas_int incy, + const T* param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotm_name, n, x, incx, y, incy, param, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotm_batched -r", + rocblas_precision_string, + "-n", + n, + "--incx", + incx, + "--incy", + incy, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rotm_name, + "N", + n, + "incx", + incx, + "incy", + incy, + "batch", + batch_count); + + if(!x || !y || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotm_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* param, + 
rocblas_int batch_count) +{ + return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* param, + rocblas_int batch_count) +{ + return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotm_strided_batched.cpp b/library/src/blas1/rocblas_rotm_strided_batched.cpp new file mode 100644 index 000000000..ef4fb5475 --- /dev/null +++ b/library/src/blas1/rocblas_rotm_strided_batched.cpp @@ -0,0 +1,132 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotm.hpp" +#include "utility.h" + +namespace +{ + constexpr int NB = 512; + + template + constexpr char rocblas_rotm_name[] = "unknown"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_srotm_strided_batched"; + template <> + constexpr char rocblas_rotm_name[] = "rocblas_drotm_strided_batched"; + + template + rocblas_status rocblas_rotm_strided_batched_impl(rocblas_handle handle, + rocblas_int n, + T* x, + rocblas_int incx, + rocblas_stride stride_x, + T* y, + rocblas_int incy, + rocblas_stride stride_y, + const T* param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + rocblas_rotm_name, + n, + x, + incx, + stride_x, + y, + incy, + stride_y, + param, + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotm_strided_batched -r", + rocblas_precision_string, + "-n", + n, + 
"--incx", + incx, + "--stride_x", + stride_x, + "--incy", + incy, + "--stride_y", + stride_y, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, + rocblas_rotm_name, + "N", + n, + "incx", + incx, + "stride_x", + stride_x, + "incy", + incy, + "stride_y", + stride_y, + "batch", + batch_count); + + if(!x || !y || !param) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotm_template( + handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, param, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handle, + rocblas_int n, + float* x, + rocblas_int incx, + rocblas_stride stride_x, + float* y, + rocblas_int incy, + rocblas_stride stride_y, + const float* param, + rocblas_int batch_count) +{ + return rocblas_rotm_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, + rocblas_int n, + double* x, + rocblas_int incx, + rocblas_stride stride_x, + double* y, + rocblas_int incy, + rocblas_stride stride_y, + const double* param, + rocblas_int batch_count) +{ + return rocblas_rotm_strided_batched_impl( + handle, n, x, incx, stride_x, y, incy, stride_y, param, batch_count); +} + +} // extern "C" diff --git a/library/src/include/utility.h b/library/src/include/utility.h index d112dc55c..0289da55e 100644 --- a/library/src/include/utility.h +++ b/library/src/include/utility.h @@ -103,14 +103,6 @@ __forceinline__ __device__ __host__ T* return p[block] + offset; } -// For device array of device pointers -template 
-__forceinline__ __device__ __host__ T* - load_ptr_batch(T** p, rocblas_int block, rocblas_int offset, rocblas_stride stride) -{ - return p[block] + offset; -} - #endif // GOOGLE_TEST inline bool isAligned(const void* pointer, size_t byte_count) From 997dbbcc3b2150413f1d30ff06ed58a76add9eb3 Mon Sep 17 00:00:00 2001 From: First Last Date: Wed, 2 Oct 2019 12:31:09 -0600 Subject: [PATCH 03/11] Added rotg batched and strided_batched --- clients/benchmarks/client.cpp | 47 ++- clients/common/rocblas_gentest.py | 7 + clients/gtest/blas1_gtest.cpp | 31 +- clients/gtest/blas1_gtest.yaml | 17 +- clients/include/rocblas.hpp | 48 ++++ clients/include/rocblas_common.yaml | 8 +- clients/include/testing_rot_batched.hpp | 2 +- clients/include/testing_rotg.hpp | 8 +- clients/include/testing_rotg_batched.hpp | 272 ++++++++++++++++++ .../include/testing_rotg_strided_batched.hpp | 255 ++++++++++++++++ clients/include/testing_rotm_batched.hpp | 2 +- clients/include/type_dispatch.hpp | 6 +- library/include/rocblas-functions.h | 134 +++++++++ library/src/CMakeLists.txt | 2 + library/src/blas1/rocblas_rotg.cpp | 89 +----- library/src/blas1/rocblas_rotg.hpp | 135 +++++++++ library/src/blas1/rocblas_rotg_batched.cpp | 108 +++++++ .../blas1/rocblas_rotg_strided_batched.cpp | 133 +++++++++ library/src/blas1/rocblas_rotm.hpp | 10 +- 19 files changed, 1210 insertions(+), 104 deletions(-) create mode 100644 library/src/blas1/rocblas_rotg.hpp create mode 100644 library/src/blas1/rocblas_rotg_batched.cpp create mode 100644 library/src/blas1/rocblas_rotg_strided_batched.cpp diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index a8845dd16..de52fe89a 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -24,6 +24,8 @@ #include "testing_rot_batched.hpp" #include "testing_rot_strided_batched.hpp" #include "testing_rotg.hpp" +#include "testing_rotg_batched.hpp" +#include "testing_rotg_strided_batched.hpp" #include "testing_rotm.hpp" #include 
"testing_rotm_batched.hpp" #include "testing_rotm_strided_batched.hpp" @@ -178,8 +180,12 @@ struct perf_blas< testing_set_get_vector(arg); else if(!strcmp(arg.function, "set_get_matrix")) testing_set_get_matrix(arg); - else if(!strcmp(arg.function, "rotg")) - testing_rotg(arg); + // else if(!strcmp(arg.function, "rotg")) + // testing_rotg(arg); + // else if(!strcmp(arg.function, "rotg_batched")) + // testing_rotg_batched(arg); + // else if(!strcmp(arg.function, "rotg_strided_batched")) + // testing_rotg_strided_batched(arg); else if(!strcmp(arg.function, "rotm")) testing_rotm(arg); else if(!strcmp(arg.function, "rotm_batched")) @@ -406,6 +412,40 @@ struct perf_blas_scal< } }; +template +struct perf_blas_rotg : rocblas_test_invalid +{ +}; + +template +struct perf_blas_rotg< + Ta, + Tb, + typename std::enable_if< + (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{}) + || (std::is_same{} && std::is_same{})>::type> +{ + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "rotg")) + testing_rotg(arg); + else if(!strcmp(arg.function, "rotg_batched")) + testing_rotg_batched(arg); + else if(!strcmp(arg.function, "rotg_strided_batched")) + testing_rotg_strided_batched(arg); + else + throw std::invalid_argument("Invalid combination --function "s + arg.function + + " --a_type " + rocblas_datatype2string(arg.a_type) + + " --b_type " + rocblas_datatype2string(arg.b_type)); + } +}; + int run_bench_test(Arguments& arg) { // disable unit_check in client benchmark, it is only used in gtest unit test @@ -568,6 +608,9 @@ int run_bench_test(Arguments& arg) if(!strcmp(function, "scal") || !strcmp(function, "scal_batched") || !strcmp(function, "scal_strided_batched")) rocblas_blas1_dispatch(arg); + else if(!strcmp(function, "rotg") || !strcmp(function, "rotg_batched") + || !strcmp(function, "rotg_strided_batched")) + rocblas_blas1_dispatch(arg); else 
if(!strcmp(function, "rot") || !strcmp(function, "rot_batched") || !strcmp(function, "rot_strided_batched")) rocblas_blas1_dispatch(arg); diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index cc42ac57a..4a6c55327 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -215,6 +215,13 @@ def setdefaults(test): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) + if test['function'] in ('rotg_strided_batched'): + if 'stride_scale' in test: + test.setdefault('stride_a', int(test['stride_scale'])) + test.setdefault('stride_b', int(test['stride_scale'])) + test.setdefault('stride_c', int(test['stride_scale'])) + test.setdefault('stride_d', int(test['stride_scale'])) + test.setdefault('stride_x', 0) test.setdefault('stride_y', 0) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 50e22e484..111b4b5a1 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -99,20 +99,23 @@ namespace { bool is_scal = (BLAS1 == blas1::scal || BLAS1 == blas1::scal_batched || BLAS1 == blas1::scal_strided_batched); + bool is_rotg = (BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched + || BLAS1 == blas1::rotg_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched || BLAS1 == blas1::copy_batched || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rotm_batched); + || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotg_batched); bool is_strided = (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched || BLAS1 == blas1::scal_strided_batched || BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched || BLAS1 == blas1::rot_strided_batched - || BLAS1 == blas1::rotm_strided_batched); + || BLAS1 == blas1::rotm_strided_batched + || BLAS1 == blas1::rotg_strided_batched); - if((is_scal || BLAS1 == 
blas1::rot || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotg) + if((is_scal || is_rotg || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched) && arg.a_type != arg.b_type) name << '_' << rocblas_datatype2string(arg.b_type); if((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched @@ -120,14 +123,16 @@ namespace && arg.compute_type != arg.a_type) name << '_' << rocblas_datatype2string(arg.compute_type); - name << '_' << arg.N; + if(!is_rotg) + name << '_' << arg.N; if(BLAS1 == blas1::axpy || is_scal) name << '_' << arg.alpha << "_" << arg.alphai; - name << '_' << arg.incx; + if(!is_rotg) + name << '_' << arg.incx; - if(is_strided) + if(is_strided && !is_rotg) { name << '_' << arg.stride_x; } @@ -149,6 +154,12 @@ namespace name << '_' << arg.stride_y; } + if(is_rotg) + { + name << '_' << arg.stride_a << '_' << arg.stride_b << '_' << arg.stride_c << '_' + << arg.stride_d; + } + if(is_batched || is_strided) { name << "_" << arg.batch_count; @@ -239,7 +250,9 @@ namespace || (std::is_same{} && std::is_same{} && std::is_same{}))) - || (BLAS1 == blas1::rotg && std::is_same{} + || ((BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched + || BLAS1 == blas1::rotg_strided_batched) + && std::is_same{} && ((std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) @@ -327,6 +340,8 @@ BLAS1_TESTING(rot, ARG3) BLAS1_TESTING(rot_batched, ARG3) BLAS1_TESTING(rot_strided_batched, ARG3) BLAS1_TESTING(rotg, ARG2) +BLAS1_TESTING(rotg_batched, ARG2) +BLAS1_TESTING(rotg_strided_batched, ARG2) BLAS1_TESTING(rotm, ARG1) BLAS1_TESTING(rotm_batched, ARG1) BLAS1_TESTING(rotm_strided_batched, ARG1) diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 9fd899f15..19a9d647a 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -51,6 +51,19 @@ Tests: - rotg: *rotg_precisions - rotmg: 
*single_double_precisions_complex_real + - name: blas1_batched + category: quick + batch_count: [-1, 0, 5] + function: + - rotg_batched: *rotg_precisions + + - name: blas1_strided_batched + category: quick + batch_count: [-1, 0, 5] + stride_scale: [ 1, 7 ] + function: + - rotg_strided_batched: *rotg_precisions + - name: blas1_batched category: quick N: [ -1, 0, 5, 33792 ] @@ -331,17 +344,19 @@ Tests: - scal_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_batched_bad_arg: *single_double_precisions_complex_real - rot_batched_bad_arg: *rot_precisions + - rotg_batched_bad_arg: *rot_precisions - rotm_batched_bad_arg: *single_double_precisions_complex_real - name: blas1_strided_batched_bad_arg category: pre_checkin batch_count: [1, 10] - strideScale: [ 1 ] + stride_scale: [ 1 ] function: - scal_strided_batched_bad_arg: *single_double_precisions_complex_real - scal_strided_batched_bad_arg: *single_double_complex_real_in_complex_out - copy_strided_batched_bad_arg: *single_double_precisions_complex_real - rot_strided_batched_bad_arg: *rot_precisions + - rotg_strided_batched_bad_arg: *rot_precisions - rotm_strided_batched_bad_arg: *single_double_precisions_complex_real ... 
diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index fa569bf0e..14757770b 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -603,6 +603,54 @@ static constexpr auto rocblas_rotg = rocblas_crotg template <> static constexpr auto rocblas_rotg = rocblas_zrotg; +// rotg_batched +template +rocblas_status (*rocblas_rotg_batched)(rocblas_handle handle, + T* const a[], + T* const b[], + U* const c[], + T* const s[], + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_srotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_drotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_crotg_batched; + +template <> +static constexpr auto rocblas_rotg_batched = rocblas_zrotg_batched; + +//rotg_strided_batched +template +rocblas_status (*rocblas_rotg_strided_batched)(rocblas_handle handle, + T* a, + rocblas_stride stride_a, + T* b, + rocblas_stride stride_b, + U* c, + rocblas_stride stride_c, + T* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotg_strided_batched = rocblas_srotg_strided_batched; + +template <> +static constexpr auto rocblas_rotg_strided_batched = rocblas_drotg_strided_batched; + +template <> +static constexpr auto + rocblas_rotg_strided_batched = rocblas_crotg_strided_batched; + +template <> +static constexpr auto + rocblas_rotg_strided_batched = rocblas_zrotg_strided_batched; + //rotm template rocblas_status (*rocblas_rotm)(rocblas_handle handle, diff --git a/clients/include/rocblas_common.yaml b/clients/include/rocblas_common.yaml index 6f201fc2a..527fb0ba1 100644 --- a/clients/include/rocblas_common.yaml +++ b/clients/include/rocblas_common.yaml @@ -317,10 +317,10 @@ Defaults: uplo: '*' diag: '*' batch_count: -1 - stride_a: 0 - stride_b: 0 - stride_c: 0 - stride_d: 0 + # stride_a: 0 + # stride_b: 0 + # stride_c: 0 + # stride_d: 0 stride_scale: 1 norm_check: 
0 unit_check: 1 diff --git a/clients/include/testing_rot_batched.hpp b/clients/include/testing_rot_batched.hpp index 5866eae21..7d4b871c5 100644 --- a/clients/include/testing_rot_batched.hpp +++ b/clients/include/testing_rot_batched.hpp @@ -81,7 +81,7 @@ void testing_rot_batched(const Arguments& arg) EXPECT_ROCBLAS_STATUS( (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count)), rocblas_status_invalid_size); - if(!batch_count) + else CHECK_ROCBLAS_ERROR( (rocblas_rot_batched(handle, N, dx, incx, dy, incy, dc, ds, batch_count))); return; diff --git a/clients/include/testing_rotg.hpp b/clients/include/testing_rotg.hpp index b64ce34bd..7c15b9604 100644 --- a/clients/include/testing_rotg.hpp +++ b/clients/include/testing_rotg.hpp @@ -120,10 +120,10 @@ void testing_rotg(const Arguments& arg) if(arg.unit_check) { - // unit_check_general(1, 1, 1, ca, ha); - // unit_check_general(1, 1, 1, cb, hb); - // unit_check_general(1, 1, 1, cc, hc); - // unit_check_general(1, 1, 1, cs, hs); + unit_check_general(1, 1, 1, ca, ha); + unit_check_general(1, 1, 1, cb, hb); + unit_check_general(1, 1, 1, cc, hc); + unit_check_general(1, 1, 1, cs, hs); } if(arg.norm_check) diff --git a/clients/include/testing_rotg_batched.hpp b/clients/include/testing_rotg_batched.hpp index e69de29bb..e42d23998 100644 --- a/clients/include/testing_rotg_batched.hpp +++ b/clients/include/testing_rotg_batched.hpp @@ -0,0 +1,272 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotg_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 1; + + rocblas_local_handle handle; + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(nullptr, da, db, dc, ds, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, nullptr, db, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, nullptr, dc, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, nullptr, ds, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, dc, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotg_batched(const Arguments& arg) +{ + const int TEST_COUNT = 1; //00; + rocblas_int batch_count = arg.batch_count; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector da(safe_size); + device_vector db(safe_size); + device_vector dc(safe_size); + device_vector ds(safe_size); + + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + 
CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count))); + return; + } + + // Initial Data on CPU + host_vector ha[batch_count]; + host_vector hb[batch_count]; + host_vector hc[batch_count]; + host_vector hs[batch_count]; + + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + + for(int b = 0; b < batch_count; b++) + { + ha[b] = host_vector(1); + hb[b] = host_vector(1); + hc[b] = host_vector(1); + hs[b] = host_vector(1); + } + + for(int i = 0; i < TEST_COUNT; i++) + { + host_vector ca[batch_count]; + host_vector cb[batch_count]; + host_vector cc[batch_count]; + host_vector cs[batch_count]; + + rocblas_seedrand(); + for(int b = 0; b < batch_count; b++) + { + rocblas_init(ha[b], 1, 1, 1); + rocblas_init(hb[b], 1, 1, 1); + rocblas_init(hc[b], 1, 1, 1); + rocblas_init(hs[b], 1, 1, 1); + ca[b] = ha[b]; + cb[b] = hb[b]; + cc[b] = hc[b]; + cs[b] = hs[b]; + } + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotg(ca[b], cb[b], cc[b], cs[b]); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector ra[batch_count]; + host_vector rb[batch_count]; + host_vector rc[batch_count]; + host_vector rs[batch_count]; + T* ra_in[batch_count]; + T* rb_in[batch_count]; + U* rc_in[batch_count]; + T* rs_in[batch_count]; + for(int b = 0; b < batch_count; b++) + { + ra_in[b] = ra[b] = ha[b]; + rb_in[b] = rb[b] = hb[b]; + rc_in[b] = rc[b] = hc[b]; + rs_in[b] = rs[b] = hs[b]; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR( + (rocblas_rotg_batched(handle, ra_in, rb_in, rc_in, rs_in, batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, 
batch_count, 1, ra, ca); + unit_check_general(1, 1, batch_count, 1, rb, cb); + unit_check_general(1, 1, batch_count, 1, rc, cc); + unit_check_general(1, 1, batch_count, 1, rs, cs); + } + + if(arg.norm_check) + { + norm_error_host = norm_check_general('F', 1, 1, batch_count, 1, ra, ca); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rb, cb); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rc, cc); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rs, cs); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + device_batch_vector bc(batch_count, 1); + device_batch_vector bs(batch_count, 1); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(ba[b], ha[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bb[b], hb[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bc[b], hc[b], sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bs[b], hs[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(da, ba, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, bb, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, bc, sizeof(U*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, bs, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_batched(handle, da, db, dc, ds, batch_count))); + + host_vector ra[batch_count]; + host_vector rb[batch_count]; + host_vector rc[batch_count]; + host_vector rs[batch_count]; + for(int b = 0; b < batch_count; b++) + { + ra[b] = host_vector(1); + rb[b] = host_vector(1); + rc[b] = host_vector(1); + 
rs[b] = host_vector(1); + CHECK_HIP_ERROR(hipMemcpy(ra[b], ba[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rb[b], bb[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rc[b], bc[b], sizeof(U), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rs[b], bs[b], sizeof(T), hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, ra, ca); + unit_check_general(1, 1, batch_count, 1, rb, cb); + unit_check_general(1, 1, batch_count, 1, rc, cc); + unit_check_general(1, 1, batch_count, 1, rs, cs); + } + + if(arg.norm_check) + { + norm_error_device = norm_check_general('F', 1, 1, batch_count, 1, ra, ca); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rb, cb); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rc, cc); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rs, cs); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be much quicker + // (TODO: or is there another reason we are typically using host_mode for timing?) 
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector da(batch_count); + device_vector db(batch_count); + device_vector dc(batch_count); + device_vector ds(batch_count); + device_batch_vector ba(batch_count, 1); + device_batch_vector bb(batch_count, 1); + device_batch_vector bc(batch_count, 1); + device_batch_vector bs(batch_count, 1); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(ba[b], ha[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bb[b], hb[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bc[b], hc[b], sizeof(U), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bs[b], hs[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(da, ba, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, bb, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, bc, sizeof(U*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, bs, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotg_batched(handle, da, db, dc, ds, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotg_batched(handle, da, db, dc, ds, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotg_strided_batched.hpp b/clients/include/testing_rotg_strided_batched.hpp index e69de29bb..e30a69a4d 100644 --- a/clients/include/testing_rotg_strided_batched.hpp +++ 
b/clients/include/testing_rotg_strided_batched.hpp @@ -0,0 +1,255 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotg_strided_batched_bad_arg(const Arguments& arg) +{ + static const size_t safe_size = 1; + rocblas_int batch_count = 5; + rocblas_stride stride_a = 10; + rocblas_stride stride_b = 10; + rocblas_stride stride_c = 10; + rocblas_stride stride_s = 10; + + rocblas_local_handle handle; + device_vector da(batch_count * stride_a); + device_vector db(batch_count * stride_b); + device_vector dc(batch_count * stride_c); + device_vector ds(batch_count * stride_s); + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + nullptr, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, nullptr, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, nullptr, stride_b, dc, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, nullptr, stride_c, ds, stride_s, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, nullptr, stride_s, batch_count)), + 
rocblas_status_invalid_pointer); +} + +template +void testing_rotg_strided_batched(const Arguments& arg) +{ + const int TEST_COUNT = 1; + rocblas_int stride_a = arg.stride_a; + rocblas_int stride_b = arg.stride_b; + rocblas_int stride_c = arg.stride_c; + rocblas_int stride_s = arg.stride_d; + rocblas_int batch_count = arg.batch_count; + + rocblas_local_handle handle; + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + static const size_t safe_size = 1; // arbitrarily set to 100 + device_vector da(safe_size); + device_vector db(safe_size); + device_vector dc(safe_size); + device_vector ds(safe_size); + if(!da || !db || !dc || !ds) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotg_strided_batched)(handle, + da, + stride_a, + db, + stride_b, + dc, + stride_c, + ds, + stride_s, + batch_count), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count))); + return; + } + + size_t size_a = size_t(stride_a) * size_t(batch_count); + size_t size_b = size_t(stride_b) * size_t(batch_count); + size_t size_c = size_t(stride_c) * size_t(batch_count); + size_t size_s = size_t(stride_s) * size_t(batch_count); + + host_vector ha(size_a); + host_vector hb(size_b); + host_vector hc(size_c); + host_vector hs(size_s); + + for(int i = 0; i < TEST_COUNT; i++) + { + // Initial data on CPU + rocblas_seedrand(); + rocblas_init(ha, 1, 1, 1, stride_a, batch_count); + rocblas_init(hb, 1, 1, 1, stride_b, batch_count); + rocblas_init(hc, 1, 1, 1, stride_c, batch_count); + rocblas_init(hs, 1, 1, 1, stride_s, batch_count); + + // CPU_BLAS + host_vector ca = ha; + host_vector cb = hb; + host_vector cc = hc; 
+ host_vector cs = hs; + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotg( + ca + b * stride_a, cb + b * stride_b, cc + b * stride_c, cs + b * stride_s); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector ra = ha; + host_vector rb = hb; + host_vector rc = hc; + host_vector rs = hs; + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, ra, stride_a, rb, stride_b, rc, stride_c, rs, stride_s, batch_count))); + + if(arg.unit_check) + { + double rel_error = std::numeric_limits::epsilon() * 100; + near_check_general(1, 1, batch_count, 1, stride_a, ca, ra, rel_error); + near_check_general(1, 1, batch_count, 1, stride_b, cb, rb, rel_error); + near_check_general(1, 1, batch_count, 1, stride_c, cc, rc, rel_error); + near_check_general(1, 1, batch_count, 1, stride_s, cs, rs, rel_error); + } + + if(arg.norm_check) + { + norm_error_host + = norm_check_general('F', 1, 1, 1, stride_a, batch_count, ca, ra); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_b, batch_count, cb, rb); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_c, batch_count, cc, rc); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_s, batch_count, cs, rs); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector da(size_a); + device_vector db(size_b); + device_vector dc(size_c); + device_vector ds(size_s); + CHECK_HIP_ERROR(hipMemcpy(da, ha, sizeof(T) * size_a, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, hb, sizeof(T) * size_b, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U) * size_c, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(T) * size_s, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotg_strided_batched( + handle, da, 
stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count))); + host_vector ra(size_a); + host_vector rb(size_b); + host_vector rc(size_c); + host_vector rs(size_s); + CHECK_HIP_ERROR(hipMemcpy(ra, da, sizeof(T) * size_a, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rb, db, sizeof(T) * size_b, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rc, dc, sizeof(U) * size_c, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rs, ds, sizeof(T) * size_s, hipMemcpyDeviceToHost)); + + if(arg.unit_check) + { + double rel_error = std::numeric_limits::epsilon() * 100; + near_check_general(1, 1, batch_count, 1, stride_a, ca, ra, rel_error); + near_check_general(1, 1, batch_count, 1, stride_b, cb, rb, rel_error); + near_check_general(1, 1, batch_count, 1, stride_c, cc, rc, rel_error); + near_check_general(1, 1, batch_count, 1, stride_s, cs, rs, rel_error); + } + + if(arg.norm_check) + { + norm_error_host + = norm_check_general('F', 1, 1, 1, stride_a, batch_count, ca, ra); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_b, batch_count, cb, rb); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_c, batch_count, cc, rc); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_s, batch_count, cs, rs); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be quicker + // (TODO: or is there another reason we are typically using host_mode for timing?) 
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector da(size_a); + device_vector db(size_b); + device_vector dc(size_c); + device_vector ds(size_s); + CHECK_HIP_ERROR(hipMemcpy(da, ha, sizeof(T) * size_a, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, hb, sizeof(T) * size_b, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc, sizeof(U) * size_c, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(ds, hs, sizeof(T) * size_s, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotg_strided_batched( + handle, da, stride_a, db, stride_b, dc, stride_c, ds, stride_s, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/clients/include/testing_rotm_batched.hpp b/clients/include/testing_rotm_batched.hpp index 8ab71d6d6..c8e9e8427 100644 --- a/clients/include/testing_rotm_batched.hpp +++ b/clients/include/testing_rotm_batched.hpp @@ -77,7 +77,7 @@ void testing_rotm_batched(const Arguments& arg) EXPECT_ROCBLAS_STATUS( (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count)), rocblas_status_invalid_size); - if(!batch_count) + else CHECK_ROCBLAS_ERROR( (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); return; diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index 4717ddf64..d70b3af6d 100644 --- 
a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -49,7 +49,7 @@ auto rocblas_blas1_dispatch(const Arguments& arg) if(Tb == Ti) return rocblas_simple_dispatch(arg); else - { // for csscal and zdscal only + { // for csscal and zdscal and complex rotg only if(Ti == rocblas_datatype_f32_c && Tb == rocblas_datatype_f32_r) return TEST{}(arg); else if(Ti == rocblas_datatype_f64_c && Tb == rocblas_datatype_f64_r) @@ -60,6 +60,10 @@ auto rocblas_blas1_dispatch(const Arguments& arg) return TEST{}(arg); else if(Ti == rocblas_datatype_f64_c && Tb == rocblas_datatype_f64_r) return TEST{}(arg); + else if(Ti == rocblas_datatype_f32_r && Tb == rocblas_datatype_f32_r) + return TEST{}(arg); + else if(Ti == rocblas_datatype_f64_r && Tb == rocblas_datatype_f64_r) + return TEST{}(arg); // else if(Ti == rocblas_datatype_f16_c && To == rocblas_datatype_f16_r) // return TEST{}(arg); diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 3e28299fb..16d713495 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1497,6 +1497,140 @@ ROCBLAS_EXPORT rocblas_status rocblas_zrotg(rocblas_handle handle, double* c, rocblas_double_complex* s); +/*! \brief BLAS Level 1 API + + \details + rotg_batched creates the Givens rotation matrix for the batched vectors (a b). + a, b, c, and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + a batched array of single input vector elements, overwritten with r. 
+ @param[inout] + b batched array of single input vector elements, overwritten with z. + @param[inout] + c batched array of cosine elements of Givens rotations. + @param[inout] + s batched array of sine elements of Givens rotations. + @param[in] + batch_count rocblas_int + number of batches (length of arrays a, b, c, and s). + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotg_batched(rocblas_handle handle, + float* const a[], + float* const b[], + float* const c[], + float* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotg_batched(rocblas_handle handle, + double* const a[], + double* const b[], + double* const c[], + double* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crotg_batched(rocblas_handle handle, + rocblas_float_complex* const a[], + rocblas_float_complex* const b[], + float* const c[], + rocblas_float_complex* const s[], + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrotg_batched(rocblas_handle handle, + rocblas_double_complex* const a[], + rocblas_double_complex* const b[], + double* const c[], + rocblas_double_complex* const s[], + rocblas_int batch_count); + +/*! \brief BLAS Level 1 API + + \details + rotg_strided_batched creates the Givens rotation matrix for the strided batched vectors (a b). + a, b, c, and s may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. 
+ @param[inout] + a strided_batched pointer to single input vector elements, overwritten with r. + @param[in] + stride_a rocblas_stride + distance between elements of a in batch (distance between a_i and a_(i + 1)) + @param[inout] + b strided_batched pointer to single input vector elements, overwritten with z. + @param[in] + stride_b rocblas_stride + distance between elements of b in batch (distance between b_i and b_(i + 1)) + @param[inout] + c strided_batched pointer to cosine elements of Givens rotations. + @param[in] + stride_c rocblas_stride + distance between elements of c in batch (distance between c_i and c_(i + 1)) + @param[inout] + s strided_batched pointer to sine elements of Givens rotations. + @param[in] + stride_s rocblas_stride + distance between elements of s in batch (distance between s_i and s_(i + 1)) + @param[in] + batch_count rocblas_int + number of batches (length of arrays a, b, c, and s). + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotg_strided_batched(rocblas_handle handle, + float* a, + rocblas_stride stride_a, + float* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + float* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotg_strided_batched(rocblas_handle handle, + double* a, + rocblas_stride stride_a, + double* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + double* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_crotg_strided_batched(rocblas_handle handle, + rocblas_float_complex* a, + rocblas_stride stride_a, + rocblas_float_complex* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + rocblas_float_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_zrotg_strided_batched(rocblas_handle handle, + rocblas_double_complex* a, + rocblas_stride 
stride_a, + rocblas_double_complex* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + rocblas_double_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count); + /*! \brief BLAS Level 1 API \details diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index ce47d1ddc..c442b831b 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -142,6 +142,8 @@ set( rocblas_blas1_source blas1/rocblas_rot_batched.cpp blas1/rocblas_rot_strided_batched.cpp blas1/rocblas_rotg.cpp + blas1/rocblas_rotg_batched.cpp + blas1/rocblas_rotg_strided_batched.cpp blas1/rocblas_rotm.cpp blas1/rocblas_rotm_batched.cpp blas1/rocblas_rotm_strided_batched.cpp diff --git a/library/src/blas1/rocblas_rotg.cpp b/library/src/blas1/rocblas_rotg.cpp index d293a10cf..9884c9d90 100644 --- a/library/src/blas1/rocblas_rotg.cpp +++ b/library/src/blas1/rocblas_rotg.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#include "rocblas_rotg.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -8,64 +9,6 @@ namespace { - template , int>::type = 0> - __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) - { - T scale = rocblas_abs(a) + rocblas_abs(b); - if(scale == 0.0) - { - c = 1.0; - s = 0.0; - a = 0.0; - b = 0.0; - } - else - { - T sa = a / scale; - T sb = b / scale; - T r = scale * sqrt(sa * sa + sb * sb); - T roe = rocblas_abs(a) > rocblas_abs(b) ? 
a : b; - r = copysign(r, roe); - c = a / r; - s = b / r; - T z = 1.0; - if(rocblas_abs(a) > rocblas_abs(b)) - z = s; - if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0) - z = 1.0 / c; - a = r; - b = z; - } - } - - template , int>::type = 0> - __device__ __host__ void rotg_calc(T& a, T& b, U& c, T& s) - { - if(!rocblas_abs(a)) - { - c = 0; - s = {1, 0}; - a = b; - } - else - { - auto scale = rocblas_abs(a) + rocblas_abs(b); - auto sa = rocblas_abs(a / scale); - auto sb = rocblas_abs(b / scale); - auto norm = scale * sqrt(sa * sa + sb * sb); - auto alpha = a / rocblas_abs(a); - c = rocblas_abs(a) / norm; - s = alpha * conj(b) / norm; - a = alpha * norm; - } - } - - template - __global__ void rotg_kernel(T* a, T* b, U* c, T* s) - { - rotg_calc(*a, *b, *c, *s); - } - template constexpr char rocblas_rotg_name[] = "unknown"; template <> @@ -78,7 +21,7 @@ namespace constexpr char rocblas_rotg_name[] = "rocblas_zrotg"; template - rocblas_status rocblas_rotg(rocblas_handle handle, T* a, T* b, U* c, T* s) + rocblas_status rocblas_rotg_impl(rocblas_handle handle, T* a, T* b, U* c, T* s) { if(!handle) return rocblas_status_invalid_handle; @@ -87,7 +30,11 @@ namespace if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_rotg_name, a, b, c, s); if(layer_mode & rocblas_layer_mode_log_bench) - log_bench(handle, "./rocblas-bench -f rotg -r", rocblas_precision_string); + log_bench(handle, + "./rocblas-bench -f rotg --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string); if(layer_mode & rocblas_layer_mode_log_profile) log_profile(handle, rocblas_rotg_name); @@ -96,19 +43,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - { - hipLaunchKernelGGL(rotg_kernel, 1, 1, 0, rocblas_stream, a, b, c, s); - } - else - { - RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - rotg_calc(*a, *b, *c, *s); - } - - return 
rocblas_status_success; + return rocblas_rotg_template(handle, a, 0, 0, b, 0, 0, c, 0, 0, s, 0, 0, 1); } } // namespace @@ -123,12 +58,12 @@ extern "C" { rocblas_status rocblas_srotg(rocblas_handle handle, float* a, float* b, float* c, float* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_drotg(rocblas_handle handle, double* a, double* b, double* c, double* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_crotg(rocblas_handle handle, @@ -137,7 +72,7 @@ rocblas_status rocblas_crotg(rocblas_handle handle, float* c, rocblas_float_complex* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } rocblas_status rocblas_zrotg(rocblas_handle handle, @@ -146,7 +81,7 @@ rocblas_status rocblas_zrotg(rocblas_handle handle, double* c, rocblas_double_complex* s) { - return rocblas_rotg(handle, a, b, c, s); + return rocblas_rotg_impl(handle, a, b, c, s); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotg.hpp b/library/src/blas1/rocblas_rotg.hpp new file mode 100644 index 000000000..9c13d2ff9 --- /dev/null +++ b/library/src/blas1/rocblas_rotg.hpp @@ -0,0 +1,135 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template , int>::type = 0> +__device__ __host__ void rocblas_rotg_calc(T& a, T& b, U& c, T& s) +{ + T scale = rocblas_abs(a) + rocblas_abs(b); + if(scale == 0.0) + { + c = 1.0; + s = 0.0; + a = 0.0; + b = 0.0; + } + else + { + T sa = a / scale; + T sb = b / scale; + T r = scale * sqrt(sa * sa + sb * sb); + T roe = rocblas_abs(a) > rocblas_abs(b) ? 
a : b; + r = copysign(r, roe); + c = a / r; + s = b / r; + T z = 1.0; + if(rocblas_abs(a) > rocblas_abs(b)) + z = s; + if(rocblas_abs(b) >= rocblas_abs(a) && c != 0.0) + z = 1.0 / c; + a = r; + b = z; + } +} + +template , int>::type = 0> +__device__ __host__ void rocblas_rotg_calc(T& a, T& b, U& c, T& s) +{ + if(!rocblas_abs(a)) + { + c = 0; + s = {1, 0}; + a = b; + } + else + { + auto scale = rocblas_abs(a) + rocblas_abs(b); + auto sa = rocblas_abs(a / scale); + auto sb = rocblas_abs(b / scale); + auto norm = scale * sqrt(sa * sa + sb * sb); + auto alpha = a / rocblas_abs(a); + c = rocblas_abs(a) / norm; + s = alpha * conj(b) / norm; + a = alpha * norm; + } +} + +template +__global__ void rocblas_rotg_kernel(T a_in, + rocblas_int offset_a, + rocblas_stride stride_a, + T b_in, + rocblas_int offset_b, + rocblas_stride stride_b, + U c_in, + rocblas_int offset_c, + rocblas_stride stride_c, + T s_in, + rocblas_int offset_s, + rocblas_stride stride_s) +{ + auto a = load_ptr_batch(a_in, hipBlockIdx_x, offset_a, stride_a); + auto b = load_ptr_batch(b_in, hipBlockIdx_x, offset_b, stride_b); + auto c = load_ptr_batch(c_in, hipBlockIdx_x, offset_c, stride_c); + auto s = load_ptr_batch(s_in, hipBlockIdx_x, offset_s, stride_s); + rocblas_rotg_calc(*a, *b, *c, *s); +} + +template +rocblas_status rocblas_rotg_template(rocblas_handle handle, + T a_in, + rocblas_int offset_a, + rocblas_stride stride_a, + T b_in, + rocblas_int offset_b, + rocblas_stride stride_b, + U c_in, + rocblas_int offset_c, + rocblas_stride stride_c, + T s_in, + rocblas_int offset_s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + hipStream_t rocblas_stream = handle->rocblas_stream; + + if(rocblas_pointer_mode_device == handle->pointer_mode) + { + hipLaunchKernelGGL(rocblas_rotg_kernel, + batch_count, + 1, + 0, + rocblas_stream, + a_in, + offset_a, + stride_a, + b_in, + offset_b, + stride_b, + c_in, + offset_c, + stride_c, + s_in, + offset_s, + stride_s); + } + else + { + 
RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); + for(int i = 0; i < batch_count; i++) + { + auto a = load_ptr_batch(a_in, i, offset_a, stride_a); + auto b = load_ptr_batch(b_in, i, offset_b, stride_b); + auto c = load_ptr_batch(c_in, i, offset_c, stride_c); + auto s = load_ptr_batch(s_in, i, offset_s, stride_s); + + rocblas_rotg_calc(*a, *b, *c, *s); + } + } + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotg_batched.cpp b/library/src/blas1/rocblas_rotg_batched.cpp new file mode 100644 index 000000000..fb4260cde --- /dev/null +++ b/library/src/blas1/rocblas_rotg_batched.cpp @@ -0,0 +1,108 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotg_name[] = "unknown"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_srotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_drotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_crotg_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_zrotg_batched"; + + template + rocblas_status rocblas_rotg_batched_impl(rocblas_handle handle, + T* const a[], + T* const b[], + U* const c[], + T* const s[], + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotg_name, a, b, c, s, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotg_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & 
rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotg_name, "batch", batch_count); + + if(!a || !b || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotg_template(handle, a, 0, 0, b, 0, 0, c, 0, 0, s, 0, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srotg_batched(rocblas_handle handle, + float* const a[], + float* const b[], + float* const c[], + float* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_drotg_batched(rocblas_handle handle, + double* const a[], + double* const b[], + double* const c[], + double* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_crotg_batched(rocblas_handle handle, + rocblas_float_complex* const a[], + rocblas_float_complex* const b[], + float* const c[], + rocblas_float_complex* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +rocblas_status rocblas_zrotg_batched(rocblas_handle handle, + rocblas_double_complex* const a[], + rocblas_double_complex* const b[], + double* const c[], + rocblas_double_complex* const s[], + rocblas_int batch_count) +{ + return rocblas_rotg_batched_impl(handle, a, b, c, s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotg_strided_batched.cpp b/library/src/blas1/rocblas_rotg_strided_batched.cpp new file mode 100644 index 000000000..a50f47452 --- /dev/null +++ b/library/src/blas1/rocblas_rotg_strided_batched.cpp @@ -0,0 +1,133 @@ +/* 
************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotg_name[] = "unknown"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_srotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_drotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_crotg_strided_batched"; + template <> + constexpr char rocblas_rotg_name[] = "rocblas_zrotg_strided_batched"; + + template + rocblas_status rocblas_rotg_strided_batched_impl(rocblas_handle handle, + T* a, + rocblas_stride stride_a, + T* b, + rocblas_stride stride_b, + U* c, + rocblas_stride stride_c, + T* s, + rocblas_stride stride_s, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotg_name, a, b, c, s, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotg_batched --a_type", + rocblas_precision_string, + "--b_type", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotg_name, "batch", batch_count); + + if(!a || !b || !c || !s) + return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotg_template( + handle, a, 0, stride_a, b, 0, stride_b, c, 0, stride_c, s, 0, stride_s, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * 
=========================================================================== + */ + +extern "C" { + +rocblas_status rocblas_srotg_strided_batched(rocblas_handle handle, + float* a, + rocblas_stride stride_a, + float* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + float* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_drotg_strided_batched(rocblas_handle handle, + double* a, + rocblas_stride stride_a, + double* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + double* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_crotg_strided_batched(rocblas_handle handle, + rocblas_float_complex* a, + rocblas_stride stride_a, + rocblas_float_complex* b, + rocblas_stride stride_b, + float* c, + rocblas_stride stride_c, + rocblas_float_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +rocblas_status rocblas_zrotg_strided_batched(rocblas_handle handle, + rocblas_double_complex* a, + rocblas_stride stride_a, + rocblas_double_complex* b, + rocblas_stride stride_b, + double* c, + rocblas_stride stride_c, + rocblas_double_complex* s, + rocblas_stride stride_s, + rocblas_int batch_count) +{ + return rocblas_rotg_strided_batched_impl( + handle, a, stride_a, b, stride_b, c, stride_c, s, stride_s, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp index d4f7429f9..7e04851a2 100644 --- a/library/src/blas1/rocblas_rotm.hpp +++ b/library/src/blas1/rocblas_rotm.hpp @@ -23,11 +23,11 @@ __global__ void 
rotm_kernel(rocblas_int n, U h22_device_host, rocblas_stride stride_param) { - auto flag = load_scalar(flag_device_host); //, hipBlockIdx_y, stride_param); - auto h11 = load_scalar(h11_device_host); //, hipBlockIdx_y, stride_param); - auto h21 = load_scalar(h21_device_host); //, hipBlockIdx_y, stride_param); - auto h12 = load_scalar(h12_device_host); //, hipBlockIdx_y, stride_param); - auto h22 = load_scalar(h22_device_host); //, hipBlockIdx_y, stride_param); + auto flag = load_scalar(flag_device_host, hipBlockIdx_y, stride_param); + auto h11 = load_scalar(h11_device_host, hipBlockIdx_y, stride_param); + auto h21 = load_scalar(h21_device_host, hipBlockIdx_y, stride_param); + auto h12 = load_scalar(h12_device_host, hipBlockIdx_y, stride_param); + auto h22 = load_scalar(h22_device_host, hipBlockIdx_y, stride_param); auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; From 52299929ad9b4fbf5768878eeb14bf65d5da75af Mon Sep 17 00:00:00 2001 From: First Last Date: Wed, 2 Oct 2019 14:10:53 -0600 Subject: [PATCH 04/11] Changes to rotm so template can accept array for scalar (param). 
--- library/src/blas1/rocblas_rotm.cpp | 2 +- library/src/blas1/rocblas_rotm.hpp | 45 +++++++++++++++++-- library/src/blas1/rocblas_rotm_batched.cpp | 2 +- .../blas1/rocblas_rotm_strided_batched.cpp | 2 +- library/src/blas1/rocblas_scal.hpp | 2 +- 5 files changed, 46 insertions(+), 7 deletions(-) diff --git a/library/src/blas1/rocblas_rotm.cpp b/library/src/blas1/rocblas_rotm.cpp index 5ffd97723..de58007bf 100644 --- a/library/src/blas1/rocblas_rotm.cpp +++ b/library/src/blas1/rocblas_rotm.cpp @@ -51,7 +51,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1); + return rocblas_rotm_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1, (T*)nullptr); } } // namespace diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp index 7e04851a2..b49ccffa7 100644 --- a/library/src/blas1/rocblas_rotm.hpp +++ b/library/src/blas1/rocblas_rotm.hpp @@ -69,10 +69,22 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, rocblas_stride stride_y, const T* param, rocblas_stride stride_param, - rocblas_int batch_count) + rocblas_int batch_count, + T* mem) { + // Memory queries must be in template as _impl doesn't have stride_param parameter (for calls from + // outside of rocblas) + if(handle->is_device_memory_size_query()) + { + if(stride_param && rocblas_pointer_mode_host == handle->pointer_mode && n > 0 && incx > 0 + && incy > 0 && batch_count > 0) + return handle->set_optimal_device_memory_size(sizeof(T) * batch_count * stride_param); + else + return rocblas_status_size_unchanged; + } + // Quick return if possible - if(n <= 0 || incx <= 0 || incy <= 0 || batch_count == 0) + if(n <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) return rocblas_status_success; if(rocblas_pointer_mode_host == handle->pointer_mode && param[0] == -2) return rocblas_status_success; @@ -102,7 +114,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, 
param + 3, param + 4, stride_param); - else // c and s are on host + else if(!stride_param) // single param on host hipLaunchKernelGGL(rotm_kernel, blocks, threads, @@ -123,6 +135,33 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, param[3], param[4], stride_param); + else // array of params on host, copy to device + { + // This should NOT happen from calls from the API currently. + RETURN_IF_HIP_ERROR( + hipMemcpy(mem, param, sizeof(T) * batch_count * stride_param, hipMemcpyHostToDevice)); + + hipLaunchKernelGGL(rotm_kernel, + blocks, + threads, + 0, + rocblas_stream, + n, + x, + offset_x, + incx, + stride_x, + y, + offset_y, + incy, + stride_y, + mem, + mem + 1, + mem + 2, + mem + 3, + mem + 4, + stride_param); + } return rocblas_status_success; } \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotm_batched.cpp b/library/src/blas1/rocblas_rotm_batched.cpp index b7e7a650b..09e2e12b0 100644 --- a/library/src/blas1/rocblas_rotm_batched.cpp +++ b/library/src/blas1/rocblas_rotm_batched.cpp @@ -66,7 +66,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); return rocblas_rotm_template( - handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, batch_count); + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, batch_count, (T*)nullptr); } } // namespace diff --git a/library/src/blas1/rocblas_rotm_strided_batched.cpp b/library/src/blas1/rocblas_rotm_strided_batched.cpp index ef4fb5475..680dc12ae 100644 --- a/library/src/blas1/rocblas_rotm_strided_batched.cpp +++ b/library/src/blas1/rocblas_rotm_strided_batched.cpp @@ -86,7 +86,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); return rocblas_rotm_template( - handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, param, 0, batch_count); + handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, param, 0, batch_count, (T*)nullptr); } } // namespace diff --git a/library/src/blas1/rocblas_scal.hpp b/library/src/blas1/rocblas_scal.hpp index 23ee018aa..2c8663c64 100644 --- 
a/library/src/blas1/rocblas_scal.hpp +++ b/library/src/blas1/rocblas_scal.hpp @@ -39,7 +39,7 @@ rocblas_status rocblas_scal_template(rocblas_handle handle, // outside of rocblas) if(handle->is_device_memory_size_query()) { - if(stride_alpha && rocblas_pointer_mode_device == handle->pointer_mode && n > 0 && incx > 0 + if(stride_alpha && rocblas_pointer_mode_host == handle->pointer_mode && n > 0 && incx > 0 && batch_count > 0) return handle->set_optimal_device_memory_size(sizeof(V) * batch_count * stride_alpha); else From 68c006abe3385755cb2f25197d0970431970f280 Mon Sep 17 00:00:00 2001 From: First Last Date: Wed, 2 Oct 2019 17:43:29 -0600 Subject: [PATCH 05/11] Added rotmg batched and strided_batched, other misc. changes --- clients/benchmarks/client.cpp | 6 + clients/common/rocblas_gentest.py | 7 + clients/gtest/blas1_gtest.cpp | 16 +- clients/gtest/blas1_gtest.yaml | 4 + clients/include/rocblas.hpp | 36 ++ clients/include/testing_rotmg_batched.hpp | 294 ++++++++++++++++ .../include/testing_rotmg_strided_batched.hpp | 315 ++++++++++++++++++ library/include/rocblas-functions.h | 116 +++++++ library/src/CMakeLists.txt | 2 + library/src/blas1/rocblas_rotg.hpp | 1 + library/src/blas1/rocblas_rotm.cpp | 3 +- .../blas1/rocblas_rotm_strided_batched.cpp | 16 +- library/src/blas1/rocblas_rotmg.cpp | 171 +--------- library/src/blas1/rocblas_rotmg.hpp | 237 +++++++++++++ library/src/blas1/rocblas_rotmg_batched.cpp | 86 +++++ .../blas1/rocblas_rotmg_strided_batched.cpp | 133 ++++++++ 16 files changed, 1273 insertions(+), 170 deletions(-) create mode 100644 clients/include/testing_rotmg_batched.hpp create mode 100644 clients/include/testing_rotmg_strided_batched.hpp create mode 100644 library/src/blas1/rocblas_rotmg.hpp create mode 100644 library/src/blas1/rocblas_rotmg_batched.cpp create mode 100644 library/src/blas1/rocblas_rotmg_strided_batched.cpp diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index de52fe89a..c1644bedb 100644 --- 
a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -30,6 +30,8 @@ #include "testing_rotm_batched.hpp" #include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" +#include "testing_rotmg_batched.hpp" +#include "testing_rotmg_strided_batched.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" #include "testing_scal_strided_batched.hpp" @@ -194,6 +196,10 @@ struct perf_blas< testing_rotm_strided_batched(arg); else if(!strcmp(arg.function, "rotmg")) testing_rotmg(arg); + else if(!strcmp(arg.function, "rotmg_batched")) + testing_rotmg_batched(arg); + else if(!strcmp(arg.function, "rotmg_strided_batched")) + testing_rotmg_strided_batched(arg); else if(!strcmp(arg.function, "gemv")) testing_gemv(arg); else if(!strcmp(arg.function, "gemv_batched")) diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 4a6c55327..105f2b370 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -222,6 +222,13 @@ def setdefaults(test): test.setdefault('stride_c', int(test['stride_scale'])) test.setdefault('stride_d', int(test['stride_scale'])) + if test['function'] in ('rotmg_strided_batched',): + if 'stride_scale' in test: + test.setdefault('stride_a', int(test['stride_scale'])) + test.setdefault('stride_b', int(test['stride_scale'])) + test.setdefault('stride_x', int(test['stride_scale'])) + test.setdefault('stride_y', int(test['stride_scale'])) + test.setdefault('stride_x', 0) test.setdefault('stride_y', 0) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 111b4b5a1..2c28972bc 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -25,6 +25,8 @@ #include "testing_rotm_batched.hpp" #include "testing_rotm_strided_batched.hpp" #include "testing_rotmg.hpp" +#include "testing_rotmg_batched.hpp" +#include "testing_rotmg_strided_batched.hpp" #include "testing_scal.hpp" #include "testing_scal_batched.hpp" #include 
"testing_scal_strided_batched.hpp" @@ -68,6 +70,8 @@ namespace rotm_batched, rotm_strided_batched, rotmg, + rotmg_batched, + rotmg_strided_batched, }; // ---------------------------------------------------------------------------- @@ -104,7 +108,8 @@ namespace bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched || BLAS1 == blas1::copy_batched || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotg_batched); + || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotg_batched + || BLAS1 == blas1::rotmg_batched); bool is_strided = (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched || BLAS1 == blas1::scal_strided_batched @@ -112,7 +117,8 @@ namespace || BLAS1 == blas1::copy_strided_batched || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm_strided_batched - || BLAS1 == blas1::rotg_strided_batched); + || BLAS1 == blas1::rotg_strided_batched + || BLAS1 == blas1::rotmg_strided_batched); if((is_scal || is_rotg || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched) @@ -263,7 +269,9 @@ namespace && std::is_same{} && std::is_same{} && (std::is_same{} || std::is_same{})) - || (BLAS1 == blas1::rotmg && std::is_same{} && std::is_same{} + || ((BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched + || BLAS1 == blas1::rotmg_strided_batched) + && std::is_same{} && std::is_same{} && (std::is_same{} || std::is_same{}))>; // Creates tests for one of the BLAS 1 functions @@ -346,6 +354,8 @@ BLAS1_TESTING(rotm, ARG1) BLAS1_TESTING(rotm_batched, ARG1) BLAS1_TESTING(rotm_strided_batched, ARG1) BLAS1_TESTING(rotmg, ARG1) +BLAS1_TESTING(rotmg_batched, ARG1) +BLAS1_TESTING(rotmg_strided_batched, ARG1) // clang-format on diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 19a9d647a..94332ce16 100644 --- a/clients/gtest/blas1_gtest.yaml +++ 
b/clients/gtest/blas1_gtest.yaml @@ -56,6 +56,7 @@ Tests: batch_count: [-1, 0, 5] function: - rotg_batched: *rotg_precisions + - rotmg_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: quick @@ -63,6 +64,7 @@ Tests: stride_scale: [ 1, 7 ] function: - rotg_strided_batched: *rotg_precisions + - rotmg_strided_batched: *single_double_precisions_complex_real - name: blas1_batched category: quick @@ -345,6 +347,7 @@ Tests: - copy_batched_bad_arg: *single_double_precisions_complex_real - rot_batched_bad_arg: *rot_precisions - rotg_batched_bad_arg: *rot_precisions + - rotmg_batched_bad_arg: *single_double_precisions_complex_real - rotm_batched_bad_arg: *single_double_precisions_complex_real - name: blas1_strided_batched_bad_arg @@ -357,6 +360,7 @@ Tests: - copy_strided_batched_bad_arg: *single_double_precisions_complex_real - rot_strided_batched_bad_arg: *rot_precisions - rotg_strided_batched_bad_arg: *rot_precisions + - rotmg_strided_batched_bad_arg: *single_double_precisions_complex_real - rotm_strided_batched_bad_arg: *single_double_precisions_complex_real ... 
diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index 14757770b..48caee8d8 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -711,6 +711,42 @@ static constexpr auto rocblas_rotmg = rocblas_srotmg; template <> static constexpr auto rocblas_rotmg = rocblas_drotmg; +//rotmg_batched +template +rocblas_status (*rocblas_rotmg_batched)(rocblas_handle handle, + T* const d1[], + T* const d2[], + T* const x1[], + const T* const y1[], + T* param, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotmg_batched = rocblas_srotmg_batched; + +template <> +static constexpr auto rocblas_rotmg_batched = rocblas_drotmg_batched; + +//rotmg_strided_batched +template +rocblas_status (*rocblas_rotmg_strided_batched)(rocblas_handle handle, + T* d1, + rocblas_stride stride_d1, + T* d2, + rocblas_stride stride_d2, + T* x1, + rocblas_stride stride_x1, + const T* y1, + rocblas_stride stride_y1, + T* param, + rocblas_int batch_count); + +template <> +static constexpr auto rocblas_rotmg_strided_batched = rocblas_srotmg_strided_batched; + +template <> +static constexpr auto rocblas_rotmg_strided_batched = rocblas_drotmg_strided_batched; + /* * =========================================================================== * level 2 BLAS diff --git a/clients/include/testing_rotmg_batched.hpp b/clients/include/testing_rotmg_batched.hpp new file mode 100644 index 000000000..6fc1d253a --- /dev/null +++ b/clients/include/testing_rotmg_batched.hpp @@ -0,0 +1,294 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotmg_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 5; + + rocblas_local_handle handle; + device_vector d1(batch_count); + device_vector d2(batch_count); + device_vector x1(batch_count); + device_vector y1(batch_count); + device_vector param(safe_size); + + if(!d1 || !d2 || !x1 || !y1 || !param) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_batched(nullptr, d1, d2, x1, y1, param, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, nullptr, d2, x1, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, nullptr, x1, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, d2, nullptr, y1, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, d2, x1, nullptr, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_batched(handle, d1, d2, x1, y1, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotmg_batched(const Arguments& arg) +{ + const int TEST_COUNT = 1; //00; + rocblas_int batch_count = arg.batch_count; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector 
d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector params(5); + + if(!d1 || !d2 || !x1 || !y1 || !params) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_batched(handle, d1, d2, x1, y1, params, batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR( + (rocblas_rotmg_batched(handle, d1, d2, x1, y1, params, batch_count))); + + return; + } + + // Initial Data on CPU + host_vector hd1[batch_count]; + host_vector hd2[batch_count]; + host_vector hx1[batch_count]; + host_vector hy1[batch_count]; + host_vector hparams(5); + + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + + for(int b = 0; b < batch_count; b++) + { + hd1[b] = host_vector(1); + hd2[b] = host_vector(1); + hx1[b] = host_vector(1); + hy1[b] = host_vector(1); + } + + for(int i = 0; i < TEST_COUNT; i++) + { + host_vector cd1[batch_count]; + host_vector cd2[batch_count]; + host_vector cx1[batch_count]; + host_vector cy1[batch_count]; + + rocblas_seedrand(); + rocblas_init(hparams, 1, 5, 1); + host_vector cparams = hparams; + for(int b = 0; b < batch_count; b++) + { + rocblas_init(hd1[b], 1, 1, 1); + rocblas_init(hd2[b], 1, 1, 1); + rocblas_init(hx1[b], 1, 1, 1); + rocblas_init(hy1[b], 1, 1, 1); + cd1[b] = hd1[b]; + cd2[b] = hd2[b]; + cx1[b] = hx1[b]; + cy1[b] = hy1[b]; + } + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotmg(cd1[b], cd2[b], cx1[b], cy1[b], cparams); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector rd1[batch_count]; + host_vector rd2[batch_count]; + host_vector rx1[batch_count]; + host_vector ry1[batch_count]; + T* 
rd1_in[batch_count]; + T* rd2_in[batch_count]; + T* rx1_in[batch_count]; + T* ry1_in[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rd1_in[b] = rd1[b] = hd1[b]; + rd2_in[b] = rd2[b] = hd2[b]; + rx1_in[b] = rx1[b] = hx1[b]; + ry1_in[b] = ry1[b] = hy1[b]; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR((rocblas_rotmg_batched( + handle, rd1_in, rd2_in, rx1_in, ry1_in, hparams, batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, ry1, cy1); + } + + if(arg.norm_check) + { + norm_error_host = norm_check_general('F', 1, 1, batch_count, 1, rd1, cd1); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); + norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector dd1(batch_count); + device_vector dd2(batch_count); + device_vector dx1(batch_count); + device_vector dy1(batch_count); + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + device_vector dparams(5); + + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bd1[b], hd1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dd1, bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR( + (rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count))); + + host_vector rd1[batch_count]; + host_vector rd2[batch_count]; + host_vector rx1[batch_count]; + host_vector ry1[batch_count]; + for(int b = 0; b < batch_count; b++) + { + rd1[b] = host_vector(1); + rd2[b] = host_vector(1); + rx1[b] = host_vector(1); + ry1[b] = host_vector(1); + CHECK_HIP_ERROR(hipMemcpy(rd1[b], bd1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rd2[b], bd2[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rx1[b], bx1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry1[b], by1[b], sizeof(T), hipMemcpyDeviceToHost)); + } + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, ry1, cy1); + } + + if(arg.norm_check) + { + norm_error_device = norm_check_general('F', 1, 1, batch_count, 1, rd1, cd1); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); + norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be much quicker + // (TODO: or is there another reason we are typically using host_mode for timing?) 
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector dd1(batch_count); + device_vector dd2(batch_count); + device_vector dx1(batch_count); + device_vector dy1(batch_count); + device_batch_vector bd1(batch_count, 1); + device_batch_vector bd2(batch_count, 1); + device_batch_vector bx1(batch_count, 1); + device_batch_vector by1(batch_count, 1); + device_vector dparams(5); + for(int b = 0; b < batch_count; b++) + { + CHECK_HIP_ERROR(hipMemcpy(bd1[b], hd1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dd1, bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotmg_batched(handle, dd1, dd2, dx1, dy1, dparams, batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git 
a/clients/include/testing_rotmg_strided_batched.hpp b/clients/include/testing_rotmg_strided_batched.hpp new file mode 100644 index 000000000..aa6fe5a9f --- /dev/null +++ b/clients/include/testing_rotmg_strided_batched.hpp @@ -0,0 +1,315 @@ +/* ************************************************************************ + * Copyright 2018-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ + +#include "cblas_interface.hpp" +#include "norm.hpp" +#include "rocblas.hpp" +#include "rocblas_init.hpp" +#include "rocblas_math.hpp" +#include "rocblas_random.hpp" +#include "rocblas_test.hpp" +#include "rocblas_vector.hpp" +#include "unit.hpp" +#include "utility.hpp" + +template +void testing_rotmg_strided_batched_bad_arg(const Arguments& arg) +{ + rocblas_int batch_count = 5; + static const size_t safe_size = 5; + + rocblas_local_handle handle; + device_vector d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector param(safe_size); + + if(!d1 || !d2 || !x1 || !y1 || !param) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCBLAS_STATUS( + (rocblas_rotmg_strided_batched(nullptr, d1, 0, d2, 0, x1, 0, y1, 0, param, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, nullptr, 0, d2, 0, x1, 0, y1, 0, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, nullptr, 0, x1, 0, y1, 0, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, d2, 0, nullptr, 0, y1, 0, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, d1, 0, d2, 0, x1, 0, nullptr, 0, param, batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + 
handle, d1, 0, d2, 0, x1, 0, y1, 0, nullptr, batch_count)), + rocblas_status_invalid_pointer); +} + +template +void testing_rotmg_strided_batched(const Arguments& arg) +{ + const int TEST_COUNT = 1; //00; + rocblas_int batch_count = arg.batch_count; + rocblas_int stride_d1 = arg.stride_a; + rocblas_int stride_d2 = arg.stride_b; + rocblas_int stride_x1 = arg.stride_x; + rocblas_int stride_y1 = arg.stride_y; + rocblas_local_handle handle; + + double gpu_time_used, cpu_time_used; + double norm_error_host = 0.0, norm_error_device = 0.0; + + // check to prevent undefined memory allocation error + if(batch_count <= 0) + { + size_t safe_size = 1; + device_vector d1(safe_size); + device_vector d2(safe_size); + device_vector x1(safe_size); + device_vector y1(safe_size); + device_vector params(5); + + if(!d1 || !d2 || !x1 || !y1 || !params) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + if(batch_count < 0) + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + params, + batch_count)), + rocblas_status_invalid_size); + else + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + params, + batch_count))); + + return; + } + + size_t size_d1 = batch_count * stride_d1; + size_t size_d2 = batch_count * stride_d2; + size_t size_x1 = batch_count * stride_x1; + size_t size_y1 = batch_count * stride_y1; + + // Initial Data on CPU + host_vector hd1(size_d1); + host_vector hd2(size_d2); + host_vector hx1(size_x1); + host_vector hy1(size_y1); + host_vector hparams(5); + + for(int i = 0; i < TEST_COUNT; i++) + { + rocblas_seedrand(); + rocblas_init(hparams, 1, 5, 1); + rocblas_init(hd1, 1, 1, 1, stride_d1, batch_count); + rocblas_init(hd2, 1, 1, 1, stride_d2, batch_count); + rocblas_init(hx1, 1, 1, 1, stride_x1, 
batch_count); + rocblas_init(hy1, 1, 1, 1, stride_y1, batch_count); + + host_vector cparams = hparams; + host_vector cd1 = hd1; + host_vector cd2 = hd2; + host_vector cx1 = hx1; + host_vector cy1 = hy1; + + cpu_time_used = get_time_us(); + for(int b = 0; b < batch_count; b++) + { + cblas_rotmg(cd1 + b * stride_d1, + cd2 + b * stride_d2, + cx1 + b * stride_x1, + cy1 + b * stride_y1, + cparams); + } + cpu_time_used = get_time_us() - cpu_time_used; + + // Test rocblas_pointer_mode_host + { + host_vector rd1 = hd1; + host_vector rd2 = hd2; + host_vector rx1 = hx1; + host_vector ry1 = hy1; + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + rd1, + stride_d1, + rd2, + stride_d2, + rx1, + stride_x1, + ry1, + stride_y1, + hparams, + batch_count))); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, stride_d1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + } + + if(arg.norm_check) + { + norm_error_host + = norm_check_general('F', 1, 1, 1, stride_d1, batch_count, rd1, cd1); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_d2, batch_count, rd2, cd2); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); + norm_error_host + += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + } + } + + // Test rocblas_pointer_mode_device + { + device_vector dd1(size_d1); + device_vector dd2(size_d2); + device_vector dx1(size_x1); + device_vector dy1(size_y1); + device_vector dparams(5); + + CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, sizeof(T) * size_x1, hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + batch_count))); + + host_vector rd1(size_d1); + host_vector rd2(size_d2); + host_vector rx1(size_x1); + host_vector ry1(size_y1); + + CHECK_HIP_ERROR(hipMemcpy(rd1, dd1, sizeof(T) * size_d1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rd2, dd2, sizeof(T) * size_d2, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(rx1, dx1, sizeof(T) * size_x1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(ry1, dy1, sizeof(T) * size_y1, hipMemcpyDeviceToHost)); + + if(arg.unit_check) + { + unit_check_general(1, 1, batch_count, 1, stride_d1, rd1, cd1); + unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); + unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); + unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + } + + if(arg.norm_check) + { + norm_error_device + = norm_check_general('F', 1, 1, 1, stride_d1, batch_count, rd1, cd1); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_d2, batch_count, rd2, cd2); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); + norm_error_device + += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + } + } + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = 100; + // Device mode will be much quicker + // (TODO: or is there another reason we are typically using host_mode for timing?) 
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + device_vector dd1(size_d1); + device_vector dd2(size_d2); + device_vector dx1(size_x1); + device_vector dy1(size_y1); + device_vector dparams(5); + + CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, sizeof(T) * size_x1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + + for(int iter = 0; iter < number_cold_calls; iter++) + { + rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + batch_count); + } + gpu_time_used = get_time_us(); // in microseconds + for(int iter = 0; iter < number_hot_calls; iter++) + { + rocblas_rotmg_strided_batched(handle, + dd1, + stride_d1, + dd2, + stride_d2, + dx1, + stride_x1, + dy1, + stride_y1, + dparams, + batch_count); + } + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + std::cout << "rocblas-us,CPU-us"; + if(arg.norm_check) + std::cout << ",norm_error_host_ptr,norm_error_device"; + std::cout << std::endl; + + std::cout << gpu_time_used << "," << cpu_time_used; + if(arg.norm_check) + std::cout << ',' << norm_error_host << ',' << norm_error_device; + std::cout << std::endl; + } +} diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 16d713495..3519597ed 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1851,6 +1851,122 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg( ROCBLAS_EXPORT rocblas_status rocblas_drotmg( rocblas_handle handle, double* d1, double* d2, double* x1, const double* y1, double* param); +/*! 
\brief BLAS Level 1 API + + \details + rotmg_batched creates the modified Givens rotation matrix for the batched vectors (d1 * x1, d2 * y1). + Parameters may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + d1 batched array of input scalars that is overwritten. + @param[inout] + d2 batched array of input scalars that is overwritten. + @param[inout] + x1 batched array of input scalars that is overwritten. + @param[in] + y1 batched array of input scalars. + @param[out] + param vector of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of instances in the batch. + + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, + float* const d1[], + float* const d2[], + float* const x1[], + const float* const y1[], + float* param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, + double* const d1[], + double* const d2[], + double* const x1[], + const double* const y1[], + double* param, + rocblas_int batch_count); + +/*! 
\brief BLAS Level 1 API + + \details + rotmg_strided_batched creates the modified Givens rotation matrix for the batched vectors (d1 * x1, d2 * y1). + Parameters may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + If the pointer mode is set to rocblas_pointer_mode_host, this function blocks the CPU until the GPU has finished and the results are available in host memory. + If the pointer mode is set to rocblas_pointer_mode_device, this function returns immediately and synchronization is required to read the results. + + @param[in] + handle rocblas_handle + handle to the rocblas library context queue. + @param[inout] + d1 batched array of input scalars that is overwritten. + @param[inout] + d2 batched array of input scalars that is overwritten. + @param[inout] + x1 batched array of input scalars that is overwritten. + @param[in] + y1 batched array of input scalars. + @param[out] + param vector of 5 elements defining the rotation. + param[0] = flag + param[1] = H11 + param[2] = H21 + param[3] = H12 + param[4] = H22 + The flag parameter defines the form of H: + flag = -1 => H = ( H11 H12 H21 H22 ) + flag = 0 => H = ( 1.0 H12 H21 1.0 ) + flag = 1 => H = ( H11 1.0 -1.0 H22 ) + flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) + param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + batch_count rocblas_int + the number of instances in the batch. 
+ + ********************************************************************/ + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle handle, + float* d1, + rocblas_stride stride_d1, + float* d2, + rocblas_stride stride_d2, + float* x1, + rocblas_stride stride_x1, + const float* y1, + rocblas_stride stride_y1, + float* param, + rocblas_int batch_count); + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, + double* d1, + rocblas_stride stride_d1, + double* d2, + rocblas_stride stride_d2, + double* x1, + rocblas_stride stride_x1, + const double* y1, + rocblas_stride stride_y1, + double* param, + rocblas_int batch_count); + /* * =========================================================================== * level 2 BLAS diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index c442b831b..62959d3b4 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -148,6 +148,8 @@ set( rocblas_blas1_source blas1/rocblas_rotm_batched.cpp blas1/rocblas_rotm_strided_batched.cpp blas1/rocblas_rotmg.cpp + blas1/rocblas_rotmg_batched.cpp + blas1/rocblas_rotmg_strided_batched.cpp blas1/rocblas_swap_batched.cpp blas1/rocblas_swap_strided_batched.cpp ) diff --git a/library/src/blas1/rocblas_rotg.hpp b/library/src/blas1/rocblas_rotg.hpp index 9c13d2ff9..fea9c9c26 100644 --- a/library/src/blas1/rocblas_rotg.hpp +++ b/library/src/blas1/rocblas_rotg.hpp @@ -120,6 +120,7 @@ rocblas_status rocblas_rotg_template(rocblas_handle handle, else { RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); + // TODO: make this faster for larger batches. 
for(int i = 0; i < batch_count; i++) { auto a = load_ptr_batch(a_in, i, offset_a, stride_a); diff --git a/library/src/blas1/rocblas_rotm.cpp b/library/src/blas1/rocblas_rotm.cpp index de58007bf..22fad8b15 100644 --- a/library/src/blas1/rocblas_rotm.cpp +++ b/library/src/blas1/rocblas_rotm.cpp @@ -51,7 +51,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template(handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1, (T*)nullptr); + return rocblas_rotm_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1, (T*)nullptr); } } // namespace diff --git a/library/src/blas1/rocblas_rotm_strided_batched.cpp b/library/src/blas1/rocblas_rotm_strided_batched.cpp index 680dc12ae..41204a052 100644 --- a/library/src/blas1/rocblas_rotm_strided_batched.cpp +++ b/library/src/blas1/rocblas_rotm_strided_batched.cpp @@ -85,8 +85,20 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template( - handle, n, x, 0, incx, stride_x, y, 0, incy, stride_y, param, 0, batch_count, (T*)nullptr); + return rocblas_rotm_template(handle, + n, + x, + 0, + incx, + stride_x, + y, + 0, + incy, + stride_y, + param, + 0, + batch_count, + (T*)nullptr); } } // namespace diff --git a/library/src/blas1/rocblas_rotmg.cpp b/library/src/blas1/rocblas_rotmg.cpp index f01aca7fa..dc3773f11 100644 --- a/library/src/blas1/rocblas_rotmg.cpp +++ b/library/src/blas1/rocblas_rotmg.cpp @@ -1,6 +1,7 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. 
* ************************************************************************ */ +#include "rocblas_rotmg.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" @@ -8,153 +9,6 @@ namespace { - template - __device__ __host__ void rotmg_calc(T& d1, T& d2, T& x1, const T& y1, T* param) - { - const T gam = 4096; - const T gamsq = gam * gam; - const T rgamsq = 1 / gamsq; - - T flag = -1; - T h11 = 0, h21 = 0, h12 = 0, h22 = 0; - - if(d1 < 0) - { - d1 = d2 = x1 = 0; - } - else - { - T p2 = d2 * y1; - if(p2 == 0) - { - flag = -2; - param[0] = flag; - return; - } - T p1 = d1 * x1; - T q2 = p2 * y1; - T q1 = p1 * x1; - if(rocblas_abs(q1) > rocblas_abs(q2)) - { - h21 = -y1 / x1; - h12 = p2 / p1; - T u = 1 - h12 * h21; - if(u > 0) - { - flag = 0; - d1 /= u; - d2 /= u; - x1 *= u; - } - } - else - { - if(q2 < 0) - { - d1 = d2 = x1 = 0; - } - else - { - flag = 1; - h11 = p1 / p2; - h22 = x1 / y1; - T u = 1 + h11 * h22; - T temp = d2 / u; - d2 = d1 / u; - d1 = temp; - x1 = y1 * u; - } - } - - if(d1 != 0) - { - while((d1 <= rgamsq) || (d1 >= gamsq)) - { - if(flag == 0) - { - h11 = h22 = 1; - flag = -1; - } - else - { - h21 = -1; - h12 = 1; - flag = -1; - } - if(d1 <= rgamsq) - { - d1 *= gamsq; - x1 /= gam; - h11 /= gam; - h12 /= gam; - } - else - { - d1 /= gamsq; - x1 *= gam; - h11 *= gam; - h12 *= gam; - } - } - } - - if(d2 != 0) - { - while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq)) - { - if(flag == 0) - { - h11 = h22 = 1; - flag = -1; - } - else - { - h21 = -1; - h12 = 1; - flag = -1; - } - if(rocblas_abs(d2) <= rgamsq) - { - d2 *= gamsq; - h21 /= gam; - h22 /= gam; - } - else - { - d2 /= gamsq; - h21 *= gam; - h22 *= gam; - } - } - } - } - - if(flag < 0) - { - param[1] = h11; - param[2] = h21; - param[3] = h12; - param[4] = h22; - } - else if(flag == 0) - { - param[2] = h21; - param[3] = h12; - } - else - { - param[1] = h11; - param[4] = h22; - } - param[0] = flag; - } - - template - __global__ void rotmg_kernel(T* d1, T* d2, T* x1, const T* y1, T* 
param) - { - rotmg_calc(*d1, *d2, *x1, *y1, param); - } - template constexpr char rocblas_rotmg_name[] = "unknown"; template <> @@ -163,7 +17,8 @@ namespace constexpr char rocblas_rotmg_name[] = "rocblas_drotmg"; template - rocblas_status rocblas_rotmg(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param) + rocblas_status + rocblas_rotmg_impl(rocblas_handle handle, T* d1, T* d2, T* x1, const T* y1, T* param) { if(!handle) return rocblas_status_invalid_handle; @@ -172,7 +27,7 @@ namespace if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_rotmg_name, d1, d2, x1, y1, param); if(layer_mode & rocblas_layer_mode_log_bench) - log_trace(handle, "./rocblas-bench -f rotmg -r", rocblas_precision_string); + log_bench(handle, "./rocblas-bench -f rotmg -r", rocblas_precision_string); if(layer_mode & rocblas_layer_mode_log_profile) log_profile(handle, rocblas_rotmg_name); @@ -181,19 +36,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(rocblas_pointer_mode_device == handle->pointer_mode) - { - hipLaunchKernelGGL(rotmg_kernel, 1, 1, 0, rocblas_stream, d1, d2, x1, y1, param); - } - else - { - RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - rotmg_calc(*d1, *d2, *x1, *y1, param); - } - - return rocblas_status_success; + return rocblas_rotmg_template(handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 1); } } // namespace @@ -209,13 +52,13 @@ extern "C" { ROCBLAS_EXPORT rocblas_status rocblas_srotmg( rocblas_handle handle, float* d1, float* d2, float* x1, const float* y1, float* param) { - return rocblas_rotmg(handle, d1, d2, x1, y1, param); + return rocblas_rotmg_impl(handle, d1, d2, x1, y1, param); } ROCBLAS_EXPORT rocblas_status rocblas_drotmg( rocblas_handle handle, double* d1, double* d2, double* x1, const double* y1, double* param) { - return rocblas_rotmg(handle, d1, d2, x1, y1, param); + return rocblas_rotmg_impl(handle, d1, d2, x1, y1, param); } 
} // extern "C" diff --git a/library/src/blas1/rocblas_rotmg.hpp b/library/src/blas1/rocblas_rotmg.hpp new file mode 100644 index 000000000..d0319542e --- /dev/null +++ b/library/src/blas1/rocblas_rotmg.hpp @@ -0,0 +1,237 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template +__device__ __host__ void rocblas_rotmg_calc(T& d1, T& d2, T& x1, const T& y1, T* param) +{ + const T gam = 4096; + const T gamsq = gam * gam; + const T rgamsq = 1 / gamsq; + + T flag = -1; + T h11 = 0, h21 = 0, h12 = 0, h22 = 0; + + if(d1 < 0) + { + d1 = d2 = x1 = 0; + } + else + { + T p2 = d2 * y1; + if(p2 == 0) + { + flag = -2; + param[0] = flag; + return; + } + T p1 = d1 * x1; + T q2 = p2 * y1; + T q1 = p1 * x1; + if(rocblas_abs(q1) > rocblas_abs(q2)) + { + h21 = -y1 / x1; + h12 = p2 / p1; + T u = 1 - h12 * h21; + if(u > 0) + { + flag = 0; + d1 /= u; + d2 /= u; + x1 *= u; + } + } + else + { + if(q2 < 0) + { + d1 = d2 = x1 = 0; + } + else + { + flag = 1; + h11 = p1 / p2; + h22 = x1 / y1; + T u = 1 + h11 * h22; + T temp = d2 / u; + d2 = d1 / u; + d1 = temp; + x1 = y1 * u; + } + } + + if(d1 != 0) + { + while((d1 <= rgamsq) || (d1 >= gamsq)) + { + if(flag == 0) + { + h11 = h22 = 1; + flag = -1; + } + else + { + h21 = -1; + h12 = 1; + flag = -1; + } + if(d1 <= rgamsq) + { + d1 *= gamsq; + x1 /= gam; + h11 /= gam; + h12 /= gam; + } + else + { + d1 /= gamsq; + x1 *= gam; + h11 *= gam; + h12 *= gam; + } + } + } + + if(d2 != 0) + { + while((rocblas_abs(d2) <= rgamsq) || (rocblas_abs(d2) >= gamsq)) + { + if(flag == 0) + { + h11 = h22 = 1; + flag = -1; + } + else + { + h21 = -1; + h12 = 1; + flag = -1; + } + if(rocblas_abs(d2) <= rgamsq) + { + d2 *= gamsq; + h21 /= gam; + h22 /= gam; + } + else + { + d2 /= gamsq; + h21 *= gam; + h22 *= gam; + } + } 
+ } + } + + if(flag < 0) + { + param[1] = h11; + param[2] = h21; + param[3] = h12; + param[4] = h22; + } + else if(flag == 0) + { + param[2] = h21; + param[3] = h12; + } + else + { + param[1] = h11; + param[4] = h22; + } + param[0] = flag; +} + +template +__global__ void rocblas_rotmg_kernel(T d1_in, + rocblas_int offset_d1, + rocblas_stride stride_d1, + T d2_in, + rocblas_int offset_d2, + rocblas_stride stride_d2, + T x1_in, + rocblas_int offset_x1, + rocblas_stride stride_x1, + U y1_in, + rocblas_int offset_y1, + rocblas_stride stride_y1, + V param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + auto d1 = load_ptr_batch(d1_in, hipBlockIdx_x, offset_d1, stride_d1); + auto d2 = load_ptr_batch(d2_in, hipBlockIdx_x, offset_d2, stride_d2); + auto x1 = load_ptr_batch(x1_in, hipBlockIdx_x, offset_x1, stride_x1); + auto y1 = load_ptr_batch(y1_in, hipBlockIdx_x, offset_y1, stride_y1); + auto p = param + hipBlockIdx_x * stride_param; + rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); +} + +template +rocblas_status rocblas_rotmg_template(rocblas_handle handle, + T d1_in, + rocblas_int offset_d1, + rocblas_stride stride_d1, + T d2_in, + rocblas_int offset_d2, + rocblas_stride stride_d2, + T x1_in, + rocblas_int offset_x1, + rocblas_stride stride_x1, + U y1_in, + rocblas_int offset_y1, + rocblas_stride stride_y1, + V param, + rocblas_stride stride_param, + rocblas_int batch_count) +{ + if(!batch_count) + return rocblas_status_success; + + hipStream_t rocblas_stream = handle->rocblas_stream; + if(rocblas_pointer_mode_device == handle->pointer_mode) + { + hipLaunchKernelGGL(rocblas_rotmg_kernel, + batch_count, + 1, + 0, + rocblas_stream, + d1_in, + offset_d1, + stride_d1, + d2_in, + offset_d2, + stride_d2, + x1_in, + offset_x1, + stride_x1, + y1_in, + offset_y1, + stride_y1, + param, + stride_param, + batch_count); + } + else + { + RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); + // TODO: make this faster for larger batches. 
+ for(int i = 0; i < batch_count; i++) + { + auto d1 = load_ptr_batch(d1_in, i, offset_d1, stride_d1); + auto d2 = load_ptr_batch(d2_in, i, offset_d2, stride_d2); + auto x1 = load_ptr_batch(x1_in, i, offset_x1, stride_x1); + auto y1 = load_ptr_batch(y1_in, i, offset_y1, stride_y1); + auto p = param + i * stride_param; + + rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); + } + } + + return rocblas_status_success; +} \ No newline at end of file diff --git a/library/src/blas1/rocblas_rotmg_batched.cpp b/library/src/blas1/rocblas_rotmg_batched.cpp new file mode 100644 index 000000000..81f257f62 --- /dev/null +++ b/library/src/blas1/rocblas_rotmg_batched.cpp @@ -0,0 +1,86 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotmg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotmg_name[] = "unknown"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_srotmg_batched"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_drotmg_batched"; + + template + rocblas_status rocblas_rotmg_batched_impl(rocblas_handle handle, + T* const d1[], + T* const d2[], + T* const x1[], + const T* const y1[], + T* param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + if(batch_count < 0) + return rocblas_status_invalid_size; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, rocblas_rotmg_name, d1, d2, x1, y1, param, batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotmg_batched -r", + rocblas_precision_string, + "--batch", + batch_count); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotmg_name, "batch", batch_count); + 
+ if(!d1 || !d2 || !x1 || !y1 || !param) + return rocblas_status_invalid_pointer; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotmg_template( + handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, + float* const d1[], + float* const d2[], + float* const x1[], + const float* const y1[], + float* param, + rocblas_int batch_count) +{ + return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, + double* const d1[], + double* const d2[], + double* const x1[], + const double* const y1[], + double* param, + rocblas_int batch_count) +{ + return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); +} + +} // extern "C" diff --git a/library/src/blas1/rocblas_rotmg_strided_batched.cpp b/library/src/blas1/rocblas_rotmg_strided_batched.cpp new file mode 100644 index 000000000..fb08b8e12 --- /dev/null +++ b/library/src/blas1/rocblas_rotmg_strided_batched.cpp @@ -0,0 +1,133 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. 
+ * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "rocblas_rotmg.hpp" +#include "utility.h" + +namespace +{ + template + constexpr char rocblas_rotmg_name[] = "unknown"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_srotmg_strided_batched"; + template <> + constexpr char rocblas_rotmg_name[] = "rocblas_drotmg_strided_batched"; + + template + rocblas_status rocblas_rotmg_strided_batched_impl(rocblas_handle handle, + T* d1, + rocblas_stride stride_d1, + T* d2, + rocblas_stride stride_d2, + T* x1, + rocblas_stride stride_x1, + const T* y1, + rocblas_stride stride_y1, + T* param, + rocblas_int batch_count) + { + if(!handle) + return rocblas_status_invalid_handle; + if(batch_count < 0) + return rocblas_status_invalid_size; + + auto layer_mode = handle->layer_mode; + if(layer_mode & rocblas_layer_mode_log_trace) + log_trace(handle, + rocblas_rotmg_name, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + batch_count); + if(layer_mode & rocblas_layer_mode_log_bench) + log_bench(handle, + "./rocblas-bench -f rotmg_strided_batched -r", + rocblas_precision_string, + "--batch", + batch_count, + "--stride_a", + stride_d1, + "--stride_b", + stride_d2, + "--stride_x", + stride_x1, + "--stride_y", + stride_y1); + if(layer_mode & rocblas_layer_mode_log_profile) + log_profile(handle, rocblas_rotmg_name, "batch", batch_count); + + if(!d1 || !d2 || !x1 || !y1 || !param) + return rocblas_status_invalid_pointer; + + RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); + + return rocblas_rotmg_template(handle, + d1, + 0, + stride_d1, + d2, + 0, + stride_d2, + x1, + 0, + stride_x1, + y1, + 0, + stride_y1, + param, + 0, + batch_count); + } + +} // namespace + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + 
*/ + +extern "C" { + +ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle handle, + float* d1, + rocblas_stride stride_d1, + float* d2, + rocblas_stride stride_d2, + float* x1, + rocblas_stride stride_x1, + const float* y1, + rocblas_stride stride_y1, + float* param, + rocblas_int batch_count) +{ + return rocblas_rotmg_strided_batched_impl( + handle, d1, stride_d1, d2, stride_d2, x1, stride_x1, y1, stride_y1, param, batch_count); +} + +ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, + double* d1, + rocblas_stride stride_d1, + double* d2, + rocblas_stride stride_d2, + double* x1, + rocblas_stride stride_x1, + const double* y1, + rocblas_stride stride_y1, + double* param, + rocblas_int batch_count) +{ + return rocblas_rotmg_strided_batched_impl( + handle, d1, stride_d1, d2, stride_d2, x1, stride_x1, y1, stride_y1, param, batch_count); +} + +} // extern "C" From 49e65f196f51b8e5b82cd846ee8792e021c23ab1 Mon Sep 17 00:00:00 2001 From: First Last Date: Thu, 3 Oct 2019 17:44:16 -0600 Subject: [PATCH 06/11] format --- clients/gtest/blas1_gtest.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 463202621..1477e6ca8 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -130,7 +130,8 @@ namespace || BLAS1 == blas1::rotmg_strided_batched); if((is_scal || is_rotg || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rot_strided_batched) && arg.a_type != arg.b_type) + || BLAS1 == blas1::rot_strided_batched) + && arg.a_type != arg.b_type) name << '_' << rocblas_datatype2string(arg.b_type); if((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched) @@ -157,8 +158,9 @@ namespace || BLAS1 == blas1::dotc_batched || BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched || BLAS1 == blas1::swap || BLAS1 == blas1::swap_batched || 
BLAS1 == blas1::swap_strided_batched - || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched || BLAS1 == blas1::rot_strided_batched - || BLAS1 == blas1::rotm || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotm_strided_batched) + || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm + || BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotm_strided_batched) { name << '_' << arg.incy; } From 026e45b3714dadc7f0cf31e5438201c3b01f3a70 Mon Sep 17 00:00:00 2001 From: First Last Date: Fri, 4 Oct 2019 11:25:06 -0600 Subject: [PATCH 07/11] Misc. changes --- clients/benchmarks/client.cpp | 6 ------ clients/gtest/blas1_gtest.cpp | 12 +++++------- clients/gtest/blas1_gtest.yaml | 6 ++++-- clients/include/rocblas_common.yaml | 4 ---- clients/include/testing_rotg_batched.hpp | 2 +- .../include/testing_rotg_strided_batched.hpp | 2 +- clients/include/testing_rotmg_batched.hpp | 2 +- .../include/testing_rotmg_strided_batched.hpp | 2 +- library/include/rocblas-functions.h | 18 +++++++++++++++--- library/src/blas1/rocblas_rotg.hpp | 5 ++++- library/src/blas1/rocblas_rotm.hpp | 5 +++-- library/src/blas1/rocblas_rotmg.hpp | 2 +- library/src/blas1/rocblas_rotmg_batched.cpp | 4 ++-- .../blas1/rocblas_rotmg_strided_batched.cpp | 4 ++-- 14 files changed, 40 insertions(+), 34 deletions(-) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index a460f293b..71a263b8d 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -182,12 +182,6 @@ struct perf_blas< testing_set_get_vector(arg); else if(!strcmp(arg.function, "set_get_matrix")) testing_set_get_matrix(arg); - // else if(!strcmp(arg.function, "rotg")) - // testing_rotg(arg); - // else if(!strcmp(arg.function, "rotg_batched")) - // testing_rotg_batched(arg); - // else if(!strcmp(arg.function, "rotg_strided_batched")) - // testing_rotg_strided_batched(arg); else if(!strcmp(arg.function, "rotm")) testing_rotm(arg); else 
if(!strcmp(arg.function, "rotm_batched")) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 1477e6ca8..4f6a547b7 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -109,6 +109,8 @@ namespace { bool is_scal = (BLAS1 == blas1::scal || BLAS1 == blas1::scal_batched || BLAS1 == blas1::scal_strided_batched); + bool is_rot = (BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched + || BLAS1 == blas1::rot_strided_batched); bool is_rotg = (BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched || BLAS1 == blas1::rotg_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched @@ -129,13 +131,9 @@ namespace || BLAS1 == blas1::rotg_strided_batched || BLAS1 == blas1::rotmg_strided_batched); - if((is_scal || is_rotg || BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rot_strided_batched) - && arg.a_type != arg.b_type) + if((is_scal || is_rotg || is_rot) && arg.a_type != arg.b_type) name << '_' << rocblas_datatype2string(arg.b_type); - if((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched - || BLAS1 == blas1::rot_strided_batched) - && arg.compute_type != arg.a_type) + if(is_rot && arg.compute_type != arg.a_type) name << '_' << rocblas_datatype2string(arg.compute_type); if(!is_rotg) @@ -172,7 +170,7 @@ namespace name << '_' << arg.stride_y; } - if(is_rotg) + if(BLAS1 == blas1::rotg_strided_batched) { name << '_' << arg.stride_a << '_' << arg.stride_b << '_' << arg.stride_c << '_' << arg.stride_d; diff --git a/clients/gtest/blas1_gtest.yaml b/clients/gtest/blas1_gtest.yaml index 2325652f4..65d92fa6d 100644 --- a/clients/gtest/blas1_gtest.yaml +++ b/clients/gtest/blas1_gtest.yaml @@ -2,8 +2,6 @@ include: rocblas_common.yaml include: known_bugs.yaml - - Definitions: - &N_range - [ -1, 0, 5, 10, 500, 1000, 1024, 1025, 7111, 10000, 33792 ] @@ -367,6 +365,8 @@ Tests: function: - swap_batched: *single_double_precisions_complex_real - copy_batched: 
*single_double_precisions_complex_real + - rot_batched: *rot_precisions + - rotm_batched: *single_double_precisions_complex_real - name: blas1_strided_batched category: pre_checkin @@ -377,6 +377,8 @@ Tests: function: - swap_strided_batched: *single_double_precisions_complex_real - copy_strided_batched: *single_double_precisions_complex_real + - rot_strided_batched: *rot_precisions + - rotm_strided_batched: *single_double_precisions_complex_real # nightly - name: blas1 diff --git a/clients/include/rocblas_common.yaml b/clients/include/rocblas_common.yaml index d3aa2849f..620099283 100644 --- a/clients/include/rocblas_common.yaml +++ b/clients/include/rocblas_common.yaml @@ -317,10 +317,6 @@ Defaults: uplo: '*' diag: '*' batch_count: -1 - # stride_a: 0 - # stride_b: 0 - # stride_c: 0 - # stride_d: 0 norm_check: 0 unit_check: 1 timing: 0 diff --git a/clients/include/testing_rotg_batched.hpp b/clients/include/testing_rotg_batched.hpp index e42d23998..8a30dbd32 100644 --- a/clients/include/testing_rotg_batched.hpp +++ b/clients/include/testing_rotg_batched.hpp @@ -46,7 +46,7 @@ void testing_rotg_batched_bad_arg(const Arguments& arg) template void testing_rotg_batched(const Arguments& arg) { - const int TEST_COUNT = 1; //00; + const int TEST_COUNT = 100; rocblas_int batch_count = arg.batch_count; rocblas_local_handle handle; diff --git a/clients/include/testing_rotg_strided_batched.hpp b/clients/include/testing_rotg_strided_batched.hpp index e30a69a4d..baeec8112 100644 --- a/clients/include/testing_rotg_strided_batched.hpp +++ b/clients/include/testing_rotg_strided_batched.hpp @@ -59,7 +59,7 @@ void testing_rotg_strided_batched_bad_arg(const Arguments& arg) template void testing_rotg_strided_batched(const Arguments& arg) { - const int TEST_COUNT = 1; + const int TEST_COUNT = 100; rocblas_int stride_a = arg.stride_a; rocblas_int stride_b = arg.stride_b; rocblas_int stride_c = arg.stride_c; diff --git a/clients/include/testing_rotmg_batched.hpp 
b/clients/include/testing_rotmg_batched.hpp index 6fc1d253a..3686ee4b7 100644 --- a/clients/include/testing_rotmg_batched.hpp +++ b/clients/include/testing_rotmg_batched.hpp @@ -53,7 +53,7 @@ void testing_rotmg_batched_bad_arg(const Arguments& arg) template void testing_rotmg_batched(const Arguments& arg) { - const int TEST_COUNT = 1; //00; + const int TEST_COUNT = 100; rocblas_int batch_count = arg.batch_count; rocblas_local_handle handle; diff --git a/clients/include/testing_rotmg_strided_batched.hpp b/clients/include/testing_rotmg_strided_batched.hpp index aa6fe5a9f..a01eddc40 100644 --- a/clients/include/testing_rotmg_strided_batched.hpp +++ b/clients/include/testing_rotmg_strided_batched.hpp @@ -55,7 +55,7 @@ void testing_rotmg_strided_batched_bad_arg(const Arguments& arg) template void testing_rotmg_strided_batched(const Arguments& arg) { - const int TEST_COUNT = 1; //00; + const int TEST_COUNT = 100; rocblas_int batch_count = arg.batch_count; rocblas_int stride_d1 = arg.stride_a; rocblas_int stride_d2 = arg.stride_b; diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index d53f1a2a0..6d513523b 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1994,7 +1994,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, /*! \brief BLAS Level 1 API \details - rotm_strided_batched applies the modified Givens rotation matrix defined by param to batched vectors x and y. + rotm_strided_batched applies the modified Givens rotation matrix defined by param to strided batched vectors x and y. @param[in] handle rocblas_handle @@ -2003,7 +2003,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, n rocblas_int number of elements in the x and y vectors. @param[inout] - x pointers storing batched vectors x on the GPU. + x pointers storing strided batched vectors x on the GPU. @param[in] incx rocblas_int specifies the increment between elements of x. 
@@ -2011,7 +2011,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, stride_x rocblas_stride specifies the increment between the beginning of x_i and x_(i + 1) @param[inout] - y pointers storing batched vectors y on the GPU. + y pointers storing strided batched vectors y on the GPU. @param[in] incy rocblas_int specifies the increment between elements of y. @@ -2167,12 +2167,24 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, handle to the rocblas library context queue. @param[inout] d1 batched array of input scalars that is overwritten. + @param[in] + stride_d1 rocblas_stride + specifies the increment between the beginning of d1_i and d1_(i+1) @param[inout] d2 batched array of input scalars that is overwritten. + @param[in] + stride_d2 rocblas_stride + specifies the increment between the beginning of d2_i and d2_(i+1) @param[inout] x1 batched array of input scalars that is overwritten. @param[in] + stride_x1 rocblas_stride + specifies the increment between the beginning of x1_i and x1_(i+1) + @param[in] y1 batched array of input scalars. + @param[in] + stride_y1 rocblas_stride + specifies the increment between the beginning of y1_i and y1_(i+1) @param[out] param vector of 5 elements defining the rotation. param[0] = flag diff --git a/library/src/blas1/rocblas_rotg.hpp b/library/src/blas1/rocblas_rotg.hpp index fea9c9c26..e46da7b23 100644 --- a/library/src/blas1/rocblas_rotg.hpp +++ b/library/src/blas1/rocblas_rotg.hpp @@ -95,6 +95,9 @@ rocblas_status rocblas_rotg_template(rocblas_handle handle, rocblas_stride stride_s, rocblas_int batch_count) { + if(!batch_count) + return rocblas_status_success; + hipStream_t rocblas_stream = handle->rocblas_stream; if(rocblas_pointer_mode_device == handle->pointer_mode) @@ -120,7 +123,7 @@ rocblas_status rocblas_rotg_template(rocblas_handle handle, else { RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - // TODO: make this faster for larger batches. 
+ // TODO: make this faster for a large number of batches. for(int i = 0; i < batch_count; i++) { auto a = load_ptr_batch(a_in, i, offset_a, stride_a); diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp index b49ccffa7..da4ab3855 100644 --- a/library/src/blas1/rocblas_rotm.hpp +++ b/library/src/blas1/rocblas_rotm.hpp @@ -76,6 +76,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, // outside of rocblas) if(handle->is_device_memory_size_query()) { + // TODO: Decide if we want to support this or not. if(stride_param && rocblas_pointer_mode_host == handle->pointer_mode && n > 0 && incx > 0 && incy > 0 && batch_count > 0) return handle->set_optimal_device_memory_size(sizeof(T) * batch_count * stride_param); @@ -134,7 +135,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, param[2], param[3], param[4], - stride_param); + 0); else // array of params on host, copy to device { // This should NOT happen from calls from the API currently. @@ -160,7 +161,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, mem + 2, mem + 3, mem + 4, - 0); + stride_param); } return rocblas_status_success; diff --git a/library/src/blas1/rocblas_rotmg.hpp b/library/src/blas1/rocblas_rotmg.hpp index d0319542e..8c77e25e1 100644 --- a/library/src/blas1/rocblas_rotmg.hpp +++ b/library/src/blas1/rocblas_rotmg.hpp @@ -220,7 +220,7 @@ rocblas_status rocblas_rotmg_template(rocblas_handle handle, else { RETURN_IF_HIP_ERROR(hipStreamSynchronize(rocblas_stream)); - // TODO: make this faster for larger batches. + // TODO: make this faster for a large number of batches. 
for(int i = 0; i < batch_count; i++) { auto d1 = load_ptr_batch(d1_in, i, offset_d1, stride_d1); diff --git a/library/src/blas1/rocblas_rotmg_batched.cpp b/library/src/blas1/rocblas_rotmg_batched.cpp index 81f257f62..f09bfcc2c 100644 --- a/library/src/blas1/rocblas_rotmg_batched.cpp +++ b/library/src/blas1/rocblas_rotmg_batched.cpp @@ -27,8 +27,6 @@ namespace { if(!handle) return rocblas_status_invalid_handle; - if(batch_count < 0) - return rocblas_status_invalid_size; auto layer_mode = handle->layer_mode; if(layer_mode & rocblas_layer_mode_log_trace) @@ -44,6 +42,8 @@ namespace if(!d1 || !d2 || !x1 || !y1 || !param) return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); diff --git a/library/src/blas1/rocblas_rotmg_strided_batched.cpp b/library/src/blas1/rocblas_rotmg_strided_batched.cpp index fb08b8e12..de914e465 100644 --- a/library/src/blas1/rocblas_rotmg_strided_batched.cpp +++ b/library/src/blas1/rocblas_rotmg_strided_batched.cpp @@ -31,8 +31,6 @@ namespace { if(!handle) return rocblas_status_invalid_handle; - if(batch_count < 0) - return rocblas_status_invalid_size; auto layer_mode = handle->layer_mode; if(layer_mode & rocblas_layer_mode_log_trace) @@ -67,6 +65,8 @@ namespace if(!d1 || !d2 || !x1 || !y1 || !param) return rocblas_status_invalid_pointer; + if(batch_count < 0) + return rocblas_status_invalid_size; RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); From 474e196964becd4fa298f32818428662bd1266e5 Mon Sep 17 00:00:00 2001 From: First Last Date: Fri, 4 Oct 2019 16:40:49 -0600 Subject: [PATCH 08/11] Allowing batched param in rotm and rotmg. 
--- clients/common/rocblas_gentest.py | 5 + clients/gtest/blas1_gtest.cpp | 11 +- clients/include/rocblas.hpp | 6 +- clients/include/testing_rotm_batched.hpp | 121 +++++----- .../include/testing_rotm_strided_batched.hpp | 220 ++++++++++++------ clients/include/testing_rotmg_batched.hpp | 75 ++++-- .../include/testing_rotmg_strided_batched.hpp | 84 ++++--- library/include/rocblas-functions.h | 55 +++-- library/src/blas1/rocblas_rotm.cpp | 4 +- library/src/blas1/rocblas_rotm.hpp | 193 +++++++++------ library/src/blas1/rocblas_rotm_batched.cpp | 38 +-- .../blas1/rocblas_rotm_strided_batched.cpp | 35 +-- library/src/blas1/rocblas_rotmg.cpp | 3 +- library/src/blas1/rocblas_rotmg.hpp | 15 +- library/src/blas1/rocblas_rotmg_batched.cpp | 8 +- .../blas1/rocblas_rotmg_strided_batched.cpp | 32 ++- 16 files changed, 566 insertions(+), 339 deletions(-) diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 2dd0c273c..39990b626 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -207,6 +207,9 @@ def setdefaults(test): if all([x in test for x in ('N', 'incy', 'stride_scale')]): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) + # we are using stride_c for param in rotm + if all([x in test for x in ('stride_scale')]): + test.setdefault('stride_c', int(test['stride_scale']) * 5) if test['function'] in ('ger_strided_batched'): if all([x in test for x in ('M', 'incx', 'stride_scale')]): @@ -223,10 +226,12 @@ def setdefaults(test): test.setdefault('stride_c', int(test['stride_scale'])) test.setdefault('stride_d', int(test['stride_scale'])) + # we are using stride_a for d1, stride_b for d2, and stride_c for param in rotmg if test['function'] in ('rotmg_strided_batched'): if 'stride_scale' in test: test.setdefault('stride_a', int(test['stride_scale'])) test.setdefault('stride_b', int(test['stride_scale'])) + test.setdefault('stride_c', int(test['stride_scale']) * 5) 
test.setdefault('stride_x', int(test['stride_scale'])) test.setdefault('stride_y', int(test['stride_scale'])) diff --git a/clients/gtest/blas1_gtest.cpp b/clients/gtest/blas1_gtest.cpp index 4f6a547b7..f2a1e7ca1 100644 --- a/clients/gtest/blas1_gtest.cpp +++ b/clients/gtest/blas1_gtest.cpp @@ -113,6 +113,8 @@ namespace || BLAS1 == blas1::rot_strided_batched); bool is_rotg = (BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched || BLAS1 == blas1::rotg_strided_batched); + bool is_rotmg = (BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched + || BLAS1 == blas1::rotmg_strided_batched); bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched || BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched || BLAS1 == blas1::copy_batched || BLAS1 == blas1::dot_batched @@ -136,13 +138,13 @@ namespace if(is_rot && arg.compute_type != arg.a_type) name << '_' << rocblas_datatype2string(arg.compute_type); - if(!is_rotg) + if(!is_rotg && !is_rotmg) name << '_' << arg.N; if(BLAS1 == blas1::axpy || is_scal) name << '_' << arg.alpha << "_" << arg.alphai; - if(!is_rotg) + if(!is_rotg && !is_rotmg) name << '_' << arg.incx; if(is_strided && !is_rotg) @@ -176,6 +178,11 @@ namespace << arg.stride_d; } + if(BLAS1 == blas1::rotm_strided_batched || BLAS1 == blas1::rotmg_strided_batched) + { + name << '_' << arg.stride_c; + } + if(is_batched || is_strided) { name << "_" << arg.batch_count; diff --git a/clients/include/rocblas.hpp b/clients/include/rocblas.hpp index 285fb84e5..ae96ec3c5 100644 --- a/clients/include/rocblas.hpp +++ b/clients/include/rocblas.hpp @@ -775,7 +775,7 @@ rocblas_status (*rocblas_rotm_batched)(rocblas_handle handle, rocblas_int incx, T* const y[], rocblas_int incy, - const T* param, + const T* const param[], rocblas_int batch_count); template <> static constexpr auto rocblas_rotm_batched = rocblas_srotm_batched; @@ -794,6 +794,7 @@ rocblas_status (*rocblas_rotm_strided_batched)(rocblas_handle handle, rocblas_int incy, rocblas_stride 
stride_y, const T* param, + rocblas_stride stride_param, rocblas_int batch_count); template <> static constexpr auto rocblas_rotm_strided_batched = rocblas_srotm_strided_batched; @@ -818,7 +819,7 @@ rocblas_status (*rocblas_rotmg_batched)(rocblas_handle handle, T* const d2[], T* const x1[], const T* const y1[], - T* param, + T* const param[], rocblas_int batch_count); template <> @@ -839,6 +840,7 @@ rocblas_status (*rocblas_rotmg_strided_batched)(rocblas_handle handle, const T* y1, rocblas_stride stride_y1, T* param, + rocblas_stride stride_param, rocblas_int batch_count); template <> diff --git a/clients/include/testing_rotm_batched.hpp b/clients/include/testing_rotm_batched.hpp index c8e9e8427..9ffedcd98 100644 --- a/clients/include/testing_rotm_batched.hpp +++ b/clients/include/testing_rotm_batched.hpp @@ -25,7 +25,7 @@ void testing_rotm_batched_bad_arg(const Arguments& arg) rocblas_local_handle handle; device_vector dx(safe_size); device_vector dy(safe_size); - device_vector dparam(5); + device_vector dparam(safe_size); if(!dx || !dy || !dparam) { CHECK_HIP_ERROR(hipErrorOutOfMemory); @@ -65,7 +65,7 @@ void testing_rotm_batched(const Arguments& arg) static const size_t safe_size = 100; // arbitrarily set to 100 device_vector dx(safe_size); device_vector dy(safe_size); - device_vector dparam(5); + device_vector dparam(safe_size); if(!dx || !dy || !dparam) { CHECK_HIP_ERROR(hipErrorOutOfMemory); @@ -88,7 +88,7 @@ void testing_rotm_batched(const Arguments& arg) device_vector dx(batch_count); device_vector dy(batch_count); - device_vector dparam(5); + device_vector dparam(batch_count); if(!dx || !dy || !dparam) { @@ -99,20 +99,24 @@ void testing_rotm_batched(const Arguments& arg) // Initial Data on CPU host_vector hx[batch_count]; host_vector hy[batch_count]; - host_vector hdata(4); - host_vector hparam(5); + host_vector hdata[batch_count]; //(4); + host_vector hparam[batch_count]; //(5); device_batch_vector bx(batch_count, size_x); device_batch_vector 
by(batch_count, size_y); + device_batch_vector bdata(batch_count, 4); + device_batch_vector bparam(batch_count, 5); for(int b = 0; b < batch_count; b++) { - hx[b] = host_vector(size_x); - hy[b] = host_vector(size_y); + hx[b] = host_vector(size_x); + hy[b] = host_vector(size_y); + hdata[b] = host_vector(4); + hparam[b] = host_vector(5); } int last = batch_count - 1; - if((!bx[last] && size_x) || (!by[last] && size_y)) + if((!bx[last] && size_x) || (!by[last] && size_y) || !bdata[last] || !bparam[last]) { CHECK_HIP_ERROR(hipErrorOutOfMemory); return; @@ -123,17 +127,20 @@ void testing_rotm_batched(const Arguments& arg) { rocblas_init(hx[b], 1, N, incx); rocblas_init(hy[b], 1, N, incy); + rocblas_init(hdata[b], 1, 4, 1); + + // CPU BLAS reference data + cblas_rotmg(&hdata[b][0], &hdata[b][1], &hdata[b][2], &hdata[b][3], hparam[b]); } - rocblas_init(hdata, 1, 4, 1); - // CPU BLAS reference data - cblas_rotmg(&hdata[0], &hdata[1], &hdata[2], &hdata[3], hparam); constexpr int FLAG_COUNT = 4; const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; for(int i = 0; i < FLAG_COUNT; i++) { - hparam[0] = FLAGS[i]; + for(int b = 0; b < batch_count; b++) + hparam[b][0] = FLAGS[i]; + host_vector cx[batch_count]; host_vector cy[batch_count]; cpu_time_used = get_time_us(); @@ -142,52 +149,53 @@ void testing_rotm_batched(const Arguments& arg) cx[b] = hx[b]; cy[b] = hy[b]; - cblas_rotm(N, cx[b], incx, cy[b], incy, hparam); + cblas_rotm(N, cx[b], incx, cy[b], incy, hparam[b]); } cpu_time_used = get_time_us() - cpu_time_used; if(arg.unit_check || arg.norm_check) { // Test rocblas_pointer_mode_host - { - CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); - for(int b = 0; b < batch_count; b++) - { - CHECK_HIP_ERROR( - hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR( - hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); - } - CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - 
CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + // TODO: THIS IS NO LONGER SUPPORTED + // { + // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + // for(int b = 0; b < batch_count; b++) + // { + // CHECK_HIP_ERROR( + // hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR( + // hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + // } + // CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - CHECK_ROCBLAS_ERROR( - (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count))); + // CHECK_ROCBLAS_ERROR( + // (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count))); - host_vector rx[batch_count]; - host_vector ry[batch_count]; - for(int b = 0; b < batch_count; b++) - { - rx[b] = host_vector(size_x); - ry[b] = host_vector(size_y); - CHECK_HIP_ERROR( - hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); - CHECK_HIP_ERROR( - hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); - } + // host_vector rx[batch_count]; + // host_vector ry[batch_count]; + // for(int b = 0; b < batch_count; b++) + // { + // rx[b] = host_vector(size_x); + // ry[b] = host_vector(size_y); + // CHECK_HIP_ERROR( + // hipMemcpy(rx[b], bx[b], sizeof(T) * size_x, hipMemcpyDeviceToHost)); + // CHECK_HIP_ERROR( + // hipMemcpy(ry[b], by[b], sizeof(T) * size_y, hipMemcpyDeviceToHost)); + // } - if(arg.unit_check) - { - T rel_error = std::numeric_limits::epsilon() * 1000; - near_check_general(1, N, batch_count, incx, cx, rx, rel_error); - near_check_general(1, N, batch_count, incy, cy, ry, rel_error); - } - if(arg.norm_check) - { - norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); - norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); - } - } 
+ // if(arg.unit_check) + // { + // T rel_error = std::numeric_limits::epsilon() * 1000; + // near_check_general(1, N, batch_count, incx, cx, rx, rel_error); + // near_check_general(1, N, batch_count, incy, cy, ry, rel_error); + // } + // if(arg.norm_check) + // { + // norm_error_host_x = norm_check_general('F', 1, N, batch_count, incx, cx, rx); + // norm_error_host_y = norm_check_general('F', 1, N, batch_count, incy, cy, ry); + // } + // } // Test rocblas_pointer_mode_device { @@ -198,10 +206,13 @@ void testing_rotm_batched(const Arguments& arg) hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); CHECK_HIP_ERROR( hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparam[b], hparam[b], sizeof(T) * 5, hipMemcpyHostToDevice)); } CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparam, hparam, sizeof(T) * 5, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, bparam, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_ROCBLAS_ERROR( (rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count))); @@ -238,23 +249,27 @@ void testing_rotm_batched(const Arguments& arg) { int number_cold_calls = 2; int number_hot_calls = 100; - CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); for(int b = 0; b < batch_count; b++) { CHECK_HIP_ERROR(hipMemcpy(bx[b], hx[b], sizeof(T) * size_x, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(by[b], hy[b], sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparam[b], hparam[b], sizeof(T) * 5, hipMemcpyHostToDevice)); } CHECK_HIP_ERROR(hipMemcpy(dx, bx, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy, by, sizeof(T*) * batch_count, 
hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, bparam, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); for(int iter = 0; iter < number_cold_calls; iter++) { - rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count); + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count); } gpu_time_used = get_time_us(); // in microseconds for(int iter = 0; iter < number_hot_calls; iter++) { - rocblas_rotm_batched(handle, N, dx, incx, dy, incy, hparam, batch_count); + rocblas_rotm_batched(handle, N, dx, incx, dy, incy, dparam, batch_count); } gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; diff --git a/clients/include/testing_rotm_strided_batched.hpp b/clients/include/testing_rotm_strided_batched.hpp index d80632bfd..7f67e3bef 100644 --- a/clients/include/testing_rotm_strided_batched.hpp +++ b/clients/include/testing_rotm_strided_batched.hpp @@ -16,18 +16,19 @@ template void testing_rotm_strided_batched_bad_arg(const Arguments& arg) { - rocblas_int N = 100; - rocblas_int incx = 1; - rocblas_stride stride_x = 1; - rocblas_int incy = 1; - rocblas_stride stride_y = 1; - rocblas_int batch_count = 5; - static const size_t safe_size = 100; + rocblas_int N = 100; + rocblas_int incx = 1; + rocblas_stride stride_x = 1; + rocblas_int incy = 1; + rocblas_stride stride_y = 1; + rocblas_stride stride_param = 1; + rocblas_int batch_count = 5; + static const size_t safe_size = 100; rocblas_local_handle handle; device_vector dx(safe_size); device_vector dy(safe_size); - device_vector dparam(5); + device_vector dparam(safe_size); if(!dx || !dy || !dparam) { CHECK_HIP_ERROR(hipErrorOutOfMemory); @@ -36,31 +37,48 @@ void testing_rotm_strided_batched_bad_arg(const Arguments& arg) EXPECT_ROCBLAS_STATUS( (rocblas_rotm_strided_batched( - nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count)), + nullptr, N, dx, incx, stride_x, dy, incy, stride_y, dparam, stride_param, batch_count)), 
rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched(handle, + N, + nullptr, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count)), + rocblas_status_invalid_pointer); + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + nullptr, + incy, + stride_y, + dparam, + stride_param, + batch_count)), + rocblas_status_invalid_pointer); EXPECT_ROCBLAS_STATUS( (rocblas_rotm_strided_batched( - handle, N, nullptr, incx, stride_x, dy, incy, stride_y, dparam, batch_count)), - rocblas_status_invalid_pointer); - EXPECT_ROCBLAS_STATUS( - (rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, nullptr, incy, stride_y, dparam, batch_count)), - rocblas_status_invalid_pointer); - EXPECT_ROCBLAS_STATUS( - (rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, batch_count)), + handle, N, dx, incx, stride_x, dy, incy, stride_y, nullptr, stride_param, batch_count)), rocblas_status_invalid_pointer); } template void testing_rotm_strided_batched(const Arguments& arg) { - rocblas_int N = arg.N; - rocblas_int incx = arg.incx; - rocblas_int stride_x = arg.stride_x; - rocblas_int stride_y = arg.stride_y; - rocblas_int incy = arg.incy; - rocblas_int batch_count = arg.batch_count; + rocblas_int N = arg.N; + rocblas_int incx = arg.incx; + rocblas_int stride_x = arg.stride_x; + rocblas_int stride_y = arg.stride_y; + rocblas_int stride_param = arg.stride_c; + rocblas_int incy = arg.incy; + rocblas_int batch_count = arg.batch_count; rocblas_local_handle handle; double gpu_time_used, cpu_time_used; @@ -73,7 +91,7 @@ void testing_rotm_strided_batched(const Arguments& arg) static const size_t safe_size = 100; // arbitrarily set to 100 device_vector dx(safe_size); device_vector dy(safe_size); - device_vector dparam(5); + device_vector dparam(safe_size); if(!dx || !dy || !dparam) { CHECK_HIP_ERROR(hipErrorOutOfMemory); @@ -82,22 +100,40 @@ void 
testing_rotm_strided_batched(const Arguments& arg) CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); if(batch_count < 0) - EXPECT_ROCBLAS_STATUS( - (rocblas_rotm_strided_batched< - T>)(handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count), - rocblas_status_invalid_size); + EXPECT_ROCBLAS_STATUS((rocblas_rotm_strided_batched)(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count), + rocblas_status_invalid_size); else - CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count))); + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count))); return; } - size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); - size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + size_t size_x = N * size_t(incx) + size_t(stride_x) * size_t(batch_count - 1); + size_t size_y = N * size_t(incy) + size_t(stride_y) * size_t(batch_count - 1); + size_t size_param = 5 + size_t(stride_param) * size_t(batch_count - 1); device_vector dx(size_x); device_vector dy(size_y); - device_vector dparam(5); + device_vector dparam(size_param); if(!dx || !dy || !dparam) { CHECK_HIP_ERROR(hipErrorOutOfMemory); @@ -107,66 +143,86 @@ void testing_rotm_strided_batched(const Arguments& arg) // Initial Data on CPU host_vector hx(size_x); host_vector hy(size_y); - host_vector hdata(4); - host_vector hparam(5); + host_vector hdata(4 * batch_count); + host_vector hparam(size_param); rocblas_seedrand(); rocblas_init(hx, 1, N, incx, stride_x, batch_count); rocblas_init(hy, 1, N, incy, stride_y, batch_count); - rocblas_init(hdata, 1, 4, 1); + rocblas_init(hdata, 1, 4, 1, 4, batch_count); // CPU BLAS reference data - cblas_rotmg(&hdata[0], &hdata[1], &hdata[2], &hdata[3], hparam); + for(int b = 0; b < 
batch_count; b++) + cblas_rotmg(hdata + b * 4, + hdata + b * 4 + 1, + hdata + b * 4 + 2, + hdata + b * 4 + 3, + hparam + b * stride_param); + constexpr int FLAG_COUNT = 4; const T FLAGS[FLAG_COUNT] = {-1, 0, 1, -2}; for(int i = 0; i < FLAG_COUNT; i++) { - hparam[0] = FLAGS[i]; + for(int b = 0; b < batch_count; b++) + (hparam + b * stride_param)[0] = FLAGS[i]; + host_vector cx = hx; host_vector cy = hy; cpu_time_used = get_time_us(); for(int b = 0; b < batch_count; b++) { - cblas_rotm(N, cx + b * stride_x, incx, cy + b * stride_y, incy, hparam); + cblas_rotm( + N, cx + b * stride_x, incx, cy + b * stride_y, incy, hparam + b * stride_param); } cpu_time_used = get_time_us() - cpu_time_used; if(arg.unit_check || arg.norm_check) { // Test rocblas_pointer_mode_host - { - CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); - CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); - CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count))); - host_vector rx(size_x); - host_vector ry(size_y); - CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); - CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); - if(arg.unit_check) - { - T rel_error = std::numeric_limits::epsilon() * 1000; - near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); - near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); - } - if(arg.norm_check) - { - norm_error_host_x - = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); - norm_error_host_y - = norm_check_general('F', 1, N, incy, stride_x, batch_count, cy, ry); - } - } + // TODO: THIS IS NO LONGER SUPPORTED + // { + // CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + // CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) 
* size_x, hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + // CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( + // handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count))); + // host_vector rx(size_x); + // host_vector ry(size_y); + // CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); + // CHECK_HIP_ERROR(hipMemcpy(ry, dy, sizeof(T) * size_y, hipMemcpyDeviceToHost)); + // if(arg.unit_check) + // { + // T rel_error = std::numeric_limits::epsilon() * 1000; + // near_check_general(1, N, batch_count, incx, stride_x, cx, rx, rel_error); + // near_check_general(1, N, batch_count, incy, stride_y, cy, ry, rel_error); + // } + // if(arg.norm_check) + // { + // norm_error_host_x + // = norm_check_general('F', 1, N, incx, stride_x, batch_count, cx, rx); + // norm_error_host_y + // = norm_check_general('F', 1, N, incy, stride_x, batch_count, cy, ry); + // } + // } // Test rocblas_pointer_mode_device { CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparam, hparam, sizeof(T) * 5, hipMemcpyHostToDevice)); - CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, dparam, batch_count))); + CHECK_HIP_ERROR( + hipMemcpy(dparam, hparam, sizeof(T) * size_param, hipMemcpyHostToDevice)); + CHECK_ROCBLAS_ERROR((rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count))); host_vector rx(size_x); host_vector ry(size_y); CHECK_HIP_ERROR(hipMemcpy(rx, dx, sizeof(T) * size_x, hipMemcpyDeviceToHost)); @@ -191,20 +247,40 @@ void testing_rotm_strided_batched(const Arguments& arg) { int number_cold_calls = 2; int number_hot_calls = 100; - 
CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); CHECK_HIP_ERROR(hipMemcpy(dx, hx, sizeof(T) * size_x, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy, hy, sizeof(T) * size_y, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparam, hparam, sizeof(T) * size_param, hipMemcpyHostToDevice)); for(int iter = 0; iter < number_cold_calls; iter++) { - rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count); + rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count); } gpu_time_used = get_time_us(); // in microseconds for(int iter = 0; iter < number_hot_calls; iter++) { - rocblas_rotm_strided_batched( - handle, N, dx, incx, stride_x, dy, incy, stride_y, hparam, batch_count); + rocblas_rotm_strided_batched(handle, + N, + dx, + incx, + stride_x, + dy, + incy, + stride_y, + dparam, + stride_param, + batch_count); } gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; diff --git a/clients/include/testing_rotmg_batched.hpp b/clients/include/testing_rotmg_batched.hpp index 3686ee4b7..728362813 100644 --- a/clients/include/testing_rotmg_batched.hpp +++ b/clients/include/testing_rotmg_batched.hpp @@ -24,7 +24,7 @@ void testing_rotmg_batched_bad_arg(const Arguments& arg) device_vector d2(batch_count); device_vector x1(batch_count); device_vector y1(batch_count); - device_vector param(safe_size); + device_vector param(batch_count); if(!d1 || !d2 || !x1 || !y1 || !param) { @@ -68,7 +68,7 @@ void testing_rotmg_batched(const Arguments& arg) device_vector d2(safe_size); device_vector x1(safe_size); device_vector y1(safe_size); - device_vector params(5); + device_vector params(safe_size); if(!d1 || !d2 || !x1 || !y1 || !params) { @@ -93,19 +93,21 @@ void testing_rotmg_batched(const Arguments& arg) host_vector 
hd2[batch_count]; host_vector hx1[batch_count]; host_vector hy1[batch_count]; - host_vector hparams(5); + host_vector hparams[batch_count]; device_batch_vector bd1(batch_count, 1); device_batch_vector bd2(batch_count, 1); device_batch_vector bx1(batch_count, 1); device_batch_vector by1(batch_count, 1); + device_batch_vector bparams(batch_count, 5); for(int b = 0; b < batch_count; b++) { - hd1[b] = host_vector(1); - hd2[b] = host_vector(1); - hx1[b] = host_vector(1); - hy1[b] = host_vector(1); + hd1[b] = host_vector(1); + hd2[b] = host_vector(1); + hx1[b] = host_vector(1); + hy1[b] = host_vector(1); + hparams[b] = host_vector(5); } for(int i = 0; i < TEST_COUNT; i++) @@ -114,26 +116,28 @@ void testing_rotmg_batched(const Arguments& arg) host_vector cd2[batch_count]; host_vector cx1[batch_count]; host_vector cy1[batch_count]; + host_vector cparams[batch_count]; rocblas_seedrand(); - rocblas_init(hparams, 1, 5, 1); - host_vector cparams = hparams; + for(int b = 0; b < batch_count; b++) { rocblas_init(hd1[b], 1, 1, 1); rocblas_init(hd2[b], 1, 1, 1); rocblas_init(hx1[b], 1, 1, 1); rocblas_init(hy1[b], 1, 1, 1); - cd1[b] = hd1[b]; - cd2[b] = hd2[b]; - cx1[b] = hx1[b]; - cy1[b] = hy1[b]; + rocblas_init(hparams[b], 1, 5, 1); + cd1[b] = hd1[b]; + cd2[b] = hd2[b]; + cx1[b] = hx1[b]; + cy1[b] = hy1[b]; + cparams[b] = hparams[b]; } cpu_time_used = get_time_us(); for(int b = 0; b < batch_count; b++) { - cblas_rotmg(cd1[b], cd2[b], cx1[b], cy1[b], cparams); + cblas_rotmg(cd1[b], cd2[b], cx1[b], cy1[b], cparams[b]); } cpu_time_used = get_time_us() - cpu_time_used; @@ -143,22 +147,25 @@ void testing_rotmg_batched(const Arguments& arg) host_vector rd2[batch_count]; host_vector rx1[batch_count]; host_vector ry1[batch_count]; + host_vector rparams[batch_count]; T* rd1_in[batch_count]; T* rd2_in[batch_count]; T* rx1_in[batch_count]; T* ry1_in[batch_count]; + T* rparams_in[batch_count]; for(int b = 0; b < batch_count; b++) { rd1_in[b] = rd1[b] = hd1[b]; rd2_in[b] = rd2[b] = hd2[b]; 
rx1_in[b] = rx1[b] = hx1[b]; ry1_in[b] = ry1[b] = hy1[b]; + rparams_in[b] = rparams[b] = hparams[b]; } CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); CHECK_ROCBLAS_ERROR((rocblas_rotmg_batched( - handle, rd1_in, rd2_in, rx1_in, ry1_in, hparams, batch_count))); + handle, rd1_in, rd2_in, rx1_in, ry1_in, rparams_in, batch_count))); if(arg.unit_check) { @@ -166,6 +173,7 @@ void testing_rotmg_batched(const Arguments& arg) unit_check_general(1, 1, batch_count, 1, rd2, cd2); unit_check_general(1, 1, batch_count, 1, rx1, cx1); unit_check_general(1, 1, batch_count, 1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, rparams, cparams); } if(arg.norm_check) @@ -174,6 +182,8 @@ void testing_rotmg_batched(const Arguments& arg) norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); norm_error_host += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + norm_error_host + += norm_check_general('F', 1, 5, batch_count, 1, rparams, cparams); } } @@ -183,11 +193,12 @@ void testing_rotmg_batched(const Arguments& arg) device_vector dd2(batch_count); device_vector dx1(batch_count); device_vector dy1(batch_count); + device_vector dparams(batch_count); device_batch_vector bd1(batch_count, 1); device_batch_vector bd2(batch_count, 1); device_batch_vector bx1(batch_count, 1); device_batch_vector by1(batch_count, 1); - device_vector dparams(5); + device_batch_vector bparams(batch_count, 5); for(int b = 0; b < batch_count; b++) { @@ -195,12 +206,15 @@ void testing_rotmg_batched(const Arguments& arg) CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparams[b], hparams[b], sizeof(T) * 5, hipMemcpyHostToDevice)); } CHECK_HIP_ERROR(hipMemcpy(dd1, 
bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, bparams, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); CHECK_ROCBLAS_ERROR( @@ -210,16 +224,20 @@ void testing_rotmg_batched(const Arguments& arg) host_vector rd2[batch_count]; host_vector rx1[batch_count]; host_vector ry1[batch_count]; + host_vector rparams[batch_count]; for(int b = 0; b < batch_count; b++) { - rd1[b] = host_vector(1); - rd2[b] = host_vector(1); - rx1[b] = host_vector(1); - ry1[b] = host_vector(1); + rd1[b] = host_vector(1); + rd2[b] = host_vector(1); + rx1[b] = host_vector(1); + ry1[b] = host_vector(1); + rparams[b] = host_vector(5); CHECK_HIP_ERROR(hipMemcpy(rd1[b], bd1[b], sizeof(T), hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(rd2[b], bd2[b], sizeof(T), hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(rx1[b], bx1[b], sizeof(T), hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(ry1[b], by1[b], sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(rparams[b], bparams[b], sizeof(T) * 5, hipMemcpyDeviceToHost)); } if(arg.unit_check) @@ -228,6 +246,7 @@ void testing_rotmg_batched(const Arguments& arg) unit_check_general(1, 1, batch_count, 1, rd2, cd2); unit_check_general(1, 1, batch_count, 1, rx1, cx1); unit_check_general(1, 1, batch_count, 1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, rparams, cparams); } if(arg.norm_check) @@ -236,6 +255,8 @@ void testing_rotmg_batched(const Arguments& arg) norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, rd2, cd2); norm_error_device += 
norm_check_general('F', 1, 1, batch_count, 1, rx1, cx1); norm_error_device += norm_check_general('F', 1, 1, batch_count, 1, ry1, cy1); + norm_error_device + += norm_check_general('F', 1, 5, batch_count, 1, rparams, cparams); } } } @@ -244,31 +265,35 @@ void testing_rotmg_batched(const Arguments& arg) { int number_cold_calls = 2; int number_hot_calls = 100; - // Device mode will be much quicker - // (TODO: or is there another reason we are typically using host_mode for timing?) + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); device_vector dd1(batch_count); device_vector dd2(batch_count); device_vector dx1(batch_count); device_vector dy1(batch_count); + device_vector dparams(batch_count); device_batch_vector bd1(batch_count, 1); device_batch_vector bd2(batch_count, 1); device_batch_vector bx1(batch_count, 1); device_batch_vector by1(batch_count, 1); - device_vector dparams(5); + device_batch_vector bparams(batch_count, 5); + for(int b = 0; b < batch_count; b++) { CHECK_HIP_ERROR(hipMemcpy(bd1[b], hd1[b], sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(bd2[b], hd2[b], sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(bx1[b], hx1[b], sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(by1[b], hy1[b], sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(bparams[b], hparams[b], sizeof(T) * 5, hipMemcpyHostToDevice)); } CHECK_HIP_ERROR(hipMemcpy(dd1, bd1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dd2, bd2, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dx1, bx1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy1, by1, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, bparams, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); for(int iter = 0; iter < 
number_cold_calls; iter++) { diff --git a/clients/include/testing_rotmg_strided_batched.hpp b/clients/include/testing_rotmg_strided_batched.hpp index a01eddc40..127660796 100644 --- a/clients/include/testing_rotmg_strided_batched.hpp +++ b/clients/include/testing_rotmg_strided_batched.hpp @@ -32,35 +32,36 @@ void testing_rotmg_strided_batched_bad_arg(const Arguments& arg) return; } - EXPECT_ROCBLAS_STATUS( - (rocblas_rotmg_strided_batched(nullptr, d1, 0, d2, 0, x1, 0, y1, 0, param, batch_count)), - rocblas_status_invalid_handle); EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( - handle, nullptr, 0, d2, 0, x1, 0, y1, 0, param, batch_count)), + nullptr, d1, 0, d2, 0, x1, 0, y1, 0, param, 0, batch_count)), + rocblas_status_invalid_handle); + EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( + handle, nullptr, 0, d2, 0, x1, 0, y1, 0, param, 0, batch_count)), rocblas_status_invalid_pointer); EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( - handle, d1, 0, nullptr, 0, x1, 0, y1, 0, param, batch_count)), + handle, d1, 0, nullptr, 0, x1, 0, y1, 0, param, 0, batch_count)), rocblas_status_invalid_pointer); EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( - handle, d1, 0, d2, 0, nullptr, 0, y1, 0, param, batch_count)), + handle, d1, 0, d2, 0, nullptr, 0, y1, 0, param, 0, batch_count)), rocblas_status_invalid_pointer); EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( - handle, d1, 0, d2, 0, x1, 0, nullptr, 0, param, batch_count)), + handle, d1, 0, d2, 0, x1, 0, nullptr, 0, param, 0, batch_count)), rocblas_status_invalid_pointer); EXPECT_ROCBLAS_STATUS((rocblas_rotmg_strided_batched( - handle, d1, 0, d2, 0, x1, 0, y1, 0, nullptr, batch_count)), + handle, d1, 0, d2, 0, x1, 0, y1, 0, nullptr, 0, batch_count)), rocblas_status_invalid_pointer); } template void testing_rotmg_strided_batched(const Arguments& arg) { - const int TEST_COUNT = 100; - rocblas_int batch_count = arg.batch_count; - rocblas_int stride_d1 = arg.stride_a; - rocblas_int stride_d2 = 
arg.stride_b; - rocblas_int stride_x1 = arg.stride_x; - rocblas_int stride_y1 = arg.stride_y; + const int TEST_COUNT = 100; + rocblas_int batch_count = arg.batch_count; + rocblas_int stride_d1 = arg.stride_a; + rocblas_int stride_d2 = arg.stride_b; + rocblas_int stride_x1 = arg.stride_x; + rocblas_int stride_y1 = arg.stride_y; + rocblas_int stride_param = arg.stride_c; rocblas_local_handle handle; double gpu_time_used, cpu_time_used; @@ -74,7 +75,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) device_vector d2(safe_size); device_vector x1(safe_size); device_vector y1(safe_size); - device_vector params(5); + device_vector params(safe_size); if(!d1 || !d2 || !x1 || !y1 || !params) { @@ -94,6 +95,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) y1, stride_y1, params, + stride_param, batch_count)), rocblas_status_invalid_size); else @@ -107,27 +109,29 @@ void testing_rotmg_strided_batched(const Arguments& arg) y1, stride_y1, params, + stride_param, batch_count))); return; } - size_t size_d1 = batch_count * stride_d1; - size_t size_d2 = batch_count * stride_d2; - size_t size_x1 = batch_count * stride_x1; - size_t size_y1 = batch_count * stride_y1; + size_t size_d1 = batch_count * stride_d1; + size_t size_d2 = batch_count * stride_d2; + size_t size_x1 = batch_count * stride_x1; + size_t size_y1 = batch_count * stride_y1; + size_t size_param = batch_count * stride_param; // Initial Data on CPU host_vector hd1(size_d1); host_vector hd2(size_d2); host_vector hx1(size_x1); host_vector hy1(size_y1); - host_vector hparams(5); + host_vector hparams(size_param); for(int i = 0; i < TEST_COUNT; i++) { rocblas_seedrand(); - rocblas_init(hparams, 1, 5, 1); + rocblas_init(hparams, 1, 5, 1, stride_param, batch_count); rocblas_init(hd1, 1, 1, 1, stride_d1, batch_count); rocblas_init(hd2, 1, 1, 1, stride_d2, batch_count); rocblas_init(hx1, 1, 1, 1, stride_x1, batch_count); @@ -146,16 +150,17 @@ void testing_rotmg_strided_batched(const Arguments& arg) cd2 + b * 
stride_d2, cx1 + b * stride_x1, cy1 + b * stride_y1, - cparams); + cparams + b * stride_param); } cpu_time_used = get_time_us() - cpu_time_used; // Test rocblas_pointer_mode_host { - host_vector rd1 = hd1; - host_vector rd2 = hd2; - host_vector rx1 = hx1; - host_vector ry1 = hy1; + host_vector rd1 = hd1; + host_vector rd2 = hd2; + host_vector rx1 = hx1; + host_vector ry1 = hy1; + host_vector rparams = hparams; CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); @@ -168,7 +173,8 @@ void testing_rotmg_strided_batched(const Arguments& arg) stride_x1, ry1, stride_y1, - hparams, + rparams, + stride_param, batch_count))); if(arg.unit_check) @@ -177,6 +183,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, stride_param, rparams, cparams); } if(arg.norm_check) @@ -189,6 +196,8 @@ void testing_rotmg_strided_batched(const Arguments& arg) += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); norm_error_host += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + norm_error_host += norm_check_general( + 'F', 1, 5, 1, stride_param, batch_count, rparams, cparams); } } @@ -198,13 +207,14 @@ void testing_rotmg_strided_batched(const Arguments& arg) device_vector dd2(size_d2); device_vector dx1(size_x1); device_vector dy1(size_y1); - device_vector dparams(5); + device_vector dparams(size_param); CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, sizeof(T) * size_x1, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, 
hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dparams, hparams, sizeof(T) * size_param, hipMemcpyHostToDevice)); CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); CHECK_ROCBLAS_ERROR((rocblas_rotmg_strided_batched(handle, @@ -217,17 +227,21 @@ void testing_rotmg_strided_batched(const Arguments& arg) dy1, stride_y1, dparams, + stride_param, batch_count))); host_vector rd1(size_d1); host_vector rd2(size_d2); host_vector rx1(size_x1); host_vector ry1(size_y1); + host_vector rparams(size_param); CHECK_HIP_ERROR(hipMemcpy(rd1, dd1, sizeof(T) * size_d1, hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(rd2, dd2, sizeof(T) * size_d2, hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(rx1, dx1, sizeof(T) * size_x1, hipMemcpyDeviceToHost)); CHECK_HIP_ERROR(hipMemcpy(ry1, dy1, sizeof(T) * size_y1, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(rparams, dparams, sizeof(T) * size_param, hipMemcpyDeviceToHost)); if(arg.unit_check) { @@ -235,6 +249,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) unit_check_general(1, 1, batch_count, 1, stride_d2, rd2, cd2); unit_check_general(1, 1, batch_count, 1, stride_x1, rx1, cx1); unit_check_general(1, 1, batch_count, 1, stride_y1, ry1, cy1); + unit_check_general(1, 5, batch_count, 1, stride_param, rparams, cparams); } if(arg.norm_check) @@ -247,6 +262,8 @@ void testing_rotmg_strided_batched(const Arguments& arg) += norm_check_general('F', 1, 1, 1, stride_x1, batch_count, rx1, cx1); norm_error_device += norm_check_general('F', 1, 1, 1, stride_y1, batch_count, ry1, cy1); + norm_error_device += norm_check_general( + 'F', 1, 5, 1, stride_param, batch_count, rparams, cparams); } } } @@ -255,21 +272,20 @@ void testing_rotmg_strided_batched(const Arguments& arg) { int number_cold_calls = 2; int number_hot_calls = 100; - // Device mode will be much quicker - // (TODO: or is there another reason we are typically using host_mode for timing?)
+ CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); device_vector dd1(size_d1); device_vector dd2(size_d2); device_vector dx1(size_x1); device_vector dy1(size_y1); - device_vector dparams(5); + device_vector dparams(size_param); CHECK_HIP_ERROR(hipMemcpy(dd1, hd1, sizeof(T) * size_d1, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dd2, hd2, sizeof(T) * size_d2, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dx1, hx1, sizeof(T) * size_x1, hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dy1, hy1, sizeof(T) * size_y1, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * 5, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dparams, hparams, sizeof(T) * size_param, hipMemcpyHostToDevice)); for(int iter = 0; iter < number_cold_calls; iter++) { @@ -283,6 +299,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) dy1, stride_y1, dparams, + stride_param, batch_count); } gpu_time_used = get_time_us(); // in microseconds @@ -298,6 +315,7 @@ void testing_rotmg_strided_batched(const Arguments& arg) dy1, stride_y1, dparams, + stride_param, batch_count); } gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 6d513523b..29fe7d21e 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1955,7 +1955,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, incy rocblas_int specifies the increment between elements of y. @param[in] - param vector of 5 elements defining the rotation. + param array of vectors of 5 elements defining the rotation. 
param[0] = flag param[1] = H11 param[2] = H21 @@ -1966,30 +1966,30 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm(rocblas_handle handle, flag = 0 => H = ( 1.0 H12 H21 1.0 ) flag = 1 => H = ( H11 1.0 -1.0 H22 ) flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) - param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + param may ONLY be stored on the device for the batched version of this function. @param[in] batch_count rocblas_int the number of x and y arrays, i.e. the number of batches. ********************************************************************/ -ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, - rocblas_int n, - float* const x[], - rocblas_int incx, - float* const y[], - rocblas_int incy, - const float* param, - rocblas_int batch_count); +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* const param[], + rocblas_int batch_count); -ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, - rocblas_int n, - double* const x[], - rocblas_int incx, - double* const y[], - rocblas_int incy, - const double* param, - rocblas_int batch_count); +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* const param[], + rocblas_int batch_count); /*! \brief BLAS Level 1 API @@ -2019,7 +2019,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, stride_y rocblas_stride specifies the increment between the beginning of y_i and y_(i + 1) @param[in] - param vector of 5 elements defining the rotation. + param strided_batched array of vectors of 5 elements defining the rotation. 
param[0] = flag param[1] = H11 param[2] = H21 @@ -2030,7 +2030,10 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, flag = 0 => H = ( 1.0 H12 H21 1.0 ) flag = 1 => H = ( H11 1.0 -1.0 H22 ) flag = -2 => H = ( 1.0 0.0 0.0 1.0 ) - param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + param may ONLY be stored on the device for the strided_batched version of this function. + @param[in] + stride_param rocblas_stride + specifies the increment between the beginning of param_i and param_(i + 1) @param[in] batch_count rocblas_int the number of x and y arrays, i.e. the number of batches. @@ -2046,6 +2049,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handl rocblas_int incy, rocblas_stride stride_y, const float* param, + rocblas_stride stride_param, rocblas_int batch_count); ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, @@ -2057,6 +2061,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handl rocblas_int incy, rocblas_stride stride_y, const double* param, + rocblas_stride stride_param, rocblas_int batch_count); /*! \brief BLAS Level 1 API @@ -2120,7 +2125,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg( @param[in] y1 batched array of input scalars. @param[out] - param vector of 5 elements defining the rotation. + param batched array of vectors of 5 elements defining the rotation. 
param[0] = flag param[1] = H11 param[2] = H21 @@ -2143,7 +2148,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, float* const d2[], float* const x1[], const float* const y1[], - float* param, + float* const param[], rocblas_int batch_count); ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, @@ -2151,7 +2156,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, double* const d2[], double* const x1[], const double* const y1[], - double* param, + double* const param[], rocblas_int batch_count); /*! \brief BLAS Level 1 API @@ -2186,7 +2191,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, stride_y1 rocblas_stride specifies the increment between the beginning of y1_i and y1_(i+1) @param[out] - param vector of 5 elements defining the rotation. + param batched array of vectors of 5 elements defining the rotation. param[0] = flag param[1] = H11 param[2] = H21 @@ -2214,6 +2219,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle hand const float* y1, rocblas_stride stride_y1, float* param, + rocblas_stride stride_param, rocblas_int batch_count); ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, @@ -2226,6 +2232,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle hand const double* y1, rocblas_stride stride_y1, double* param, + rocblas_stride stride_param, rocblas_int batch_count); /* diff --git a/library/src/blas1/rocblas_rotm.cpp b/library/src/blas1/rocblas_rotm.cpp index 22fad8b15..96339e12a 100644 --- a/library/src/blas1/rocblas_rotm.cpp +++ b/library/src/blas1/rocblas_rotm.cpp @@ -51,8 +51,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template( - handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 1, (T*)nullptr); + return rocblas_rotm_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 0, 1); } } // namespace diff 
--git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp index da4ab3855..25a7d0e25 100644 --- a/library/src/blas1/rocblas_rotm.hpp +++ b/library/src/blas1/rocblas_rotm.hpp @@ -7,30 +7,24 @@ #include "utility.h" template -__global__ void rotm_kernel(rocblas_int n, - T x_in, - rocblas_int offset_x, - rocblas_int incx, - rocblas_stride stride_x, - T y_in, - rocblas_int offset_y, - rocblas_int incy, - rocblas_stride stride_y, - U flag_device_host, - U h11_device_host, - U h21_device_host, - U h12_device_host, - U h22_device_host, - rocblas_stride stride_param) +__device__ void rotm_kernel_calc(rocblas_int n, + T x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U flag, + U h11, + U h21, + U h12, + U h22) { - auto flag = load_scalar(flag_device_host, hipBlockIdx_y, stride_param); - auto h11 = load_scalar(h11_device_host, hipBlockIdx_y, stride_param); - auto h21 = load_scalar(h21_device_host, hipBlockIdx_y, stride_param); - auto h12 = load_scalar(h12_device_host, hipBlockIdx_y, stride_param); - auto h22 = load_scalar(h22_device_host, hipBlockIdx_y, stride_param); - auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); - auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); - ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + auto x = load_ptr_batch(x_in, hipBlockIdx_y, offset_x, stride_x); + auto y = load_ptr_batch(y_in, hipBlockIdx_y, offset_y, stride_y); + ptrdiff_t tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; if(tid < n && flag != -2) { @@ -56,38 +50,110 @@ __global__ void rotm_kernel(rocblas_int n, } } -template +template +__global__ void rotm_kernel_batched(rocblas_int n, + T x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + U param, + rocblas_int offset_param, + rocblas_stride 
stride_param) +{ + auto p = load_ptr_batch(param, hipBlockIdx_y, offset_param, stride_param); + auto flag = p[0]; + auto h11 = p[1]; + auto h21 = p[2]; + auto h12 = p[3]; + auto h22 = p[4]; + rotm_kernel_calc(n, + x_in, + offset_x, + incx, + stride_x, + y_in, + offset_y, + incy, + stride_y, + flag, + h11, + h21, + h12, + h22); +} + +template +__global__ void rotm_kernel_regular(rocblas_int n, + T* x_in, + rocblas_int offset_x, + rocblas_int incx, + rocblas_stride stride_x, + T* y_in, + rocblas_int offset_y, + rocblas_int incy, + rocblas_stride stride_y, + T flag, + T h11, + T h21, + T h12, + T h22) +{ + rotm_kernel_calc(n, + x_in, + offset_x, + incx, + stride_x, + y_in, + offset_y, + incy, + stride_y, + flag, + h11, + h21, + h12, + h22); +} + +// Workaround to avoid constexpr if - Helper function to quick return when param[0] == -2 +template +bool quick_return_param(const T* param, rocblas_stride stride_param) +{ + if(param[0] == -2 && stride_param == 0) + return true; + return false; +} + +template +bool quick_return_param(const T* const param[], rocblas_stride stride_param) +{ + return false; +} + +template rocblas_status rocblas_rotm_template(rocblas_handle handle, rocblas_int n, - U x, + T x, rocblas_int offset_x, rocblas_int incx, rocblas_stride stride_x, - U y, + T y, rocblas_int offset_y, rocblas_int incy, rocblas_stride stride_y, - const T* param, + U param, + rocblas_int offset_param, rocblas_stride stride_param, - rocblas_int batch_count, - T* mem) + rocblas_int batch_count) { - // Memory queries must be in template as _impl doesn't have stride_param parameter (for calls from - // outside of rocblas) - if(handle->is_device_memory_size_query()) - { - // TODO: Decide if we want to support this or not. 
- if(stride_param && rocblas_pointer_mode_host == handle->pointer_mode && n > 0 && incx > 0 - && incy > 0 && batch_count > 0) - return handle->set_optimal_device_memory_size(sizeof(T) * batch_count * stride_param); - else - return rocblas_status_size_unchanged; - } - // Quick return if possible if(n <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) return rocblas_status_success; - if(rocblas_pointer_mode_host == handle->pointer_mode && param[0] == -2) + + if(quick_return_param(param, stride_param)) return rocblas_status_success; dim3 blocks((n - 1) / NB + 1, batch_count); @@ -95,7 +161,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, hipStream_t rocblas_stream = handle->rocblas_stream; if(rocblas_pointer_mode_device == handle->pointer_mode) - hipLaunchKernelGGL(rotm_kernel, + hipLaunchKernelGGL(rotm_kernel_batched, blocks, threads, 0, @@ -110,13 +176,10 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, incy, stride_y, param, - param + 1, - param + 2, - param + 3, - param + 4, + offset_param, stride_param); - else if(!stride_param) // single param on host - hipLaunchKernelGGL(rotm_kernel, + else if(!BATCHED_OR_STRIDED) + hipLaunchKernelGGL(rotm_kernel_regular, blocks, threads, 0, @@ -134,34 +197,12 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, param[1], param[2], param[3], - param[4], - 0); - else // array of params on host, copy to device + param[4]); + else // host mode not implemented for (strided_)batched functions { - // This should NOT happen from calls from the API currently. - RETURN_IF_HIP_ERROR( - hipMemcpy(mem, param, sizeof(T) * batch_count * stride_param, hipMemcpyHostToDevice)); - - hipLaunchKernelGGL(rotm_kernel, - blocks, - threads, - 0, - rocblas_stream, - n, - x, - offset_x, - incx, - stride_x, - y, - offset_y, - incy, - stride_y, - mem, - mem + 1, - mem + 2, - mem + 3, - mem + 4, - stride_param); + // TODO: if desired we can use a host for loop to iterate through + // batches in this scenario. 
Currently simply not implemented. + return rocblas_status_not_implemented; } return rocblas_status_success; diff --git a/library/src/blas1/rocblas_rotm_batched.cpp b/library/src/blas1/rocblas_rotm_batched.cpp index 09e2e12b0..64ed7bff8 100644 --- a/library/src/blas1/rocblas_rotm_batched.cpp +++ b/library/src/blas1/rocblas_rotm_batched.cpp @@ -25,7 +25,7 @@ namespace rocblas_int incx, T* const y[], rocblas_int incy, - const T* param, + const T* const param[], rocblas_int batch_count) { if(!handle) @@ -65,8 +65,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template( - handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, batch_count, (T*)nullptr); + return rocblas_rotm_template( + handle, n, x, 0, incx, 0, y, 0, incy, 0, param, 0, 0, batch_count); } } // namespace @@ -79,26 +79,26 @@ namespace extern "C" { -ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, - rocblas_int n, - float* const x[], - rocblas_int incx, - float* const y[], - rocblas_int incy, - const float* param, - rocblas_int batch_count) +ROCBLAS_EXPORT rocblas_status rocblas_srotm_batched(rocblas_handle handle, + rocblas_int n, + float* const x[], + rocblas_int incx, + float* const y[], + rocblas_int incy, + const float* const param[], + rocblas_int batch_count) { return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); } -ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, - rocblas_int n, - double* const x[], - rocblas_int incx, - double* const y[], - rocblas_int incy, - const double* param, - rocblas_int batch_count) +ROCBLAS_EXPORT rocblas_status rocblas_drotm_batched(rocblas_handle handle, + rocblas_int n, + double* const x[], + rocblas_int incx, + double* const y[], + rocblas_int incy, + const double* const param[], + rocblas_int batch_count) { return rocblas_rotm_batched_impl(handle, n, x, incx, y, incy, param, batch_count); } diff --git 
a/library/src/blas1/rocblas_rotm_strided_batched.cpp b/library/src/blas1/rocblas_rotm_strided_batched.cpp index 41204a052..0dbeeb97d 100644 --- a/library/src/blas1/rocblas_rotm_strided_batched.cpp +++ b/library/src/blas1/rocblas_rotm_strided_batched.cpp @@ -28,6 +28,7 @@ namespace rocblas_int incy, rocblas_stride stride_y, const T* param, + rocblas_stride stride_param, rocblas_int batch_count) { if(!handle) @@ -85,20 +86,20 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotm_template(handle, - n, - x, - 0, - incx, - stride_x, - y, - 0, - incy, - stride_y, - param, - 0, - batch_count, - (T*)nullptr); + return rocblas_rotm_template(handle, + n, + x, + 0, + incx, + stride_x, + y, + 0, + incy, + stride_y, + param, + 0, + stride_param, + batch_count); } } // namespace @@ -120,10 +121,11 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotm_strided_batched(rocblas_handle handl rocblas_int incy, rocblas_stride stride_y, const float* param, + rocblas_stride stride_param, rocblas_int batch_count) { return rocblas_rotm_strided_batched_impl( - handle, n, x, incx, stride_x, y, incy, stride_y, param, batch_count); + handle, n, x, incx, stride_x, y, incy, stride_y, param, stride_param, batch_count); } ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handle, @@ -135,10 +137,11 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handl rocblas_int incy, rocblas_stride stride_y, const double* param, + rocblas_stride stride_param, rocblas_int batch_count) { return rocblas_rotm_strided_batched_impl( - handle, n, x, incx, stride_x, y, incy, stride_y, param, batch_count); + handle, n, x, incx, stride_x, y, incy, stride_y, param, stride_param, batch_count); } } // extern "C" diff --git a/library/src/blas1/rocblas_rotmg.cpp b/library/src/blas1/rocblas_rotmg.cpp index dc3773f11..e7a8e3e64 100644 --- a/library/src/blas1/rocblas_rotmg.cpp +++ b/library/src/blas1/rocblas_rotmg.cpp @@ -36,7 +36,8 @@ namespace 
RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - return rocblas_rotmg_template(handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 1); + return rocblas_rotmg_template( + handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 0, 1); } } // namespace diff --git a/library/src/blas1/rocblas_rotmg.hpp b/library/src/blas1/rocblas_rotmg.hpp index 8c77e25e1..49b1e2e31 100644 --- a/library/src/blas1/rocblas_rotmg.hpp +++ b/library/src/blas1/rocblas_rotmg.hpp @@ -147,7 +147,7 @@ __device__ __host__ void rocblas_rotmg_calc(T& d1, T& d2, T& x1, const T& y1, T* param[0] = flag; } -template +template __global__ void rocblas_rotmg_kernel(T d1_in, rocblas_int offset_d1, rocblas_stride stride_d1, @@ -160,7 +160,8 @@ __global__ void rocblas_rotmg_kernel(T d1_in, U y1_in, rocblas_int offset_y1, rocblas_stride stride_y1, - V param, + T param, + rocblas_int offset_param, rocblas_stride stride_param, rocblas_int batch_count) { @@ -168,11 +169,11 @@ __global__ void rocblas_rotmg_kernel(T d1_in, auto d2 = load_ptr_batch(d2_in, hipBlockIdx_x, offset_d2, stride_d2); auto x1 = load_ptr_batch(x1_in, hipBlockIdx_x, offset_x1, stride_x1); auto y1 = load_ptr_batch(y1_in, hipBlockIdx_x, offset_y1, stride_y1); - auto p = param + hipBlockIdx_x * stride_param; + auto p = load_ptr_batch(param, hipBlockIdx_x, offset_param, stride_param); rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); } -template +template rocblas_status rocblas_rotmg_template(rocblas_handle handle, T d1_in, rocblas_int offset_d1, @@ -186,7 +187,8 @@ rocblas_status rocblas_rotmg_template(rocblas_handle handle, U y1_in, rocblas_int offset_y1, rocblas_stride stride_y1, - V param, + T param, + rocblas_int offset_param, rocblas_stride stride_param, rocblas_int batch_count) { @@ -214,6 +216,7 @@ rocblas_status rocblas_rotmg_template(rocblas_handle handle, offset_y1, stride_y1, param, + offset_param, stride_param, batch_count); } @@ -227,7 +230,7 @@ rocblas_status rocblas_rotmg_template(rocblas_handle handle, auto d2 = 
load_ptr_batch(d2_in, i, offset_d2, stride_d2); auto x1 = load_ptr_batch(x1_in, i, offset_x1, stride_x1); auto y1 = load_ptr_batch(y1_in, i, offset_y1, stride_y1); - auto p = param + i * stride_param; + auto p = load_ptr_batch(param, i, offset_param, stride_param); rocblas_rotmg_calc(*d1, *d2, *x1, *y1, p); } diff --git a/library/src/blas1/rocblas_rotmg_batched.cpp b/library/src/blas1/rocblas_rotmg_batched.cpp index f09bfcc2c..dc0856322 100644 --- a/library/src/blas1/rocblas_rotmg_batched.cpp +++ b/library/src/blas1/rocblas_rotmg_batched.cpp @@ -22,7 +22,7 @@ namespace T* const d2[], T* const x1[], const T* const y1[], - T* param, + T* const param[], rocblas_int batch_count) { if(!handle) @@ -48,7 +48,7 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); return rocblas_rotmg_template( - handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, batch_count); + handle, d1, 0, 0, d2, 0, 0, x1, 0, 0, y1, 0, 0, param, 0, 0, batch_count); } } // namespace @@ -66,7 +66,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg_batched(rocblas_handle handle, float* const d2[], float* const x1[], const float* const y1[], - float* param, + float* const param[], rocblas_int batch_count) { return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); @@ -77,7 +77,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_batched(rocblas_handle handle, double* const d2[], double* const x1[], const double* const y1[], - double* param, + double* const param[], rocblas_int batch_count) { return rocblas_rotmg_batched_impl(handle, d1, d2, x1, y1, param, batch_count); diff --git a/library/src/blas1/rocblas_rotmg_strided_batched.cpp b/library/src/blas1/rocblas_rotmg_strided_batched.cpp index de914e465..0f63df43b 100644 --- a/library/src/blas1/rocblas_rotmg_strided_batched.cpp +++ b/library/src/blas1/rocblas_rotmg_strided_batched.cpp @@ -27,6 +27,7 @@ namespace const T* y1, rocblas_stride stride_y1, T* param, + rocblas_stride stride_param, rocblas_int batch_count) { 
if(!handle) @@ -85,6 +86,7 @@ namespace stride_y1, param, 0, + stride_param, batch_count); } @@ -108,10 +110,21 @@ ROCBLAS_EXPORT rocblas_status rocblas_srotmg_strided_batched(rocblas_handle hand const float* y1, rocblas_stride stride_y1, float* param, + rocblas_stride stride_param, rocblas_int batch_count) { - return rocblas_rotmg_strided_batched_impl( - handle, d1, stride_d1, d2, stride_d2, x1, stride_x1, y1, stride_y1, param, batch_count); + return rocblas_rotmg_strided_batched_impl(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + stride_param, + batch_count); } ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle handle, @@ -124,10 +137,21 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotmg_strided_batched(rocblas_handle hand const double* y1, rocblas_stride stride_y1, double* param, + rocblas_stride stride_param, rocblas_int batch_count) { - return rocblas_rotmg_strided_batched_impl( - handle, d1, stride_d1, d2, stride_d2, x1, stride_x1, y1, stride_y1, param, batch_count); + return rocblas_rotmg_strided_batched_impl(handle, + d1, + stride_d1, + d2, + stride_d2, + x1, + stride_x1, + y1, + stride_y1, + param, + stride_param, + batch_count); } } // extern "C" From ee9ed076193dba9e34180833065a631b85a8c24f Mon Sep 17 00:00:00 2001 From: First Last Date: Mon, 7 Oct 2019 10:30:57 -0600 Subject: [PATCH 09/11] Addressed comments and fix a rotm error --- library/include/rocblas-functions.h | 3 +++ library/src/blas1/rocblas_rot.hpp | 2 +- library/src/blas1/rocblas_rotm.hpp | 11 ++++++----- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 29fe7d21e..3aa3959c8 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -2096,6 +2096,9 @@ ROCBLAS_EXPORT rocblas_status rocblas_drotm_strided_batched(rocblas_handle handl flag = 1 => H = ( H11 1.0 -1.0 H22 ) flag = -2 => H = ( 1.0 0.0 0.0 
1.0 ) param may be stored in either host or device memory, location is specified by calling rocblas_set_pointer_mode. + @param[in] + stride_param rocblas_stride + specifies the increment between the beginning of param_i and param_(i + 1) ********************************************************************/ diff --git a/library/src/blas1/rocblas_rot.hpp b/library/src/blas1/rocblas_rot.hpp index 6bce83b39..6d7b0028f 100644 --- a/library/src/blas1/rocblas_rot.hpp +++ b/library/src/blas1/rocblas_rot.hpp @@ -140,4 +140,4 @@ rocblas_status rocblas_rot_template(rocblas_handle handle, s_stride); return rocblas_status_success; -} \ No newline at end of file +} diff --git a/library/src/blas1/rocblas_rotm.hpp b/library/src/blas1/rocblas_rotm.hpp index 25a7d0e25..2d64e6b6d 100644 --- a/library/src/blas1/rocblas_rotm.hpp +++ b/library/src/blas1/rocblas_rotm.hpp @@ -120,15 +120,16 @@ __global__ void rotm_kernel_regular(rocblas_int n, // Workaround to avoid constexpr if - Helper function to quick return when param[0] == -2 template -bool quick_return_param(const T* param, rocblas_stride stride_param) +bool quick_return_param(rocblas_handle handle, const T* param, rocblas_stride stride_param) { - if(param[0] == -2 && stride_param == 0) - return true; + if(rocblas_pointer_mode_host == handle->pointer_mode) + if(param[0] == -2 && stride_param == 0) + return true; return false; } template -bool quick_return_param(const T* const param[], rocblas_stride stride_param) +bool quick_return_param(rocblas_handle handle, const T* const param[], rocblas_stride stride_param) { return false; } @@ -153,7 +154,7 @@ rocblas_status rocblas_rotm_template(rocblas_handle handle, if(n <= 0 || incx <= 0 || incy <= 0 || batch_count <= 0) return rocblas_status_success; - if(quick_return_param(param, stride_param)) + if(quick_return_param(handle, param, stride_param)) return rocblas_status_success; dim3 blocks((n - 1) / NB + 1, batch_count); From 8f7c408fe676e32c4d32322fba85393f25cb433e Mon Sep 17 00:00:00 
2001 From: First Last Date: Mon, 7 Oct 2019 11:26:04 -0600 Subject: [PATCH 10/11] Addressing PR comments --- clients/common/rocblas_gentest.py | 6 +++++- clients/include/testing_rot.hpp | 10 ++++++++-- clients/include/testing_rot_batched.hpp | 10 ++++++++-- clients/include/testing_rot_strided_batched.hpp | 10 ++++++++-- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 39990b626..b159ce65a 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -219,6 +219,8 @@ def setdefaults(test): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) + # we are using stride_c for arg c and stride_d for arg s in rotg + # these are single values for each batch if test['function'] in ('rotg_strided_batched'): if 'stride_scale' in test: test.setdefault('stride_a', int(test['stride_scale'])) @@ -226,7 +228,9 @@ def setdefaults(test): test.setdefault('stride_c', int(test['stride_scale'])) test.setdefault('stride_d', int(test['stride_scale'])) - # we are using stride_a for d1, stride_b for d2, and stride_c for param in rotmg + # we are using stride_a for d1, stride_b for d2, and stride_c for param in + # rotmg.
These are single values for each batch, except param which is + # a 5 element array if test['function'] in ('rotmg_strided_batched'): if 'stride_scale' in test: test.setdefault('stride_a', int(test['stride_scale'])) diff --git a/clients/include/testing_rot.hpp b/clients/include/testing_rot.hpp index f238dde65..d9db8e480 100644 --- a/clients/include/testing_rot.hpp +++ b/clients/include/testing_rot.hpp @@ -96,8 +96,14 @@ void testing_rot(const Arguments& arg) rocblas_seedrand(); rocblas_init(hx, 1, N, incx); rocblas_init(hy, 1, N, incy); - rocblas_init(hc, 1, 1, 1); - rocblas_init(hs, 1, 1, 1); + + // Random alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); // CPU BLAS reference data host_vector cx = hx; diff --git a/clients/include/testing_rot_batched.hpp b/clients/include/testing_rot_batched.hpp index 7d4b871c5..fef7ad3ac 100644 --- a/clients/include/testing_rot_batched.hpp +++ b/clients/include/testing_rot_batched.hpp @@ -128,8 +128,14 @@ void testing_rot_batched(const Arguments& arg) rocblas_init(hx[b], 1, N, incx); rocblas_init(hy[b], 1, N, incy); } - rocblas_init(hc, 1, 1, 1); - rocblas_init(hs, 1, 1, 1); + + // Random alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); // CPU BLAS reference data host_vector cx[batch_count]; diff --git a/clients/include/testing_rot_strided_batched.hpp b/clients/include/testing_rot_strided_batched.hpp index babcf2392..9d9ff2ae5 100644 --- a/clients/include/testing_rot_strided_batched.hpp +++ b/clients/include/testing_rot_strided_batched.hpp @@ -127,8 +127,14 @@ void testing_rot_strided_batched(const Arguments& arg) rocblas_seedrand(); rocblas_init(hx, 1, N, incx, stride_x, batch_count); rocblas_init(hy, 1, N, incy, stride_y, batch_count); - rocblas_init(hc, 1, 1, 1); - rocblas_init(hs, 1, 1, 1); + + // Random
alpha (0 - 10) + host_vector alpha(1); + rocblas_init(alpha, 1, 1, 1); + + // cos and sin of alpha (in rads) + hc[0] = cos(alpha[0]); + hs[0] = sin(alpha[0]); // CPU BLAS reference data host_vector cx = hx; From 1580d890196b48a398899e0b8faa3b167f58dab8 Mon Sep 17 00:00:00 2001 From: First Last Date: Mon, 7 Oct 2019 17:44:28 -0600 Subject: [PATCH 11/11] Small documentation fix --- library/include/rocblas-functions.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 3aa3959c8..f6b208612 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -1434,12 +1434,12 @@ ROCBLAS_EXPORT rocblas_status rocblas_izamin(rocblas_handle handl n rocblas_int number of elements in the x and y vectors. @param[inout] - x pointer storing vector x on the GPU. + x pointer storing vector x in device memory. @param[in] incx rocblas_int specifies the increment between elements of x. @param[inout] - y pointer storing vector y on the GPU. + y pointer storing vector y in device memory. @param[in] incy rocblas_int specifies the increment between elements of y. @@ -1517,12 +1517,12 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot(rocblas_handle handle, n rocblas_int number of elements in the x and y vectors. @param[inout] - x array of pointers storing vector x on the GPU. + x array of pointers storing vector x in device memory. @param[in] incx rocblas_int specifies the increment between elements of x. @param[inout] - y array of pointers storing vector y on the GPU. + y array of pointers storing vector y in device memory. @param[in] incy rocblas_int specifies the increment between elements of y. @@ -1609,7 +1609,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot_batched(rocblas_handle n rocblas_int number of elements in the x and y vectors. @param[inout] - x pointer storing strided vectors x on the GPU. + x pointer storing strided vectors x in device memory. 
@param[in] incx rocblas_int specifies the increment between elements of x. @@ -1617,7 +1617,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdrot_batched(rocblas_handle stride_x rocblas_stride specifies the increment from the beginning of x_i to the beginning of x_(i+1) @param[inout] - y pointer storing strided vectors y on the GPU. + y pointer storing strided vectors y in device memory. @param[in] incy rocblas_int specifies the increment between elements of y.