Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conjugate solve #50

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Base(ABC):
__numpy_ufunc__ = True
__array_ufunc__ = None

_is_conjugate = False

def __init__(
self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs
):
Expand Down Expand Up @@ -251,7 +253,13 @@ def _transpose_class(self):
return self.__class__

def transpose(self):
"""Return the transposed solve operator."""
"""Return the transposed solve operator.

Returns
-------
pymatsolver.solvers.Base
"""

if self.is_symmetric:
return self
if self._transpose_class is None:
Expand All @@ -274,6 +282,23 @@ def T(self):
"""
return self.transpose()

def conjugate(self):
"""Return the complex conjugate version of this solver.

Returns
-------
pymatsolver.solvers.Base
"""
if self.is_real:
return self
else:
# make a shallow copy of myself
conjugated = copy.copy(self)
conjugated._is_conjugate = not self._is_conjugate
return conjugated

conj = conjugate

def _compute_accuracy(self, rhs, x):
resid_norm = np.linalg.norm(rhs - self.A @ x)
rhs_norm = np.linalg.norm(rhs)
Expand Down Expand Up @@ -308,6 +333,8 @@ def solve(self, rhs):
if ndim == 1:
if len(rhs) != n:
raise ValueError(f'Expected a vector of length {n}, got {len(rhs)}')
if self._is_conjugate:
rhs = rhs.conjugate()
x = self._solve_single(rhs)
else:
if ndim == 2 and rhs.shape[-1] == 1:
Expand All @@ -331,6 +358,8 @@ def solve(self, rhs):
# (which is more common for direct solvers).
rhs = rhs.transpose()
# should end up with shape (n, -1)
if self._is_conjugate:
rhs = rhs.conjugate()
x = self._solve_multiple(rhs)
if do_broadcast:
# undo the reshaping above
Expand All @@ -347,6 +376,9 @@ def solve(self, rhs):
#TODO remove this in v0.4.0.
if x.size == n:
x = x.reshape(-1)

if self._is_conjugate:
x = x.conjugate()
return x

@abstractmethod
Expand Down
46 changes: 46 additions & 0 deletions tests/test_conjugate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
import pymatsolver
import numpy as np
import scipy.sparse as sp
import numpy.testing as npt


@pytest.mark.parametrize('solver_class', [pymatsolver.Solver, pymatsolver.SolverLU, pymatsolver.Pardiso, pymatsolver.Mumps])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
@pytest.mark.parametrize('n_rhs', [1, 4])
def test_conjugate_solve(solver_class, dtype, n_rhs):
if solver_class is pymatsolver.Pardiso and not pymatsolver.AvailableSolvers['Pardiso']:
pytest.skip("pydiso not installed.")
if solver_class is pymatsolver.Mumps and not pymatsolver.AvailableSolvers['Mumps']:
pytest.skip("python-mumps not installed.")

n = 10
D = sp.diags(np.linspace(1, 10, n))
if dtype == np.float64:
L = sp.diags([1, -1], [0, -1], shape=(n, n))

sol = np.linspace(0.9, 1.1, n)
# non-symmetric real matrix
else:
# non-symmetric
L = sp.diags([1, -1j], [0, -1], shape=(n, n))
sol = np.linspace(0.9, 1.1, n) - 1j * np.linspace(0.9, 1.1, n)[::-1]

if n_rhs > 1:
sol = np.pad(sol[:, None], [(0, 0), (0, n_rhs - 1)], mode='constant')

A = D @ L @ D @ L.T

# double check it solves
rhs = A @ sol
Ainv = solver_class(A)
npt.assert_allclose(Ainv @ rhs, sol)

# is conjugate solve correct?
rhs_conj = A.conjugate() @ sol
Ainv_conj = Ainv.conjugate()
npt.assert_allclose(Ainv_conj @ rhs_conj, sol)

# is conjugate -> conjugate solve correct?
Ainv2 = Ainv_conj.conjugate()
npt.assert_allclose(Ainv2 @ rhs, sol)
Loading