Prevent the recalculation of permutation in lu_solve() #5826
Labels: enhancement (New feature or request)
Use case: repeated solution of a linear system based on an LU decomposition.

Currently, the only exposed methods that make this possible are jax.scipy.linalg.lu_factor() in conjunction with jax.scipy.linalg.lu_solve(). However, to conform with SciPy's API, jax.scipy.linalg.lu_factor() drops the third output, permutation, provided by jax.lax.linalg.lu(). Hence, the permutation is recalculated from the pivots in every invocation of jax.scipy.linalg.lu_solve() (code). For reasons that should be investigated separately, lu_pivots_to_permutation() exhibits poor performance on GPU.
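To make the use case concrete, here is a minimal sketch of the current pattern. It assumes a recent JAX release and uses an arbitrary, nonsingular 3×3 example matrix. The sketch factors once with lu_factor(), solves several right-hand sides with lu_solve(), and also fetches the permutation that jax.lax.linalg.lu() returns. The final lu_pivots_to_permutation() call exists only to show the pivot-to-permutation conversion that, as described above, each lu_solve() call has to repeat.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg
from jax.scipy import linalg as jsl

# Arbitrary nonsingular example matrix (det = 3).
a = jnp.array([[4.0, 3.0, 2.0],
               [2.0, 1.0, 3.0],
               [3.0, 2.0, 1.0]])
n = a.shape[0]

# Factor once. Following SciPy, lu_factor returns only (lu, piv).
lu, piv = jsl.lu_factor(a)

# Solve repeatedly. Each lu_solve call only receives piv, so the
# permutation must be rebuilt from the pivots on every call.
rhs = [jnp.array([1.0, 0.0, 0.0]),
       jnp.array([0.0, 1.0, 0.0]),
       jnp.array([0.0, 0.0, 1.0])]
solutions = [jsl.lu_solve((lu, piv), b) for b in rhs]

# The lower-level routine also returns the permutation that lu_factor drops.
_, _, perm = lax_linalg.lu(a)

# The conversion that is repeated inside every lu_solve call.
rebuilt = lax_linalg.lu_pivots_to_permutation(piv, n)
```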
There are two possible solutions:

1. Add an optional argument to jax.scipy.linalg.lu_solve() that allows the user to pass in permutation. That would cause divergence from SciPy's API.
2. Expose jax.lax.linalg.lu_solve() directly to the user.
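Both solutions would let the permutation be computed once and then reused. The following is a rough illustration of that approach, not the proposed API. It uses a hypothetical helper, lu_solve_with_permutation(), which is not part of JAX. The helper takes the lu and permutation outputs of jax.lax.linalg.lu() and performs the two triangular solves with jax.scipy.linalg.solve_triangular(). It assumes the convention that a[perm] equals L @ U, where L is the unit lower triangle of lu and U is its upper triangle.

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg
from jax.scipy.linalg import solve_triangular


@jax.jit
def lu_solve_with_permutation(lu, perm, b):
    """Hypothetical helper (not a JAX API): solve a @ x = b.

    Assumes lu and perm come from jax.lax.linalg.lu(a), so that
    a[perm] == L @ U. No pivot-to-permutation conversion happens here.
    """
    pb = b[perm]                                                  # P b
    y = solve_triangular(lu, pb, lower=True, unit_diagonal=True)  # L y = P b
    return solve_triangular(lu, y, lower=False)                   # U x = y


a = jnp.array([[4.0, 3.0, 2.0],
               [2.0, 1.0, 3.0],
               [3.0, 2.0, 1.0]])

# Factor once and keep the permutation instead of discarding it.
lu, _pivots, perm = lax_linalg.lu(a)

# Reuse lu and perm for every right-hand side (rows of the identity here).
rhs = jnp.eye(3)
solutions = [lu_solve_with_permutation(lu, perm, b) for b in rhs]
```

Each right-hand side then costs only a gather of b by perm plus two triangular solves, so the pivot conversion is not repeated. This sketch does not show whether avoiding the conversion actually removes the GPU slowdown. That would need to be measured.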