@@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils;
 typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
     sycl::queue &,
     std::int64_t,
+    std::int64_t,
     char *,
     std::int64_t,
     std::int64_t,
@@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t
 
 template <typename T>
 static sycl::event getrf_batch_impl(sycl::queue &exec_q,
+                                    std::int64_t m,
                                     std::int64_t n,
                                     char *in_a,
                                     std::int64_t lda,
@@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
     T *a = reinterpret_cast<T *>(in_a);
 
     const std::int64_t scratchpad_size =
-        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
+        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, m, n, lda, stride_a,
                                                    stride_ipiv, batch_size);
     T *scratchpad = nullptr;
 
@@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
 
     getrf_batch_event = mkl_lapack::getrf_batch(
         exec_q,
-        n, // The order of each square matrix in the batch; (0 ≤ n).
+        m, // The number of rows in each matrix in the batch; (0 ≤ m).
            // It must be a non-negative integer.
         n, // The number of columns in each matrix in the batch; (0 ≤ n).
            // It must be a non-negative integer.
-        a, // Pointer to the batch of square matrices, each of size (n x n).
+        a, // Pointer to the batch of input matrices, each of size (m x n).
         lda, // The leading dimension of each matrix in the batch.
         stride_a, // Stride between consecutive matrices in the batch.
         ipiv, // Pointer to the array of pivot indices for each matrix in
@@ -179,6 +181,7 @@ std::pair<sycl::event, sycl::event>
                 const dpctl::tensor::usm_ndarray &a_array,
                 const dpctl::tensor::usm_ndarray &ipiv_array,
                 py::list dev_info,
+                std::int64_t m,
                 std::int64_t n,
                 std::int64_t stride_a,
                 std::int64_t stride_ipiv,
@@ -191,21 +194,21 @@ std::pair<sycl::event, sycl::event>
     if (a_array_nd < 3) {
         throw py::value_error(
             "The input array has ndim=" + std::to_string(a_array_nd) +
-            ", but an array with ndim >= 3 is expected.");
+            ", but an array with ndim >= 3 is expected");
     }
 
     if (ipiv_array_nd != 2) {
         throw py::value_error("The array of pivot indices has ndim=" +
                               std::to_string(ipiv_array_nd) +
-                              ", but a 2-dimensional array is expected.");
+                              ", but a 2-dimensional array is expected");
     }
 
     const int dev_info_size = py::len(dev_info);
     if (dev_info_size != batch_size) {
         throw py::value_error("The size of 'dev_info' (" +
                               std::to_string(dev_info_size) +
                               ") does not match the expected batch size (" +
-                              std::to_string(batch_size) + ").");
+                              std::to_string(batch_size) + ")");
     }
 
     // check compatibility of execution queue and allocation queue
@@ -221,10 +224,11 @@ std::pair<sycl::event, sycl::event>
     }
 
     bool is_a_array_c_contig = a_array.is_c_contiguous();
+    bool is_a_array_f_contig = a_array.is_f_contiguous();
     bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
-    if (!is_a_array_c_contig) {
+    if (!is_a_array_c_contig && !is_a_array_f_contig) {
         throw py::value_error("The input array "
-                              "must be C-contiguous");
+                              "must be contiguous");
     }
     if (!is_ipiv_array_c_contig) {
         throw py::value_error("The array of pivot indices "
@@ -240,27 +244,34 @@ std::pair<sycl::event, sycl::event>
     if (getrf_batch_fn == nullptr) {
         throw py::value_error(
             "No getrf_batch implementation defined for the provided type "
-            "of the input matrix.");
+            "of the input matrix");
     }
 
     auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
     int ipiv_array_type_id =
         ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
 
     if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
-        throw py::value_error("The type of 'ipiv_array' must be int64.");
+        throw py::value_error("The type of 'ipiv_array' must be int64");
+    }
+
+    const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw();
+    if (ipiv_array_shape[0] != batch_size ||
+        ipiv_array_shape[1] != std::min(m, n)) {
+        throw py::value_error(
+            "The shape of 'ipiv_array' must be (batch_size, min(m, n))");
     }
 
     char *a_array_data = a_array.get_data();
-    const std::int64_t lda = std::max<size_t>(1UL, n);
+    const std::int64_t lda = std::max<size_t>(1UL, m);
 
     char *ipiv_array_data = ipiv_array.get_data();
     std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
 
     std::vector<sycl::event> host_task_events;
     sycl::event getrf_batch_ev = getrf_batch_fn(
-        exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
-        dev_info, host_task_events, depends);
+        exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv,
+        batch_size, dev_info, host_task_events, depends);
 
     sycl::event args_ev = dpctl::utils::keep_args_alive(
         exec_q, {a_array, ipiv_array}, host_task_events);