Skip to content

Commit 58a79e8

Browse files
committed
Use lapack instead of scipy_linalg.solve_triangular
1 parent 135b8d9 commit 58a79e8

File tree

2 files changed

+71
-10
lines changed

2 files changed

+71
-10
lines changed

pytensor/tensor/slinalg.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import scipy.linalg as scipy_linalg
99
from numpy.exceptions import ComplexWarning
1010
from scipy.linalg import get_lapack_funcs
11-
from scipy.linalg._misc import LinAlgWarning
11+
from scipy.linalg._misc import LinAlgError, LinAlgWarning
1212

1313
import pytensor
1414
from pytensor import ifelse
@@ -897,15 +897,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs):
897897

898898
def perform(self, node, inputs, outputs):
899899
A, b = inputs
900-
outputs[0][0] = scipy_linalg.solve_triangular(
901-
A,
902-
b,
903-
lower=self.lower,
904-
trans=0,
905-
unit_diagonal=self.unit_diagonal,
906-
check_finite=self.check_finite,
907-
overwrite_b=self.overwrite_b,
908-
)
900+
901+
if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
902+
raise ValueError("array must not contain infs or NaNs")
903+
904+
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
905+
raise ValueError("expected square matrix")
906+
907+
if A.shape[0] != b.shape[0]:
908+
raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")
909+
910+
(trtrs,) = get_lapack_funcs(("trtrs",), (A, b))
911+
912+
# Quick return for empty arrays
913+
if b.size == 0:
914+
outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
915+
return
916+
917+
if A.flags["F_CONTIGUOUS"]:
918+
x, info = trtrs(
919+
A,
920+
b,
921+
overwrite_b=self.overwrite_b,
922+
lower=self.lower,
923+
trans=0,
924+
unitdiag=self.unit_diagonal,
925+
)
926+
else:
927+
# transposed system is solved since trtrs expects Fortran ordering
928+
x, info = trtrs(
929+
A.T,
930+
b,
931+
overwrite_b=self.overwrite_b,
932+
lower=not self.lower,
933+
trans=1,
934+
unitdiag=self.unit_diagonal,
935+
)
936+
937+
if info > 0:
938+
raise LinAlgError(
939+
f"singular matrix: resolution failed at diagonal {info-1}"
940+
)
941+
elif info < 0:
942+
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
943+
944+
outputs[0][0] = x
909945

910946
def L_op(self, inputs, outputs, output_gradients):
911947
res = super().L_op(inputs, outputs, output_gradients)

tests/tensor/test_slinalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,31 @@ def solve_op(A, b):
513513

514514
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
515515

516+
def test_solve_triangular_empty(self):
517+
rng = np.random.default_rng(utt.fetch_seed())
518+
A = pt.tensor("A", shape=(5, 5))
519+
b = pt.tensor("b", shape=(5, 0))
520+
521+
A_val = rng.random((5, 5)).astype(config.floatX)
522+
b_empty = np.empty([5, 0], dtype=config.floatX)
523+
524+
A_func = functools.partial(self.A_func, lower=True, unit_diagonal=True)
525+
526+
x = solve_triangular(
527+
A_func(A),
528+
b,
529+
lower=True,
530+
trans=0,
531+
unit_diagonal=True,
532+
b_ndim=len((5, 0)),
533+
)
534+
535+
f = function([A, b], x)
536+
537+
res = f(A_val, b_empty)
538+
assert res.size == 0
539+
assert res.dtype == config.floatX
540+
516541

517542
class TestCholeskySolve(utt.InferShapeTester):
518543
def setup_method(self):

0 commit comments

Comments
 (0)