diff --git a/projects/rocsolver/clients/common/lapack/testing_syevd_heevd.hpp b/projects/rocsolver/clients/common/lapack/testing_syevd_heevd.hpp index 27496a2edc2..19d767dd47c 100644 --- a/projects/rocsolver/clients/common/lapack/testing_syevd_heevd.hpp +++ b/projects/rocsolver/clients/common/lapack/testing_syevd_heevd.hpp @@ -256,6 +256,82 @@ void syevd_heevd_eig7_initData(const rocblas_handle handle, CHECK_HIP_ERROR(dA.transfer_from(hA)); } } +template +void syevd_heevd_eig7_initData_gpu(const rocblas_handle handle, + const rocblas_evect evect, + const rocblas_int n, + device_strided_batch_vector& dA, + const rocblas_int lda, + const rocblas_int bc, + host_strided_batch_vector& hA, + std::vector& A, + bool test = true) +{ + if(bc > 1) + { + syevd_heevd_eig7_initData(handle, evect, n, dA, lda, bc, hA, A, test); + } + + else + { + using S = decltype(std::real(T{})); + + if(CPU) + { + // generates spectrum + S eps = std::numeric_limits::epsilon(); + for(auto i = 0; i < n - 1; ++i) + hA[0][i + i * lda] = (i + 1) * eps; + hA[0][n - 1 + (n - 1) * lda] = 1; + CHECK_HIP_ERROR(dA.transfer_from(hA)); + + // generates orthogonal matrix + rocblas_int n2 = n * n; + host_strided_batch_vector hQ(n2, 1, n2, 1); + device_strided_batch_vector dQ(n2, 1, n2, 1); + device_strided_batch_vector dipiv(n, 1, n, 1); + rocblas_init(hQ, true); + CHECK_HIP_ERROR(dQ.transfer_from(hQ)); + rocsolver_geqr2_geqrf(false, true, handle, n, n, dQ.data(), n, n2, dipiv.data(), n, 1); + + // generates matrix with given spectrum + rocsolver_ormxr_unmxr(true, handle, rocblas_side_left, rocblas_operation_transpose, n, + n, n, dQ.data(), n, dipiv.data(), dA.data(), lda); + rocsolver_ormxr_unmxr(true, handle, rocblas_side_right, rocblas_operation_none, n, n, n, + dQ.data(), n, dipiv.data(), dA.data(), lda); + CHECK_HIP_ERROR(hA.transfer_from(dA)); + + // make copy of original data to test vectors if required + if(test && evect == rocblas_evect_original) + { + for(rocblas_int i = 0; i < n; i++) + { + 
for(rocblas_int j = 0; j < n; j++) + A[i + j * lda] = hA[0][i + j * lda]; + } + } + } + + if(GPU) + { + // now copy to the GPU + CHECK_HIP_ERROR(dA.transfer_from(hA)); + } + } +} +template +void syevd_heevd_eig7_initData_gpu(const rocblas_handle handle, + const rocblas_evect evect, + const rocblas_int n, + device_batch_vector& dA, + const rocblas_int lda, + const rocblas_int bc, + host_batch_vector& hA, + std::vector& A, + bool test = true) +{ + syevd_heevd_eig7_initData(handle, evect, n, dA, lda, bc, hA, A, test); +} // Creates an `n` by `n` tridiagonal, Wilkinson matrix, which is formed as follows: // @@ -482,9 +558,13 @@ void syevd_heevd_initData(const rocblas_handle handle, std::vector& A, bool test = true) { +#define USE_GPU true if((std::getenv("TEST_EIG7") != nullptr) || (std::getenv("SYEVD_TEST_EIG7") != nullptr)) { - syevd_heevd_eig7_initData(handle, evect, n, dA, lda, bc, hA, A, test); + if(USE_GPU) + syevd_heevd_eig7_initData_gpu(handle, evect, n, dA, lda, bc, hA, A, test); + else + syevd_heevd_eig7_initData(handle, evect, n, dA, lda, bc, hA, A, test); } else if((std::getenv("TEST_WILKINSON") != nullptr) || (std::getenv("SYEVD_TEST_WILKINSON") != nullptr)) diff --git a/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.cpp b/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.cpp index 64413fcb75d..6074f4bb634 100644 --- a/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.cpp +++ b/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.cpp @@ -61,43 +61,41 @@ rocblas_status rocsolver_stedc_impl(rocblas_handle handle, rocblas_int batch_count = 1; // memory workspace sizes: - // size for lasrt stack/stedc workspace - size_t size_work_stack; // size for temporary computations - size_t size_tempvect, size_tempgemm; + size_t size_tempvect, size_workSvec, size_workStmp; // size for pointers to workspace (batched case) size_t size_workArr; - // size for vector with positions of split blocks - size_t size_splits_map; + // size for 
vector with positions of split blocks and different indices + size_t size_workInt; // size for temporary diagonal and z vectors. - size_t size_tmpz; - rocsolver_stedc_getMemorySize(evect, n, batch_count, &size_work_stack, - &size_tempvect, &size_tempgemm, &size_tmpz, - &size_splits_map, &size_workArr); + size_t size_workSz; + rocsolver_stedc_getMemorySize(evect, n, batch_count, &size_tempvect, + &size_workSvec, &size_workStmp, &size_workSz, + &size_workInt, &size_workArr); if(rocblas_is_device_memory_size_query(handle)) - return rocblas_set_optimal_device_memory_size(handle, size_work_stack, size_tempvect, - size_tempgemm, size_tmpz, size_splits_map, + return rocblas_set_optimal_device_memory_size(handle, size_tempvect, size_workSvec, + size_workStmp, size_workSz, size_workInt, size_workArr); // memory workspace allocation - void *work_stack, *tempvect, *tempgemm, *tmpz, *splits_map, *workArr; - rocblas_device_malloc mem(handle, size_work_stack, size_tempvect, size_tempgemm, size_tmpz, - size_splits_map, size_workArr); + void *tempvect, *workSvec, *workStmp, *workSz, *workInt, *workArr; + rocblas_device_malloc mem(handle, size_tempvect, size_workSvec, size_workStmp, size_workSz, + size_workInt, size_workArr); if(!mem) return rocblas_status_memory_error; - work_stack = mem[0]; - tempvect = mem[1]; - tempgemm = mem[2]; - tmpz = mem[3]; - splits_map = mem[4]; + tempvect = mem[0]; + workSvec = mem[1]; + workStmp = mem[2]; + workSz = mem[3]; + workInt = mem[4]; workArr = mem[5]; // execution return rocsolver_stedc_template( handle, evect, n, D, shiftD, strideD, E, shiftE, strideE, C, shiftC, ldc, strideC, info, - batch_count, work_stack, (S*)tempvect, (S*)tempgemm, (S*)tmpz, (rocblas_int*)splits_map, + batch_count, (S*)tempvect, workSvec, (S*)workStmp, (S*)workSz, (rocblas_int*)workInt, (S**)workArr); } diff --git a/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.hpp b/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.hpp index 
2a4c30045d4..0c423259611 100644 --- a/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.hpp +++ b/projects/rocsolver/library/src/auxiliary/rocauxiliary_stedc.hpp @@ -39,213 +39,27 @@ #include "rocsolver/rocsolver.h" #include -#include -#include ROCSOLVER_BEGIN_NAMESPACE #define STEDC_BDIM 512 // Number of threads per thread-block used in main stedc kernels -#define STEDC_SOLVE_BDIM 4 // Number of threads per thread-block used in solver kernel +#define STEDC_BDIM_VALUES 4 // Number of threads per thread-block used in mergeValues kernel +#define STEDC_BDIM_SOLVE 64 // Number of threads per thread-block used in the QR eigensolver -// bit indicating base deflation candidate -#define L_F_BCAND_BIT 0 -// bit indicating top deflation candidate -#define L_F_TCAND_BIT 1 - -// TODO: using macro STEDC_EXTERNAL_GEMM = true for now. In the future we can pass -// STEDC_EXTERNAL_GEMM at run time to switch between internal vector updates and -// external gemm-based updates. -#define STEDC_EXTERNAL_GEMM true - -__host__ __device__ inline rocblas_int get_splits_size(const rocblas_int n) -{ - // splits_map layout: - // n - number of eigenvalues (matrix size) - // m - number of merges on a current level - // struct { - // 0 rocblas_int msz[m], _[n-m]; // size of each merge - // 1 rocblas_int mps[m], _[n-m]; // starting position of each merge - // 2 rocblas_int bsz[2*m], _[n-2*m]; // size of each block - // 3 rocblas_int bps[2*m], _[n-2*m]; // starting position of each block - // 4 rocblas_int em[n]; // id of a corresponding merge (per each eigenvalue) - // 5 rocblas_int nsz[n]; // size of a corresponding merge (per each eigenvalue) - // 6 rocblas_int nps[n]; // starting position of a corresponding merge (per each eigenvalue) - // 7 rocblas_int ndd[n]; // degrees of secular equation (per each eigenvalue) - // 8 rocblas_int mask[n]; // if mask[i] = 0, the value in position i has been deflated - // 9 rocblas_int dcount[n]; // number of deflations - // 10 rocblas_int map[n]; // 
original indices of a sorted values - // 11 rocblas_int cand[n]; // deflation candidate flags - // 12 rocblas_int dbg[n]; // - // }; - return 13 * n; -} - -template -__host__ __device__ inline S* ptr_msz(rocblas_int n, S* splits) -{ - return splits + 0 * n; -} -template -__host__ __device__ inline S* ptr_mps(rocblas_int n, S* splits) -{ - return splits + 1 * n; -} -template -__host__ __device__ inline S* ptr_bsz(rocblas_int n, S* splits) -{ - return splits + 2 * n; -} -template -__host__ __device__ inline S* ptr_bps(rocblas_int n, S* splits) -{ - return splits + 3 * n; -} -template -__host__ __device__ inline S* ptr_em(rocblas_int n, S* splits) -{ - return splits + 4 * n; -} -template -__host__ __device__ inline S* ptr_nsz(rocblas_int n, S* splits) -{ - return splits + 5 * n; -} -template -__host__ __device__ inline S* ptr_nps(rocblas_int n, S* splits) -{ - return splits + 6 * n; -} -template -__host__ __device__ inline S* ptr_ndd(rocblas_int n, S* splits) -{ - return splits + 7 * n; -} -template -__host__ __device__ inline S* ptr_mask(rocblas_int n, S* splits) -{ - return splits + 8 * n; -} -template -__host__ __device__ inline S* ptr_dcount(rocblas_int n, S* splits) -{ - return splits + 9 * n; -} -template -__host__ __device__ inline S* ptr_map(rocblas_int n, S* splits) -{ - return splits + 10 * n; -} -template -__host__ __device__ inline S* ptr_cand(rocblas_int n, S* splits) -{ - return splits + 11 * n; -} -template -__host__ __device__ inline S* ptr_dbg(rocblas_int n, S* splits) -{ - return splits + 12 * n; -} - -__host__ __device__ inline rocblas_int get_tmpz_size(const rocblas_int n) -{ - // tmpz layout: - // n - number of eigenvalues (matrix size) - // m - number of merges on a current level - // struct { - // 0 S z[n]; // the rank-1 modification vectors in the merges - // 1 S evs[n]; // roots of secular equations - // 2 S cc[n]; // c value of rotation of corresponding deflation - // 3 S ss[n]; // s value of rotation of corresponding deflation - // 4 S 
tolsD[n]; // tollerance for deflation of repeaded values in D - // 5 S tolsZ[n]; // tollerance for deflation of zero values in z - // 6 S md[n]; // sorted d values - // 7 S cd[n]; // sorted and compacted d values - // 8 S cz[n]; // sorted and compacted z values - // 9 S r1p[n]; // p component of the rank-1 modification - // }; - return 10 * n; -} - -template -__host__ __device__ inline S* ptr_z(rocblas_int n, S* tmpz) -{ - return tmpz + 0 * n; -} -template -__host__ __device__ inline S* ptr_evs(rocblas_int n, S* tmpz) -{ - return tmpz + 1 * n; -} -template -__host__ __device__ inline S* ptr_cc(rocblas_int n, S* tmpz) -{ - return tmpz + 2 * n; -} -template -__host__ __device__ inline S* ptr_ss(rocblas_int n, S* tmpz) -{ - return tmpz + 3 * n; -} -template -__host__ __device__ inline S* ptr_tolsD(rocblas_int n, S* tmpz) -{ - return tmpz + 4 * n; -} -template -__host__ __device__ inline S* ptr_tolsZ(rocblas_int n, S* tmpz) -{ - return tmpz + 5 * n; -} -template -__host__ __device__ inline S* ptr_md(rocblas_int n, S* tmpz) -{ - return tmpz + 6 * n; -} -template -__host__ __device__ inline S* ptr_cd(rocblas_int n, S* tmpz) -{ - return tmpz + 7 * n; -} -template -__host__ __device__ inline S* ptr_cz(rocblas_int n, S* tmpz) -{ - return tmpz + 8 * n; -} -template -__host__ __device__ inline S* ptr_r1p(rocblas_int n, S* tmpz) -{ - return tmpz + 9 * n; -} - -__host__ __device__ inline rocblas_int get_tempgemm_size(const rocblas_int n) -{ - // tempgemm layout: - // struct { - // 0 S vecs[n*n]; // temp vectors - // 1 S etmpd[n*n]; // temp deltas used in solving secular equations, also used for temp vectors - // }; - return 2 * n * n; -} - -template -__host__ __device__ inline S* ptr_vecs(rocblas_int n, S* tempgemm) -{ - return tempgemm + 0 * n * n; -} -template -__host__ __device__ inline S* ptr_etmpd(rocblas_int n, S* tempgemm) -{ - return tempgemm + 1 * n * n; -} +// STEDC_USE_EXTERNAL_UPDATE=true forces the use of external gemms for the vector updates. 
+// STEDC_WITH_STRIDED_BATCHED_GEMM=true forces the use of strided_batched gemms when possible. +#define STEDC_USE_EXTERNAL_UPDATE true +#define STEDC_WITH_STRIDED_BATCHED_GEMM false /*************** Main kernels *********************************************************/ /**************************************************************************************/ //--------------------------------------------------------------------------------------// /** STEDC_DIVIDE_KERNEL implements the divide phase of the DC algorithm. It - divides the input matrix into a 'blks' sub-blocks. - - This kernel is to be called with as many groups in x as needed to cover all - the batch_count problems. Each thread will work with a matrix in the batch. + divides the input matrix into 'blks' sub-blocks. + - This kernel is to be called with as many groups in x as needed to cover all + the batch_count problems. + - Each thread will work with a matrix in the batch. - Size of groups is set to STEDC_BDIM. **/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) @@ -257,7 +71,7 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) S* EE, const rocblas_stride strideE, const rocblas_int batch_count, - rocblas_int* splitsA) + rocblas_int* workInt) { // threads and groups indices rocblas_int bid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; @@ -270,29 +84,51 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* msz = ptr_msz(n, splits); - rocblas_int* mps = ptr_mps(n, splits); + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* ns = ps + blks; // find sizes of sub-blocks - msz[0] = n; - for(int i = 0; i < levs; ++i) + if(STEDC_USE_EXTERNAL_UPDATE && STEDC_WITH_STRIDED_BATCHED_GEMM && batch_count == 1) { - for(int j = (1 << i); j > 0; --j) + // division schema when using strided_batched_gemms for updates + rocblas_int sz = n / blks; + 
rocblas_int res = n - sz * blks; + if(res < blks / 2) + { + res = blks - res; + for(auto i = 0; i < blks; ++i) + ns[i] = i < res ? sz : sz + 1; + } + else { - rocblas_int t = msz[j - 1]; - msz[j * 2 - 1] = t / 2 + (t & 1); - msz[j * 2 - 2] = t / 2; + for(auto i = 0; i < blks; ++i) + ns[i] = i < res ? sz + 1 : sz; + } + } + else + { + // normal division schema + ns[0] = n; + rocblas_int t, t2; + for(auto i = 0; i < levs; ++i) + { + for(auto j = (1 << i); j > 0; --j) + { + t = ns[j - 1]; + t2 = t / 2; + ns[j * 2 - 1] = (2 * t2 < t) ? t2 + 1 : t2; + ns[j * 2 - 2] = t2; + } } } // find beginning of sub-blocks and update elements in D rocblas_int p2 = 0; - mps[0] = p2; - for(int i = 1; i < blks; ++i) + ps[0] = p2; + for(auto i = 1; i < blks; ++i) { - p2 += msz[i - 1]; - mps[i] = p2; + p2 += ns[i - 1]; + ps[i] = p2; // perform sub-block division S p = E[p2 - 1]; @@ -306,31 +142,31 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) /** STEDC_SOLVE_KERNEL implements the solver phase of the DC algorithm to compute the eigenvalues/eigenvectors of the 'blks' different sub-blocks of a matrix. - Call this kernel with batch_count groups in y, and 'blks' groups in x. - Groups are single-thread **/ + - Each group will solve a sub-block. 
+ - Groups contain a single wavefront **/ template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_solve_kernel(const rocblas_int levs, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* EE, - const rocblas_stride strideE, - S* CC, - const rocblas_int shiftC, - const rocblas_int ldc, - const rocblas_stride strideC, - rocblas_int* iinfo, - S* WA, - rocblas_int* splitsA, - const S eps, - const S ssfmin, - const S ssfmax) +ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM_SOLVE) + stedc_solve_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int n, + S* DD, + const rocblas_stride strideD, + S* EE, + const rocblas_stride strideE, + S* CC, + const rocblas_int shiftC, + const rocblas_int ldc, + const rocblas_stride strideC, + rocblas_int* iinfo, + S* WA, + rocblas_int* workInt, + const S eps, + const S ssfmin, + const S ssfmax) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - // sub-block id rocblas_int sid = hipBlockIdx_x; - // thread index rocblas_int tidb = hipThreadIdx_x; rocblas_int tidb_inc = hipBlockDim_x; @@ -341,562 +177,594 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_solve_kernel(const roc rocblas_int* info = iinfo + bid; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* msz = ptr_msz(n, splits); - rocblas_int* mps = ptr_mps(n, splits); - // workspace for solvers - S* W = WA + bid * 2 * n; + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + S* W = WA + bid * (2 * n); // Solve the blks sub-blocks in parallel (using classic QR iteration). + if(sid < blks) { - rocblas_int sz = msz[sid]; // size of sub-block - rocblas_int p2 = mps[sid]; // start position of sub-block + rocblas_int tmp = sid + 1; + rocblas_int pout = tmp < blks ? 
ps[tmp] : n; + rocblas_int pin = ps[sid]; // start position of sub-block + rocblas_int sbs = pout - pin; // size of sub-block - run_steqr(tidb, tidb_inc, sz, D + p2, E + p2, C + p2 + p2 * ldc, ldc, info, W + p2 * 2, - 30 * sz, eps, ssfmin, ssfmax, false); + run_steqr(tidb, tidb_inc, sbs, D + pin, E + pin, C + pin + pin * ldc, ldc, info, + W + pin * 2, 30 * sbs, eps, ssfmin, ssfmax, true); } } //--------------------------------------------------------------------------------------// -/** STEDC_UPDATE_SPLITS updates merge block related parts of a splits struct for each - level of merge. - - This kernel is to be called with 1 group in x and batch_count groups in y. - - Size of groups is set to STEDC_BDIM. **/ -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_update_splits(const rocblas_int levs, - const rocblas_int k, - const rocblas_int n, - rocblas_int* splitsA) +/** STEDC_MERGESORT_KERNEL combines the two sorted arrays containing the eigenvalues of + every pair of sub-blocks that need to be merged, and gets its corresponding vector z. + - Call this kernel with batch_count groups in y, and as many groups in x as needed + to cover the n values of the matrix. + - Each thread will deal with one value. 
+ - Size of groups is set to STEDC_BDIM.**/ +template +ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) + stedc_mergeSort_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* DD, + const rocblas_stride strideD, + S* CC, + const rocblas_int shiftC, + const rocblas_int ldc, + const rocblas_stride strideC, + S* workSvec, + rocblas_int* workInt) { + // threads and groups indices rocblas_int bid = hipBlockIdx_y; - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* em = ptr_em(n, splits); - rocblas_int* msz = ptr_msz(n, splits); - rocblas_int* mps = ptr_mps(n, splits); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - rocblas_int* bsz = ptr_bsz(n, splits); - rocblas_int* bps = ptr_bps(n, splits); - rocblas_int* map = ptr_map(n, splits); - rocblas_int* dcount = ptr_dcount(n, splits); - - rocblas_int n_blocks = 1 << levs; - rocblas_int n_merges = 1 << (levs - k - 1); - - // init em array - if(k == 0) - { - for(int i = hipThreadIdx_x; i < n_blocks; i += hipBlockDim_x) - { - rocblas_int sz = msz[i]; - rocblas_int p1 = mps[i]; - for(int j = 0; j < sz; j++) - { - em[p1 + j] = i; - } - } - } - - // previous merges becomes blocks on a current level - for(int i = hipThreadIdx_x; i < n_merges * 2; i += hipBlockDim_x) - { - bsz[i] = msz[i]; - bps[i] = mps[i]; - } - __syncthreads(); + rocblas_int gid = hipBlockIdx_x; + rocblas_int nofg = hipGridDim_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int totdim = nofg * dim; + rocblas_int tid = gid * dim + hipThreadIdx_x; - // update sizes and initial positions - for(int i = hipThreadIdx_x; i < n_merges; i += hipBlockDim_x) - { - rocblas_int sz1 = bsz[i * 2 + 0]; - rocblas_int sz2 = bsz[i * 2 + 1]; - rocblas_int p1 = bps[i * 2]; - msz[i] = sz1 + sz2; - mps[i] = p1; - } - __syncthreads(); - for(int i = hipThreadIdx_x; i < n; i += hipBlockDim_x) - { - rocblas_int m = em[i] / 2; - nsz[i] = msz[m]; - nps[i] = mps[m]; - map[i] = 
0; - dcount[i] = 0; - } - __syncthreads(); + // select batch instance to work with + S* C = load_ptr_batch(CC, bid, shiftC, strideC); + S* D = DD + bid * strideD; - for(int i = hipThreadIdx_x; i < n; i += hipBlockDim_x) + // temporary arrays in global memory + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* idd1 = ps + 2 * blks; + rocblas_int* bp = idd1 + 4 * n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* ev1 = z1 + 2 * n; + + // work with all the values (items) in parallel + for(auto tx = tid; tx < n; tx += totdim) { - em[i] /= 2; + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // item 'tx' belongs to sub-block 'bx' and thus participates + // in the merge to create the new sub-block 'nbx' + rocblas_int bx = bisearch(tx, ps, blks, false, false) - 1; + rocblas_int nbx = bx / dm2; + bp[tx] = bx; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout' + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + + // the position where the item 'tx' will end up in the ordered array is 'pos' + S val = D[tx]; + rocblas_int pos1 = tx < pmid ? bisearch(val, D + pmid, pout - pmid, true, false) + : bisearch(val, D + pin, pmid - pin, false, false); + rocblas_int pos2 = tx < pmid ? tx - pin : tx - pmid; + rocblas_int pos = pos1 + pos2; + + // get merged ordered array 'ev' and permutation map 'per' + rocblas_int* idd = idd1 + pin; + S* ev = ev1 + pin; + ev[pos] = val; + idd[pos] = tx; + + // get vector Z + const S inv_sqrt2 = 1 / std::sqrt(2); + val = tx < pmid ? 
C[pmid - 1 + tx * ldc] : C[pmid + tx * ldc]; + z1[tx] = val * inv_sqrt2; } } //--------------------------------------------------------------------------------------// -/** STEDC_MERGEPREPARE_DEFLATEZERO_KERNEL finds and stores tolerances and performs - deflation of zero values - - Call this kernel with batch_count groups in y, and as many groups as half of the - unmerged sub-blocks in current level in x. Each group works with a merge of a pair - of sub-blocks. Groups are size STEDC_BDIM **/ +/** STEDC_MERGESEQUENCES_KERNEL finds the sequences of repeated eigenvalues for the + relative tolerance on every pair of sub-blocks that need to be merged. + - Call this kernel with batch_count groups in y, and as many groups as pairs of + sub-blocks to be merged in x. Each group will deal with one merge. + - Size of groups is set to STEDC_BDIM **/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergePrepare_DeflateZero_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* EE, - const rocblas_stride strideE, - S* CC, - const rocblas_int shiftC, - const rocblas_int ldc, - const rocblas_stride strideC, - S* tmpzA, - rocblas_int* splitsA, - const S eps) + stedc_mergeSequences_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* EE, + const rocblas_stride strideE, + S* workSvec, + rocblas_int* workInt, + const S eps) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - // merge sub-block id - rocblas_int sid = hipBlockIdx_x; + rocblas_int nbx = hipBlockIdx_x; + rocblas_int tid = hipThreadIdx_x; + rocblas_int dim = hipBlockDim_x; // select batch instance to work with - S* C = load_ptr_batch(CC, bid, shiftC, strideC); - S* D = DD + bid * strideD; S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* mask = ptr_mask(n, splits); - rocblas_int* bsz 
= ptr_bsz(n, splits); - rocblas_int* bps = ptr_bps(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_z(n, tmpz); - S* r1p = ptr_r1p(n, tmpz); - S* tolsD = ptr_tolsD(n, tmpz); - S* tolsZ = ptr_tolsZ(n, tmpz); - - // Work with merges on level k. A thread-group works with two leaves in the merge tree. - { - // 1. find rank-1 modification components (z and p) for this merge - // ---------------------------------------------------------------- - rocblas_int sz1 = bsz[sid * 2 + 0]; - rocblas_int sz2 = bsz[sid * 2 + 1]; - rocblas_int p1 = bps[sid * 2 + 0]; - rocblas_int p2 = bps[sid * 2 + 1]; - - // Find off-diagonal element of the merge - // rank-1 modification component p correspond to the last element in the first sub-block - S p = 2 * E[p2 - 1]; - for(int i = hipThreadIdx_x; i < sz1 + sz2; i += hipBlockDim_x) - { - r1p[p1 + i] = p; - } + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd1 = nrs + blks; + rocblas_int* idd2 = idd1 + n; + rocblas_int* bp = idd1 + 4 * n; + rocblas_int* dcount = idd2 + n; + rocblas_int* rmap = dcount + n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* z2 = z1 + n; + S* ev1 = z2 + n; + S* ev2 = ev1 + n; + S* ev3 = ev2 + n; + S* c = ev3 + n; + S* s = c + n; - S maxd = 0; - S maxz = 0; - // copy z values from first sub-block - // copy last line of the first sub-block - for(int i = hipThreadIdx_x; i < sz1; i += hipBlockDim_x) - { - S val = C[p2 - 1 + (p1 + i) * ldc] / sqrt(2); - z[p1 + i] = val; - maxz = std::max(maxz, std::abs(val)); - } - // copy first line of the second sub-block - for(int i = hipThreadIdx_x; i < sz2; i += hipBlockDim_x) - { - S val = C[p2 - 0 + (p2 + i) * ldc] / sqrt(2); - z[p2 + i] = val; - maxz = std::max(maxz, std::abs(val)); - } + // temporary arrays in shared memory + // used to store temp values during the different reductions + extern __shared__ rocblas_int shmem[]; + rocblas_int* posi = shmem; + rocblas_int* posf = posi + (1 << (k + 
1)); + S* shmaxz = reinterpret_cast(posf + (1 << (k + 1))); + S shmaxd; + + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Its size is 'sz'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int sz = pout - pin; + + // find max values of evs and z in the sub-blocks + S valz, vald, maxz = 0, maxd = 0; + for(auto ii = tid; ii < sz; ii += dim) + { + rocblas_int i = ii + pin; + valz = std::abs(z1[i]); + maxz = (valz > maxz) ? valz : maxz; + } + shmaxz[tid] = maxz; + if(tid == 0) + { + maxd = abs(ev1[pin]); + vald = abs(ev1[pout - 1]); + maxd = (vald > maxd) ? vald : maxd; + } + __syncthreads(); - // 2. calculate deflation tolerance - // ---------------------------------------------------------------- - // compute maximum of diagonal and z in each merge block - rocblas_int sz = sz1 + sz2; - for(int i = hipThreadIdx_x; i < sz; i += hipBlockDim_x) + // reduction + for(auto r = dim / 2; r > 0; r /= 2) + { + if(tid < r) { - maxd = std::max(maxd, abs(D[p1 + i])); + valz = shmaxz[tid + r]; + maxz = (valz > maxz) ? valz : maxz; + shmaxz[tid] = maxz; } - - // temporary arrays in shared memory - // used to store temp values during reduction - __shared__ S lmaxz[STEDC_BDIM]; - __shared__ S lmaxd[STEDC_BDIM]; - lmaxd[hipThreadIdx_x] = maxd; - lmaxz[hipThreadIdx_x] = maxz; __syncthreads(); + } + if(tid == 0) + shmaxz[0] = (maxz > maxd) ? 
maxz : maxd; + __syncthreads(); + maxd = shmaxz[0]; - rocblas_int dim2 = hipBlockDim_x / 2; - while(dim2 > 0) + // tol should be 8 * eps * (max diagonal or z element participating in the merge) + S tol = 8 * eps * maxd; + + // Mark deflated values in each sub-block + for(auto tx = tid; tx < dm2; tx += dim) + { + rocblas_int miposi = -1; + rocblas_int miposf = -1; + + // each sub-block 'bx' starts and ends at 'in' and 'out', respectively + rocblas_int bx = nbx * dm2 + tx; + rocblas_int in = ps[bx]; + tmp = bx + 1; + rocblas_int out = tmp < blks ? ps[tmp] : n; + + // find sequences of repeated values + rocblas_int i = in; + while(i < out) { - if(hipThreadIdx_x < dim2) + rocblas_int count = 0; + rocblas_int map = idd1[i]; + vald = ev1[i]; + valz = z1[map]; + + if(std::abs(p * valz) <= tol) { - S vald = lmaxd[hipThreadIdx_x + dim2]; - S valz = lmaxz[hipThreadIdx_x + dim2]; - maxd = std::max(maxd, vald); - maxz = std::max(maxz, valz); - lmaxd[hipThreadIdx_x] = maxd; - lmaxz[hipThreadIdx_x] = maxz; + // if element in z is zero, i cannot be the base of a new sequence + dcount[i] = 0; + i++; + } + else + { + // otherwise, take base and search for sequence + miposf = i; + miposi = (miposi == -1) ? i : miposi; + rocblas_int oldi = i; + count = 1; + i++; + while(i < out) + { + S valdt = ev1[i]; + if(abs(vald - valdt) <= tol) + { + // value repeated for given tolerance. 
It is part of the sequence + count++; + dcount[i] = 0; + i++; + } + else + break; + } + dcount[oldi] = count; } - dim2 /= 2; - __syncthreads(); } - // tol should be 8 * eps * (max diagonal or z element participating in merge) - maxd = lmaxd[0]; - maxz = lmaxz[0]; + // posi and posf contains the base of the first and last sequences of values + // in each sub-block + posi[tx] = miposi; + posf[tx] = miposf; + } + __syncthreads(); - S tolD = 8 * eps * std::max(maxd, maxz); - S tolZ = 8 * eps * std::max(maxd, maxz); - // store tolerances in global memory - for(int i = hipThreadIdx_x; i < sz; i += hipBlockDim_x) + // now reduce the results of all the sub-blocks + // merging the different sequences when required. + dm = 1; + for(auto kk = 0; kk <= k; ++kk) + { + dm *= 2; + for(auto tx = tid; tx < dm2 / dm; tx += dim) { - tolsD[p1 + i] = tolD; - tolsZ[p1 + i] = tolZ; - } + rocblas_int sh = nbx * dm2; + rocblas_int bin = sh + tx * dm; + rocblas_int in = posf[bin - sh]; + rocblas_int bj = bin + dm / 2; + rocblas_int j = posi[bj - sh]; + rocblas_int bout = bin + dm; + rocblas_int out = bout < blks ? ps[bout] : n; + bool go = (j >= 0 && in >= 0); + + // find sequences to merge + while(go && j < out) + { + go = false; - // 3. deflate eigenvalues - // ---------------------------------------------------------------- - // deflate zero components - for(int i = hipThreadIdx_x; i < sz; i += hipBlockDim_x) - { - S g = z[p1 + i]; - // deflated ev if component in z is zero - mask[p1 + i] = (abs(p * g) <= tolZ) ? 
0 : 1; - } - } -} + // the beginning of base sequence is 'vald' and there are 'i' + // elements of the sequence in front that are within tolerance + vald = ev1[in]; + rocblas_int count = dcount[j]; + rocblas_int i = bisearch(vald + tol, ev1 + j, count, false, false); -//--------------------------------------------------------------------------------------// -/** STEDC_MERGEPREPARE_SORTD_KERNEL sorts D array and construct map of original positions - - Call this kernel with n groups in x and batch_count groups in y. - Groups are size STEDC_BDIM **/ -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergePrepare_SortD_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* tmpzA, - rocblas_int* splitsA) -{ - // threads and groups indices - // batch instance id - rocblas_int bid = hipBlockIdx_y; - // group id - rocblas_int gid = hipBlockIdx_x; - - // select batch instance to work with - S* D = DD + bid * strideD; + // if i == 0, there is nothing to merge; we are done + if(i > 0) + { + // otherwise there is a sequence to merge; i elements will be merged + dcount[in] = i + j - in; + dcount[j] = 0; + j += i; - // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* mask = ptr_mask(n, splits); - rocblas_int* ndd = ptr_ndd(n, splits); - rocblas_int* map = ptr_map(n, splits); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* md = ptr_md(n, tmpz); - - S d = D[gid]; - rocblas_int sz = nsz[gid]; - rocblas_int p1 = nps[gid]; - rocblas_int def = mask[gid]; - rocblas_int dd = 0; - - constexpr int regs = 8; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (sz - 1) / chunk_width + 1; - S bval[regs]; - int maskval[regs]; + // if i == count, everything was merged; we are done + if(i < count) + { + // otherwise not everything was merged; see if there is a new sequence 
+ // in the 'count - i' not merged elements + count -= i; + i = 0; + while(i < count && std::abs(p * z1[idd1[j + i]]) <= tol) + i++; + + // if i == count, there is no new sequence; we are done + if(i < count) + { + // otherwise we have a new base sequence; see if there are more + // sequences in front + in = j + i; + dcount[in] = count - i; + rocblas_int inj = j + count; + count = out - inj; + i = 0; + while(i < count && dcount[inj + i] == 0) + i++; + + // if i == count, no more sequences; we are done + if(i < count) + { + // otherwise there are more sequences in front + // that need to be analyzed + go = true; + j = inj + i; + } + } + } + } + } - int nan = 0; - int lt = 0; - int eq = 0; + // merges are done... + // update position of first and last sequences in each merged block + rocblas_int tmp1 = posf[bj - sh]; + rocblas_int tmp2 = j > tmp1 ? in : tmp1; + tmp1 = posi[bin - sh]; + tmp = std::max(tmp1, tmp2); + go = (tmp1 > 0 && tmp2 < 0) || (tmp1 < 0 && tmp2 > 0); + posf[bin - sh] = go ? tmp : tmp2; + posi[bin - sh] = go ? tmp : tmp1; + } + __syncthreads(); + } - for(int chunk = 0; chunk < n_chunks; chunk++) + // compute final number of non-deflated elements in each sub-block + for(auto tx = tid; tx < dm2; tx += dim) { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < sz) - { - bval[i] = D[p1 + x]; - maskval[i] = mask[p1 + x]; - } - } - for(int i = 0; i < regs; i++) + // each sub-block 'bx' starts and ends at 'in' and 'out', respectively + rocblas_int bx = nbx * dm2 + tx; + rocblas_int in = ps[bx]; + tmp = bx + 1; + rocblas_int out = tmp < blks ?
ps[tmp] : n; + + rocblas_int count = 0; + rocblas_int j = in; + while(j < out) { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < sz) + tmp = dcount[j]; + rocblas_int jinc = 1; + if(tmp > 0) { - nan += std::isnan(bval[i]); - dd += maskval[i] > 0; - // lt - how many values are less then the current - // eq - how many values are equal to the current - // all zero deflated values have to be grouped at the end - // so we order any deflated value > any non-deflated value - // def == 0 - current is deflated, maskval[i] == 1 - other value is not deflated - lt += (def < maskval[i]) || (def == maskval[i] && bval[i] < d); - eq += (def == maskval[i]) && (bval[i] == d && (p1 + x) < gid); + jinc = tmp; + count++; + for(auto i = 1; i < tmp; ++i) + { + rocblas_int map = idd1[j + i]; + valz = z1[map]; + if(std::abs(p * valz) <= tol) + dcount[j + i] = -1; + } } + j += jinc; } + posf[tx] = count; } + __syncthreads(); - int pos = lt + eq; - __shared__ int lpos[STEDC_BDIM]; - __shared__ int ldd[STEDC_BDIM]; - - // reduction - lpos[hipThreadIdx_x] = pos; - ldd[hipThreadIdx_x] = dd; - for(int r = hipBlockDim_x / 2; r > 0; r /= 2) + // reduce the results of all the sub-blocks + for(auto kk = 0; kk <= k; ++kk) { - __syncthreads(); - if(hipThreadIdx_x < r) + for(auto tx = tid; tx < dm2 / 2; tx += dim) { - pos += lpos[hipThreadIdx_x + r]; - dd += ldd[hipThreadIdx_x + r]; - lpos[hipThreadIdx_x] = pos; - ldd[hipThreadIdx_x] = dd; + dm = 1 << kk; + rocblas_int bx = (tx / dm) * dm * 2 + dm - 1; + rocblas_int bxx = bx + tx % dm + 1; + posf[bxx] += posf[bx]; } + __syncthreads(); } - if(hipThreadIdx_x == 0) - { - ndd[pos + p1] = dd; - map[pos + p1] = gid; - md[pos + p1] = d; - } - - __syncthreads(); - // The NAN fp value is unordered, so it is possible that with computed - // new positions it would be silently overwriten with non NAN value. - // Make sure we propagate NAN. 
It is likely to have more NANs in the output - // than in the input, but the following computations are doomed anyway. - if(nan) - { - md[gid] = NAN; - } + // store final number of non-deflated elements in 'nrs' + for(auto tx = tid; tx < dm2; tx += dim) + nrs[nbx * dm2 + tx] = posf[tx]; } //--------------------------------------------------------------------------------------// -/** STEDC_MERGEPREPARE_SETCANDFLAGS_KERNEL fills cand[] array with deflation candidate flags - - Call this kernel with ((n - 1)/STEDC_BDIM+1) groups in x and batch_count groups in y. - Groups are size STEDC_BDIM **/ +/** STEDC_MERGEDEFLATE_KERNEL performs deflation in the sequences of repeated values of + every pair of sub-blocks that need to be merged. + - Call this kernel with batch_count groups in y, and as many groups in x as needed + to cover the n values of the matrix. + - Each thread will deal with one value. + - Size of groups is set to STEDC_BDIM_VALUES.**/ template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergePrepare_SetCandFlags_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* tmpzA, - rocblas_int* splitsA) +ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_mergeDeflate_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* workSvec, + rocblas_int* workInt, + const S eps) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - - // select batch instance to work with - S* D = DD + bid * strideD; + rocblas_int gid = hipBlockIdx_x; + rocblas_int nofg = hipGridDim_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int totdim = nofg * dim; + rocblas_int tid = gid * dim + hipThreadIdx_x; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* nps = ptr_nps(n, splits); - rocblas_int* ndd = ptr_ndd(n, splits); - rocblas_int* cand = ptr_cand(n, splits); - - S* tmpz = tmpzA + 
bid * get_tmpz_size(n); - S* md = ptr_md(n, tmpz); - S* tolsD = ptr_tolsD(n, tmpz); - - constexpr int F_BCAND = 1 << L_F_BCAND_BIT; - constexpr int F_TCAND = 1 << L_F_TCAND_BIT; - - // find deflate candidates - int i = hipThreadIdx_x + hipBlockDim_x * hipBlockIdx_x; - if(i < n) + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd1 = nrs + blks; + rocblas_int* idd2 = idd1 + n; + rocblas_int* bp = idd1 + 4 * n; + rocblas_int* dcount = idd2 + n; + rocblas_int* rmap = dcount + n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* z2 = z1 + n; + S* ev1 = z2 + n; + S* ev2 = ev1 + n; + S* ev3 = ev2 + n; + S* c = ev3 + n; + S* s = c + n; + + // work with all the values (items) in parallel + for(auto tx = tid; tx < n; tx += totdim) { - int next = (i + 1) < n ? (i + 1) : i; - int prev = (i > 0) ? (i - 1) : 0; - S tol = tolsD[i]; - S d = md[i]; - S dn = md[next]; - S dp = md[prev]; - rocblas_int dd = ndd[i]; - rocblas_int p1 = nps[i]; - rocblas_int pn = nps[next]; - rocblas_int pp = nps[prev]; - - int bcandidate = (i - p1 < dd - 1) // current and next are not yet deflated - && p1 == pn // in the same merge block - && std::abs(d - dn) <= tol // within tolerance - && i != (n - 1) // isn't last - ; - int tcandidate = (i - p1 < dd) // current and prev are not yet deflated - && p1 == pp // in the same merge block - && std::abs(d - dp) <= tol // within tolerance - && i > 0 // isn't first - ; - cand[i] = (bcandidate << L_F_BCAND_BIT) + (tcandidate << L_F_TCAND_BIT); - } -} - -//--------------------------------------------------------------------------------------// -/** STEDC_MERGEPREPARE_DEFLATECOUNT_KERNEL fills dcount[] array with number of deflations for each base point - - Call this kernel with ((n - 1)/STEDC_BDIM+1) groups in x and batch_count groups in y. 
- Groups are size STEDC_BDIM **/ -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergePrepare_DeflateCount_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* tmpzA, - rocblas_int* splitsA) -{ - // threads and groups indices - // batch instance id - rocblas_int bid = hipBlockIdx_y; + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // value 'tx' belongs to sub-block 'bx' and thus participates + // in the merge to create the new sub-block 'nbx' + rocblas_int bx = bp[tx]; + rocblas_int nbx = bx / dm2; + + // the sub-block 'bx' starts at 'in', and ends at 'out' + // the new sub-block 'nbx' starts at 'pin', and ends at 'pout' + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + rocblas_int in = ps[bx]; + tmp = bx + 1; + rocblas_int out = tmp < blks ? ps[tmp] : n; + + // 'count' is the number of non-deflated values until value 'tx' + rocblas_int count = (bx % dm2 == 0) ? 0 : nrs[bx - 1]; + rocblas_int pj = tx - pin; + rocblas_int j = tx - in; + for(auto i = 0; i < j; ++i) + { + if(dcount[in + i] > 0) + count++; + } - // select batch instance to work with - S* D = DD + bid * strideD; + rocblas_int map = idd1[tx]; + S vald = ev1[tx]; + S valz = z1[map]; + rocblas_int dcnt = dcount[tx]; - // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* dcount = ptr_dcount(n, splits); - rocblas_int* cand = ptr_cand(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* md = ptr_md(n, tmpz); - S* tolsD = ptr_tolsD(n, tmpz); - - constexpr rocblas_int max_len = 4096; - __shared__ S ldsD[max_len]; - __shared__ int lcand[max_len]; - constexpr int F_BCAND = 1 << L_F_BCAND_BIT; - constexpr int F_TCAND = 1 << L_F_TCAND_BIT; - - int start = hipBlockDim_x * hipBlockIdx_x; - int base = start + hipThreadIdx_x; - int prev = (base > 0) ? (base - 1) : 0; - int candp = (prev < n) ? 
cand[prev] : 0; - int candb = (base < n) ? cand[base] : 0; - S bval = (base < n) ? md[base] : 0; - S tol = (base < n) ? tolsD[base] : 0; - - // cache max_len values of D[] and cand[] - for(int i = hipThreadIdx_x; i < max_len; i += hipBlockDim_x) - { - int x = start + i; - ldsD[i] = (x < n) ? md[x] : 0; - lcand[i] = (x < n) ? cand[x] : 0; - } - __syncthreads(); + // if value ev is marked as repeated, move it to the deflated list and finish + if(dcnt < 1) + { + rocblas_int idx = pout - 1 - pj + count; + ev3[idx] = vald; + idd2[idx] = map; + } - if((candb & F_BCAND) && (base == 0 || !(candp & F_BCAND))) - { - int top = base + 2; - int candt = lcand[top - start]; - while(top < n && (candt & F_TCAND)) + // otherwise move it to the non-deflated list, and compute rotations to zero out + // the corresponding z element when required + else { - // first max_len values are prefetched into lds, - // access global memory only if need to go beyond that - // which is very unlikely - S tval = (top - start) < max_len ? ldsD[top - start] : md[top]; + dcnt--; + dcount[tx] = dcnt; + rocblas_int idx = pin + count; + ev2[idx] = vald; + idd2[idx] = -(map + 1); - if((tval - bval) > tol) + for(auto i = 0; i < dcnt; ++i) { - dcount[base] = top - base - 1; - base = top; - bval = tval; + rocblas_int idx2 = tx + 1 + i; + rocblas_int mapt = dcount[idx2]; + if(mapt == 0) + { + // a rotation is needed + mapt = idd1[idx2]; + S valzt = z1[mapt]; + S cc, ss, rr; + lartg(valz, valzt, cc, ss, rr); + valz = rr; + + // save the rotation encoded for mergeRotate + rmap[idx2] = mapt; + c[mapt] = cc; + s[mapt] = ss; + } + else + { + // no rotation required + rmap[idx2] = -1; + } } - top++; - - // first max_len values are prefetched into lds, - // access global memory only if need to go beyond that - // which is very unlikely - candt = (top - start) < max_len ?
lcand[top - start] : cand[top]; + z2[idx] = valz; } - dcount[base] = top - base - 1; } } //--------------------------------------------------------------------------------------// -/** STEDC_MERGEPREPARE_DEFLATEAPPLY_KERNEL applies deflations and saves c/s values used to rotate C vectors - - Call this kernel with ((n - 1)/STEDC_BDIM+1) groups in x and batch_count groups in y. - Groups are size STEDC_BDIM **/ +/** STEDC_MERGEPREPARE_KERNEL prepares the components for the secular equations of every + pair of sub-blocks that need to be merged. + - Call this kernel with batch_count groups in y, and n groups in x. + - Size of groups is set to STEDC_BDIM **/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergePrepare_DeflateApply_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* tmpzA, - rocblas_int* splitsA) + stedc_mergePrepare_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* EE, + const rocblas_stride strideE, + S* workSvec, + S* workStmp, + rocblas_int* workInt) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; + rocblas_int jj = hipBlockIdx_x; + rocblas_int dimr = hipBlockDim_x; + rocblas_int rid = hipThreadIdx_x; // select batch instance to work with - S* D = DD + bid * strideD; + S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* mask = ptr_mask(n, splits); - rocblas_int* map = ptr_map(n, splits); - rocblas_int* dcount = ptr_dcount(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_z(n, tmpz); - S* cc = ptr_cc(n, tmpz); - S* ss = ptr_ss(n, tmpz); - - constexpr rocblas_int max_len = 4096; - __shared__ S lz[max_len]; - __shared__ int lmap[max_len]; - - int start = hipBlockDim_x * hipBlockIdx_x; - int base = start + hipThreadIdx_x; - int cnt = (base < n) ? 
dcount[base] : 0; - - // cache max_len values of map[] and appropriate z[] - for(int i = hipThreadIdx_x; i < max_len; i += hipBlockDim_x) + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* z2 = z1 + n; + S* ev2 = z2 + 2 * n; + S* temps = workStmp + bid * (n * n); + + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // column 'jj' belongs to sub-block 'bx' and thus forms part of + // the new sub-block 'nbx' + rocblas_int bx = bp[jj]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + rocblas_int j = jj - pin; + + if(j < nr) { - int x = start + i; - lmap[i] = (x < n) ? map[x] : 0; - lz[i] = (x < n) ? z[map[x]] : 0; - } - __syncthreads(); + S* tmpd = temps + pin * n; + S* ev = ev2 + pin; + S* Z = z1 + pin; - if(cnt) - { - S baseval = lz[hipThreadIdx_x]; - for(int j = 0; j < cnt; j++) + // if 'p' is negative, the values are copied as negative in reverse order + // as required by the secular equation solvers + bool pneg = (p < 0); + rocblas_int sig = pneg ? -1 : 1; + rocblas_int start = pneg ? nr - 1 : 0; + + for(auto i = rid; i < nr; i += dimr) { - int top = base + j + 1; - - // first max_len values are prefetched into lds, - // access global memory only if need to go beyond that - // which is very unlikely - int idx = (top - start) < max_len ? lmap[top - start] : map[top]; - S g = (top - start) < max_len ? 
lz[top - start] : z[idx]; - - S f = baseval; - S c, s, rr; - lartg(f, g, c, s, rr); - baseval = rr; - - mask[idx] = 0; - z[idx] = 0; - cc[idx] = c; - ss[idx] = s; + int id = start + sig * i; + tmpd[i + j * n] = sig * ev[id]; + if(j == 0) + Z[i] = z2[id + pin]; } - z[lmap[hipThreadIdx_x]] = baseval; } } @@ -904,31 +772,35 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) /** STEDC_MERGEROTATE_KERNEL performs rotation of vectors corresponding to deflations - Call this kernel with batch_count groups in y, and n (matrix size) groups in x. - Each group will deal with one deflation group, groups that don't correspond to - a deflation group will do nothing **/ + a deflation group will do nothing. + - Size of groups is set to STEDC_BDIM **/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeRotate_kernel(const rocblas_int k, + stedc_mergeRotate_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, const rocblas_int n, S* CC, const rocblas_int shiftC, const rocblas_int ldc, const rocblas_stride strideC, - S* tmpzA, - rocblas_int* splitsA) + S* workSvec, + rocblas_int* workInt) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; + // select batch instance to work with S* C = load_ptr_batch(CC, bid, shiftC, strideC); - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* map = ptr_map(n, splits); - rocblas_int* dcounts = ptr_dcount(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* cc = ptr_cc(n, tmpz); - S* ss = ptr_ss(n, tmpz); + // temporary arrays in global memory + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* idd1 = ps + 2 * blks; + rocblas_int* dcount = idd1 + 2 * n; + rocblas_int* rmap = dcount + n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* cc = z1 + 5 * n; + S* ss = cc + n; constexpr int regs = 16; const int chunk_width = regs * hipBlockDim_x; @@ -937,835 +809,603 @@ ROCSOLVER_KERNEL void 
__launch_bounds__(STEDC_BDIM) S tval[regs]; rocblas_int dgs = hipBlockIdx_x; - rocblas_int dcnt = dcounts[dgs]; + rocblas_int dcnt = dcount[dgs]; if(dcnt) { - rocblas_int base = map[dgs]; + rocblas_int base = idd1[dgs]; S* Cbase = C + base * ldc; - for(int chunk = 0; chunk < n_chunks; chunk++) + for(auto chunk = 0; chunk < n_chunks; chunk++) { - for(int i = 0; i < regs; i++) + for(auto i = 0; i < regs; i++) { int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; if(x < n) - { bval[i] = Cbase[x]; - } } - for(int dn = 0; dn < dcnt; dn++) + for(auto dn = 0; dn < dcnt; dn++) { - rocblas_int top = map[dgs + dn + 1]; - S c = cc[top]; - S s = ss[top]; - S* Ctop = C + top * ldc; - - for(int i = 0; i < regs; i++) + rocblas_int top = rmap[dgs + dn + 1]; + if(top > -1) { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) + S c = cc[top]; + S s = ss[top]; + S* Ctop = C + top * ldc; + + for(auto i = 0; i < regs; i++) { - tval[i] = Ctop[x]; + int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; + if(x < n) + tval[i] = Ctop[x]; } - } - for(int i = 0; i < regs; i++) - { - S valf = bval[i]; - S valg = tval[i]; - bval[i] = valf * c - valg * s; - tval[i] = valf * s + valg * c; - } + for(auto i = 0; i < regs; i++) + { + S valf = bval[i]; + S valg = tval[i]; + bval[i] = valf * c - valg * s; + tval[i] = valf * s + valg * c; + } - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) + for(auto i = 0; i < regs; i++) { - Ctop[x] = tval[i]; + int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; + if(x < n) + Ctop[x] = tval[i]; } } __syncthreads(); } - for(int i = 0; i < regs; i++) + for(auto i = 0; i < regs; i++) { int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; if(x < n) - { Cbase[x] = bval[i]; - } } } } } -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeValues_SortDZ_kernel(const rocblas_int k, - const rocblas_int n, - 
S* DD, - const rocblas_stride strideD, - S* tmpzA, - rocblas_int* splitsA) -{ - // threads and groups indices - // batch instance id - rocblas_int bid = hipBlockIdx_y; - // group id - rocblas_int gid = hipBlockIdx_x; - - // select batch instance to work with - S* D = DD + bid * strideD; - - // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* mask = ptr_mask(n, splits); - rocblas_int* map = ptr_map(n, splits); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - rocblas_int* ndd = ptr_ndd(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_z(n, tmpz); - S* cd = ptr_cd(n, tmpz); - S* cz = ptr_cz(n, tmpz); - S* r1p = ptr_r1p(n, tmpz); - S* evs = ptr_evs(n, tmpz); - - S sig = (r1p[gid] < 0) ? -1 : 1; - S d_ = D[gid]; - S sd_ = sig * d_; - S z_ = z[gid]; - rocblas_int sz = nsz[gid]; - rocblas_int p1 = nps[gid]; - rocblas_int def = mask[gid]; - - constexpr int regs = 8; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (sz - 1) / chunk_width + 1; - S bval[regs]; - int maskval[regs]; - - int nan = 0; - int lt = 0; - int eq = 0; - - rocblas_int dd = 0; - - for(int chunk = 0; chunk < n_chunks; chunk++) - { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < sz) - { - bval[i] = sig * D[p1 + x]; - maskval[i] = mask[p1 + x]; - } - } - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < sz) - { - nan += std::isnan(bval[i]); - dd += maskval[i] > 0; - // lt - how many values are less then the current - // eq - how many values are equal to the current - // all zero deflated values have to be grouped at the end - // so we order any deflated value > any non-deflated value - // def == 0 - current is deflated, maskval[i] == 1 - other value is not deflated - lt += (def < maskval[i]) || (def == maskval[i] && bval[i] < sd_); - eq += (def == 
maskval[i]) && (bval[i] == sd_ && (p1 + x) < gid); - } - } - } - - int pos = lt + eq; - - // Reduction of pos and dd across workgroup - __shared__ int lpos[STEDC_BDIM]; - __shared__ int ldd[STEDC_BDIM]; - lpos[hipThreadIdx_x] = pos; - ldd[hipThreadIdx_x] = dd; - __syncthreads(); - rocblas_int dim2 = hipBlockDim_x / 2; - while(dim2 > 0) - { - if(hipThreadIdx_x < dim2) - { - pos += lpos[hipThreadIdx_x + dim2]; - dd += ldd[hipThreadIdx_x + dim2]; - lpos[hipThreadIdx_x] = pos; - ldd[hipThreadIdx_x] = dd; - } - dim2 /= 2; - __syncthreads(); - } - - if(hipThreadIdx_x == 0) - { - ndd[pos + p1] = dd; - map[pos + p1] = gid; - cd[pos + p1] = sd_; - cz[pos + p1] = z_; - // copy over all diagonal elements in ev. ev will be overwritten - // by the new computed eigenvalues of the merged block - evs[pos + p1] = d_; - } - - __syncthreads(); - // The NAN fp value is unordered, so it is possible that with computed - // new positions it would be silently overwriten with non NAN value. - // Make sure we propagate NAN. It is likely to have more NANs in the output - // than in the input, but the following computations are doomed anyway. 
- if(nan) - { - cd[gid] = NAN; - } -} - -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeValues_copyD_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* tmpzA, - S* tempgemmA, - rocblas_int* splitsA) -{ - // threads and groups indices - // batch instance id - rocblas_int bid = hipBlockIdx_y; - rocblas_int eid = hipBlockIdx_x; - - // select batch instance to work with - S* D = DD + bid * strideD; - - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* ndd = ptr_ndd(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* evs = ptr_evs(n, tmpz); - S* cd = ptr_cd(n, tmpz); - - S* tempgemm = tempgemmA + bid * get_tempgemm_size(n); - S* etmpd = ptr_etmpd(n, tempgemm); - - rocblas_int dd = ndd[eid]; - rocblas_int p1 = nps[eid]; - - // copy sorted values back to D - int x = hipThreadIdx_x + hipBlockDim_x * hipBlockIdx_x; - if(x < n) - { - D[x] = evs[x]; - } - - // make copies of the non-deflated ordered diagonal elements - // (i.e. the poles of the secular eqn) so that the distances to the - // eigenvalues (D - lambda_i) are updated while computing each eigenvalue. - // This will prevent collapses and division by zero when an eigenvalue - // is too close to a pole. - for(int i = hipThreadIdx_x; i < dd; i += hipBlockDim_x) - { - etmpd[eid * n + i] = cd[p1 + i]; - } -} - //--------------------------------------------------------------------------------------// -/** STEDC_MERGEVALUES_KERNEL solves the secular equation for every pair of sub-blocks - that need to be merged. +/** STEDC_MERGEVALUES_KERNEL solves the secular equation for every value of every pair of + sub-blocks that need to be merged, and re-scales vector z accordingly. - Call this kernel with batch_count groups in y, and as many groups in x as needed - to cover n (i.e. n_groups_x * groups_size_x >= n). 
Groups are size STEDC_SOLVE_BDIM **/ - + to cover the n values of the matrix. + - Each thread will deal with one value. + - Size of groups is set to STEDC_BDIM_VALUES.**/ template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_SOLVE_BDIM) - stedc_mergeValues_Solve_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* EE, - const rocblas_stride strideE, - S* tmpzA, - S* tempgemmA, - rocblas_int* splitsA, - const S eps, - const S ssfmin, - const S ssfmax) +ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM_VALUES) + stedc_mergeValues_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* EE, + const rocblas_stride strideE, + S* workSvec, + S* workStmp, + rocblas_int* workInt, + const S eps, + const S ssfmin, + const S ssfmax) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; + rocblas_int gid = hipBlockIdx_x; + rocblas_int nofg = hipGridDim_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int totdim = nofg * dim; + rocblas_int tid = gid * dim + hipThreadIdx_x; + + // select batch instance to work with + S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* ndd = ptr_ndd(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_cz(n, tmpz); - S* r1p = ptr_r1p(n, tmpz); - S* evs = ptr_evs(n, tmpz); - // updated eigenvectors after merges - S* tempgemm = tempgemmA + bid * get_tempgemm_size(n); - S* etmpd = ptr_etmpd(n, tempgemm); - - int i = hipThreadIdx_x + hipBlockDim_x * hipBlockIdx_x; - if(i < n) + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd2 = nrs + blks + n; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* ev3 = z1 + 4 * n; + S* temps = workStmp + bid * (n * n); + + // work with all the 
values (items) in parallel + for(auto tx = tid; tx < n; tx += totdim) { - S p = r1p[i]; - rocblas_int p1 = nps[i]; - rocblas_int dd = ndd[i]; - - /* ----------------------------------------------------------------- */ - - // 2. Solve secular eqns, i.e. find the dd zeros - // corresponding to non-deflated new eigenvalues of the merged block - /* ----------------------------------------------------------------- */ - // each thread will find a different zero in parallel - if((i - p1) < dd) + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // item 'tx' belongs to sub-block 'bx' and thus participates + // in the merge to create the new sub-block 'nbx' + rocblas_int bx = bp[tx]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + + // solve secular equation for every non-deflated value + rocblas_int linfo; + + if(idd2[tx] < 0) { - int cc = i - p1; - - // computed zero will overwrite 'ev' at the corresponding position. - // 'etmpd' will be updated with the distances D - lambda_i. - // deflated values are not changed. 
- rocblas_int linfo; - #if defined(ROCSOLVER_USE_REFERENCE_SECULAR_EQUATIONS_SOLVER) - linfo = slaed4(dd, cc, etmpd + i * n, z + p1, std::abs(p), evs[i]); + linfo = slaed4(nr, tx - pin, temps + tx * n, z1 + pin, std::abs(p), ev3[tx]); #else - if(cc == dd - 1) - linfo = seq_solve_ext(dd, etmpd + i * n, z + p1, std::abs(p), evs + i, eps, ssfmin, - ssfmax); + if(tx - pin == nr - 1) + linfo = seq_solve_ext(nr, temps + tx * n, z1 + pin, std::abs(p), ev3[tx], eps, + ssfmin, ssfmax); else - linfo = seq_solve(dd, etmpd + i * n, z + p1, std::abs(p), cc, evs + i, eps, ssfmin, + linfo = seq_solve(nr, temps + tx * n, z1 + pin, std::abs(p), ev3[tx], eps, ssfmin, ssfmax); #endif - if(p < 0) - evs[i] *= -1; + ev3[tx] *= -1; } } } +//--------------------------------------------------------------------------------------// +/** STEDC_MERGEREINSERT_KERNEL combines and sort the new eigenvalues with the deflated values + - Call this kernel with batch_count groups in y, and as many groups in x as needed + to cover the n values of the matrix. + - Each thread will deal with one value. 
+ - Size of groups is set to STEDC_BDIM.**/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeValues_Rescale_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* EE, - const rocblas_stride strideE, - S* tmpzA, - S* tempgemmA, - rocblas_int* splitsA, - const S eps, - const S ssfmin, - const S ssfmax) + stedc_mergeReinsert_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* DD, + const rocblas_stride strideD, + S* EE, + const rocblas_stride strideE, + S* workSvec, + rocblas_int* workInt) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - // value id - rocblas_int eid = hipBlockIdx_x; + rocblas_int gid = hipBlockIdx_x; + rocblas_int nofg = hipGridDim_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int totdim = nofg * dim; + rocblas_int tid = gid * dim + hipThreadIdx_x; // select batch instance to work with S* D = DD + bid * strideD; + S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* nps = ptr_nps(n, splits); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* ndd = ptr_ndd(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_cz(n, tmpz); - // updated eigenvectors after merges - S* tempgemm = tempgemmA + bid * get_tempgemm_size(n); - S* etmpd = ptr_etmpd(n, tempgemm); - - rocblas_int sz = nsz[eid]; - rocblas_int p1 = nps[eid]; - rocblas_int dd = ndd[eid]; - - // Re-scale vector Z to avoid bad numerics when an eigenvalue - // is too close to a pole - rocblas_int i = eid - p1; - if(i < dd) + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd1 = nrs + blks; + rocblas_int* idd2 = idd1 + n; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* ev3 = z1 + 4 * n; + + rocblas_int dm = 1 << k; + rocblas_int dm2 = 
dm << 1; + + // work with all the values (items) in parallel + for(auto j = tid; j < n; j += totdim) { - S valf = 1; - for(int j = hipThreadIdx_x; j < dd; j += hipBlockDim_x) - { - S valg = etmpd[(p1 + j) * n + i]; - valf *= ((p1 + i) == (p1 + j)) ? valg : valg / (D[p1 + i] - D[p1 + j]); - } - __shared__ S lds[STEDC_BDIM]; - rocblas_int dim2 = hipBlockDim_x / 2; - - lds[hipThreadIdx_x] = valf; - __syncthreads(); - while(dim2 > 0) - { - if(hipThreadIdx_x < dim2) - { - valf *= lds[hipThreadIdx_x + dim2]; - lds[hipThreadIdx_x] = valf; - } - dim2 /= 2; - __syncthreads(); - } - - if(hipThreadIdx_x == 0) - { - valf = sqrt(std::abs(valf)); - z[eid] = z[eid] < 0 ? -valf : valf; - } + // item 'j' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[j]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + + // re-insert deflated values to keep new sub-blocks ordered + rocblas_int nf = pout - pin - nr; // number of deflated values + + // the position where the item 'j' will end up in the ordered array is 'pos' + S val = ev3[j]; + rocblas_int pos1 = (j < nr + pin) ? bisearch(val, ev3 + nr + pin, nf, true, true) + : bisearch(val, ev3 + pin, nr, false, (p < 0)); + rocblas_int pos2 = (j < nr + pin) ? (p < 0 ? 
pin + nr - 1 - j : j - pin) : pout - j - 1; + rocblas_int pos = pos1 + pos2 + pin; + + // get merged ordered array 'ev' and permutation map 'ord' + D[pos] = val; + + rocblas_int ind = idd2[j]; + if(ind < 0) + idd1[pos] = -(j + 1); + else + idd1[pos] = ind; } } //--------------------------------------------------------------------------------------// -/** STEDC_MERGEVECTORS_KERNEL prepares vectors from the secular equation for - every pair of sub-blocks that need to be merged. +/** STEDC_MERGERESCALE_KERNEL reconstructs perturbed vector Z of the rank-1 system. - Call this kernel with batch_count groups in y, and n groups in x. - Each group works with a column. Groups are size STEDC_BDIM. - - If a group has an id larger than the actual number of columns it will do nothing. **/ -template + - Each group will deal with one row of Z corresponding to each merge. + - Size of groups is set to STEDC_BDIM.**/ +template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeVectors_kernel(const rocblas_int k, + stedc_mergeRescale_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, const rocblas_int n, - S* CC, - const rocblas_int shiftC, - const rocblas_int ldc, - const rocblas_stride strideC, - S* tmpzA, - S* tempgemmA, - rocblas_int* splitsA) + S* EE, + const rocblas_stride strideE, + S* workSvec, + S* workStmp, + S* workSz, + rocblas_int* workInt) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - // eigenvalue id - rocblas_int eid = hipBlockIdx_x; - // thread id + rocblas_int ii = hipBlockIdx_x; rocblas_int tidb = hipThreadIdx_x; rocblas_int dim = hipBlockDim_x; // select batch instance to work with - S* C = load_ptr_batch(CC, bid, shiftC, strideC); + S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - rocblas_int* ndd = ptr_ndd(n, splits); - - 
S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* z = ptr_cz(n, tmpz); - // updated eigenvectors after merges - S* tempgemm = tempgemmA + bid * get_tempgemm_size(n); - S* vecs = ptr_vecs(n, tempgemm); - S* etmpd = ptr_etmpd(n, tempgemm); - - // Work with merges on level k. Each thread-group works with one vector. - // determine boundaries of what would be the new merged sub-block - rocblas_int p1 = nps[eid]; - rocblas_int sz = nsz[eid]; - rocblas_int dd = ndd[eid]; - - __syncthreads(); + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* z1 = workSvec + bid * (std::max(7, n) * n); + S* ev2 = z1 + 3 * n; + S* temps = workStmp + bid * (n * n); + S* zf = workSz + bid * n; // temporary arrays in shared memory // used to store temp values during the different reductions __shared__ S inrms[STEDC_BDIM]; - // Prepare vectors corresponding to non-deflated values - S nrm; - S* putvec = USEGEMM ? vecs : etmpd; + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // 'ii' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[ii]; + rocblas_int nbx = bx / dm2; - if(eid - p1 < dd) + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + + S* evd = ev2 + pin; + rocblas_int start = (p < 0) ? nr - 1 : 0; + rocblas_int inc = (p < 0) ? 
-1 : 1; + + rocblas_int i = ii - pin; + + // compute re-scaled vector Z of rank-1 perturbed system + if(i < nr) { - // compute vectors of rank-1 perturbed system and their norms - nrm = 0; - for(int i = tidb; i < dd; i += dim) - { - S valf = z[p1 + i] / etmpd[i + eid * n]; - nrm += valf * valf; - putvec[i + eid * n] = valf; - } + rocblas_int sgnz = (z1[i + pin] < 0) ? -1 : 1; + S dd = evd[start + inc * i]; + S mul = 1; - // reduction (for the norms) - inrms[tidb] = nrm; - for(int r = dim / 2; r > 0; r /= 2) + for(auto j = tidb; j < nr; j += dim) { - __syncthreads(); - if(tidb < r) - { - nrm += inrms[tidb + r]; - inrms[tidb] = nrm; - } + S num = std::abs(temps[i + (j + pin) * n]); + S den = (j == i) ? 1 : std::abs(dd - evd[start + inc * j]); + mul *= num / den; } + inrms[tidb] = mul; __syncthreads(); - nrm = sqrt(inrms[0]); - } - if(USEGEMM) - { - // when using external gemms for the update, we need to - // put vectors in padded matrix 'etmpd' - // (this is to compute 'vecs = C * etmpd' using external gemm call) - for(int i = tidb; i < p1 + sz; i += dim) + // reduction (for the norms) + for(auto r = dim / 2; r > 0; r /= 2) { - if(i >= p1 && (eid - p1) < dd && (i - p1) < dd) + if(tidb < r) { - etmpd[i + eid * n] = vecs[i - p1 + eid * n] / nrm; + mul *= inrms[tidb + r]; + inrms[tidb] = mul; } - else - etmpd[i + eid * n] = 0; + __syncthreads(); } - } - else - { - // otherwise, use internal gemm-like procedure to - // multiply by C (row by row) - if(eid - p1 < dd) - { - for(int ii = 0; ii < sz; ++ii) - { - rocblas_int i = p1 + ii; - - // inner products - S temp = 0; - for(int kk = tidb; kk < dd; kk += dim) - temp += C[i + (p1 + kk) * ldc] * etmpd[kk + eid * n]; - - // reduction - inrms[tidb] = temp; - for(int r = dim / 2; r > 0; r /= 2) - { - __syncthreads(); - if(tidb < r) - { - temp += inrms[tidb + r]; - inrms[tidb] = temp; - } - } - __syncthreads(); - // result - if(tidb == 0) - vecs[i + eid * n] = temp / nrm; - __syncthreads(); - } - } + if(tidb == 0) + zf[i + pin] 
= sgnz * std::sqrt(mul); } } //--------------------------------------------------------------------------------------// -/** STEDC_MERGEUPDATE_KERNEL updates vectors and values after a merge is done. +/** STEDC_MERGEVECTORS_KERNEL computes vectors of the rank-1 system for + every pair of sub-blocks that need to be merged. - Call this kernel with batch_count groups in y, and n groups in x. - Each group works with a column. Groups are size STEDC_BDIM. **/ + - Each group works with a column/vector. + - Groups are size STEDC_BDIM **/ template ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) - stedc_mergeUpdate_kernel(const rocblas_int k, - const rocblas_int n, - S* DD, - const rocblas_stride strideD, - S* CC, - const rocblas_int shiftC, - const rocblas_int ldc, - const rocblas_stride strideC, - S* tmpzA, - S* tempgemmA, - rocblas_int* splitsA) + stedc_mergeVectors_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* EE, + const rocblas_stride strideE, + S* workSvec, + S* workStmp, + S* workSz, + rocblas_int* workInt) { // threads and groups indices - // batch instance id rocblas_int bid = hipBlockIdx_y; - // eigenvalue id - rocblas_int eid = hipBlockIdx_x; + rocblas_int j = hipBlockIdx_x; + rocblas_int tidb = hipThreadIdx_x; + rocblas_int dim = hipBlockDim_x; // select batch instance to work with - S* C = load_ptr_batch(CC, bid, shiftC, strideC); - S* D = DD + bid * strideD; + S* E = EE + bid * strideE; // temporary arrays in global memory - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* ndd = ptr_ndd(n, splits); - rocblas_int* nsz = ptr_nsz(n, splits); - rocblas_int* nps = ptr_nps(n, splits); - - S* tmpz = tmpzA + bid * get_tmpz_size(n); - S* evs = ptr_evs(n, tmpz); - // updated eigenvectors after merges - S* tempgemm = tempgemmA + bid * get_tempgemm_size(n); - S* vecs = ptr_vecs(n, tempgemm); - - rocblas_int p1 = nps[eid]; - rocblas_int sz = nsz[eid]; - rocblas_int dd = ndd[eid]; - - 
// update D and C with computed values and vectors - if(eid - p1 < dd) - { - if(hipThreadIdx_x == 0) - D[eid] = evs[eid]; - for(int i = p1 + hipThreadIdx_x; i < p1 + sz; i += hipBlockDim_x) - C[i + eid * ldc] = vecs[i + eid * n]; - } -} - -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_copyD(const rocblas_int n, - S* DDin, - const rocblas_stride strideDin, - S* DDout, - const rocblas_stride strideDout) -{ - // batch instance id - rocblas_int bid = hipBlockIdx_y; - - S* Din = DDin + bid * strideDin; - S* Dout = DDout + bid * strideDout; + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd2 = nrs + blks + n; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* vecs = workSvec + bid * (std::max(7, n) * n); + S* temps = workStmp + bid * (n * n); + S* zf = workSz + bid * n; - int tid = hipThreadIdx_x; - - constexpr int regs = 16; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (n - 1) / chunk_width + 1; - S bval[regs]; + // temporary arrays in shared memory + // used to store temp values during the different reductions + __shared__ S inrms[STEDC_BDIM]; - for(int chunk = 0; chunk < n_chunks; chunk++) + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // column 'j' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[j]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + rocblas_int start = (p < 0) ? nr - 1 : 0; + rocblas_int inc = (p < 0) ? 
-1 : 1; + + // compute vectors of rank-1 perturbed system and their norms + if(idd2[j] < 0 && j < n) { - for(int i = 0; i < regs; i++) + S tm, nrm = 0; + for(auto i = tidb; i < nr; i += dim) { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - bval[i] = Din[x]; + S tot = zf[i + pin] / temps[i + j * n]; + vecs[i + j * n] = tot; + nrm += tot * tot; } - for(int i = 0; i < regs; i++) + inrms[tidb] = nrm; + __syncthreads(); + + // reduction (for the norms) + for(auto r = dim / 2; r > 0; r /= 2) { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - Dout[x] = bval[i]; + if(tidb < r) + { + nrm += inrms[tidb + r]; + inrms[tidb] = nrm; + } + __syncthreads(); } - } -} - -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_copyC(const rocblas_int n, - U1 CCin, - const rocblas_int shiftCin, - const rocblas_int ldcin, - const rocblas_stride strideCin, - U2 CCout, - const rocblas_int shiftCout, - const rocblas_int ldcout, - const rocblas_stride strideCout) -{ - // batch instance id - rocblas_int bid = hipBlockIdx_y; - // group id - rocblas_int gid = hipBlockIdx_x; - - T* Cin = load_ptr_batch(CCin, bid, shiftCin, strideCin); - T* Cout = load_ptr_batch(CCout, bid, shiftCout, strideCout); - - T* src = Cin + ldcin * gid; - T* dst = Cout + ldcout * gid; + nrm = std::sqrt(inrms[0]); - constexpr int regs = 16; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (n - 1) / chunk_width + 1; - T bval[regs]; + // normalize + for(auto i = tidb; i < nr; i += dim) + vecs[i + j * n] /= nrm; + } - for(int chunk = 0; chunk < n_chunks; chunk++) + /*if(!STEDC_USE_EXTERNAL_UPDATE) { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - bval[i] = src[x]; - } - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - dst[x] = bval[i]; - } - } + // Vectors should be updated at this point 
when not using external gemm update. + // TODO: the code needs to be revisited and adapted for this new implementation + // of stedc. Performance of the internal gemm needs to be re-evaluated. + }*/ } -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_reshuffleC(const rocblas_int n, - U1 CCin, - const rocblas_int shiftCin, - const rocblas_int ldcin, - const rocblas_stride strideCin, - U2 CCout, - const rocblas_int shiftCout, - const rocblas_int ldcout, - const rocblas_stride strideCout, - rocblas_int* splitsA) +//--------------------------------------------------------------------------------------// +/** STEDC_MERGEPREPGEMM1_KERNEL prepares the matrix of vectors of the rank-1 system for + the gemm to update eigenvectors (pad with zeros and insert 1 for deflated values). + - Call this kernel with batch_count groups in y, and n groups in x. + - Groups are size STEDC_BDIM **/ +template +__launch_bounds__(STEDC_BDIM) ROCSOLVER_KERNEL + void stedc_mergePrepgemm1_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* workStmp, + rocblas_int* workInt) { - // batch instance id + // threads and groups indices rocblas_int bid = hipBlockIdx_y; - // group id - rocblas_int gid = hipBlockIdx_x; - - rocblas_int* splits = splitsA + bid * get_splits_size(n); - rocblas_int* map = ptr_map(n, splits); - - rocblas_int dst_row = gid; - rocblas_int src_row = map[gid]; - - T* Cin = load_ptr_batch(CCin, bid, shiftCin, strideCin); - T* Cout = load_ptr_batch(CCout, bid, shiftCout, strideCout); - - T* src = Cin + ldcin * src_row; - T* dst = Cout + ldcout * dst_row; - - constexpr int regs = 16; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (n - 1) / chunk_width + 1; - T bval[regs]; + rocblas_int j = hipBlockIdx_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int tid = hipThreadIdx_x; - for(int chunk = 0; chunk < n_chunks; chunk++) + // temporary arrays in global memory + rocblas_int* ps = workInt + 
bid * (5 * n + 2 * blks); + rocblas_int* idd1 = ps + 2 * blks; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* temps = workStmp + bid * (n * n); + + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // column 'j' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[j]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', and + // it ends at 'pout'. Its size is 'sz' + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + tmp += dm2; + rocblas_int pout = tmp < blks ? ps[tmp] : n; + rocblas_int sz = pout - pin; + + rocblas_int t = idd1[j]; + + for(auto ii = tid; ii < sz; ii += dim) { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - bval[i] = src[x]; - } - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - dst[x] = bval[i]; - } + rocblas_int i = ii + pin; + temps[i + j * n] = (i == t) ? 1 : 0; } } -/** STEDC_SORT sorts computed eigenvalues and eigenvectors in increasing order **/ -template -ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) stedc_sort(const rocblas_int n, - S* DDin, - const rocblas_stride strideDin, - S* DDout, - const rocblas_stride strideDout, - U1 CCin, - const rocblas_int shiftCin, - const rocblas_int ldcin, - const rocblas_stride strideCin, - U2 CCout, - const rocblas_int shiftCout, - const rocblas_int ldcout, - const rocblas_stride strideCout - -) +//--------------------------------------------------------------------------------------// +/** STEDC_MERGEPREPGEMM_KERNEL prepares the matrix of vectors of the rank-1 system for + the gemm to update eigenvectors (pad with zeros and permutate rows and columns). 
+ - Call this kernel with batch_count groups in y, and n groups in x + - Groups are size STEDC_BDIM **/ +template +__launch_bounds__(STEDC_BDIM) ROCSOLVER_KERNEL + void stedc_mergePrepgemm_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* EE, + const rocblas_stride strideE, + S* workSvec, + S* workStmp, + rocblas_int* workInt) { - // batch instance id + // threads and groups indices rocblas_int bid = hipBlockIdx_y; - // group id - rocblas_int gid = hipBlockIdx_x; - - S* Din = DDin + bid * strideDin; - S* Dout = DDout + bid * strideDout; - - int tid = hipThreadIdx_x; - - S d = Din[gid]; - - constexpr int regs = 16; - const int chunk_width = regs * hipBlockDim_x; - const int n_chunks = (n - 1) / chunk_width + 1; - T bvalT[regs]; - S* bvalS = reinterpret_cast(bvalT); + rocblas_int j = hipBlockIdx_x; + rocblas_int tid = hipThreadIdx_x; + rocblas_int dim = hipBlockDim_x; - int nan = 0; - int lt = 0; - int eq = 0; + // select batch instance to work with + S* E = EE + bid * strideE; - for(int chunk = 0; chunk < n_chunks; chunk++) + // temporary arrays in global memory + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* nrs = ps + blks; + rocblas_int* idd1 = nrs + blks; + rocblas_int* idd2 = idd1 + n; + rocblas_int* bp = ps + 2 * blks + 4 * n; + S* vecs = workSvec + bid * (std::max(7, n) * n); + S* temps = workStmp + bid * (n * n); + + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // column 'j' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[j]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', the middle point is 'pmid', and + // it ends at 'pout'. Element 'p' is found at middle point + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + rocblas_int pmid = ps[tmp + dm]; + tmp += dm2; + rocblas_int pout = tmp < blks ? 
ps[tmp] : n; + S p = 2 * E[pmid - 1]; + rocblas_int nr = nrs[(nbx + 1) * dm2 - 1]; // number of non-deflated values in sub-block + + rocblas_int start = (p < 0) ? pin + nr - 1 : pin; + rocblas_int inc = (p < 0) ? -1 : 1; + + // put vectors in padded matrix 'temps' to use external gemm for the update + rocblas_int ind = idd1[j]; + + if(ind < 0) { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - { - bvalS[i] = Din[x]; - } - } - for(int i = 0; i < regs; i++) + for(auto i = tid; i < nr; i += dim) { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - { - nan += std::isnan(bvalS[i]); - // lt - how many values are less then the current - // eq - how many values are equal to the current - lt += (bvalS[i] < d); - eq += (bvalS[i] == d && x < gid); - } - } - } + // read rank-1 vector value from 'vecs' (this permutates columns) + rocblas_int jv = -(ind + 1); + S val = vecs[i + jv * n]; - int pos = lt + eq; - // reduction - __shared__ int lpos[STEDC_BDIM]; - lpos[hipThreadIdx_x] = pos; - for(int r = hipBlockDim_x / 2; r > 0; r /= 2) - { - __syncthreads(); - if(hipThreadIdx_x < r) - { - pos += lpos[hipThreadIdx_x + r]; - lpos[hipThreadIdx_x] = pos; + // write in final position in 'temps' (this permutates rows) + rocblas_int it = -(idd2[start + inc * i] + 1); + temps[it + j * n] = val; } } - __syncthreads(); - pos = lpos[0]; - - if(hipThreadIdx_x == 0) - { - Dout[pos] = d; - } - - // The NAN fp value is unordered, so it is possible that with computed - // new positions it would be silently overwriten with non NAN value. - // Make sure we propagate NAN. It is likely to have more NANs in the output - // than in the input, but the following computations are doomed anyway. 
- if(nan) - { - Dout[gid] = NAN; - } +} - T* Cin = load_ptr_batch(CCin, bid, shiftCin, strideCin); - T* Cout = load_ptr_batch(CCout, bid, shiftCout, strideCout); +//--------------------------------------------------------------------------------------// +/** STEDC_MERGEUPDATE_KERNEL updates vectors after a merge is done. + (simply copy results from temporary arrays into V) + - Call this kernel with batch_count groups in y, and n groups in x + - Groups are size STEDC_BDIM **/ +template +ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM) + stedc_mergeUpdate_kernel(const rocblas_int levs, + const rocblas_int blks, + const rocblas_int k, + const rocblas_int n, + S* CC, + const rocblas_int shiftC, + const rocblas_int ldc, + const rocblas_stride strideC, + S* workSvec, + rocblas_int* workInt) +{ + // threads and groups indices + rocblas_int bid = hipBlockIdx_y; + rocblas_int j = hipBlockIdx_x; + rocblas_int dim = hipBlockDim_x; + rocblas_int tid = hipThreadIdx_x; - T* src = Cin + ldcin * gid; - T* dst = Cout + ldcout * pos; + // select batch instance to work with + S* C = load_ptr_batch(CC, bid, shiftC, strideC); + S* W = workSvec + bid * (n * n); - for(int chunk = 0; chunk < n_chunks; chunk++) + // temporary arrays in global memory + rocblas_int* ps = workInt + bid * (5 * n + 2 * blks); + rocblas_int* bp = ps + 2 * blks + 4 * n; + + rocblas_int dm = 1 << k; + rocblas_int dm2 = dm << 1; + + // column 'j' belongs to sub-block 'bx' and thus form vector of + // the new sub-block 'nbx' + rocblas_int bx = bp[j]; + rocblas_int nbx = bx / dm2; + + // the new sub-block starts at 'pin', and + // it ends at 'pout'. Its size is 'sz' + rocblas_int tmp = nbx * dm2; + rocblas_int pin = ps[tmp]; + tmp += dm2; + rocblas_int pout = tmp < blks ? 
ps[tmp] : n; + rocblas_int sz = pout - pin; + + for(auto ii = tid; ii < sz; ii += dim) { - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - bvalT[i] = src[x]; - } - for(int i = 0; i < regs; i++) - { - int x = chunk * chunk_width + i * hipBlockDim_x + hipThreadIdx_x; - if(x < n) - dst[x] = bvalT[i]; - } + rocblas_int i = ii + pin; + C[i + j * ldc] = W[i + j * n]; } } @@ -1794,11 +1434,11 @@ template void rocsolver_stedc_getMemorySize(const rocblas_evect evect, const rocblas_int n, const rocblas_int batch_count, - size_t* size_work_stack, size_t* size_tempvect, - size_t* size_tempgemm, - size_t* size_tmpz, - size_t* size_splits_map, + size_t* size_workSvec, + size_t* size_workStmp, + size_t* size_workSz, + size_t* size_workInt, size_t* size_workArr) { constexpr bool COMPLEX = rocblas_is_complex; @@ -1806,12 +1446,12 @@ void rocsolver_stedc_getMemorySize(const rocblas_evect evect, // if quick return no workspace needed if(n <= 1 || !batch_count) { - *size_work_stack = 0; *size_tempvect = 0; - *size_tempgemm = 0; + *size_workSvec = 0; + *size_workStmp = 0; *size_workArr = 0; - *size_splits_map = 0; - *size_tmpz = 0; + *size_workInt = 0; + *size_workSz = 0; return; } @@ -1819,45 +1459,66 @@ void rocsolver_stedc_getMemorySize(const rocblas_evect evect, if(evect == rocblas_evect_none) { *size_tempvect = 0; - *size_tempgemm = 0; + *size_workStmp = 0; *size_workArr = 0; - *size_splits_map = 0; - *size_tmpz = 0; - rocsolver_sterf_getMemorySize(n, batch_count, size_work_stack); + *size_workInt = 0; + *size_workSz = 0; + rocsolver_sterf_getMemorySize(n, batch_count, size_workSvec); } // if size is too small with classic solver else if(n < STEDC_MIN_DC_SIZE) { *size_tempvect = 0; - *size_tempgemm = 0; + *size_workStmp = 0; *size_workArr = 0; - *size_splits_map = 0; - *size_tmpz = 0; - rocsolver_steqr_getMemorySize(evect, n, batch_count, size_work_stack); + *size_workInt = 0; + *size_workSz = 0; + 
+ rocsolver_steqr_getMemorySize(evect, n, batch_count, size_workSvec); } // otherwise use divide and conquer algorithm: else { + // find number of sub-blocks + rocblas_int levs = stedc_num_levels(n); + rocblas_int blks = 1 << levs; + + // requirements for batched operations + *size_workArr = 0; + if(batch_count > 1) + { + if(BATCHED && !COMPLEX) + *size_workArr = sizeof(S*) * batch_count; + } + else + { + if(STEDC_USE_EXTERNAL_UPDATE && !STEDC_WITH_STRIDED_BATCHED_GEMM) + { + rocblas_int max_n_merges = 1 << (levs - 1); + *size_workArr = sizeof(S*) * max_n_merges * 3; + } + } + // requirements for solver of small independent blocks - rocsolver_steqr_getMemorySize(evect, n, batch_count, size_work_stack); + size_t vec1; + rocsolver_steqr_getMemorySize(evect, n, batch_count, &vec1); - // extra requirements for original eigenvectors of small independent blocks + // extra requirements for original eigenvectors when needed if(evect != rocblas_evect_tridiagonal) *size_tempvect = sizeof(S) * (n * n) * batch_count; else *size_tempvect = 0; - *size_tempgemm = sizeof(S) * get_tempgemm_size(n) * batch_count; - // blocks for batched GEMM are at least 8 x 8 - auto max_n_merges = 1 << (stedc_num_levels(n) - 1); - *size_workArr = sizeof(S*) * std::max(max_n_merges * 3, batch_count); - // size for split blocks and sub-blocks positions - *size_splits_map = sizeof(rocblas_int) * get_splits_size(n) * batch_count; + // extra requirements for divide and conquer process + size_t vec2 = sizeof(S) * (std::max(7, n) * n) * batch_count; + *size_workSvec = std::max(vec1, vec2); + + *size_workStmp = sizeof(S) * (n * n) * batch_count; - // size for temporary diagonal and rank-1 modif vector - *size_tmpz = sizeof(S) * get_tmpz_size(n) * batch_count; + *size_workInt = sizeof(rocblas_int) * (5 * n + 2 * blks) * batch_count; + + *size_workSz = sizeof(S) * (n)*batch_count; } } @@ -1915,11 +1576,11 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, const rocblas_stride strideC,
rocblas_int* info, const rocblas_int batch_count, - void* work_stack, S* tempvect, - S* tempgemm, - S* tmpz, - rocblas_int* splits, + void* workSvec, + S* workStmp, + S* workSz, + rocblas_int* workInt, S** workArr) { ROCSOLVER_ENTER("stedc", "evect:", evect, "n:", n, "shiftD:", shiftD, "shiftE:", shiftE, @@ -1929,8 +1590,6 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, if(batch_count == 0) return rocblas_status_success; - auto const splits_map = splits; - hipStream_t stream; rocblas_get_stream(handle, &stream); @@ -1952,22 +1611,20 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, if(evect == rocblas_evect_none) { rocsolver_sterf_template(handle, n, D, shiftD, strideD, E, shiftE, strideE, info, - batch_count, static_cast(work_stack)); + batch_count, static_cast(workSvec)); } // if size is too small with classic solver, use steqr else if(n < STEDC_MIN_DC_SIZE) { rocsolver_steqr_template(handle, evect, n, D, shiftD, strideD, E, shiftE, strideE, C, - shiftC, ldc, strideC, info, batch_count, work_stack); + shiftC, ldc, strideC, info, batch_count, workSvec); } // otherwise use divide and conquer algorithm: else { - // initialize temporary array for vector updates - size_t size_tempgemm = sizeof(S) * get_tempgemm_size(n) * batch_count; - HIP_CHECK(hipMemsetAsync((void*)tempgemm, 0, size_tempgemm, stream)); + S* workSvecs = (S*)workSvec; // everything must be executed with scalars on the host rocblas_pointer_mode old_mode; @@ -1976,7 +1633,7 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, S one = 1.0; S zero = 0.0; - // constants + // numerics constants S eps = get_epsilon(); S ssfmin = get_safemin(); S ssfmax = S(1.0) / ssfmin; @@ -2008,118 +1665,193 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, rocblas_int groups = (batch_count - 1) / STEDC_BDIM + 1; ROCSOLVER_LAUNCH_KERNEL((stedc_divide_kernel), dim3(groups), dim3(STEDC_BDIM), 0, stream, levs, blks, n, D + shiftD, strideD, E + shiftE, strideE, - 
batch_count, splits); + batch_count, workInt); // 2. solve phase //----------------------------- - ROCSOLVER_LAUNCH_KERNEL((stedc_solve_kernel), dim3(blks, batch_count), dim3(64), 0, - stream, levs, n, D + shiftD, strideD, E + shiftE, strideE, V, 0, - ldv, strideV, info, (S*)work_stack, splits, eps, ssfmin, ssfmax); + ROCSOLVER_LAUNCH_KERNEL((stedc_solve_kernel), dim3(blks, batch_count), + dim3(STEDC_BDIM_SOLVE), 0, stream, levs, blks, n, D + shiftD, + strideD, E + shiftE, strideE, V, 0, ldv, strideV, info, workSvecs, + workInt, eps, ssfmin, ssfmax); // 3. merge phase //---------------- - // launch merge for level k - for(rocblas_int k = 0; k < levs; ++k) + rocblas_int numgrps = (n - 1) / STEDC_BDIM + 1; + rocblas_int nng = (n - 1) / STEDC_BDIM_VALUES + 1; + + // prepare for batched gemms if necessary + // using this approach only in the non-batch syevd calls + bool use_batched_gemm = (batch_count == 1 && !STEDC_WITH_STRIDED_BATCHED_GEMM); + bool use_strided_batched_gemm = (batch_count == 1 && STEDC_WITH_STRIDED_BATCHED_GEMM); + std::vector ns(blks); + rocblas_int res, bc, dm, nn, shv, shw, stv, stw; + rocblas_int dm2 = 1; + rocblas_int bbs = blks; + + if(STEDC_USE_EXTERNAL_UPDATE && use_strided_batched_gemm) + { + rocblas_int sz = n / blks; + res = n - sz * blks; + if(res < blks / 2) + { + res = blks - res; + for(auto i = 0; i < blks; ++i) + ns[i] = i < res ? sz : sz + 1; + } + else + { + for(auto i = 0; i < blks; ++i) + ns[i] = i < res ? sz + 1 : sz; + } + } + + // ****************** launch merge for level k **********************// + // ------------------------------------------------------------------// + for(auto k = 0; k < levs; ++k) { - rocblas_int n_merges = 1 << (levs - k - 1); - ROCSOLVER_LAUNCH_KERNEL(stedc_update_splits, dim3(1, batch_count), dim3(STEDC_BDIM), 0, - stream, levs, k, n, splits); - - // a. 
prepare secular equations - ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_DeflateZero_kernel), - dim3(n_merges, batch_count), dim3(STEDC_BDIM), 0, stream, k, n, - D + shiftD, strideD, E + shiftE, strideE, V, 0, ldv, strideV, - tmpz, splits, eps); - - ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_SortD_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, D + shiftD, strideD, tmpz, - splits); - rocblas_int numgrps_deflate = (n - 1) / STEDC_BDIM + 1; - ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_SetCandFlags_kernel), - dim3(numgrps_deflate, batch_count), dim3(STEDC_BDIM), 0, stream, - k, n, D + shiftD, strideD, tmpz, splits); - ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_DeflateCount_kernel), - dim3(numgrps_deflate, batch_count), dim3(STEDC_BDIM), 0, stream, - k, n, D + shiftD, strideD, tmpz, splits); - ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_DeflateApply_kernel), - dim3(numgrps_deflate, batch_count), dim3(STEDC_BDIM), 0, stream, - k, n, D + shiftD, strideD, tmpz, splits); + // a. 
merge sort and deflation + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeSort_kernel), dim3(numgrps, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, D + shiftD, + strideD, V, 0, ldv, strideV, workSvecs, workInt); + + rocblas_int ngps = blks / (1 << (k + 1)); + size_t lmemsize = sizeof(S) * blks + sizeof(rocblas_int) * 2 * (1 << (k + 1)); + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeSequences_kernel), dim3(ngps, batch_count), + dim3(STEDC_BDIM), lmemsize, stream, levs, blks, k, n, + E + shiftE, strideE, workSvecs, workInt, eps); + + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeDeflate_kernel), dim3(numgrps, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, workSvecs, + workInt, eps); + + ROCSOLVER_LAUNCH_KERNEL((stedc_mergePrepare_kernel), dim3(n, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, E + shiftE, + strideE, workSvecs, workStmp, workInt); ROCSOLVER_LAUNCH_KERNEL((stedc_mergeRotate_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, V, 0, ldv, strideV, tmpz, - splits); - - ROCSOLVER_LAUNCH_KERNEL((stedc_mergeValues_SortDZ_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, D + shiftD, strideD, tmpz, - splits); - ROCSOLVER_LAUNCH_KERNEL((stedc_mergeValues_copyD_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, D + shiftD, strideD, tmpz, - tempgemm, splits); - - ROCSOLVER_LAUNCH_KERNEL(stedc_copyC, dim3(n, batch_count), dim3(STEDC_BDIM), 0, - stream, n, V, 0, ldv, strideV, ptr_vecs(n, tempgemm), 0, n, - get_tempgemm_size(n)); - - ROCSOLVER_LAUNCH_KERNEL(stedc_reshuffleC, dim3(n, batch_count), dim3(STEDC_BDIM), 0, - stream, n, ptr_vecs(n, tempgemm), 0, n, get_tempgemm_size(n), V, - 0, ldv, strideV, splits); - - rocblas_int numgrps_solve = (n - 1) / STEDC_SOLVE_BDIM + 1; - ROCSOLVER_LAUNCH_KERNEL((stedc_mergeValues_Solve_kernel), - dim3(numgrps_solve, batch_count), dim3(STEDC_SOLVE_BDIM), 0, - stream, k, n, D + shiftD, strideD, E + shiftE, strideE, tmpz, - tempgemm, splits, eps, ssfmin, ssfmax); 
- - ROCSOLVER_LAUNCH_KERNEL((stedc_mergeValues_Rescale_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, D + shiftD, strideD, - E + shiftE, strideE, tmpz, tempgemm, splits, eps, ssfmin, ssfmax); - - // c. find merged eigenvectors - ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel), - dim3(n, batch_count), dim3(STEDC_BDIM), 0, stream, k, n, V, 0, - ldv, strideV, tmpz, tempgemm, splits); - - if(STEDC_EXTERNAL_GEMM) + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, V, 0, ldv, + strideV, workSvecs, workInt); + + // b. compute new values + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeValues_kernel), dim3(nng, batch_count), + dim3(STEDC_BDIM_VALUES), 0, stream, levs, blks, k, n, E + shiftE, + strideE, workSvecs, workStmp, workInt, eps, ssfmin, ssfmax); + + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeReinsert_kernel), dim3(numgrps, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, D + shiftD, + strideD, E + shiftE, strideE, workSvecs, workInt); + + // c. compute new vectors + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeRescale_kernel), dim3(n, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, E + shiftE, + strideE, workSvecs, workStmp, workSz, workInt); + + ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel), dim3(n, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, E + shiftE, + strideE, workSvecs, workStmp, workSz, workInt); + + // d. vector updates + if(STEDC_USE_EXTERNAL_UPDATE) { - // using external gemms with padded matrices to do the vector update - // One single full gemm of size n x n x n merges all the blocks in the level - // TODO: using macro STEDC_EXTERNAL_GEMM = true for now. In the future we can pass - // STEDC_EXTERNAL_GEMM at run time to switch between internal vector updates and - // external gemm based updates. 
- if(n <= 1024 || batch_count > 1) + ROCSOLVER_LAUNCH_KERNEL(stedc_mergePrepgemm1_kernel, dim3(n, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, workStmp, + workInt); + + ROCSOLVER_LAUNCH_KERNEL(stedc_mergePrepgemm_kernel, dim3(n, batch_count), + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, E + shiftE, + strideE, workSvecs, workStmp, workInt); + + if(use_strided_batched_gemm) { - rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, n, n, n, - &one, V, 0, ldv, strideV, ptr_etmpd(n, tempgemm), 0, n, - get_tempgemm_size(n), &zero, ptr_vecs(n, tempgemm), 0, n, - get_tempgemm_size(n), batch_count, workArr); + HIP_CHECK(hipMemsetAsync((void*)(workSvecs), 0, sizeof(S) * (n * n), stream)); + + dm = dm2; + dm2 *= 2; + res /= 2; + bbs /= 2; + rocblas_int idx = res; + for(auto kk = 0; kk < blks; kk += dm2) + ns[kk] = ns[kk] + ns[kk + dm]; + + // first batch call + shv = 0; + shw = 0; + bc = idx; + if(bc > 0) + { + nn = ns[(idx - 1) * dm2]; + stv = nn * (ldv + 1); + stw = nn * (n + 1); + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, nn, + nn, nn, &one, V, shv, ldv, stv, workStmp, shw, n, stw, &zero, + workSvecs, shw, n, stw, bc, workArr); + + shv = bc * nn * (ldv + 1); + shw = bc * nn * (n + 1); + } + + // middle batch call + if(idx < bbs - 1) + { + nn = ns[idx * dm2]; + bc = (nn == ns[(idx + 1) * dm2]) ? 
0 : 1; + if(bc > 0) + { + stv = nn * (ldv + 1); + stw = nn * (n + 1); + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, + nn, nn, nn, &one, V, shv, ldv, stv, workStmp, shw, n, + stw, &zero, workSvecs, shw, n, stw, bc, workArr); + + shv += bc * nn * (ldv + 1); + shw += bc * nn * (n + 1); + } + idx += bc; + } + + // last batch call + bc = bbs - idx; + if(bc > 0) + { + nn = ns[idx * dm2]; + stv = nn * (ldv + 1); + stw = nn * (n + 1); + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, nn, + nn, nn, &one, V, shv, ldv, stv, workStmp, shw, n, stw, &zero, + workSvecs, shw, n, stw, bc, workArr); + } } - else + + else if(use_batched_gemm) { - HIP_CHECK(hipMemsetAsync((void*)tempgemm, 0, n * n * sizeof(S), stream)); + HIP_CHECK(hipMemsetAsync((void*)(workSvecs), 0, sizeof(S) * (n * n), stream)); + rocblas_int n_merges = 1 << (levs - k - 1); if(n % n_merges == 0) { - int sz = n / n_merges; - rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, sz, - sz, sz, &one, V, 0, ldv, sz * ldv + sz, - ptr_etmpd(n, tempgemm), 0, n, sz * n + sz, &zero, - ptr_vecs(n, tempgemm), 0, n, sz * n + sz, n_merges, workArr); + // if all sub-blocks are of same size, only one uniform batch call is required + nn = n / n_merges; + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, nn, + nn, nn, &one, V, 0, ldv, nn * ldv + nn, workStmp, 0, n, + nn * n + nn, &zero, workSvecs, 0, n, nn * n + nn, n_merges, + workArr); } else { - rocblas_int lvl = levs - k - 1; - std::vector ns(n_merges); + // otherwise 2 batched calls, with sizes ns[0] and ns[0] + 1, are required ns[0] = n; - for(int i = 0; i < lvl; ++i) + rocblas_int t, t2; + for(auto i = 0; i < levs - k - 1; ++i) { - for(int j = (1 << i); j > 0; --j) + for(auto j = (1 << i); j > 0; --j) { - auto t = ns[j - 1]; - ns[j * 2 - 1] = t / 2 + (t & 1); - ns[j * 2 - 2] = t / 2; + t = ns[j - 1]; + t2 = t / 2; + ns[j * 2 - 1] = (2 * t2 < t) ? 
t2 + 1 : t2; + ns[j * 2 - 2] = t2; } } - // there can only be 2 block sizes: ns[0] and ns[0]+1 + std::array, 2> uniform_batch; uniform_batch[0].reserve(n_merges); uniform_batch[1].reserve(n_merges); @@ -2134,8 +1866,8 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, { auto ps = b[j]; hABC[j + 0 * nbb] = ps + ps * ldv + V; - hABC[j + 1 * nbb] = ps + ps * n + ptr_etmpd(n, tempgemm); - hABC[j + 2 * nbb] = ps + ps * n + ptr_vecs(n, tempgemm); + hABC[j + 1 * nbb] = ps + ps * n + workStmp; + hABC[j + 2 * nbb] = ps + ps * n + workSvecs; } HIP_CHECK(hipMemcpyAsync(workArr, hABC.data(), 3 * nbb * sizeof(S*), hipMemcpyHostToDevice, stream)); @@ -2146,29 +1878,35 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, } } } + + else + { + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, n, n, n, + &one, V, 0, ldv, strideV, workStmp, 0, n, n * n, &zero, + workSvecs, 0, n, n * n, batch_count, workArr); + } } - // d. update level + // e. update for next level ROCSOLVER_LAUNCH_KERNEL((stedc_mergeUpdate_kernel), dim3(n, batch_count), - dim3(STEDC_BDIM), 0, stream, k, n, D + shiftD, strideD, V, 0, - ldv, strideV, tmpz, tempgemm, splits); + dim3(STEDC_BDIM), 0, stream, levs, blks, k, n, V, 0, ldv, + strideV, workSvecs, workInt); } - // 4. update and sort + // 4. 
Final update //---------------------- if(evect != rocblas_evect_tridiagonal) { // eigenvectors C <- C*V - local_gemm(handle, n, C, shiftC, ldc, strideC, V, tempgemm, - tempgemm + strideV, 0, ldv, strideV, batch_count, - workArr); + local_gemm(handle, n, C, shiftC, ldc, strideC, V, workSvecs, + workStmp, 0, ldv, strideV, batch_count, workArr); } else if constexpr(rocblas_is_complex) { // V is stored in C but is of type S; need to convert to type T // tempgemm = V ROCSOLVER_LAUNCH_KERNEL(copy_mat, dim3(groupsn, groupsn, batch_count), dim3(BS2, BS2), - 0, stream, copymat_to_buffer, n, n, V, 0, ldv, strideV, tempgemm); + 0, stream, copymat_to_buffer, n, n, V, 0, ldv, strideV, workStmp); // imag(C) = zeros ROCSOLVER_LAUNCH_KERNEL(set_zero, dim3(groupsn, groupsn, batch_count), @@ -2177,19 +1915,9 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle, // real(C) = tempgemm ROCSOLVER_LAUNCH_KERNEL((copy_mat), dim3(groupsn, groupsn, batch_count), dim3(BS2, BS2), 0, stream, copymat_from_buffer, n, n, C, shiftC, - ldc, strideC, tempgemm); + ldc, strideC, workStmp); } - ROCSOLVER_LAUNCH_KERNEL(stedc_copyD, dim3(1, batch_count), dim3(STEDC_BDIM), 0, stream, n, - D + shiftD, strideD, tmpz, n); - - ROCSOLVER_LAUNCH_KERNEL(stedc_copyC, dim3(n, batch_count), dim3(STEDC_BDIM), 0, stream, - n, C, shiftC, ldc, strideC, (T*)tempgemm, 0, n, n * n); - - ROCSOLVER_LAUNCH_KERNEL(stedc_sort, dim3(n, batch_count), dim3(STEDC_BDIM), 0, stream, n, - tmpz, n, D + shiftD, strideD, (T*)tempgemm, 0, n, n * n, C, shiftC, - ldc, strideC); - rocblas_set_pointer_mode(handle, old_mode); } diff --git a/projects/rocsolver/library/src/include/lib_device_helpers.hpp b/projects/rocsolver/library/src/include/lib_device_helpers.hpp index 100184e4a64..cee82ed9a30 100644 --- a/projects/rocsolver/library/src/include/lib_device_helpers.hpp +++ b/projects/rocsolver/library/src/include/lib_device_helpers.hpp @@ -1375,4 +1375,105 @@ ROCSOLVER_KERNEL void swap_kernel(I const n, T* const x, I const incx, T* 
const } } +/** BISEARCH implements a binary search to find the position of 'val' in a sorted array 'X' of 'n' elements. + If STRICT = true, it returns the number of elements in 'X' that are strictly smaller than 'val'. + If STRICT = false, it returns the number of elements in 'X' that are smaller than or + equal to 'val'. If REVERSE = true, 'X' is scanned back to front, i.e. position m refers to X[n - m] instead of X[m - 1] (presumably for arrays stored in descending order). Returns 0 when n = 0. **/ +template +__device__ __host__ rocblas_int bisearch(T val, T* X, rocblas_int n, bool STRICT, bool REVERSE) +{ + rocblas_int d = 1; + rocblas_int u = n; + rocblas_int m; + T test; + + // quick return + if(n == 0) + return 0; + + if(REVERSE) + { + if(STRICT) + { + // while there is still an interval to search + while(d != u) + { + // find middle point in the interval [d, u] + m = (u - d - 1) / 2 + 1 + d; + test = X[n - m]; + + // correct interval accordingly + if(test >= val) + u = m - 1; + else + d = m; + } + // return result + test = X[n - d]; + return test >= val ? 0 : d; + } + else + { + // while there is still an interval to search + while(d != u) + { + // find middle point in the interval [d, u] + m = (u - d - 1) / 2 + 1 + d; + test = X[n - m]; + + // correct interval accordingly + if(test > val) + u = m - 1; + else + d = m; + } + // return result + test = X[n - d]; + return test > val ? 0 : d; + } + } + + else + { + if(STRICT) + { + // while there is still an interval to search + while(d != u) + { + // find middle point in the interval [d, u] + m = (u - d - 1) / 2 + 1 + d; + test = X[m - 1]; + + // correct interval accordingly + if(test >= val) + u = m - 1; + else + d = m; + } + // return result + test = X[d - 1]; + return test >= val ? 0 : d; + } + else + { + // while there is still an interval to search + while(d != u) + { + // find middle point in the interval [d, u] + m = (u - d - 1) / 2 + 1 + d; + test = X[m - 1]; + + // correct interval accordingly + if(test > val) + u = m - 1; + else + d = m; + } + // return result + test = X[d - 1]; + return test > val ? 
0 : d; + } + } +} + ROCSOLVER_END_NAMESPACE diff --git a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.cpp b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.cpp index a3171775dba..0b6c08481a2 100644 --- a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.cpp +++ b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.cpp @@ -24,7 +24,7 @@ rocblas_status rocsolver_syevd_heevd_impl(rocblas_handle handle, return rocblas_status_invalid_handle; // argument checking - rocblas_status st = rocsolver_syevd_heevd_argCheck(handle, evect, uplo, n, A, lda, D, E, info); + rocblas_status st = rocsolver_syev_heev_argCheck(handle, evect, uplo, n, A, lda, D, E, info); if(st != rocblas_status_continue) return st; diff --git a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.hpp b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.hpp index f75f8ac00b4..83aa71a8544 100644 --- a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.hpp +++ b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd.hpp @@ -35,8 +35,8 @@ #include "auxiliary/rocauxiliary_ormtr_unmtr.hpp" #include "auxiliary/rocauxiliary_stedc.hpp" #include "auxiliary/rocauxiliary_sterf.hpp" -#include "lib_device_helpers.hpp" #include "rocblas.hpp" +#include "roclapack_syev_heev.hpp" #include "roclapack_sytrd_hetrd.hpp" #include "rocsolver/rocsolver.h" @@ -124,41 +124,6 @@ void rocsolver_syevd_heevd_getMemorySize(rocblas_handle handle, *size_workArr = std::max(*size_workArr, 2 * sizeof(T*) * batch_count); } -/** Argument checking **/ -template -rocblas_status rocsolver_syevd_heevd_argCheck(rocblas_handle handle, - const rocblas_evect evect, - const rocblas_fill uplo, - const rocblas_int n, - T A, - const rocblas_int lda, - S* D, - S* E, - rocblas_int* info, - const rocblas_int batch_count = 1) -{ - // order is important for unit tests: - - // 1. 
invalid/non-supported values - if((evect != rocblas_evect_original && evect != rocblas_evect_none) - || (uplo != rocblas_fill_lower && uplo != rocblas_fill_upper)) - return rocblas_status_invalid_value; - - // 2. invalid size - if(n < 0 || lda < n || batch_count < 0) - return rocblas_status_invalid_size; - - // skip pointer check if querying memory size - if(rocblas_is_device_memory_size_query(handle)) - return rocblas_status_continue; - - // 3. invalid pointers - if((n && !A) || (n && !E) || (n && !D) || (batch_count && !info)) - return rocblas_status_invalid_pointer; - - return rocblas_status_continue; -} - template void rocsolver_syevd_heevd_getMemorySize(rocblas_handle handle, const rocblas_evect evect, @@ -353,7 +318,7 @@ rocblas_status rocsolver_syevd_heevd_template(rocblas_handle handle, rocsolver_stedc_template( handle, rocblas_evect_tridiagonal, n, D, 0, strideD, E, 0, strideE, tmptau_W, 0, ldw, - strideW, info, batch_count, work3, (S*)work2, (S*)work1, tmpz, splits, (S**)workArr); + strideW, info, batch_count, (S*)work3, work2, (S*)work1, tmpz, splits, (S**)workArr); // update the eigenvectors (if applicable) if(evect == rocblas_evect_original) @@ -488,7 +453,7 @@ rocblas_status rocsolver_syevd_heevd_template(rocblas_handle handle, rocsolver_stedc_template( handle, rocblas_evect_tridiagonal, n, D, 0, strideD, E, 0, strideE, tmptau_W, 0, ldw, - strideW, info, batch_count, work3, (S*)work2, (S*)work1, tmpz, splits, (S**)workArr); + strideW, info, batch_count, (S*)work3, work2, (S*)work1, tmpz, splits, (S**)workArr); // update the eigenvectors (if applicable) if(evect == rocblas_evect_original) diff --git a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_batched.cpp b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_batched.cpp index c7b4c66890c..aa8d5519057 100644 --- a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_batched.cpp +++ b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_batched.cpp @@ -52,7 +52,7 @@ 
rocblas_status rocsolver_syevd_heevd_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_syevd_heevd_argCheck(handle, evect, uplo, n, A, lda, D, E, info, batch_count); + = rocsolver_syev_heev_argCheck(handle, evect, uplo, n, A, lda, D, E, info, batch_count); if(st != rocblas_status_continue) return st; diff --git a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_strided_batched.cpp b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_strided_batched.cpp index cafade5ac8d..6d01054c7db 100644 --- a/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_strided_batched.cpp +++ b/projects/rocsolver/library/src/lapack/roclapack_syevd_heevd_strided_batched.cpp @@ -54,7 +54,7 @@ rocblas_status rocsolver_syevd_heevd_strided_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_syevd_heevd_argCheck(handle, evect, uplo, n, A, lda, D, E, info, batch_count); + = rocsolver_syev_heev_argCheck(handle, evect, uplo, n, A, lda, D, E, info, batch_count); if(st != rocblas_status_continue) return st;