diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index 4d0317c..a1ab55b 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -77,8 +77,10 @@ def __init__( if is_symmetric is None: if sp.issparse(A): is_symmetric = (A.T != A).nnz == 0 - else: + elif isinstance(A, np.ndarray): is_symmetric = issymmetric(A) + else: + is_symmetric = False self.is_symmetric = is_symmetric if is_hermitian is None: if self.is_real: @@ -86,8 +88,10 @@ def __init__( else: if sp.issparse(A): is_hermitian = (A.T.conjugate() != A).nnz == 0 - else: + elif isinstance(A, np.ndarray): is_hermitian = ishermitian(A) + else: + is_hermitian = False self.is_hermitian = is_hermitian diff --git a/tests/test_Scipy.py b/tests/test_Scipy.py index c07e0c9..a080031 100644 --- a/tests/test_Scipy.py +++ b/tests/test_Scipy.py @@ -1,5 +1,6 @@ from pymatsolver import Solver, Diagonal, SolverCG, SolverLU import scipy.sparse as sp +from scipy.sparse.linalg import aslinearoperator import numpy as np import numpy.testing as npt import pytest @@ -57,6 +58,17 @@ def test_solver(a_matrix, n_rhs, solver): npt.assert_allclose(x, b, atol=tol) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +def test_iterative_solver_linear_op(dtype): + n = 10 + A = aslinearoperator(sp.eye(n).astype(dtype)) + + Ainv = SolverCG(A) + + rhs = np.linspace(0.9, 1.1, n) + + npt.assert_allclose(Ainv @ rhs, rhs) + @pytest.mark.parametrize('n_rhs', [1, 5]) def test_diag_solver(n_rhs): n = 10 diff --git a/tests/test_Wrappers.py b/tests/test_Wrappers.py index 10b0dab..62a5809 100644 --- a/tests/test_Wrappers.py +++ b/tests/test_Wrappers.py @@ -14,12 +14,14 @@ def test_wrapper_unused_kwargs(solver_class): with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"): solver_class(A, not_a_keyword_arg=True) + def test_good_arg_iterative(): # Ensure this doesn't throw a warning! with warnings.catch_warnings(): warnings.simplefilter("error") SolverCG(sp.eye(10), rtol=1e-4) + def test_good_arg_direct(): # Ensure this doesn't throw a warning! with warnings.catch_warnings(): @@ -40,7 +42,6 @@ def __init__(self, A): WrappedClass(sp.eye(2)) - def test_direct_clean_function(): def direct_func(A): class Empty(): @@ -67,6 +68,7 @@ def clean(self): Ainv.clean() assert Ainv.solver.A is None + def test_iterative_deprecations(): with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"): @@ -75,6 +77,7 @@ def test_iterative_deprecations(): with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"): wrap_iterative(lambda a, x: x, accuracy_tol=1E-3) + def test_non_scipy_iterative(): def iterative_solver(A, x): return x