Skip to content

Commit

Permalink
Merge pull request #48 from simpeg/consistent_refactor
Browse files Browse the repository at this point in the history
Refactor base class
  • Loading branch information
jcapriot authored Oct 10, 2024
2 parents cf3f2e0 + 5047329 commit 3297054
Show file tree
Hide file tree
Showing 13 changed files with 1,356 additions and 342 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Testing
on:
push:
branches:
- '*'
- 'main'
tags:
- 'v*'
pull_request:
Expand Down Expand Up @@ -75,6 +75,8 @@ jobs:
uses: codecov/codecov-action@v4
with:
verbose: true # optional (default = false)
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

distribute:
name: Distributing from 3.8
Expand Down
7 changes: 4 additions & 3 deletions pymatsolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
.. autosummary::
:toctree: generated/
Triangle
Forward
Backward
Expand Down Expand Up @@ -60,9 +61,9 @@
}

# Simple solvers
from .solvers import Diagonal, Forward, Backward
from .wrappers import WrapDirect
from .wrappers import WrapIterative
from .solvers import Diagonal, Triangle, Forward, Backward
from .wrappers import wrap_direct, WrapDirect
from .wrappers import wrap_iterative, WrapIterative

# Scipy Iterative solvers
from .iterative import SolverCG
Expand Down
92 changes: 67 additions & 25 deletions pymatsolver/direct/mumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,106 @@
from mumps import Context

class Mumps(Base):
"""
Mumps solver
"""The MUMPS direct solver.
This solver uses the python-mumps wrappers to factorize a sparse matrix, and use that factorization for solving.
Parameters
----------
A
Matrix to solve with.
ordering : str, default 'metis'
Which ordering algorithm to use. See the `python-mumps` documentation for more details.
is_symmetric : bool, optional
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
default to ``False`` if those fail.
is_positive_definite : bool, optional
Whether the matrix is positive definite.
check_accuracy : bool, optional
Whether to check the accuracy of the solution.
check_rtol : float, optional
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
_transposed = False
ordering = ''

def __init__(self, A, **kwargs):
self.set_kwargs(**kwargs)
def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
is_hermitian = kwargs.pop('is_hermitian', False)
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
if ordering is None:
ordering = "metis"
self.ordering = ordering
self.solver = Context()
self._set_A(A)
self.A = A
self._set_A(self.A)

def _set_A(self, A):
self.solver.set_matrix(
A,
symmetric=self.is_symmetric,
# positive_definite=self.is_positive_definite # doesn't (yet) support setting positive definiteness
)

@property
def ordering(self):
return getattr(self, '_ordering', "metis")
"""The ordering algorithm to use.
Returns
-------
str
"""
return self._ordering

@ordering.setter
def ordering(self, value):
self._ordering = value
self._ordering = str(value)

@property
def _factored(self):
return self.solver.factored

@property
def get_attributes(self):
attrs = super().get_attributes()
attrs['ordering'] = self.ordering
return attrs

def transpose(self):
trans_obj = Mumps.__new__(Mumps)
trans_obj.A = self.A
trans_obj._A = self.A
for attr, value in self.get_attributes().items():
setattr(trans_obj, attr, value)
trans_obj.solver = self.solver
trans_obj.is_symmetric = self.is_symmetric
trans_obj.is_positive_definite = self.is_positive_definite
trans_obj.ordering = self.ordering
trans_obj._transposed = not self._transposed
return trans_obj

T = transpose

def factor(self, A=None):
reuse_analysis = False
if A is not None:
self._set_A(A)
self.A = A
"""(Re)factor the A matrix.
Parameters
----------
A : scipy.sparse.spmatrix
The matrix to be factorized. If a previous factorization has been performed, this will
reuse the previous factorization's analysis.
"""
reuse_analysis = self._factored
do_factor = not self._factored
if A is not None and A is not self.A:
# if it was previously factored then re-use the analysis.
reuse_analysis = self._factored
if not self._factored:
self._set_A(A)
self._A = A
do_factor = True
if do_factor:
pivot_tol = 0.0 if self.is_positive_definite else 0.01
self.solver.factor(
ordering=self.ordering, reuse_analysis=reuse_analysis, pivot_tol=pivot_tol
)

def _solveM(self, rhs):
def _solve_multiple(self, rhs):
self.factor()
if self._transposed:
self.solver.mumps_instance.icntl[9] = 0
Expand All @@ -68,4 +110,4 @@ def _solveM(self, rhs):
sol = self.solver.solve(rhs)
return sol

_solve1 = _solveM
_solve_single = _solve_multiple
90 changes: 65 additions & 25 deletions pymatsolver/direct/pardiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,50 @@
from pydiso.mkl_solver import set_mkl_pardiso_threads, get_mkl_pardiso_max_threads

class Pardiso(Base):
"""The Pardiso direct solver.
This solver uses the `pydiso` Intel MKL wrapper to factorize a sparse matrix, and use that
factorization for solving.
Parameters
----------
A : scipy.sparse.spmatrix
Matrix to solve with.
n_threads : int, optional
Number of threads to use for the `Pardiso` routine in Intel's MKL.
is_symmetric : bool, optional
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
default to ``False`` if those fail.
is_positive_definite : bool, optional
Whether the matrix is positive definite.
is_hermitian : bool, optional
Whether the matrix is hermitian. By default, it will perform some simple tests to check, and default to
``False`` if those fail.
check_accuracy : bool, optional
Whether to check the accuracy of the solution.
check_rtol : float, optional
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
Pardiso Solver

https://github.com/simpeg/pydiso
_transposed = False

documentation::
http://www.pardiso-project.org/
"""

_factored = False

def __init__(self, A, **kwargs):
self.A = A
self.set_kwargs(**kwargs)
def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
self.solver = MKLPardisoSolver(
self.A,
matrix_type=self._matrixType(),
factor=False
)
if n_threads is not None:
self.n_threads = n_threads

def _matrixType(self):
"""
Expand Down Expand Up @@ -65,28 +88,45 @@ def _matrixType(self):
return 13

def factor(self, A=None):
if A is not None:
self._factored = False
self.A = A
if not self._factored:
"""(Re)factor the A matrix.
Parameters
----------
A : scipy.sparse.spmatrix
The matrix to be factorized. If a previous factorization has been performed, this will
reuse the previous factorization's analysis.
"""
if A is not None and self.A is not A:
self._A = A
self.solver.refactor(self.A)
self._factored = True

def _solveM(self, rhs):
self.factor()
sol = self.solver.solve(rhs)
def _solve_multiple(self, rhs):
sol = self.solver.solve(rhs, transpose=self._transposed)
return sol

def transpose(self):
trans_obj = Pardiso.__new__(Pardiso)
trans_obj._A = self.A
for attr, value in self.get_attributes().items():
setattr(trans_obj, attr, value)
trans_obj.solver = self.solver
trans_obj._transposed = not self._transposed
return trans_obj

@property
def n_threads(self):
"""
Number of threads to use for the Pardiso solver routine. This property
is global to all Pardiso solver objects for a single python process.
"""Number of threads to use for the Pardiso solver routine.
This property is global to all Pardiso solver objects for a single python process.
Returns
-------
int
"""
return get_mkl_pardiso_max_threads()

@n_threads.setter
def n_threads(self, n_threads):
set_mkl_pardiso_threads(n_threads)

_solve1 = _solveM
_solve_single = _solve_multiple
Loading

0 comments on commit 3297054

Please sign in to comment.