Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor base class #48

Merged
merged 25 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0e57ffd
validate kwargs for wrapped functions
jcapriot Sep 27, 2024
a6690a3
Refactor many internals to be more consistent between classes.
jcapriot Sep 28, 2024
7876f77
fix mumps transpose operation.
jcapriot Sep 28, 2024
76c9c34
add n_threads as a kwarg on Pardiso
jcapriot Sep 28, 2024
58eb8d1
add base kwargs to subclasses
jcapriot Oct 1, 2024
2a58618
update wrapper to not gobble `is_symmetric` and `is_hermitian`
jcapriot Oct 1, 2024
dae993c
update deprecation messages
jcapriot Oct 1, 2024
f0dee04
add reminder note to remove piece in next version.
jcapriot Oct 1, 2024
7feff87
remove unused
jcapriot Oct 1, 2024
3b889b0
upload to codecov with CODECOV_TOKEN
jcapriot Oct 3, 2024
f4c0b1e
Add minimal docstrings describing parameters and returns
jcapriot Oct 4, 2024
178d619
Merge branch 'docstrings' into signature_checking
jcapriot Oct 4, 2024
6d56478
fix mumps re-factor logic
jcapriot Oct 9, 2024
5e3663c
Merge remote-tracking branch 'origin/consistent_refactor' into consis…
jcapriot Oct 9, 2024
a6ac4d9
add tests for bicg warnings and coverage
jcapriot Oct 9, 2024
f082535
add some coverage tests for the basic solver
jcapriot Oct 9, 2024
4efc207
run BiCG through the matrix test suite
jcapriot Oct 10, 2024
9350dba
more tests for coverage.
jcapriot Oct 10, 2024
9305f05
Coverage for basic tests
jcapriot Oct 10, 2024
4e9ebed
triangle tests
jcapriot Oct 10, 2024
eefd9fb
wrapper tests
jcapriot Oct 10, 2024
a558b37
more wrapper tests.
jcapriot Oct 10, 2024
e6d3dad
add test for non-scipy iterative solver
jcapriot Oct 10, 2024
46683d6
Do not run on every branch.
jcapriot Oct 10, 2024
5047329
add pardiso tests for PD matrices.
jcapriot Oct 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Testing
on:
push:
branches:
- '*'
- 'main'
tags:
- 'v*'
pull_request:
Expand Down Expand Up @@ -75,6 +75,8 @@ jobs:
uses: codecov/codecov-action@v4
with:
verbose: true # optional (default = false)
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

distribute:
name: Distributing from 3.8
Expand Down
7 changes: 4 additions & 3 deletions pymatsolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
.. autosummary::
:toctree: generated/

Triangle
Forward
Backward

Expand Down Expand Up @@ -60,9 +61,9 @@
}

# Simple solvers
from .solvers import Diagonal, Forward, Backward
from .wrappers import WrapDirect
from .wrappers import WrapIterative
from .solvers import Diagonal, Triangle, Forward, Backward
from .wrappers import wrap_direct, WrapDirect
from .wrappers import wrap_iterative, WrapIterative

# Scipy Iterative solvers
from .iterative import SolverCG
Expand Down
92 changes: 67 additions & 25 deletions pymatsolver/direct/mumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,106 @@
from mumps import Context

class Mumps(Base):
"""
Mumps solver
"""The MUMPS direct solver.

This solver uses the python-mumps wrappers to factorize a sparse matrix, and use that factorization for solving.

Parameters
----------
A
Matrix to solve with.
ordering : str, default 'metis'
Which ordering algorithm to use. See the `python-mumps` documentation for more details.
is_symmetric : bool, optional
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
default to ``False`` if those fail.
is_positive_definite : bool, optional
Whether the matrix is positive definite.
check_accuracy : bool, optional
Whether to check the accuracy of the solution.
check_rtol : float, optional
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
_transposed = False
ordering = ''

def __init__(self, A, **kwargs):
self.set_kwargs(**kwargs)
def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
is_hermitian = kwargs.pop('is_hermitian', False)
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
if ordering is None:
ordering = "metis"
self.ordering = ordering
self.solver = Context()
self._set_A(A)
self.A = A
self._set_A(self.A)

def _set_A(self, A):
self.solver.set_matrix(
A,
symmetric=self.is_symmetric,
# positive_definite=self.is_positive_definite # doesn't (yet) support setting positive definiteness
)

@property
def ordering(self):
return getattr(self, '_ordering', "metis")
"""The ordering algorithm to use.

Returns
-------
str
"""
return self._ordering

@ordering.setter
def ordering(self, value):
self._ordering = value
self._ordering = str(value)

@property
def _factored(self):
return self.solver.factored

@property
def get_attributes(self):
attrs = super().get_attributes()
attrs['ordering'] = self.ordering
return attrs

def transpose(self):
trans_obj = Mumps.__new__(Mumps)
trans_obj.A = self.A
trans_obj._A = self.A
for attr, value in self.get_attributes().items():
setattr(trans_obj, attr, value)
trans_obj.solver = self.solver
trans_obj.is_symmetric = self.is_symmetric
trans_obj.is_positive_definite = self.is_positive_definite
trans_obj.ordering = self.ordering
trans_obj._transposed = not self._transposed
return trans_obj

T = transpose

def factor(self, A=None):
reuse_analysis = False
if A is not None:
self._set_A(A)
self.A = A
"""(Re)factor the A matrix.

Parameters
----------
A : scipy.sparse.spmatrix
The matrix to be factorized. If a previous factorization has been performed, this will
reuse the previous factorization's analysis.
"""
reuse_analysis = self._factored
do_factor = not self._factored
if A is not None and A is not self.A:
# if it was previously factored then re-use the analysis.
reuse_analysis = self._factored
if not self._factored:
self._set_A(A)
self._A = A
do_factor = True
if do_factor:
pivot_tol = 0.0 if self.is_positive_definite else 0.01
self.solver.factor(
ordering=self.ordering, reuse_analysis=reuse_analysis, pivot_tol=pivot_tol
)

def _solveM(self, rhs):
def _solve_multiple(self, rhs):
self.factor()
if self._transposed:
self.solver.mumps_instance.icntl[9] = 0
Expand All @@ -68,4 +110,4 @@ def _solveM(self, rhs):
sol = self.solver.solve(rhs)
return sol

_solve1 = _solveM
_solve_single = _solve_multiple
90 changes: 65 additions & 25 deletions pymatsolver/direct/pardiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,50 @@
from pydiso.mkl_solver import set_mkl_pardiso_threads, get_mkl_pardiso_max_threads

class Pardiso(Base):
"""The Pardiso direct solver.

This solver uses the `pydiso` Intel MKL wrapper to factorize a sparse matrix, and use that
factorization for solving.

Parameters
----------
A : scipy.sparse.spmatrix
Matrix to solve with.
n_threads : int, optional
Number of threads to use for the `Pardiso` routine in Intel's MKL.
is_symmetric : bool, optional
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
default to ``False`` if those fail.
is_positive_definite : bool, optional
Whether the matrix is positive definite.
is_hermitian : bool, optional
Whether the matrix is hermitian. By default, it will perform some simple tests to check, and default to
``False`` if those fail.
check_accuracy : bool, optional
Whether to check the accuracy of the solution.
check_rtol : float, optional
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
Pardiso Solver

https://github.com/simpeg/pydiso
_transposed = False


documentation::

http://www.pardiso-project.org/
"""

_factored = False

def __init__(self, A, **kwargs):
self.A = A
self.set_kwargs(**kwargs)
def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
self.solver = MKLPardisoSolver(
self.A,
matrix_type=self._matrixType(),
factor=False
)
if n_threads is not None:
self.n_threads = n_threads

def _matrixType(self):
"""
Expand Down Expand Up @@ -65,28 +88,45 @@ def _matrixType(self):
return 13

def factor(self, A=None):
if A is not None:
self._factored = False
self.A = A
if not self._factored:
"""(Re)factor the A matrix.

Parameters
----------
A : scipy.sparse.spmatrix
The matrix to be factorized. If a previous factorization has been performed, this will
reuse the previous factorization's analysis.
"""
if A is not None and self.A is not A:
self._A = A
self.solver.refactor(self.A)
self._factored = True

def _solveM(self, rhs):
self.factor()
sol = self.solver.solve(rhs)
def _solve_multiple(self, rhs):
sol = self.solver.solve(rhs, transpose=self._transposed)
return sol

def transpose(self):
trans_obj = Pardiso.__new__(Pardiso)
trans_obj._A = self.A
for attr, value in self.get_attributes().items():
setattr(trans_obj, attr, value)
trans_obj.solver = self.solver
trans_obj._transposed = not self._transposed
return trans_obj

@property
def n_threads(self):
"""
Number of threads to use for the Pardiso solver routine. This property
is global to all Pardiso solver objects for a single python process.
"""Number of threads to use for the Pardiso solver routine.

This property is global to all Pardiso solver objects for a single python process.

Returns
-------
int
"""
return get_mkl_pardiso_max_threads()

@n_threads.setter
def n_threads(self, n_threads):
set_mkl_pardiso_threads(n_threads)

_solve1 = _solveM
_solve_single = _solve_multiple
Loading
Loading