diff --git a/paddle/phi/kernels/funcs/matrix_solve.cu b/paddle/phi/kernels/funcs/matrix_solve.cu
index b8c20aa166fadb..9253be6d1fd87b 100644
--- a/paddle/phi/kernels/funcs/matrix_solve.cu
+++ b/paddle/phi/kernels/funcs/matrix_solve.cu
@@ -17,10 +17,131 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace phi {
 namespace funcs {
 
+#ifndef PADDLE_WITH_HIP
+/**
+ * Transform pivot array to permutation by swapping perm[i] and perm[pivot[i]]
+ * for i in 0...n-1, where pivot and perm have shape [batch_size, n].
+ * Example:
+ *   Input pivot = [[6, 7, 4, 5, 5, 7, 8, 8]]
+ *   Output perm = [[5, 6, 3, 4, 2, 1, 7, 0]]
+ */
+__global__ void UnpackPivot(const int* __restrict__ pivot,
+                            int* __restrict__ perm,
+                            int64_t batch_size,
+                            int64_t n) {
+  constexpr int warp_size = 32;
+  int warps_per_block = blockDim.x / warp_size;
+  int warp_id = threadIdx.x / warp_size;
+  int warp_offset = threadIdx.x % warp_size;
+  int64_t offset = static_cast<int64_t>(blockIdx.x) * warps_per_block + warp_id;
+  int64_t stride = static_cast<int64_t>(gridDim.x) * warps_per_block;
+
+  for (; offset < batch_size; offset += stride) {
+    // init perm[offset, :] with offset * n ... offset * n + n - 1
+    for (int64_t i = warp_offset; i < n; i += warp_size) {
+      perm[offset * n + i] = offset * n + i;
+    }
+    __syncwarp();
+
+    // Since the swapping makes entirely discrete accesses, we only use the
+    // first thread in each warp to avoid warp divergence.
+    if (warp_offset > 0) continue;
+
+    // Swap perm[i] and perm[pivot[i]] for i in 0...n-1
+    for (int64_t i = offset * n; i < offset * n + n; ++i) {
+      int64_t j = pivot[i] - 1 + offset * n;  // cuBLAS uses 1-based indices
+      int tmp = perm[i];
+      perm[i] = perm[j];
+      perm[j] = tmp;
+    }
+  }
+}
+
+/**
+ * Eliminate the L and U in the equation:
+ *   (U^T @ L^T @ P) @ X = B (the U^T @ L^T @ P is stored in A)
+ * by applying the inverses of U^T and L^T respectively. The result is:
+ *   P @ X = L^T^-1 @ U^T^-1 @ B
+ * and is stored in B.
+ */
+template <typename Context, typename T>
+void SolveLU(const phi::funcs::BlasT<Context, T>& blas,
+             int m,
+             int n,
+             const T* A,
+             T* B,
+             int batch_size) {
+  constexpr T alpha = 1.0;
+  for (int64_t i = 0; i < batch_size; ++i) {
+    // Before: U^T @ L^T @ P @ X = B
+    blas.TRSM(CblasRight,
+              CblasLower,
+              CblasTrans,
+              CblasNonUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: L^T @ P @ X = U^T^-1 @ B
+    blas.TRSM(CblasRight,
+              CblasUpper,
+              CblasTrans,
+              CblasUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: P @ X = L^T^-1 @ U^T^-1 @ B
+  }
+}
+
+// Batched version of SolveLU.
+template <typename Context, typename T>
+void BatchedSolveLU(const phi::funcs::BlasT<Context, T>& blas,
+                    int m,
+                    int n,
+                    const T** A,
+                    T** B,
+                    int batch_size) {
+  constexpr T alpha = 1.0;
+  blas.BatchedTRSM(CblasRight,
+                   CblasLower,
+                   CblasTrans,
+                   CblasNonUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+  blas.BatchedTRSM(CblasRight,
+                   CblasUpper,
+                   CblasTrans,
+                   CblasUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+}
+#endif
+
 template <typename Context, typename T>
 void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                                                 const DenseTensor& a,
@@ -39,47 +160,38 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   const int a_rank = a_dims.size();
   int n = a_dims[a_rank - 1];
   int lda = n;
-  int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
+  int64_t batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
 
   const auto& b_dims = b.dims();
   const int b_rank = b_dims.size();
   int nrhs = b_dims[b_rank - 1];
-  int ldb = b_dims[b_rank - 2];
-
-  // make sure the out dims is right
-  out->Resize(b_dims);
+  int ldb = n;
 
-  context.template Alloc<T>(out);
-
-  // copy input A to a temporary tensor tmp_a,
-  // LU factorization, written back to original matrix A, so in the beginning,
-  // it's necessary to create a temporary tensor tmp_a.
+  // 1. Copy input A to a temporary tensor tmp_a for LU factorization.
   DenseTensor tmp_a(a.dtype());
   tmp_a.Resize(a.dims());
-  context.template Alloc<T>(&tmp_a);
   phi::Copy(context, a, context.GetPlace(), false, &tmp_a);
 
-  // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
-  // because cuBlas assumes column-major while Paddle uses row-majar.
-  DenseTensor tmp_b(b.type());
-  const auto& new_dims_vec = getNewDimsVec(b_dims);
-  tmp_b.Resize(common::make_ddim(new_dims_vec));
-  context.template Alloc<T>(&tmp_b);
+  // 2. Transpose B and save it in out, because cuBlas assumes column-major
+  // while Paddle uses row-major.
+  const auto& new_b_dims = getNewDimsVec(b_dims);
+  out->Resize(common::make_ddim(new_b_dims));
+  context.template Alloc<T>(out);
   phi::funcs::TransposeNormal<Context, T> trans;
   std::vector<int> new_axis = getNewAxis(b_rank);
-  trans(context, b, &tmp_b, new_axis);
+  trans(context, b, out, new_axis);
 
   const T* a_data_in_gpu = tmp_a.data<T>();
-  const T* b_data_in_gpu = tmp_b.data<T>();
+  T* b_data_in_gpu = out->data<T>();
 
   std::vector<const T*> cpu_ptrs(batch_size * 2);
 
-  for (int i = 0; i < batch_size; ++i) {
+  for (int64_t i = 0; i < batch_size; ++i) {
     cpu_ptrs[i] = a_data_in_gpu + i * n * n;
     cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
   }
 
-  // Copy the addresses of A and tmp_b from host to device.
+  // 3. Copy the addresses of A and B from host to device.
   phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       cpu_ptrs.size() * sizeof(T*),
@@ -94,8 +206,8 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   T** gpu_tmp_b_ptrs =
       reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
 
-  // Allocate device memory for BatchedGETRF's info and pivots.
-  int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
+  // 4. Allocate device memory for BatchedGETRF's info and pivots.
+  int64_t num_ints = batch_size * (n + 1);
   phi::Allocator::AllocationPtr tmp_gpu_info_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       num_ints * sizeof(int),
@@ -111,14 +223,13 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   int* gpu_pivot_ptr =
       reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
 
-  // This function performs the LU factorization of each matrix A by the
-  // equation A = L * U. L and U are written back to original matrix A,
-  // and diagonal elements of L are discarded.
+  // 5. Perform LU factorization on A.
   blas.BatchedGETRF(n,
                     reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
                     gpu_pivot_ptr,
                     gpu_info_ptr,
                     batch_size);
+  // After: P @ A^T = L @ U
 
   // check whether BatchedGETRF is executed successfully or not
   memory_utils::Copy(phi::CPUPlace(),
@@ -139,33 +250,47 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                           info[i]));
   }
 
-  // hold the result code from BatchedGETRS
-  int host_info = 0;
+  // 6. Eliminate L and U in the equation Ax = B, where A = U^T @ L^T @ P.
+  // The batched version is advantageous for small shapes, but has errors for
+  // large shapes. In that case, we call the non-batched version batch_size
+  // times instead.
+  // Ref: https://docs.nvidia.com/cuda/cublas/#cublas-t-trsmbatched
+  constexpr int max_batch_nrhs = 65535 * 8;  // max(gridDim.y) * 8
+  if (batch_size > 1 && nrhs <= max_batch_nrhs) {
+    BatchedSolveLU(blas,
+                   nrhs,
+                   n,
+                   reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                   gpu_tmp_b_ptrs,
+                   batch_size);
+  } else {
+    SolveLU(blas, nrhs, n, a_data_in_gpu, b_data_in_gpu, batch_size);
+  }
+
+  // 7. Transpose B back to row-major form.
+  DenseTensor tmp_b(b.type());
+  tmp_b.Resize(b_dims);
+  context.template Alloc<T>(&tmp_b);
+  phi::funcs::TransposeNormal<Context, T> trans2;
+  trans2(context, *out, &tmp_b, new_axis);
 
-  // to solve the equation after LU factorization
-  CBLAS_TRANSPOSE transA = CblasTrans;
-  blas.BatchedGETRS(transA,
-                    n,
-                    nrhs,
-                    reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                    lda,
-                    gpu_pivot_ptr,
-                    gpu_tmp_b_ptrs,
-                    ldb,
-                    &host_info,
-                    batch_size);
+  // 8. Permute B according to pivots to get the final result.
+  DenseTensor perm;
+  perm.Resize({batch_size * n});
+  context.template Alloc<int>(&perm);
 
-  // check whether BatchedGETRS is executed successfully or not
-  PADDLE_ENFORCE_EQ(host_info,
-                    0,
-                    common::errors::InvalidArgument(
-                        "The [%d]'th argument to cublas*getrsBatched had "
-                        "an illegal value.",
-                        -host_info));
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(context, batch_size * 32);
+  auto stream = context.stream();
+  UnpackPivot<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+      gpu_pivot_ptr, perm.data<int>(), batch_size, n);
 
-  // transpose tmp_b to get the final result in row-major form.
-  phi::funcs::TransposeNormal<Context, T> trans2;
-  trans2(context, tmp_b, out, new_axis);
+  // fuse dims 0...rank-2 because scatter only supports one index dim
+  tmp_b.Resize({batch_size * n, nrhs});
+  out->Resize({batch_size * n, nrhs});
+  GPUScatterAssign<T, int>(context, tmp_b, perm, out);
+  out->Resize(b_dims);
+  // After: X = P^T @ L^T^-1 @ U^T^-1 @ B
 
 #else
   compute_solve_eigen<Context, T>(context, a, b, out);