From efaf7a8f00dbd7638751947bd3141c43ffb53d7c Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Fri, 11 Oct 2024 23:00:39 -0600 Subject: [PATCH 1/3] Only use the issymmetric and ishermitian funcitons on numpy arrays --- pymatsolver/solvers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index 4d0317c..a1ab55b 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -77,8 +77,10 @@ def __init__( if is_symmetric is None: if sp.issparse(A): is_symmetric = (A.T != A).nnz == 0 - else: + elif isinstance(A, np.ndarray): is_symmetric = issymmetric(A) + else: + is_symmetric = False self.is_symmetric = is_symmetric if is_hermitian is None: if self.is_real: @@ -86,8 +88,10 @@ def __init__( else: if sp.issparse(A): is_hermitian = (A.T.conjugate() != A).nnz == 0 - else: + elif isinstance(A, np.ndarray): is_hermitian = ishermitian(A) + else: + is_hermitian = False self.is_hermitian = is_hermitian From bae9724f61ede99a22576645d557cc7c892645fe Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Fri, 11 Oct 2024 23:01:02 -0600 Subject: [PATCH 2/3] add test to make sure a linear operator can be passed. --- tests/test_Scipy.py | 11 +++++++++++ tests/test_Wrappers.py | 5 ++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_Scipy.py b/tests/test_Scipy.py index c07e0c9..0cab0ed 100644 --- a/tests/test_Scipy.py +++ b/tests/test_Scipy.py @@ -1,5 +1,6 @@ from pymatsolver import Solver, Diagonal, SolverCG, SolverLU import scipy.sparse as sp +from scipy.sparse.linalg import aslinearoperator import numpy as np import numpy.testing as npt import pytest @@ -57,6 +58,16 @@ def test_solver(a_matrix, n_rhs, solver): npt.assert_allclose(x, b, atol=tol) +def test_iterative_solver_linear_op(): + n = 10 + A = aslinearoperator(sp.eye(n)) + + Ainv = SolverCG(A) + + rhs = np.linspace(0.9, 1.1, n) + + npt.assert_allclose(Ainv @ rhs, rhs) + @pytest.mark.parametrize('n_rhs', [1, 5]) def test_diag_solver(n_rhs): n = 10 diff --git a/tests/test_Wrappers.py b/tests/test_Wrappers.py index 10b0dab..62a5809 100644 --- a/tests/test_Wrappers.py +++ b/tests/test_Wrappers.py @@ -14,12 +14,14 @@ def test_wrapper_unused_kwargs(solver_class): with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"): solver_class(A, not_a_keyword_arg=True) + def test_good_arg_iterative(): # Ensure this doesn't throw a warning! with warnings.catch_warnings(): warnings.simplefilter("error") SolverCG(sp.eye(10), rtol=1e-4) + def test_good_arg_direct(): # Ensure this doesn't throw a warning! with warnings.catch_warnings(): @@ -40,7 +42,6 @@ def __init__(self, A): WrappedClass(sp.eye(2)) - def test_direct_clean_function(): def direct_func(A): class Empty(): @@ -67,6 +68,7 @@ def clean(self): Ainv.clean() assert Ainv.solver.A is None + def test_iterative_deprecations(): with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"): @@ -75,6 +77,7 @@ def test_iterative_deprecations(): with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"): wrap_iterative(lambda a, x: x, accuracy_tol=1E-3) + def test_non_scipy_iterative(): def iterative_solver(A, x): return x From 5f17cf874ed4903eb7577a0b41330cd403095e5b Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Fri, 11 Oct 2024 23:05:17 -0600 Subject: [PATCH 3/3] test with real and complex data types. --- tests/test_Scipy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_Scipy.py b/tests/test_Scipy.py index 0cab0ed..a080031 100644 --- a/tests/test_Scipy.py +++ b/tests/test_Scipy.py @@ -58,9 +58,10 @@ def test_solver(a_matrix, n_rhs, solver): npt.assert_allclose(x, b, atol=tol) -def test_iterative_solver_linear_op(): +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +def test_iterative_solver_linear_op(dtype): n = 10 - A = aslinearoperator(sp.eye(n)) + A = aslinearoperator(sp.eye(n).astype(dtype)) Ainv = SolverCG(A)