@@ -17,10 +17,131 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace phi {
 namespace funcs {
 
+#ifndef PADDLE_WITH_HIP
+/**
+ * Transform the pivot array into a permutation by swapping perm[i] and
+ * perm[pivot[i] - 1] (pivot is 1-based, as returned by cuBLAS) for i from
+ * 0 to n-1, where pivot and perm have shape [batch_size, n].
+ * Example:
+ *   Input  pivot = [[6, 7, 4, 5, 5, 7, 8, 8]]
+ *   Output perm  = [[5, 6, 3, 4, 2, 1, 7, 0]]
+ */
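+// For illustration, the example above unrolls as follows (indices are 0-based
+// after the "- 1" adjustment). Starting from perm = [0, 1, 2, 3, 4, 5, 6, 7]:
+//   i = 0: swap perm[0], perm[5] -> [5, 1, 2, 3, 4, 0, 6, 7]
+//   i = 1: swap perm[1], perm[6] -> [5, 6, 2, 3, 4, 0, 1, 7]
+//   i = 2: swap perm[2], perm[3] -> [5, 6, 3, 2, 4, 0, 1, 7]
+//   i = 3: swap perm[3], perm[4] -> [5, 6, 3, 4, 2, 0, 1, 7]
+//   i = 4: swap perm[4], perm[4] -> unchanged
+//   i = 5: swap perm[5], perm[6] -> [5, 6, 3, 4, 2, 1, 0, 7]
+//   i = 6: swap perm[6], perm[7] -> [5, 6, 3, 4, 2, 1, 7, 0]
+//   i = 7: swap perm[7], perm[7] -> unchanged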
+__global__ void UnpackPivot(const int* __restrict__ pivot,
+                            int* __restrict__ perm,
+                            int64_t batch_size,
+                            int64_t n) {
+  constexpr int warp_size = 32;
+  int warps_per_block = blockDim.x / warp_size;
+  int warp_id = threadIdx.x / warp_size;
+  int warp_offset = threadIdx.x % warp_size;
+  int64_t offset = static_cast<int64_t>(blockIdx.x) * warps_per_block + warp_id;
+  int64_t stride = static_cast<int64_t>(gridDim.x) * warps_per_block;
+
+  for (; offset < batch_size; offset += stride) {
+    // Initialize perm[offset, :] with the flattened row indices
+    // offset * n ... offset * n + n - 1, so the result can be fed directly to
+    // the scatter over the fused [batch_size * n, nrhs] view below.
+    for (int64_t i = warp_offset; i < n; i += warp_size) {
+      perm[offset * n + i] = offset * n + i;
+    }
+    __syncwarp();
+
+    // The swaps below are sequentially dependent and touch scattered
+    // addresses, so only the first thread of each warp performs them; the
+    // remaining threads skip ahead to avoid warp divergence.
+    if (warp_offset > 0) continue;
+    // Swap perm[i] and perm[pivot[i] - 1 + offset * n] for i in this batch.
+    for (int64_t i = offset * n; i < offset * n + n; ++i) {
+      int64_t j = pivot[i] - 1 + offset * n;  // cuBLAS pivots are 1-based
+      int tmp = perm[i];
+      perm[i] = perm[j];
+      perm[j] = tmp;
+    }
+  }
+}
+
+/**
+ * Eliminate L and U in the equation
+ *   (U^T @ L^T @ P) @ X = B   (U^T @ L^T @ P is stored in A)
+ * by solving against U^T and then L^T in turn. The result,
+ *   P @ X = L^T^-1 @ U^T^-1 @ B,
+ * is stored in B.
+ */
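+// Why A holds U^T @ L^T @ P: BatchedGETRF factorizes the column-major view of
+// A (i.e. A^T for row-major data) as P @ A^T = L @ U, hence
+// A = (P^T @ L @ U)^T = U^T @ L^T @ P.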
+template <typename Context, typename T>
+void SolveLU(const phi::funcs::BlasT<Context, T>& blas,
+             int m,
+             int n,
+             const T* A,
+             T* B,
+             int batch_size) {
+  constexpr T alpha = 1.0;
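+  // Note: following standard BLAS semantics, TRSM solves the triangular
+  // system in place and overwrites B with the solution, so the two solves
+  // below need no extra workspace.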
+  for (int i = 0; i < batch_size; i++) {
+    // Before: U^T @ L^T @ P @ X = B
+    blas.TRSM(CblasRight,
+              CblasLower,
+              CblasTrans,
+              CblasNonUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: L^T @ P @ X = U^T^-1 @ B
+    blas.TRSM(CblasRight,
+              CblasUpper,
+              CblasTrans,
+              CblasUnit,
+              m,
+              n,
+              alpha,
+              A + i * n * n,
+              n,
+              B + i * m * n,
+              n);
+    // After: P @ X = L^T^-1 @ U^T^-1 @ B
+  }
+}
+
+// Batched version of SolveLU.
+template <typename Context, typename T>
+void BatchedSolveLU(const phi::funcs::BlasT<Context, T>& blas,
+                    int m,
+                    int n,
+                    const T** A,
+                    T** B,
+                    int batch_size) {
+  constexpr T alpha = 1.0;
+  blas.BatchedTRSM(CblasRight,
+                   CblasLower,
+                   CblasTrans,
+                   CblasNonUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+  blas.BatchedTRSM(CblasRight,
+                   CblasUpper,
+                   CblasTrans,
+                   CblasUnit,
+                   m,
+                   n,
+                   alpha,
+                   A,
+                   n,
+                   B,
+                   n,
+                   batch_size);
+}
+#endif
+
 template <typename Context, typename T>
 void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
                                                 const DenseTensor& a,
@@ -39,47 +160,38 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   const int a_rank = a_dims.size();
   int n = a_dims[a_rank - 1];
   int lda = n;
-  int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
+  int64_t batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
 
   const auto& b_dims = b.dims();
   const int b_rank = b_dims.size();
   int nrhs = b_dims[b_rank - 1];
-  int ldb = b_dims[b_rank - 2];
-
-  // make sure the out dims is right
-  out->Resize(b_dims);
+  int ldb = n;
 
-  context.template Alloc<T>(out);
-
-  // copy input A to a temporary tensor tmp_a,
-  // LU factorization, written back to original matrix A, so in the beginning,
-  // it's necessary to create a temporary tensor tmp_a.
+  // 1. Copy input A to a temporary tensor tmp_a for LU factorization.
   DenseTensor tmp_a(a.dtype());
   tmp_a.Resize(a.dims());
-
   context.template Alloc<T>(&tmp_a);
   phi::Copy(context, a, context.GetPlace(), false, &tmp_a);
 
-  // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
-  // because cuBlas assumes column-major while Paddle uses row-majar.
-  DenseTensor tmp_b(b.type());
-  const auto& new_dims_vec = getNewDimsVec(b_dims);
-  tmp_b.Resize(common::make_ddim(new_dims_vec));
-  context.template Alloc<T>(&tmp_b);
+  // 2. Transpose B and store it in out, because cuBLAS assumes column-major
+  // while Paddle uses row-major.
+  const auto& new_b_dims = getNewDimsVec(b_dims);
+  out->Resize(common::make_ddim(new_b_dims));
+  context.template Alloc<T>(out);
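+  // getNewDimsVec / getNewAxis swap only the trailing two dimensions, so the
+  // row-major [..., n, nrhs] batches of B are stored as [..., nrhs, n] in out.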
   phi::funcs::TransposeNormal<Context, T> trans;
   std::vector<int> new_axis = getNewAxis(b_rank);
-  trans(context, b, &tmp_b, new_axis);
+  trans(context, b, out, new_axis);
 
   const T* a_data_in_gpu = tmp_a.data<T>();
-  const T* b_data_in_gpu = tmp_b.data<T>();
+  T* b_data_in_gpu = out->data<T>();
 
   std::vector<const T*> cpu_ptrs(batch_size * 2);
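+  // cpu_ptrs[0 .. batch_size) hold the per-batch A pointers and
+  // cpu_ptrs[batch_size .. 2 * batch_size) the per-batch B pointers; both are
+  // copied to the device in a single allocation below.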
-  for (int i = 0; i < batch_size; ++i) {
+  for (int64_t i = 0; i < batch_size; ++i) {
     cpu_ptrs[i] = a_data_in_gpu + i * n * n;
     cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
   }
 
-  // Copy the addresses of A and tmp_b from host to device.
+  // 3. Copy the addresses of A and B from host to device.
   phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       cpu_ptrs.size() * sizeof(T*),
@@ -94,8 +206,8 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   T** gpu_tmp_b_ptrs =
       reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
 
-  // Allocate device memory for BatchedGETRF's info and pivots.
-  int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
+  // 4. Allocate device memory for BatchedGETRF's info and pivots.
+  int64_t num_ints = batch_size * (n + 1);
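+  // batch_size ints for the GETRF info array plus batch_size * n ints for the
+  // pivots (see gpu_info_ptr / gpu_pivot_ptr below).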
   phi::Allocator::AllocationPtr tmp_gpu_info_data = phi::memory_utils::Alloc(
       context.GetPlace(),
       num_ints * sizeof(int),
@@ -111,14 +223,13 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
   int* gpu_pivot_ptr =
       reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
 
-  // This function performs the LU factorization of each matrix A by the
-  // equation A = L * U. L and U are written back to original matrix A,
-  // and diagonal elements of L are discarded.
+  // 5. Perform LU factorization on A.
   blas.BatchedGETRF(n,
                     reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
                     gpu_pivot_ptr,
                     gpu_info_ptr,
                     batch_size);
+  // After: P @ A^T = L @ U
 
   // check whether BatchedGETRF is executed successfully or not
   memory_utils::Copy(phi::CPUPlace(),
@@ -139,33 +250,47 @@ void MatrixSolveFunctor<Context, T>::operator()(const Context& context,
             info[i]));
   }
 
-  // hold the result code from BatchedGETRS
-  int host_info = 0;
+  // 6. Eliminate L and U in the equation A @ X = B, where A = U^T @ L^T @ P.
+  // The batched version is advantageous for small shapes but produces errors
+  // for large shapes; in that case, we call the non-batched version
+  // batch_size times instead.
+  // Ref: https://docs.nvidia.com/cuda/cublas/#cublas-t-trsmbatched
+  constexpr int max_batch_nrhs = 65535 * 8;  // max(gridDim.y) * 8
+  if (batch_size > 1 && nrhs <= max_batch_nrhs) {
+    BatchedSolveLU(blas,
+                   nrhs,
+                   n,
+                   reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                   gpu_tmp_b_ptrs,
+                   batch_size);
+  } else {
+    SolveLU(blas, nrhs, n, a_data_in_gpu, b_data_in_gpu, batch_size);
+  }
+
+  // 7. Transpose B back to row-major form.
+  DenseTensor tmp_b(b.type());
+  tmp_b.Resize(b_dims);
+  context.template Alloc<T>(&tmp_b);
+  phi::funcs::TransposeNormal<Context, T> trans2;
+  trans2(context, *out, &tmp_b, new_axis);
 
-  // to solve the equation after LU factorization
-  CBLAS_TRANSPOSE transA = CblasTrans;
-  blas.BatchedGETRS(transA,
-                    n,
-                    nrhs,
-                    reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                    lda,
-                    gpu_pivot_ptr,
-                    gpu_tmp_b_ptrs,
-                    ldb,
-                    &host_info,
-                    batch_size);
+  // 8. Permute B according to pivots to get the final result.
+  DenseTensor perm;
+  perm.Resize({batch_size * n});
+  context.template Alloc<int>(&perm);
 
-  // check whether BatchedGETRS is executed successfully or not
-  PADDLE_ENFORCE_EQ(host_info,
-                    0,
-                    common::errors::InvalidArgument(
-                        "The [%d]'th argument to cublas*getrsBatched had "
-                        "an illegal value.",
-                        -host_info));
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(context, batch_size * 32);
+  auto stream = context.stream();
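+  // Launch batch_size * 32 threads in total so that roughly one warp is
+  // assigned per batch; the grid-stride loop in UnpackPivot covers any
+  // remainder.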
+  UnpackPivot<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+      gpu_pivot_ptr, perm.data<int>(), batch_size, n);
 
-  // transpose tmp_b to get the final result in row-major form.
-  phi::funcs::TransposeNormal<Context, T> trans2;
-  trans2(context, tmp_b, out, new_axis);
+  // Fuse all leading dims into one because scatter only supports a single
+  // index dimension.
+  tmp_b.Resize({batch_size * n, nrhs});
+  out->Resize({batch_size * n, nrhs});
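+  // GPUScatterAssign copies row i of tmp_b into row perm[i] of out, undoing
+  // the row permutation introduced by pivoting.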
+  GPUScatterAssign<T>(context, tmp_b, perm, out);
+  out->Resize(b_dims);
+  // After: X = P^T @ L^T^-1 @ U^T^-1 @ B
 
 #else
   compute_solve_eigen<Context, T>(context, a, b, out);