diff --git a/clients/common/rocblas_gentest.py b/clients/common/rocblas_gentest.py index 383eef4e5..c05604e8e 100755 --- a/clients/common/rocblas_gentest.py +++ b/clients/common/rocblas_gentest.py @@ -207,13 +207,21 @@ def setdefaults(test): test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * test['stride_scale'])) - if test['function'] in ('ger_strided_batched'): - if all([x in test for x in ('M', 'incx', 'stride_scale')]): - test.setdefault('stride_x', int(test['M'] * abs(test['incx']) * - test['stride_scale'])) - if all([x in test for x in ('N', 'incy', 'stride_scale')]): - test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * - test['stride_scale'])) + if test['function'] in ('gemv_strided_batched', 'ger_strided_batched'): + if test['function'] in ('ger_strided_batched') or test['transA'] in ('T', 'C'): + if all([x in test for x in ('M', 'incx', 'stride_scale')]): + test.setdefault('stride_x', int(test['M'] * abs(test['incx']) * + test['stride_scale'])) + if all([x in test for x in ('N', 'incy', 'stride_scale')]): + test.setdefault('stride_y', int(test['N'] * abs(test['incy']) * + test['stride_scale'])) + else: + if all([x in test for x in ('N', 'incx', 'stride_scale')]): + test.setdefault('stride_x', int(test['N'] * abs(test['incx']) * + test['stride_scale'])) + if all([x in test for x in ('M', 'incy', 'stride_scale')]): + test.setdefault('stride_y', int(test['M'] * abs(test['incy']) * + test['stride_scale'])) test.setdefault('stride_x', 0) test.setdefault('stride_y', 0) diff --git a/clients/gtest/CMakeLists.txt b/clients/gtest/CMakeLists.txt index 9281e5a99..ead4e869e 100644 --- a/clients/gtest/CMakeLists.txt +++ b/clients/gtest/CMakeLists.txt @@ -179,7 +179,7 @@ endif( ) set( ROCBLAS_TEST_DATA "${PROJECT_BINARY_DIR}/staging/rocblas_gtest.data") add_custom_command( OUTPUT "${ROCBLAS_TEST_DATA}" COMMAND ../common/rocblas_gentest.py -I ../include rocblas_gtest.yaml -o "${ROCBLAS_TEST_DATA}" - DEPENDS ../common/rocblas_gentest.py rocblas_gtest.yaml ../include/rocblas_common.yaml known_bugs.yaml blas1_gtest.yaml gemm_gtest.yaml gemm_batched_gtest.yaml gemm_strided_batched_gtest.yaml gemv_gtest.yaml gemv_batched_gtest.yaml gemv_strided_batched_gtest.yaml symv_gtest.yaml syr_gtest.yaml ger_gtest.yaml trsm_gtest.yaml trtri_gtest.yaml geam_gtest.yaml set_get_vector_gtest.yaml set_get_matrix_gtest.yaml trsv_gtest.yaml logging_mode_gtest.yaml set_get_pointer_mode_gtest.yaml + DEPENDS ../common/rocblas_gentest.py rocblas_gtest.yaml ../include/rocblas_common.yaml known_bugs.yaml blas1_gtest.yaml gemm_gtest.yaml gemm_batched_gtest.yaml gemm_strided_batched_gtest.yaml gemv_gtest.yaml symv_gtest.yaml syr_gtest.yaml ger_gtest.yaml trsm_gtest.yaml trtri_gtest.yaml geam_gtest.yaml set_get_vector_gtest.yaml set_get_matrix_gtest.yaml trsv_gtest.yaml logging_mode_gtest.yaml set_get_pointer_mode_gtest.yaml WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" ) add_custom_target( rocblas-test-data DEPENDS "${ROCBLAS_TEST_DATA}" ) diff --git a/clients/gtest/gemv_batched_gtest.yaml b/clients/gtest/gemv_batched_gtest.yaml deleted file mode 100644 index e0eeedd3d..000000000 --- a/clients/gtest/gemv_batched_gtest.yaml +++ /dev/null @@ -1,91 +0,0 @@ ---- -include: rocblas_common.yaml -include: known_bugs.yaml - -Definitions: - - &small_matrix_size_range - - { M: -1, N: 1, lda: 1 } - - { M: 1, N: -1, lda: 1 } - - { M: 1, N: 1, lda: 0 } - - { M: 10, N: 10, lda: 9 } - - { M: 0, N: 1, lda: 1 } - - { M: 1, N: 0, lda: 1 } - - { M: -1, N: -1, lda: -1 } - - { M: 10, N: 10, lda: 2 } - - { M: 100, N: 200, lda: 200 } - - - &medium_matrix_size_range - - { M: 300, N: 400, lda: 400 } - - { M: 600, N: 500, lda: 601 } - - - &large_matrix_size_range - - { M: 1000, N: 1000, lda: 1000 } - - { M: 2000, N: 2000, lda: 2000 } - - { M: 4011, N: 4011, lda: 4011 } - - { M: 8000, N: 8000, lda: 8000 } - - - &incx_incy_range - - { incx: 2, incy: 1 } - - { incx: -1, incy: 2 } - - { incx: 1, incy: 1 } - - { incx: -1, incy: 3 } - - { incx: 3, incy: -1 } - - { incx: 0, incy: 1 } - - { incx: 1, incy: 0 } - - { incx: 0, incy: -1 } - - { incx: 10, incy: 100 } - - - &alpha_beta_range - - { alpha: 2.0, beta: 0.0 } - - { alpha: -1.0, beta: -1.0 } - - { alpha: 2.0, beta: 1.0 } - - { alpha: 0.0, beta: 1.0 } - -Tests: -- name: gemv_batched_bad_arg - category: pre_checkin - function: gemv_batched_bad_arg - precision: *single_double_precisions - transA: N - -- name: gemv_batched_NaN - category: pre_checkin - function: gemv_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *medium_matrix_size_range - incx_incy: *incx_incy_range - alpha: [ -1.0, 0, 1.0, 2.0 ] - beta: .NaN # converted to 0.0 in test code - batch_count: [ -1, 0, 1, 3 ] - -- name: gemv_batched_small - category: quick - function: gemv_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *small_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ -1, 0, 1, 3 ] - -- name: gemv_batched_medium - category: pre_checkin - function: gemv_batched - precision: *single_double_precisions_complex_real - transA: [ N, T, C ] - matrix_size: *medium_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ 3 ] - -- name: gemv_batched_large - category: nightly - function: gemv_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *large_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ 3 ] -... diff --git a/clients/gtest/gemv_gtest.yaml b/clients/gtest/gemv_gtest.yaml index 9dda01aee..db8b6ea6f 100644 --- a/clients/gtest/gemv_gtest.yaml +++ b/clients/gtest/gemv_gtest.yaml @@ -4,36 +4,36 @@ include: known_bugs.yaml Definitions: - &small_matrix_size_range - - { M: -1, N: 1, lda: 1 } - - { M: 1, N: -1, lda: 1 } - - { M: 1, N: 1, lda: 0 } - - { M: 10, N: 10, lda: 9 } - - { M: 0, N: 1, lda: 1 } - - { M: 1, N: 0, lda: 1 } - - { M: -1, N: -1, lda: -1 } - - { M: 10, N: 10, lda: 2 } - - { M: 100, N: 200, lda: 200 } + - { M: -1, N: 1, lda: 1, stride_a: 1 } + - { M: 1, N: -1, lda: 1, stride_a: 1 } + - { M: 1, N: 1, lda: 0, stride_a: 1 } + - { M: 10, N: 10, lda: 9, stride_a: 1 } + - { M: 0, N: 1, lda: 1, stride_a: 1 } + - { M: 1, N: 0, lda: 1, stride_a: 1 } + - { M: -1, N: -1, lda: -1, stride_a: 1 } + - { M: 10, N: 10, lda: 2, stride_a: 1 } + - { M: 100, N: 200, lda: 200, stride_a: 40000 } - &medium_matrix_size_range - - { M: 300, N: 400, lda: 400 } - - { M: 600, N: 500, lda: 601 } + - { M: 300, N: 400, lda: 400, stride_a: 160000 } + - { M: 600, N: 500, lda: 601, stride_a: 301000 } - &large_matrix_size_range - - { M: 1000, N: 1000, lda: 1000 } - - { M: 2000, N: 2000, lda: 2000 } - - { M: 4011, N: 4011, lda: 4011 } - - { M: 8000, N: 8000, lda: 8000 } + - { M: 1000, N: 1000, lda: 1000, stride_a: 1000000 } + - { M: 2000, N: 2000, lda: 2000, stride_a: 4000000 } + - { M: 4011, N: 4011, lda: 4011, stride_a: 16088200 } + - { M: 8000, N: 8000, lda: 8000, stride_a: 64000000 } - &incx_incy_range - - { incx: 2, incy: 1 } - - { incx: -1, incy: 2 } - - { incx: 1, incy: 1 } - - { incx: -1, incy: 3 } - - { incx: 3, incy: -1 } - - { incx: 0, incy: 1 } - - { incx: 1, incy: 0 } - - { incx: 0, incy: -1 } - - { incx: 10, incy: 100 } + - { incx: 2, incy: 1, stride_scale: 1 } + - { incx: -1, incy: 2, stride_scale: 1 } + - { incx: 1, incy: 1, stride_scale: 1 } + - { incx: -1, incy: 3, stride_scale: 1.5 } + - { incx: 3, incy: -1, stride_scale: 1 } + - { incx: 0, incy: 1, stride_scale: 1 } + - { incx: 1, incy: 0, stride_scale: 1 } + - { incx: 0, incy: -1, stride_scale: 2 } + - { incx: 10, incy: 100, stride_scale: 1 } - &alpha_beta_range - { alpha: 2.0, beta: 0.0, alphai: 1.5, betai: 0.5 } @@ -84,4 +84,98 @@ Tests: matrix_size: *large_matrix_size_range incx_incy: *incx_incy_range alpha_beta: *alpha_beta_range + +- name: gemv_batched_bad_arg + category: pre_checkin + function: gemv_batched_bad_arg + precision: *single_double_precisions + transA: N + +- name: gemv_batched_NaN + category: pre_checkin + function: gemv_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *medium_matrix_size_range + incx_incy: *incx_incy_range + alpha: [ -1.0, 0, 1.0, 2.0 ] + beta: .NaN # converted to 0.0 in test code + batch_count: [ -1, 0, 1, 3 ] + +- name: gemv_batched_small + category: quick + function: gemv_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *small_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ -1, 0, 1, 3 ] + +- name: gemv_batched_medium + category: pre_checkin + function: gemv_batched + precision: *single_double_precisions_complex_real + transA: [ N, T, C ] + matrix_size: *medium_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] + +- name: gemv_batched_large + category: nightly + function: gemv_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *large_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] + +- name: gemv_strided_batched_bad_arg + category: pre_checkin + function: gemv_strided_batched_bad_arg + precision: *single_double_precisions + transA: N + +- name: gemv_strided_batched_NaN + category: pre_checkin + function: gemv_strided_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *medium_matrix_size_range + incx_incy: *incx_incy_range + alpha: [ -1.0, 0, 1.0, 2.0 ] + beta: .NaN # converted to 0.0 in test code + batch_count: [ -1, 0, 1, 3 ] + +- name: gemv_strided_batched_small + category: quick + function: gemv_strided_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *small_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ -1, 0, 1, 3 ] + +- name: gemv_strided_batched_medium + category: pre_checkin + function: gemv_strided_batched + precision: *single_double_precisions_complex_real + transA: [ N, T, C ] + matrix_size: *medium_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] + +- name: gemv_strided_batched_large + category: nightly + function: gemv_strided_batched + precision: *single_double_precisions + transA: [ N, T, C ] + matrix_size: *large_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] ... diff --git a/clients/gtest/gemv_strided_batched_gtest.yaml b/clients/gtest/gemv_strided_batched_gtest.yaml deleted file mode 100644 index 7a4ddbd2e..000000000 --- a/clients/gtest/gemv_strided_batched_gtest.yaml +++ /dev/null @@ -1,91 +0,0 @@ ---- -include: rocblas_common.yaml -include: known_bugs.yaml - -Definitions: - - &small_matrix_size_range - - { M: -1, N: 1, lda: 1, stride_a: 1 } - - { M: 1, N: -1, lda: 1, stride_a: 1 } - - { M: 1, N: 1, lda: 0, stride_a: 1 } - - { M: 10, N: 10, lda: 9, stride_a: 1 } - - { M: 0, N: 1, lda: 1, stride_a: 1 } - - { M: 1, N: 0, lda: 1, stride_a: 1 } - - { M: -1, N: -1, lda: -1, stride_a: 1 } - - { M: 10, N: 10, lda: 2, stride_a: 1 } - - { M: 100, N: 200, lda: 200, stride_a: 40000 } - - - &medium_matrix_size_range - - { M: 300, N: 400, lda: 400, stride_a: 160000 } - - { M: 600, N: 500, lda: 601, stride_a: 301000 } - - - &large_matrix_size_range - - { M: 1000, N: 1000, lda: 1000, stride_a: 1000000 } - - { M: 2000, N: 2000, lda: 2000, stride_a: 4000000 } - - { M: 4011, N: 4011, lda: 4011, stride_a: 16088200 } - - { M: 8000, N: 8000, lda: 8000, stride_a: 64000000 } - - - &incx_incy_range - - { incx: 2, incy: 1, stride_x: 8000, stride_y: 8000 } - - { incx: -1, incy: 2, stride_x: 8000, stride_y: 8000 } - - { incx: 1, incy: 1, stride_x: 8000, stride_y: 8000 } - - { incx: -1, incy: 3, stride_x: 4000, stride_y: 4000 } - - { incx: 3, incy: -1, stride_x: 2000, stride_y: 2000 } - - { incx: 0, incy: 1, stride_x: 1000, stride_y: 1000 } - - { incx: 1, incy: 0, stride_x: 1000, stride_y: 1000 } - - { incx: 0, incy: -1, stride_x: 1, stride_y: 1 } - - { incx: 10, incy: 100, stride_x: 8000, stride_y: 8000 } - - - &alpha_beta_range - - { alpha: 2.0, beta: 0.0 } - - { alpha: -1.0, beta: -1.0 } - - { alpha: 2.0, beta: 1.0 } - - { alpha: 0.0, beta: 1.0 } - -Tests: -- name: gemv_strided_batched_bad_arg - category: pre_checkin - function: gemv_strided_batched_bad_arg - precision: *single_double_precisions - transA: N - -- name: gemv_strided_batched_NaN - category: pre_checkin - function: gemv_strided_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *medium_matrix_size_range - incx_incy: *incx_incy_range - alpha: [ -1.0, 0, 1.0, 2.0 ] - beta: .NaN # converted to 0.0 in test code - batch_count: [ -1, 0, 1, 3 ] - -- name: gemv_strided_batched_small - category: quick - function: gemv_strided_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *small_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ -1, 0, 1, 3 ] - -- name: gemv_strided_batched_medium - category: pre_checkin - function: gemv_strided_batched - precision: *single_double_precisions_complex_real - transA: [ N, T, C ] - matrix_size: *medium_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ 3 ] - -- name: gemv_strided_batched_large - category: nightly - function: gemv_strided_batched - precision: *single_double_precisions - transA: [ N, T, C ] - matrix_size: *large_matrix_size_range - incx_incy: *incx_incy_range - alpha_beta: *alpha_beta_range - batch_count: [ 3 ] -... diff --git a/clients/gtest/ger_gtest.cpp b/clients/gtest/ger_gtest.cpp index 6154bc04d..1b9f3294a 100644 --- a/clients/gtest/ger_gtest.cpp +++ b/clients/gtest/ger_gtest.cpp @@ -55,24 +55,32 @@ namespace { RocBLAS_TestName name; - name << rocblas_datatype2string(arg.a_type) << '_' << arg.M << '_' << arg.N << '_' - << arg.alpha << '_' << arg.incx; + name << rocblas_datatype2string(arg.a_type); - if(GER_TYPE == GER_STRIDED_BATCHED) - name << '_' << arg.stride_x; + if(strstr(arg.function, "_bad_arg") != nullptr) + { + name << "_bad_arg"; + } + else + { + name << '_' << arg.M << '_' << arg.N << '_' << arg.alpha << '_' << arg.incx; + + if(GER_TYPE == GER_STRIDED_BATCHED) + name << '_' << arg.stride_x; - name << '_' << arg.incy; + name << '_' << arg.incy; - if(GER_TYPE == GER_STRIDED_BATCHED) - name << '_' << arg.stride_y; + if(GER_TYPE == GER_STRIDED_BATCHED) + name << '_' << arg.stride_y; - name << '_' << arg.lda; + name << '_' << arg.lda; - if(GER_TYPE == GER_STRIDED_BATCHED) - name << '_' << arg.stride_a; + if(GER_TYPE == GER_STRIDED_BATCHED) + name << '_' << arg.stride_a; - if(GER_TYPE == GER_STRIDED_BATCHED || GER_TYPE == GER_BATCHED) - name << '_' << arg.batch_count; + if(GER_TYPE == GER_STRIDED_BATCHED || GER_TYPE == GER_BATCHED) + name << '_' << arg.batch_count; + } return std::move(name); } diff --git a/clients/gtest/ger_gtest.yaml b/clients/gtest/ger_gtest.yaml index a1ab5d422..7215ad2be 100644 --- a/clients/gtest/ger_gtest.yaml +++ b/clients/gtest/ger_gtest.yaml @@ -11,7 +11,7 @@ Definitions: - { M: 0, N: 1, lda: 1, stride_a: 1 } - { M: 1, N: 0, lda: 1, stride_a: 1 } - { M: 1, N: 1, lda: 0, stride_a: 1 } - - { M: 11, N: 12, lda: 13, stride_a: 1 } + - { M: 11, N: 12, lda: 13, stride_a: 156 } - { M: 16, N: 16, lda: 16, stride_a: 256 } - { M: 33, N: 32, lda: 33, stride_a: 1056 } - { M: 65, N: 65, lda: 66, stride_a: 4300 } @@ -46,7 +46,10 @@ Definitions: Tests: - name: ger_bad_arg category: pre_checkin - function: ger_bad_arg + function: + - ger_bad_arg + - ger_batched_bad_arg + - ger_strided_batched_bad_arg precision: *single_double_precisions - name: ger_small @@ -73,12 +76,6 @@ Tests: incx_incy: *incx_incy_range alpha: [ -0.5, 2.0, 0.0, 0.6 ] -- name: ger_batched_bad_arg - category: pre_checkin - function: ger_batched_bad_arg - precision: *single_double_precisions - batch_count: [ -5, 0, 1, 5, 10 ] - - name: ger_batched_small category: quick function: ger_batched @@ -106,13 +103,6 @@ Tests: alpha: [ -0.5, 2.0, 0.0 ] batch_count: [ 1, 3 ] -- name: ger_strided_batched_bad_arg - category: pre_checkin - function: ger_strided_batched_bad_arg - precision: *single_double_precisions - stride_scale: [ -1, 0, 0.5, 1, 2 ] - batch_count: [ -5, 0, 1, 5, 10 ] - - name: ger_strided_batched_small category: quick function: ger_strided_batched @@ -120,7 +110,7 @@ Tests: matrix_size: *small_matrix_size_range incx_incy: *incx_incy_range alpha: [ -0.5, 2.0, 0.0 ] - stride_scale: [ 0.5, 1, 2 ] + stride_scale: [ 1, 2 ] batch_count: [ -5, 0, 1, 5, 10 ] - name: ger_strided_batched_medium @@ -130,7 +120,7 @@ Tests: matrix_size: *medium_matrix_size_range incx_incy: *incx_incy_range alpha: [ -0.5, 2.0, 0.0 ] - stride_scale: [ 0.5, 1, 2 ] + stride_scale: [ 1, 2 ] batch_count: [ 1, 5, 10 ] - name: ger_strided_batched_large @@ -140,5 +130,6 @@ Tests: matrix_size: *large_matrix_size_range incx_incy: *nightly_incx_incy_range alpha: [ -0.5, 2.0, 0.0 ] + stride_scale: [ 1 ] batch_count: [ 1, 3 ] ... diff --git a/clients/gtest/rocblas_gtest.yaml b/clients/gtest/rocblas_gtest.yaml index 67770c925..f1db4f620 100644 --- a/clients/gtest/rocblas_gtest.yaml +++ b/clients/gtest/rocblas_gtest.yaml @@ -1,7 +1,5 @@ include: blas1_gtest.yaml include: gemv_gtest.yaml -include: gemv_batched_gtest.yaml -include: gemv_strided_batched_gtest.yaml include: gemm_gtest.yaml include: gemm_batched_gtest.yaml include: gemm_strided_batched_gtest.yaml diff --git a/clients/include/testing_gemv_strided_batched.hpp b/clients/include/testing_gemv_strided_batched.hpp index 381888908..4cbe4d156 100644 --- a/clients/include/testing_gemv_strided_batched.hpp +++ b/clients/include/testing_gemv_strided_batched.hpp @@ -197,8 +197,7 @@ void testing_gemv_strided_batched(const Arguments& arg) size_y = dim_y * abs_incy; // argument sanity check before allocating invalid memory - if(M < 0 || N < 0 || lda < M || lda < 1 || !incx || !incy || stride_a < size_A - || stride_x < size_x || stride_y < size_y || batch_count < 0) + if(M < 0 || N < 0 || lda < M || lda < 1 || !incx || !incy || batch_count < 0) { static const size_t safe_size = 100; // arbitrarily set to 100 device_vector dA1(safe_size); diff --git a/clients/include/testing_ger.hpp b/clients/include/testing_ger.hpp index 60df009fd..155478cd2 100644 --- a/clients/include/testing_ger.hpp +++ b/clients/include/testing_ger.hpp @@ -60,7 +60,7 @@ void testing_ger(const Arguments& arg) rocblas_int incx = arg.incx; rocblas_int incy = arg.incy; rocblas_int lda = arg.lda; - T h_alpha = (T)arg.alpha; + T h_alpha = arg.get_alpha(); rocblas_local_handle handle; diff --git a/clients/include/testing_ger_batched.hpp b/clients/include/testing_ger_batched.hpp index 0a9698224..746a73faf 100644 --- a/clients/include/testing_ger_batched.hpp +++ b/clients/include/testing_ger_batched.hpp @@ -59,7 +59,7 @@ void testing_ger_batched(const Arguments& arg) rocblas_int incx = arg.incx; rocblas_int incy = arg.incy; rocblas_int lda = arg.lda; - T h_alpha = (T)arg.alpha; + T h_alpha = arg.get_alpha(); rocblas_int batch_count = arg.batch_count; rocblas_local_handle handle; diff --git a/clients/include/testing_ger_strided_batched.hpp b/clients/include/testing_ger_strided_batched.hpp index 235dda315..cffcec420 100644 --- a/clients/include/testing_ger_strided_batched.hpp +++ b/clients/include/testing_ger_strided_batched.hpp @@ -116,7 +116,7 @@ void testing_ger_strided_batched(const Arguments& arg) rocblas_int incx = arg.incx; rocblas_int incy = arg.incy; rocblas_int lda = arg.lda; - T h_alpha = (T)arg.alpha; + T h_alpha = arg.get_alpha(); rocblas_int stride_x = arg.stride_x; rocblas_int stride_y = arg.stride_y; rocblas_int stride_a = arg.stride_a; @@ -131,8 +131,7 @@ void testing_ger_strided_batched(const Arguments& arg) size_t size_y = N * abs_incy; // argument check before allocating invalid memory - if(M < 0 || N < 0 || lda < M || lda < 1 || !incx || !incy || stride_a < size_A - || stride_x < size_x || stride_y < size_y || batch_count < 0) + if(M < 0 || N < 0 || lda < M || lda < 1 || !incx || !incy || batch_count < 0) { static const size_t safe_size = 100; // arbitrarily set to 100 device_vector dA_1(safe_size); diff --git a/clients/include/testing_scal_batched.hpp b/clients/include/testing_scal_batched.hpp index 8704bde04..5a9ab5ad6 100644 --- a/clients/include/testing_scal_batched.hpp +++ b/clients/include/testing_scal_batched.hpp @@ -176,7 +176,7 @@ void testing_scal_batched(const Arguments& arg) gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; rocblas_gflops = batch_count * scal_gflop_count(N) / gpu_time_used * 1e6 * 1; - rocblas_bandwidth = (2.0 * N) * sizeof(T) / gpu_time_used / 1e3; + rocblas_bandwidth = batch_count * (2.0 * N) * sizeof(T) / gpu_time_used / 1e3; std::cout << "N,alpha,incx,rocblas-Gflops,rocblas-GB/s,rocblas-us"; diff --git a/clients/include/testing_scal_strided_batched.hpp b/clients/include/testing_scal_strided_batched.hpp index 2cc1d2fe6..5c070c204 100644 --- a/clients/include/testing_scal_strided_batched.hpp +++ b/clients/include/testing_scal_strided_batched.hpp @@ -162,7 +162,7 @@ void testing_scal_strided_batched(const Arguments& arg) gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; rocblas_gflops = batch_count * scal_gflop_count(N) / gpu_time_used * 1e6 * 1; - rocblas_bandwidth = (2.0 * N) * sizeof(T) / gpu_time_used / 1e3; + rocblas_bandwidth = batch_count * (2.0 * N) * sizeof(T) / gpu_time_used / 1e3; std::cout << "N,alpha,incx,rocblas-Gflops,rocblas-GB/s,rocblas-us"; diff --git a/library/include/rocblas-functions.h b/library/include/rocblas-functions.h index 4ce094e8b..6f893f473 100644 --- a/library/include/rocblas-functions.h +++ b/library/include/rocblas-functions.h @@ -190,12 +190,14 @@ ROCBLAS_EXPORT rocblas_status rocblas_zdscal_batched(rocblas_handle @param[inout] x pointer storing vector x on the GPU. @param[in] - incx specifies the increment for the elements of x. + incx rocblas_int + specifies the increment for the elements of x. @param[in] - stride_x stride form the start of one vector (x_i) and the next one (x_i+1). - There are no restrictions placed on stride_x, however the user should - take care to ensure that stride_x is of appropriate size, for a typical - case this means stride_x > n * incx. + stride_x rocblas_stride + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= n * incx. @param[in] batch_count specifies the number of batches in x. ********************************************************************/ @@ -384,15 +386,21 @@ ROCBLAS_EXPORT rocblas_status rocblas_zcopy_batched(rocblas_handle specifies the increments for the elements of vectors x_i. @param[in] stridex rocblas_stride - stride form the start of one vector (x_i) and the next one (x_i+1) - @param[in] + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= n * incx. + @param[out] y pointer to the first vector (y_0) in the batch stored on the GPU. @param[in] incy rocblas_int specifies the increment for the elements of vectors y_i. @param[in] stridey rocblas_stride - stride from the start of one vector (y_i) and the next one (y_i+1) + stride from the start of one vector (y_i) and the next one (y_i+1). + There are no restrictions placed on stride_y, however the user should + take care to ensure that stride_y is of appropriate size, for a typical + case this means stride_y >= n * incy. stridey should be non zero. @param[in] incy rocblas_int specifies the increment for the elements of y. @@ -907,18 +915,24 @@ ROCBLAS_EXPORT rocblas_status rocblas_zswap_batched(rocblas_handle hand specifies the increment for the elements of x. @param[in] stridex rocblas_stride - specifies the pointer increment between batches for x. + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= n * incx. @param[inout] y a pointer to the first vector y_i on the GPU. @param[in] incy rocblas_int specifies the increment for the elements of y. @param[in] - stridey rocblas_stride - specifies the pointer increment between batches for y. - @param[in] - batch_count rocblas_int - number of instances in the batch + stridey rocblas_stride + stride from the start of one vector (y_i) and the next one (y_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_y is of appropriate size, for a typical + case this means stride_y >= n * incy. stridey should be non zero. + @param[in] + batch_count rocblas_int + number of instances in the batch ********************************************************************/ @@ -1144,7 +1158,10 @@ ROCBLAS_EXPORT rocblas_status rocblas_dzasum_batched(rocblas_handle specifies the increment for the elements of each x_i. incx must be > 0. @param[in] stridex rocblas_stride - specifies the pointer increment between batches for x. stridex must be be non zero. + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= n * incx. @param[out] results pointer to array for storing contiguous batch_count results. either on the host CPU or device GPU. @@ -1303,7 +1320,10 @@ ROCBLAS_EXPORT rocblas_status rocblas_dznrm2_batched(rocblas_handle specifies the increment for the elements of each x_i. incx must be > 0. @param[in] stridex rocblas_stride - specifies the pointer increment between batches for x. stridex must be non zero. + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= n * incx. @param[in] batch_count rocblas_int number of instances in the batch @@ -1886,7 +1906,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemv_batched(rocblas_handle lda rocblas_int specifies the leading dimension of matrices A_i. @param[in] - strideA rocblas_int + strideA rocblas_stride stride from the start of one matrix (A_i) and the next one (A_i+1) @param[in] x pointer to the first vector (x_0) in the batch stored on the GPU. @@ -1894,8 +1914,11 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemv_batched(rocblas_handle incx rocblas_int specifies the increment for the elements of vectors x_i. @param[in] - stridex rocblas_int - stride form the start of one vector (x_i) and the next one (x_i+1) + stridex rocblas_stride + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size. When trans equals rocblas_operation_none + this typically means stride_x >= n * incx, otherwise stride_x >= m * incx. @param[in] beta specifies the scalar beta. @param[inout] @@ -1904,8 +1927,11 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemv_batched(rocblas_handle incy rocblas_int specifies the increment for the elements of vectors y_i. @param[in] - stridey rocblas_int - stride from the start of one vector (y_i) and the next one (y_i+1) + stridey rocblas_stride + stride from the start of one vector (y_i) and the next one (y_i+1). + There are no restrictions placed on stride_y, however the user should + take care to ensure that stride_y is of appropriate size. When trans equals rocblas_operation_none + this typically means stride_y >= m * incy, otherwise stride_y >= n * incy. stridey should be non zero. @param[in] batch_count rocblas_int number of instances in the batch @@ -2028,7 +2054,8 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemv_strided_batched(rocblas_handle x pointer storing vector x on the GPU. @param[in] - incx specifies the increment for the elements of x. + incx rocblas_int + specifies the increment for the elements of x. ********************************************************************/ ROCBLAS_EXPORT rocblas_status rocblas_strsv(rocblas_handle handle, @@ -2080,7 +2107,8 @@ ROCBLAS_EXPORT rocblas_status rocblas_dtrsv(rocblas_handle handle, @param[in] x pointer storing vector x on the GPU. @param[in] - incx specifies the increment for the elements of x. + incx rocblas_int + specifies the increment for the elements of x. @param[in] beta specifies the scalar beta. @param[out] @@ -2287,23 +2315,29 @@ ROCBLAS_EXPORT rocblas_status rocblas_dger_batched(rocblas_handle handle, incx rocblas_int specifies the increments for the elements of vectors x_i. @param[in] - stridex rocblas_int - stride form the start of one vector (x_i) and the next one (x_i+1) - @param[in] + stridex rocblas_stride + stride from the start of one vector (x_i) and the next one (x_i+1). + There are no restrictions placed on stride_x, however the user should + take care to ensure that stride_x is of appropriate size, for a typical + case this means stride_x >= m * incx. + @param[inout] y pointer to the first vector (y_0) in the batch stored on the GPU. @param[in] incy rocblas_int specifies the increment for the elements of vectors y_i. @param[in] - stridey rocblas_int - stride from the start of one vector (y_i) and the next one (y_i+1) + stridey rocblas_stride + stride from the start of one vector (y_i) and the next one (y_i+1). + There are no restrictions placed on stride_y, however the user should + take care to ensure that stride_y is of appropriate size, for a typical + case this means stride_y >= n * incy. @param[inout] A pointer to the first matrix (A_0) in the batch stored on the GPU. @param[in] lda rocblas_int specifies the leading dimension of A. @param[in] - strideA rocblas_int + strideA rocblas_stride stride from the start of one matrix (A_i) and the next one (A_i+1) @param[in] batch_count rocblas_int @@ -2702,11 +2736,11 @@ ROCBLAS_EXPORT rocblas_status rocblas_dtrtri_batched(rocblas_handle handle, ldinvA rocblas_int specifies the leading dimension of invA. @param[in] - stride_invA rocblas_stride - "batch stride invA": stride from the start of one "invA" matrix to the next + stride_invA rocblas_stride + "batch stride invA": stride from the start of one "invA" matrix to the next @param[in] - batch_count rocblas_int - numbers of matrices in the batch + batch_count rocblas_int + numbers of matrices in the batch ********************************************************************/ ROCBLAS_EXPORT rocblas_status rocblas_strtri_strided_batched(rocblas_handle handle, @@ -3178,7 +3212,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemm_batched(rocblas_handle lda rocblas_int specifies the leading dimension of "A". @param[in] - stride_a rocblas_stride + stride_a rocblas_stride stride from the start of one "A" matrix to the next @param[in] B pointer storing strided batched matrix B on the GPU. @@ -3186,7 +3220,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemm_batched(rocblas_handle ldb rocblas_int specifies the leading dimension of "B". @param[in] - stride_b rocblas_stride + stride_b rocblas_stride stride from the start of one "B" matrix to the next @param[in] beta specifies the scalar beta. @@ -3196,7 +3230,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_zgemm_batched(rocblas_handle ldc rocblas_int specifies the leading dimension of "C". @param[in] - stride_c rocblas_stride + stride_c rocblas_stride stride from the start of one "C" matrix to the next @param[in] batch_count @@ -3980,7 +4014,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_gemm_batched_ex(rocblas_handle handle, lda rocblas_int. specifies the leading dimension of A. @param[in] - stride_a rocblas_long. + stride_a rocblas_stride. specifies stride from start of one "A" matrix to the next. @param[in] b void *. @@ -3992,7 +4026,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_gemm_batched_ex(rocblas_handle handle, ldb rocblas_int. specifies the leading dimension of B. @param[in] - stride_b rocblas_long. + stride_b rocblas_stride. specifies stride from start of one "B" matrix to the next. @param[in] beta const void *. @@ -4007,7 +4041,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_gemm_batched_ex(rocblas_handle handle, ldc rocblas_int. specifies the leading dimension of C. @param[in] - stride_c rocblas_long. + stride_c rocblas_stride. specifies stride from start of one "C" matrix to the next. @param[out] d void *. @@ -4019,7 +4053,7 @@ ROCBLAS_EXPORT rocblas_status rocblas_gemm_batched_ex(rocblas_handle handle, ldd rocblas_int. specifies the leading dimension of D. @param[in] - stride_d rocblas_long. + stride_d rocblas_stride. specifies stride from start of one "D" matrix to the next. @param[in] batch_count diff --git a/library/src/blas2/gemv_device.hpp b/library/src/blas2/gemv_device.hpp index d2b973e6f..efa160fbf 100644 --- a/library/src/blas2/gemv_device.hpp +++ b/library/src/blas2/gemv_device.hpp @@ -12,16 +12,16 @@ template {}, int>::type = 0> -__device__ void gemvn_kernel(rocblas_int m, - rocblas_int n, - U alpha, - const T* A, - rocblas_int lda, - const T* x, - rocblas_int incx, - U beta, - T* y, - rocblas_int incy) +__device__ void gemvn_kernel_calc(rocblas_int m, + rocblas_int n, + U alpha, + const T* A, + rocblas_int lda, + const T* x, + rocblas_int incx, + U beta, + T* y, + rocblas_int incy) { rocblas_int thread_id = hipThreadIdx_x + hipThreadIdx_y * hipBlockDim_x; @@ -172,16 +172,16 @@ __device__ void gemvn_kernel(rocblas_int m, // Overload for double precision complex numbers. We run out of registers // if we use the above algorithm. template -__device__ void gemvn_kernel(rocblas_int m, - rocblas_int n, - U alpha, - const rocblas_double_complex* A, - rocblas_int lda, - const rocblas_double_complex* x, - rocblas_int incx, - U beta, - rocblas_double_complex* y, - rocblas_int incy) +__device__ void gemvn_kernel_calc(rocblas_int m, + rocblas_int n, + U alpha, + const rocblas_double_complex* A, + rocblas_int lda, + const rocblas_double_complex* x, + rocblas_int incx, + U beta, + rocblas_double_complex* y, + rocblas_int incy) { rocblas_int thread_id = hipThreadIdx_x + hipThreadIdx_y * hipBlockDim_x; @@ -250,16 +250,16 @@ __device__ void gemvn_kernel(rocblas_int m, } template -__device__ void gemvc_kernel(rocblas_int m, - rocblas_int n, - U alpha, - const T* A, - rocblas_int lda, - const T* x, - rocblas_int incx, - U beta, - T* y, - rocblas_int incy) +__device__ void gemvc_kernel_calc(rocblas_int m, + rocblas_int n, + U alpha, + const T* A, + rocblas_int lda, + const T* x, + rocblas_int incx, + U beta, + T* y, + rocblas_int incy) { rocblas_int tx = hipThreadIdx_x; @@ -317,16 +317,16 @@ __device__ void gemvc_kernel(rocblas_int m, } template -__device__ void gemvt_kernel(rocblas_int m, - rocblas_int n, - U alpha, - const T* A, - rocblas_int lda, - const T* x, - rocblas_int incx, - U beta, - T* y, - rocblas_int incy) +__device__ void gemvt_kernel_calc(rocblas_int m, + rocblas_int n, + U alpha, + const T* A, + rocblas_int lda, + const T* x, + rocblas_int incx, + U beta, + T* y, + rocblas_int incy) { rocblas_int tx = hipThreadIdx_x; @@ -383,201 +383,98 @@ __device__ void gemvt_kernel(rocblas_int m, } } -template -__global__ void gemvn_kernel_strided(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* Aa, - rocblas_int lda, - rocblas_int strideA, - const T* xa, - rocblas_int incx, - rocblas_int stridex, - U beta_device_host, - T* ya, - rocblas_int incy, - rocblas_int stridey) +template +__global__ void gemvn_kernel(rocblas_int m, + rocblas_int n, + U alpha_device_host, + rocblas_stride stride_alpha, + const V* Aa, + ptrdiff_t shifta, + rocblas_int lda, + rocblas_stride strideA, + const V* xa, + ptrdiff_t shiftx, + rocblas_int incx, + rocblas_stride stridex, + U beta_device_host, + rocblas_stride stride_beta, + W* ya, + ptrdiff_t shifty, + rocblas_int incy, + rocblas_stride stridey) { rocblas_int num_threads = hipBlockDim_x * hipBlockDim_y * hipBlockDim_z; if(DIM_X * DIM_Y != num_threads) return; // need to launch exactly the same number of threads as template parameters indicate - const T* A; - const T* x; - T* y; - A = Aa + hipBlockIdx_y * strideA; - x = xa + hipBlockIdx_y * stridex; - y = ya + hipBlockIdx_y * stridey; + const T* A = load_ptr_batch(Aa, hipBlockIdx_y, shifta, strideA); + const T* x = load_ptr_batch(xa, hipBlockIdx_y, shiftx, stridex); + T* y = load_ptr_batch(ya, hipBlockIdx_y, shifty, stridey); - if(incx < 0) - x -= ptrdiff_t(incx) * (n - 1); - if(incy < 0) - y -= ptrdiff_t(incy) * (m - 1); + auto alpha = load_scalar(alpha_device_host, hipBlockIdx_y, stride_alpha); + auto beta = load_scalar(beta_device_host, hipBlockIdx_y, stride_beta); - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvn_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); + gemvn_kernel_calc(m, n, alpha, A, lda, x, incx, beta, y, incy); } -template -__global__ void gemvc_kernel_strided(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* Aa, - rocblas_int lda, - rocblas_int strideA, - const T* xa, - rocblas_int incx, - rocblas_int stridex, - U beta_device_host, - T* ya, - rocblas_int incy, - rocblas_int stridey) +template +__global__ void gemvc_kernel(rocblas_int m, + rocblas_int n, + U alpha_device_host, + rocblas_stride stride_alpha, + const V* Aa, + ptrdiff_t shifta, + rocblas_int lda, + rocblas_stride strideA, + const V* xa, + ptrdiff_t shiftx, + rocblas_int incx, + rocblas_stride stridex, + U beta_device_host, + rocblas_stride stride_beta, + W* ya, + ptrdiff_t shifty, + rocblas_int incy, + rocblas_stride stridey) { - const T* A; - const T* x; - T* y; - A = Aa + hipBlockIdx_y * strideA; - x = xa + hipBlockIdx_y * stridex; - y = ya + hipBlockIdx_y * stridey; - - if(incx < 0) - x -= ptrdiff_t(incx) * (m - 1); - if(incy < 0) - y -= ptrdiff_t(incy) * (n - 1); - - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvc_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); -} + const T* A = load_ptr_batch(Aa, hipBlockIdx_y, shifta, strideA); + const T* x = load_ptr_batch(xa, hipBlockIdx_y, shiftx, stridex); + T* y = load_ptr_batch(ya, hipBlockIdx_y, shifty, stridey); -template -__global__ void gemvt_kernel_strided(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* Aa, - rocblas_int lda, - rocblas_int strideA, - const T* xa, - rocblas_int incx, - rocblas_int stridex, - U beta_device_host, - T* ya, - rocblas_int incy, - rocblas_int stridey) -{ - const T* A; - const T* x; - T* y; - A = Aa + hipBlockIdx_y * strideA; - x = xa + hipBlockIdx_y * stridex; - y = ya + hipBlockIdx_y * stridey; - - if(incx < 0) - x -= ssize_t(incx) * (m - 1); - if(incy < 0) - y -= ssize_t(incy) * (n - 1); - - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvt_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); -} + auto alpha = load_scalar(alpha_device_host, hipBlockIdx_y, stride_alpha); + auto beta = load_scalar(beta_device_host, hipBlockIdx_y, stride_beta); -template -__global__ void gemvn_kernel_batched(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* const Aa[], - rocblas_int lda, - const T* const xa[], - rocblas_int incx, - U beta_device_host, - T* const ya[], - rocblas_int incy) -{ - rocblas_int num_threads = hipBlockDim_x * hipBlockDim_y * hipBlockDim_z; - if(DIM_X * DIM_Y != num_threads) - return; // need to launch exactly the same number of threads as template parameters indicate - - const T* A; - const T* x; - T* y; - A = Aa[hipBlockIdx_y]; - x = xa[hipBlockIdx_y]; - y = ya[hipBlockIdx_y]; - - if(incx < 0) - x -= ptrdiff_t(incx) * (n - 1); - if(incy < 0) - y -= ptrdiff_t(incy) * (m - 1); - - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvn_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); + gemvc_kernel_calc(m, n, alpha, A, lda, x, incx, beta, y, incy); } -template -__global__ void gemvc_kernel_batched(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* const Aa[], - rocblas_int lda, - const T* const xa[], - rocblas_int incx, - U beta_device_host, - T* const ya[], - rocblas_int incy) +template +__global__ void gemvt_kernel(rocblas_int m, + rocblas_int n, + U alpha_device_host, + rocblas_stride stride_alpha, + const V* Aa, + ptrdiff_t shifta, + rocblas_int lda, + rocblas_stride strideA, + const V* xa, + ptrdiff_t shiftx, + rocblas_int incx, + rocblas_stride stridex, + U beta_device_host, + rocblas_stride stride_beta, + W* ya, + ptrdiff_t shifty, + rocblas_int incy, + rocblas_stride stridey) { - const T* A; - const T* x; - T* y; - A = Aa[hipBlockIdx_y]; - x = xa[hipBlockIdx_y]; - y = ya[hipBlockIdx_y]; - - if(incx < 0) - x -= ptrdiff_t(incx) * (m - 1); - if(incy < 0) - y -= ptrdiff_t(incy) * (n - 1); - - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvc_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); -} + const T* A = load_ptr_batch(Aa, hipBlockIdx_y, shifta, strideA); + const T* x = load_ptr_batch(xa, hipBlockIdx_y, shiftx, stridex); + T* y = load_ptr_batch(ya, hipBlockIdx_y, shifty, stridey); -template -__global__ void gemvt_kernel_batched(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* const Aa[], - rocblas_int lda, - const T* const xa[], - rocblas_int incx, - U beta_device_host, - T* const ya[], - rocblas_int incy) -{ - const T* A; - const T* x; - T* y; - A = Aa[hipBlockIdx_y]; - x = xa[hipBlockIdx_y]; - y = ya[hipBlockIdx_y]; - - if(incx < 0) - x -= ssize_t(incx) * (m - 1); - if(incy < 0) - y -= ssize_t(incy) * (n - 1); - - auto alpha = load_scalar(alpha_device_host); - auto beta = load_scalar(beta_device_host); - - gemvt_kernel(m, n, alpha, A, lda, x, incx, beta, y, incy); + auto alpha = load_scalar(alpha_device_host, hipBlockIdx_y, stride_alpha); + auto beta = load_scalar(beta_device_host, hipBlockIdx_y, stride_beta); + + gemvt_kernel_calc(m, n, alpha, A, lda, x, incx, beta, y, incy); } #endif diff --git a/library/src/blas2/rocblas_gemv.cpp b/library/src/blas2/rocblas_gemv.cpp index 25aeaf9e4..f214074ef 100644 --- a/library/src/blas2/rocblas_gemv.cpp +++ b/library/src/blas2/rocblas_gemv.cpp @@ -132,7 +132,8 @@ namespace if(m < 0 || n < 0 || lda < m || lda < 1 || !incx || !incy) return rocblas_status_invalid_size; - return rocblas_gemv_template(handle, transA, m, n, alpha, A, lda, x, incx, beta, y, incy); + return rocblas_gemv_template( + handle, transA, m, n, alpha, A, 0, lda, 0, x, 0, incx, 0, beta, y, 0, incy, 0, 1); } } // namespace diff --git a/library/src/blas2/rocblas_gemv.hpp b/library/src/blas2/rocblas_gemv.hpp index a3b6561eb..9b89e7b1a 100644 --- a/library/src/blas2/rocblas_gemv.hpp +++ b/library/src/blas2/rocblas_gemv.hpp @@ -7,390 +7,40 @@ #include "handle.h" #include "rocblas.h" -template +template rocblas_status rocblas_gemv_template(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, - const T* alpha, - const T* A, + const U* alpha, + const V* A, + rocblas_int offseta, rocblas_int lda, - const T* x, + rocblas_stride strideA, + const V* x, + rocblas_int offsetx, rocblas_int incx, - const T* beta, - T* y, - rocblas_int incy) + rocblas_stride stridex, + const U* beta, + W* y, + rocblas_int offsety, + rocblas_int incy, + rocblas_stride stridey, + rocblas_int batch_count) { //quick return - if(!m || !n) - return rocblas_status_success; - - hipStream_t rocblas_stream = handle->rocblas_stream; - - if(transA == rocblas_operation_none) - { - // GEMVN_DIM_Y must be at least 4, 8 * 8 is very slow only 40Gflop/s - static constexpr int GEMVN_DIM_X = 64; - static constexpr int GEMVN_DIM_Y = 16; - rocblas_int blocks = (m - 1) / (GEMVN_DIM_X * 4) + 1; - if(std::is_same{}) - blocks = (m - 1) / (GEMVN_DIM_X) + 1; - dim3 gemvn_grid(blocks, 1); - dim3 gemvn_threads(GEMVN_DIM_X, GEMVN_DIM_Y); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL((gemvn_kernel_strided), - gemvn_grid, - gemvn_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - beta, - y, - incy, - 0); // stridey = 0 - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL((gemvn_kernel_strided), - gemvn_grid, - gemvn_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - *beta, - y, - incy, - 0); // stridey = 0 - } - } - else if(transA == rocblas_operation_transpose) - { - // transpose - // number of columns on the y-dim of the grid - static constexpr int NB = 256; - dim3 gemvt_grid(n, 1); - dim3 gemvt_threads(NB); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL(gemvt_kernel_strided, - gemvt_grid, - gemvt_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - beta, - y, - incy, - 0); // stridey = 0 - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL(gemvt_kernel_strided, - gemvt_grid, - gemvt_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - *beta, - y, - incy, - 0); // stridey = 0 - } - } - else // conjugate transpose - { - // conjugate transpose - // number of columns on the y-dim of the grid - static constexpr int NB = 256; - dim3 gemvc_grid(n, 1); - dim3 gemvc_threads(NB); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL(gemvc_kernel_strided, - gemvc_grid, - gemvc_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - beta, - y, - incy, - 0); // stridey = 0 - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL(gemvc_kernel_strided, - gemvc_grid, - gemvc_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - 0, // strideA = 0 - x, - incx, - 0, // stridex = 0 - *beta, - y, - incy, - 0); // stridey = 0 - } - } - return rocblas_status_success; -} - -template -rocblas_status rocblas_gemv_batched_template(rocblas_handle handle, - rocblas_operation transA, - rocblas_int m, - rocblas_int n, - const T* alpha, - const T* const A[], - rocblas_int lda, - const T* const x[], - rocblas_int incx, - const T* beta, - T* const y[], - rocblas_int incy, - rocblas_int batch_count) -{ - // Quick return if possible. Not Argument error if(!m || !n || !batch_count) return rocblas_status_success; hipStream_t rocblas_stream = handle->rocblas_stream; - if(transA == rocblas_operation_none) - { - // GEMVN_DIM_Y must be at least 4, 8 * 8 is very slow only 40Gflop/s - static constexpr int GEMVN_DIM_X = 64; - static constexpr int GEMVN_DIM_Y = 16; - rocblas_int blocks = (m - 1) / (GEMVN_DIM_X * 4) + 1; - if(std::is_same{}) - blocks = (m - 1) / (GEMVN_DIM_X) + 1; - - dim3 gemvn_grid(blocks, batch_count); - dim3 gemvn_threads(GEMVN_DIM_X, GEMVN_DIM_Y); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL((gemvn_kernel_batched), - gemvn_grid, - gemvn_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - x, - incx, - beta, - y, - incy); - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL((gemvn_kernel_batched), - gemvn_grid, - gemvn_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - x, - incx, - *beta, - y, - incy); - } - } - else if(transA == rocblas_operation_transpose) - { - // transpose - // number of columns on the y-dim of the grid - static constexpr int NB = 256; - dim3 gemvt_grid(n, batch_count); - dim3 gemvt_threads(NB); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL(gemvt_kernel_batched, - gemvt_grid, - gemvt_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - x, - incx, - beta, - y, - incy); - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL(gemvt_kernel_batched, - gemvt_grid, - gemvt_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - x, - incx, - *beta, - y, - incy); - } - } - else // conjugate transpose - { - // conjugate transpose - // number of columns on the y-dim of the grid - static constexpr int NB = 256; - dim3 gemvc_grid(n, batch_count); - dim3 gemvc_threads(NB); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - { - hipLaunchKernelGGL(gemvc_kernel_batched, - gemvc_grid, - gemvc_threads, - 0, - rocblas_stream, - m, - n, - alpha, - A, - lda, - x, - incx, - beta, - y, - incy); - } - else - { - if(!*alpha && *beta == 1) - return rocblas_status_success; - - hipLaunchKernelGGL(gemvc_kernel_batched, - gemvc_grid, - gemvc_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - A, - lda, - x, - incx, - *beta, - y, - incy); - } - } - - return rocblas_status_success; -} - -template -rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, - rocblas_operation transA, - rocblas_int m, - rocblas_int n, - const T* alpha, - const T* A, - rocblas_int lda, - rocblas_int strideA, - const T* x, - rocblas_int incx, - rocblas_int stridex, - const T* beta, - T* y, - rocblas_int incy, - rocblas_int stridey, - rocblas_int batch_count) -{ - // Quick return if possible. Not Argument error - if(!m || !n || !batch_count) - return rocblas_status_success; - - hipStream_t rocblas_stream = handle->rocblas_stream; + // in case of negative inc shift pointer to end of data for negative indexing tid*inc + auto shiftx = incx < 0 + ? offsetx - ptrdiff_t(incx) * ((transA == rocblas_operation_none ? n : m) - 1) + : offsetx; + auto shifty = incy < 0 + ? offsety - ptrdiff_t(incy) * ((transA == rocblas_operation_none ? m : n) - 1) + : offsety; if(transA == rocblas_operation_none) { @@ -400,13 +50,12 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, rocblas_int blocks = (m - 1) / (GEMVN_DIM_X * 4) + 1; if(std::is_same{}) blocks = (m - 1) / (GEMVN_DIM_X) + 1; - dim3 gemvn_grid(blocks, batch_count); dim3 gemvn_threads(GEMVN_DIM_X, GEMVN_DIM_Y); if(handle->pointer_mode == rocblas_pointer_mode_device) { - hipLaunchKernelGGL((gemvn_kernel_strided), + hipLaunchKernelGGL((gemvn_kernel), gemvn_grid, gemvn_threads, 0, @@ -414,14 +63,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, beta, + 0, y, + shifty, incy, stridey); } @@ -430,7 +84,7 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, if(!*alpha && *beta == 1) return rocblas_status_success; - hipLaunchKernelGGL((gemvn_kernel_strided), + hipLaunchKernelGGL((gemvn_kernel), gemvn_grid, gemvn_threads, 0, @@ -438,14 +92,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, *alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, *beta, + 0, y, + shifty, incy, stridey); } @@ -460,7 +119,7 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, if(handle->pointer_mode == rocblas_pointer_mode_device) { - hipLaunchKernelGGL(gemvt_kernel_strided, + hipLaunchKernelGGL((gemvt_kernel), gemvt_grid, gemvt_threads, 0, @@ -468,14 +127,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, beta, + 0, y, + shifty, incy, stridey); } @@ -484,7 +148,7 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, if(!*alpha && *beta == 1) return rocblas_status_success; - hipLaunchKernelGGL(gemvt_kernel_strided, + hipLaunchKernelGGL((gemvt_kernel), gemvt_grid, gemvt_threads, 0, @@ -492,14 +156,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, *alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, *beta, + 0, y, + shifty, incy, stridey); } @@ -514,7 +183,7 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, if(handle->pointer_mode == rocblas_pointer_mode_device) { - hipLaunchKernelGGL(gemvc_kernel_strided, + hipLaunchKernelGGL((gemvc_kernel), gemvc_grid, gemvc_threads, 0, @@ -522,14 +191,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, beta, + 0, y, + shifty, incy, stridey); } @@ -538,7 +212,7 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, if(!*alpha && *beta == 1) return rocblas_status_success; - hipLaunchKernelGGL(gemvc_kernel_strided, + hipLaunchKernelGGL((gemvc_kernel), gemvc_grid, gemvc_threads, 0, @@ -546,14 +220,19 @@ rocblas_status rocblas_gemv_strided_batched_template(rocblas_handle handle, m, n, *alpha, + 0, A, + offseta, lda, strideA, x, + shiftx, incx, stridex, *beta, + 0, y, + shifty, incy, stridey); } diff --git a/library/src/blas2/rocblas_gemv_batched.cpp b/library/src/blas2/rocblas_gemv_batched.cpp index c0af67b88..0ba168d74 100644 --- a/library/src/blas2/rocblas_gemv_batched.cpp +++ b/library/src/blas2/rocblas_gemv_batched.cpp @@ -138,8 +138,25 @@ namespace if(batch_count < 0) return rocblas_status_invalid_size; - return rocblas_gemv_batched_template( - handle, transA, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count); + return rocblas_gemv_template(handle, + transA, + m, + n, + alpha, + A, + 0, + lda, + 0, + x, + 0, + incx, + 0, + beta, + y, + 0, + incy, + 0, + batch_count); } } // namespace diff --git a/library/src/blas2/rocblas_gemv_strided_batched.cpp b/library/src/blas2/rocblas_gemv_strided_batched.cpp index 8e2b6fc30..f7d46f558 100644 --- a/library/src/blas2/rocblas_gemv_strided_batched.cpp +++ b/library/src/blas2/rocblas_gemv_strided_batched.cpp @@ -156,50 +156,28 @@ namespace return rocblas_status_invalid_pointer; if(m < 0 || n < 0 || lda < m || lda < 1 || !incx || !incy) return rocblas_status_invalid_size; - if(strideA < lda * n) - return rocblas_status_invalid_size; if(batch_count < 0) return rocblas_status_invalid_size; - size_t size_x, dim_x, abs_incx; - size_t size_y, dim_y, abs_incy; - - if(transA == rocblas_operation_none) - { - dim_x = n; - dim_y = m; - } - else - { - dim_x = m; - dim_y = n; - } - - abs_incx = incx >= 0 ? incx : -incx; - abs_incy = incy >= 0 ? incy : -incy; - - size_x = dim_x * abs_incx; - size_y = dim_y * abs_incy; - - if(stridex < size_x || stridey < size_y) - return rocblas_status_invalid_size; - - return rocblas_gemv_strided_batched_template(handle, - transA, - m, - n, - alpha, - A, - lda, - strideA, - x, - incx, - stridex, - beta, - y, - incy, - stridey, - batch_count); + return rocblas_gemv_template(handle, + transA, + m, + n, + alpha, + A, + 0, + lda, + strideA, + x, + 0, + incx, + stridex, + beta, + y, + 0, + incy, + stridey, + batch_count); } } //namespace diff --git a/library/src/blas2/rocblas_ger.cpp b/library/src/blas2/rocblas_ger.cpp index 218d21419..e37a6e2e0 100644 --- a/library/src/blas2/rocblas_ger.cpp +++ b/library/src/blas2/rocblas_ger.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ +#include "rocblas_ger.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" -#include "rocblas_ger_strided_batched.hpp" #include "utility.h" namespace @@ -84,12 +84,8 @@ namespace if(m < 0 || n < 0 || !incx || !incy || lda < m || lda < 1) return rocblas_status_invalid_size; - // Quick return if possible. Not Argument error - if(!m || !n) - return rocblas_status_success; - - rocblas_ger_strided_batched_template( - handle, m, n, alpha, x, 0, incx, incx * m, y, 0, incy, incy * n, A, 0, lda, lda * n, 1); + rocblas_ger_template( + handle, m, n, alpha, 0, x, 0, incx, 0, y, 0, incy, 0, A, 0, lda, 0, 1); return rocblas_status_success; } diff --git a/library/src/blas2/rocblas_ger.hpp b/library/src/blas2/rocblas_ger.hpp new file mode 100644 index 000000000..fc0997161 --- /dev/null +++ b/library/src/blas2/rocblas_ger.hpp @@ -0,0 +1,125 @@ +/* ************************************************************************ + * Copyright 2016-2019 Advanced Micro Devices, Inc. + * ************************************************************************ */ +#include "handle.h" +#include "logging.h" +#include "rocblas.h" +#include "utility.h" + +template +__global__ void ger_kernel(rocblas_int m, + rocblas_int n, + W alpha_device_host, + rocblas_stride stride_alpha, + const U __restrict__ xa, + ptrdiff_t shiftx, + rocblas_int incx, + rocblas_int stridex, + const U __restrict__ ya, + ptrdiff_t shifty, + rocblas_int incy, + rocblas_int stridey, + V Aa, + ptrdiff_t shifta, + rocblas_int lda, + rocblas_int strideA) +{ + + ptrdiff_t tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + ptrdiff_t ty = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; + + if(tx < m && ty < n) + { + auto alpha = load_scalar(alpha_device_host, hipBlockIdx_z, stride_alpha); + T* A = load_ptr_batch(Aa, hipBlockIdx_z, shifta, strideA); + const T* __restrict__ x = load_ptr_batch(xa, hipBlockIdx_z, shiftx, stridex); + const T* __restrict__ y = load_ptr_batch(ya, hipBlockIdx_z, shifty, stridey); + + A[tx + lda * ty] += alpha * x[tx * incx] * y[ty * incy]; + } +} + +template +rocblas_status rocblas_ger_template(rocblas_handle handle, + rocblas_int m, + rocblas_int n, + const W* alpha, + rocblas_stride stride_alpha, + const U* x, + rocblas_int offsetx, + rocblas_int incx, + rocblas_int stridex, + const U* y, + rocblas_int offsety, + rocblas_int incy, + rocblas_int stridey, + V* A, + rocblas_int offsetA, + rocblas_int lda, + rocblas_int strideA, + rocblas_int batch_count) +{ + // Quick return if possible. Not Argument error + if(!m || !n || !batch_count) + return rocblas_status_success; + + hipStream_t rocblas_stream = handle->rocblas_stream; + + // in case of negative inc shift pointer to end of data for negative indexing tid*inc + auto shiftx = incx < 0 ? offsetx - ptrdiff_t(incx) * (m - 1) : offsetx; + auto shifty = incy < 0 ? offsety - ptrdiff_t(incy) * (n - 1) : offsety; + + static constexpr int GEMV_DIM_X = 128; + static constexpr int GEMV_DIM_Y = 8; + rocblas_int blocksX = (m - 1) / GEMV_DIM_X + 1; + rocblas_int blocksY = (n - 1) / GEMV_DIM_Y + 1; + + dim3 grid(blocksX, blocksY, batch_count); + dim3 threads(GEMV_DIM_X, GEMV_DIM_Y); + + if(handle->pointer_mode == rocblas_pointer_mode_device) + hipLaunchKernelGGL(ger_kernel, + grid, + threads, + 0, + rocblas_stream, + m, + n, + alpha, + stride_alpha, + x, + shiftx, + incx, + stridex, + y, + shifty, + incy, + stridey, + A, + offsetA, + lda, + strideA); + else + hipLaunchKernelGGL(ger_kernel, + grid, + threads, + 0, + rocblas_stream, + m, + n, + *alpha, + stride_alpha, + x, + shiftx, + incx, + stridex, + y, + shifty, + incy, + stridey, + A, + offsetA, + lda, + strideA); + return rocblas_status_success; +} diff --git a/library/src/blas2/rocblas_ger_batched.cpp b/library/src/blas2/rocblas_ger_batched.cpp index 8715638c6..05c96ba65 100644 --- a/library/src/blas2/rocblas_ger_batched.cpp +++ b/library/src/blas2/rocblas_ger_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ -#include "rocblas_ger_batched.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" +#include "rocblas_ger.hpp" #include "utility.h" namespace @@ -111,12 +111,8 @@ namespace if(m < 0 || n < 0 || !incx || !incy || lda < m || lda < 1 || batch_count < 0) return rocblas_status_invalid_size; - // Quick return if possible. Not Argument error - if(!m || !n || !batch_count) - return rocblas_status_success; - - rocblas_ger_batched_template( - handle, m, n, alpha, x, 0, incx, y, 0, incy, A, 0, lda, batch_count); + rocblas_ger_template( + handle, m, n, alpha, 0, x, 0, incx, 0, y, 0, incy, 0, A, 0, lda, 0, batch_count); return rocblas_status_success; } diff --git a/library/src/blas2/rocblas_ger_batched.hpp b/library/src/blas2/rocblas_ger_batched.hpp deleted file mode 100644 index 172238d7f..000000000 --- a/library/src/blas2/rocblas_ger_batched.hpp +++ /dev/null @@ -1,108 +0,0 @@ -/* ************************************************************************ - * Copyright 2016-2019 Advanced Micro Devices, Inc. - * ************************************************************************ */ -#include "handle.h" -#include "logging.h" -#include "rocblas.h" -#include "utility.h" - -template -__global__ void ger_batched_kernel(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* const __restrict__ xa[], - rocblas_int shiftx, - rocblas_int incx, - const T* const __restrict__ ya[], - rocblas_int shifty, - rocblas_int incy, - T* const Aa[], - rocblas_int shiftA, - rocblas_int lda) -{ - ptrdiff_t tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - ptrdiff_t ty = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; - - if(tx < m && ty < n) - { - auto alpha = load_scalar(alpha_device_host); - T* A; - const T* __restrict__ x; - const T* __restrict__ y; - A = Aa[hipBlockIdx_z] + shiftA; - x = xa[hipBlockIdx_z] + shiftx; - y = ya[hipBlockIdx_z] + shifty; - - if(incx < 0) - x -= ssize_t(incx) * (m - 1); - if(incy < 0) - y -= ssize_t(incy) * (n - 1); - - A[tx + lda * ty] += alpha * x[tx * incx] * y[ty * incy]; - } -} - -template -rocblas_status rocblas_ger_batched_template(rocblas_handle handle, - rocblas_int m, - rocblas_int n, - const T* alpha, - const T* const x[], - rocblas_int shiftx, - rocblas_int incx, - const T* const y[], - rocblas_int shifty, - rocblas_int incy, - T* const A[], - rocblas_int shiftA, - rocblas_int lda, - rocblas_int batch_count) -{ - hipStream_t rocblas_stream = handle->rocblas_stream; - - static constexpr int GEMV_DIM_X = 128; - static constexpr int GEMV_DIM_Y = 8; - rocblas_int blocksX = (m - 1) / GEMV_DIM_X + 1; - rocblas_int blocksY = (n - 1) / GEMV_DIM_Y + 1; - - dim3 ger_batched_grid(blocksX, blocksY, batch_count); - dim3 ger_batched_threads(GEMV_DIM_X, GEMV_DIM_Y); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - hipLaunchKernelGGL(ger_batched_kernel, - ger_batched_grid, - ger_batched_threads, - 0, - rocblas_stream, - m, - n, - alpha, - x, - shiftx, - incx, - y, - shifty, - incy, - A, - shiftA, - lda); - else - hipLaunchKernelGGL(ger_batched_kernel, - ger_batched_grid, - ger_batched_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - x, - shiftx, - incx, - y, - shifty, - incy, - A, - shiftA, - lda); - return rocblas_status_success; -} diff --git a/library/src/blas2/rocblas_ger_strided_batched.cpp b/library/src/blas2/rocblas_ger_strided_batched.cpp index 9343b24f2..0aa59c71d 100644 --- a/library/src/blas2/rocblas_ger_strided_batched.cpp +++ b/library/src/blas2/rocblas_ger_strided_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ * Copyright 2016-2019 Advanced Micro Devices, Inc. * ************************************************************************ */ -#include "rocblas_ger_strided_batched.hpp" #include "handle.h" #include "logging.h" #include "rocblas.h" +#include "rocblas_ger.hpp" #include "utility.h" namespace @@ -129,31 +129,27 @@ namespace if(!x || !y || !A) return rocblas_status_invalid_pointer; - if(m < 0 || n < 0 || !incx || !incy || lda < m || lda < 1 || stridex < m * std::abs(incx) - || stridey < n * abs(incy) || strideA < lda * n || batch_count < 0) + if(m < 0 || n < 0 || !incx || !incy || lda < m || lda < 1 || batch_count < 0) return rocblas_status_invalid_size; - // Quick return if possible. Not Argument error - if(!m || !n || !batch_count) - return rocblas_status_success; - - rocblas_ger_strided_batched_template(handle, - m, - n, - alpha, - x, - 0, - incx, - stridex, - y, - 0, - incy, - stridey, - A, - 0, - lda, - strideA, - batch_count); + rocblas_ger_template(handle, + m, + n, + alpha, + 0, + x, + 0, + incx, + stridex, + y, + 0, + incy, + stridey, + A, + 0, + lda, + strideA, + batch_count); return rocblas_status_success; } diff --git a/library/src/blas2/rocblas_ger_strided_batched.hpp b/library/src/blas2/rocblas_ger_strided_batched.hpp deleted file mode 100644 index 17defad70..000000000 --- a/library/src/blas2/rocblas_ger_strided_batched.hpp +++ /dev/null @@ -1,121 +0,0 @@ -/* ************************************************************************ - * Copyright 2016-2019 Advanced Micro Devices, Inc. - * ************************************************************************ */ -#include "handle.h" -#include "logging.h" -#include "rocblas.h" -#include "utility.h" - -template -__global__ void ger_strided_batched_kernel(rocblas_int m, - rocblas_int n, - U alpha_device_host, - const T* const __restrict__ xa, - rocblas_int shiftx, - rocblas_int incx, - rocblas_int stridex, - const T* const __restrict__ ya, - rocblas_int shifty, - rocblas_int incy, - rocblas_int stridey, - T* const Aa, - rocblas_int shiftA, - rocblas_int lda, - rocblas_int strideA) -{ - - ptrdiff_t tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - ptrdiff_t ty = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; - - if(tx < m && ty < n) - { - auto alpha = load_scalar(alpha_device_host); - T* A; - const T* __restrict__ x; - const T* __restrict__ y; - A = Aa + hipBlockIdx_z * strideA + shiftA; - x = xa + hipBlockIdx_z * stridex + shiftx; - y = ya + hipBlockIdx_z * stridey + shifty; - - A[tx + lda * ty] += alpha * x[tx * incx] * y[ty * incy]; - } -} - -template -rocblas_status rocblas_ger_strided_batched_template(rocblas_handle handle, - rocblas_int m, - rocblas_int n, - const T* alpha, - const T* x, - rocblas_int shiftx, - rocblas_int incx, - rocblas_int stridex, - const T* y, - rocblas_int shifty, - rocblas_int incy, - rocblas_int stridey, - T* A, - rocblas_int shiftA, - rocblas_int lda, - rocblas_int strideA, - rocblas_int batch_count) -{ - hipStream_t rocblas_stream = handle->rocblas_stream; - - static constexpr int GEMV_DIM_X = 128; - static constexpr int GEMV_DIM_Y = 8; - rocblas_int blocksX = (m - 1) / GEMV_DIM_X + 1; - rocblas_int blocksY = (n - 1) / GEMV_DIM_Y + 1; - - dim3 ger_strided_batched_grid(blocksX, blocksY, batch_count); - dim3 ger_strided_batched_threads(GEMV_DIM_X, GEMV_DIM_Y); - - if(incx < 0) - x -= ptrdiff_t(incx) * (m - 1); - if(incy < 0) - y -= ptrdiff_t(incy) * (n - 1); - - if(handle->pointer_mode == rocblas_pointer_mode_device) - hipLaunchKernelGGL(ger_strided_batched_kernel, - ger_strided_batched_grid, - ger_strided_batched_threads, - 0, - rocblas_stream, - m, - n, - alpha, - x, - shiftx, - incx, - stridex, - y, - shifty, - incy, - stridey, - A, - shiftA, - lda, - strideA); - else - hipLaunchKernelGGL(ger_strided_batched_kernel, - ger_strided_batched_grid, - ger_strided_batched_threads, - 0, - rocblas_stream, - m, - n, - *alpha, - x, - shiftx, - incx, - stridex, - y, - shifty, - incy, - stridey, - A, - shiftA, - lda, - strideA); - return rocblas_status_success; -} diff --git a/library/src/blas2/rocblas_trsv.cpp b/library/src/blas2/rocblas_trsv.cpp index 1474a1984..19e4ecf0d 100644 --- a/library/src/blas2/rocblas_trsv.cpp +++ b/library/src/blas2/rocblas_trsv.cpp @@ -119,54 +119,92 @@ namespace { // left, lower no-transpose jb = min(BLOCK, m); - rocblas_gemv_template( - handle, transA, jb, jb, &one, invA, BLOCK, B, incx, &zero, X, 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA, + 0, + BLOCK, + 0, + B, + 0, + incx, + 0, + &zero, + X, + 0, + 1, + 0, + 1); if(BLOCK < m) { - rocblas_gemv_template(handle, - transA, - m - BLOCK, - BLOCK, - &negative_one, - A + BLOCK, - lda, - X, - 1, - &one, - B + BLOCK * incx, - incx); + rocblas_gemv_template(handle, + transA, + m - BLOCK, + BLOCK, + &negative_one, + A + BLOCK, + 0, + lda, + 0, + X, + 0, + 1, + 0, + &one, + B + BLOCK * incx, + 0, + incx, + 0, + 1); // remaining blocks for(i = BLOCK; i < m; i += BLOCK) { jb = min(m - i, BLOCK); - rocblas_gemv_template(handle, - transA, - jb, - jb, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i + BLOCK < m) - rocblas_gemv_template(handle, - transA, - m - i - BLOCK, - BLOCK, - &negative_one, - A + i + BLOCK + i * lda, - lda, - X + i, - 1, - &one, - B + (i + BLOCK) * incx, - incx); + rocblas_gemv_template(handle, + transA, + m - i - BLOCK, + BLOCK, + &negative_one, + A + i + BLOCK + i * lda, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B + (i + BLOCK) * incx, + 0, + incx, + 0, + 1); } } } @@ -177,64 +215,92 @@ namespace i = m - jb; // if m=n=35=lda=ldb, BLOCK =32, then jb = 3, i = 32; {3, 35, 3, 32, 35, 35} - rocblas_gemv_template(handle, - transA, - jb, - jb, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i >= BLOCK) { - rocblas_gemv_template(handle, - transA, - i, - jb, - &negative_one, - A + i * lda, - lda, - X + i, - 1, - &one, - B, - incx); + rocblas_gemv_template(handle, + transA, + i, + jb, + &negative_one, + A + i * lda, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B, + 0, + incx, + 0, + 1); // remaining blocks for(i = m - jb - BLOCK; i >= 0; i -= BLOCK) { //{32, 35, 32, 32, 35, 35} - rocblas_gemv_template(handle, - transA, - BLOCK, - BLOCK, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + BLOCK, + BLOCK, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i >= BLOCK) - rocblas_gemv_template(handle, - transA, - i, - BLOCK, - &negative_one, - A + i * lda, - lda, - X + i, - 1, - &one, - B, - incx); + rocblas_gemv_template(handle, + transA, + i, + BLOCK, + &negative_one, + A + i * lda, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B, + 0, + incx, + 0, + 1); } } } @@ -247,63 +313,91 @@ namespace jb = m % BLOCK == 0 ? BLOCK : m % BLOCK; i = m - jb; - rocblas_gemv_template(handle, - transA, - jb, - jb, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i - BLOCK >= 0) { - rocblas_gemv_template(handle, - transA, - jb, - i, - &negative_one, - A + i, - lda, - X + i, - 1, - &one, - B, - incx); + rocblas_gemv_template(handle, + transA, + jb, + i, + &negative_one, + A + i, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B, + 0, + incx, + 0, + 1); // remaining blocks for(i = m - jb - BLOCK; i >= 0; i -= BLOCK) { - rocblas_gemv_template(handle, - transA, - BLOCK, - BLOCK, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + BLOCK, + BLOCK, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i >= BLOCK) - rocblas_gemv_template(handle, - transA, - BLOCK, - i, - &negative_one, - A + i, - lda, - X + i, - 1, - &one, - B, - incx); + rocblas_gemv_template(handle, + transA, + BLOCK, + i, + &negative_one, + A + i, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B, + 0, + incx, + 0, + 1); } } } @@ -311,54 +405,92 @@ namespace { // left, upper transpose jb = min(BLOCK, m); - rocblas_gemv_template( - handle, transA, jb, jb, &one, invA, BLOCK, B, incx, &zero, X, 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA, + 0, + BLOCK, + 0, + B, + 0, + incx, + 0, + &zero, + X, + 0, + 1, + 0, + 1); if(BLOCK < m) { - rocblas_gemv_template(handle, - transA, - BLOCK, - m - BLOCK, - &negative_one, - A + BLOCK * lda, - lda, - X, - 1, - &one, - B + BLOCK * incx, - incx); + rocblas_gemv_template(handle, + transA, + BLOCK, + m - BLOCK, + &negative_one, + A + BLOCK * lda, + 0, + lda, + 0, + X, + 0, + 1, + 0, + &one, + B + BLOCK * incx, + 0, + incx, + 0, + 1); // remaining blocks for(i = BLOCK; i < m; i += BLOCK) { jb = min(m - i, BLOCK); - rocblas_gemv_template(handle, - transA, - jb, - jb, - &one, - invA + i * BLOCK, - BLOCK, - B + i * incx, - incx, - &zero, - X + i, - 1); + rocblas_gemv_template(handle, + transA, + jb, + jb, + &one, + invA + i * BLOCK, + 0, + BLOCK, + 0, + B + i * incx, + 0, + incx, + 0, + &zero, + X + i, + 0, + 1, + 0, + 1); if(i + BLOCK < m) - rocblas_gemv_template(handle, - transA, - BLOCK, - m - i - BLOCK, - &negative_one, - A + i + (i + BLOCK) * lda, - lda, - X + i, - 1, - &one, - B + (i + BLOCK) * incx, - incx); + rocblas_gemv_template(handle, + transA, + BLOCK, + m - i - BLOCK, + &negative_one, + A + i + (i + BLOCK) * lda, + 0, + lda, + 0, + X + i, + 0, + 1, + 0, + &one, + B + (i + BLOCK) * incx, + 0, + incx, + 0, + 1); } } } @@ -409,32 +541,46 @@ namespace A_current = parity ? A + BLOCK * ((lda + 1) * q - lda) : A + M * lda; } - rocblas_gemv_template(handle, - transA, - M, - N, - &negative_one, - A_current, - lda, - B_current, - incx, - &one, - x_temp, - 1); + rocblas_gemv_template(handle, + transA, + M, + N, + &negative_one, + A_current, + 0, + lda, + 0, + B_current, + 0, + incx, + 0, + &one, + x_temp, + 0, + 1, + 0, + 1); } - rocblas_gemv_template(handle, - transA, - BLOCK, - BLOCK, - &one, - invA + j * BLOCK * BLOCK, - BLOCK, - x_temp, - 1, - &zero, - B + j * BLOCK * incx, - incx); + rocblas_gemv_template(handle, + transA, + BLOCK, + BLOCK, + &one, + invA + j * BLOCK * BLOCK, + 0, + BLOCK, + 0, + x_temp, + 0, + 1, + 0, + &zero, + B + j * BLOCK * incx, + 0, + incx, + 0, + 1); } return rocblas_status_success;