|
8 | 8 | import scipy.linalg as scipy_linalg |
9 | 9 | from numpy.exceptions import ComplexWarning |
10 | 10 | from scipy.linalg import get_lapack_funcs |
11 | | -from scipy.linalg._misc import LinAlgWarning |
| 11 | +from scipy.linalg._misc import LinAlgError, LinAlgWarning |
12 | 12 |
|
13 | 13 | import pytensor |
14 | 14 | from pytensor import ifelse |
@@ -897,15 +897,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs): |
897 | 897 |
|
898 | 898 | def perform(self, node, inputs, outputs): |
899 | 899 | A, b = inputs |
900 | | - outputs[0][0] = scipy_linalg.solve_triangular( |
901 | | - A, |
902 | | - b, |
903 | | - lower=self.lower, |
904 | | - trans=0, |
905 | | - unit_diagonal=self.unit_diagonal, |
906 | | - check_finite=self.check_finite, |
907 | | - overwrite_b=self.overwrite_b, |
908 | | - ) |
| 900 | + |
| 901 | + if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()): |
| 902 | + raise ValueError("array must not contain infs or NaNs") |
| 903 | + |
| 904 | + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: |
| 905 | + raise ValueError("expected square matrix") |
| 906 | + |
| 907 | + if A.shape[0] != b.shape[0]: |
| 908 | + raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible") |
| 909 | + |
| 910 | + (trtrs,) = get_lapack_funcs(("trtrs",), (A, b)) |
| 911 | + |
| 912 | + # Quick return for empty arrays |
| 913 | + if b.size == 0: |
| 914 | + outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype) |
| 915 | + return |
| 916 | + |
| 917 | + if A.flags["F_CONTIGUOUS"]: |
| 918 | + x, info = trtrs( |
| 919 | + A, |
| 920 | + b, |
| 921 | + overwrite_b=self.overwrite_b, |
| 922 | + lower=self.lower, |
| 923 | + trans=0, |
| 924 | + unitdiag=self.unit_diagonal, |
| 925 | + ) |
| 926 | + else: |
| 927 | + # transposed system is solved since trtrs expects Fortran ordering |
| 928 | + x, info = trtrs( |
| 929 | + A.T, |
| 930 | + b, |
| 931 | + overwrite_b=self.overwrite_b, |
| 932 | + lower=not self.lower, |
| 933 | + trans=1, |
| 934 | + unitdiag=self.unit_diagonal, |
| 935 | + ) |
| 936 | + |
| 937 | + if info > 0: |
| 938 | + raise LinAlgError( |
| 939 | + f"singular matrix: resolution failed at diagonal {info-1}" |
| 940 | + ) |
| 941 | + elif info < 0: |
| 942 | + raise ValueError(f"illegal value in {-info}-th argument of internal trtrs") |
| 943 | + |
| 944 | + outputs[0][0] = x |
909 | 945 |
|
910 | 946 | def L_op(self, inputs, outputs, output_gradients): |
911 | 947 | res = super().L_op(inputs, outputs, output_gradients) |
|
0 commit comments