-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathpardiso.py
More file actions
134 lines (114 loc) · 4.42 KB
/
pardiso.py
File metadata and controls
134 lines (114 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from pymatsolver.solvers import Base
# pydiso (the Intel MKL PARDISO wrapper) is an optional dependency. Record whether
# it could be imported so that Pardiso.__init__ can raise a clear ImportError later
# instead of failing at module import time.
try:
    from pydiso.mkl_solver import (
        MKLPardisoSolver,
        set_mkl_pardiso_threads,
        get_mkl_pardiso_max_threads,
    )
except ImportError:
    _available = False
else:
    _available = True
class Pardiso(Base):
    """The Pardiso direct solver.

    This solver uses the `pydiso` Intel MKL wrapper to factorize a sparse matrix, and uses that
    factorization for solving.

    Parameters
    ----------
    A : scipy.sparse.spmatrix
        Matrix to solve with.
    n_threads : int, optional
        Number of threads to use for the `Pardiso` routine in Intel's MKL. This setting is
        global to the Python process, not specific to this solver object.
    is_symmetric : bool, optional
        Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
        default to ``False`` if those fail.
    is_positive_definite : bool, optional
        Whether the matrix is positive definite.
    is_hermitian : bool, optional
        Whether the matrix is hermitian. By default, it will perform some simple tests to check, and default to
        ``False`` if those fail.
    check_accuracy : bool, optional
        Whether to check the accuracy of the solution.
    check_rtol : float, optional
        The relative tolerance to check against for accuracy.
    check_atol : float, optional
        The absolute tolerance to check against for accuracy.
    **kwargs
        Extra keyword arguments. If there are any left here a warning will be raised.

    Raises
    ------
    ImportError
        If the `pydiso` package could not be imported.
    """

    # When True, solves are performed with the transposed system: this flag is forwarded
    # as pydiso's ``transpose`` argument in ``_solve_multiple`` and toggled by ``transpose``.
    _transposed = False

    def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
        if not _available:
            raise ImportError("Pardiso solver requires the pydiso package to be installed.")
        super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)
        # The matrix type is chosen once here from the symmetry/definiteness flags and is
        # not recomputed by ``factor``. ``factor=False`` defers the numeric factorization
        # to ``factor`` (presumably pydiso still runs its analysis step here — confirm
        # against the pydiso docs).
        self.solver = MKLPardisoSolver(
            self.A,
            matrix_type=self._matrixType(),
            factor=False
        )
        if n_threads is not None:
            # Process-global setting; affects every Pardiso instance, not just this one.
            self.n_threads = n_threads

    def _matrixType(self):
        """Return the MKL PARDISO ``mtype`` code for this matrix.

        The code is selected from ``is_real``, ``is_symmetric``, ``is_hermitian`` and
        ``is_positive_definite``:

        Real::

             2: symmetric positive definite
            -2: symmetric indefinite (symmetric, not flagged positive definite)
            11: nonsymmetric

        Complex::

             6: symmetric (checked before hermitian)
             4: hermitian positive definite
            -4: hermitian indefinite (hermitian, not flagged positive definite)
            13: nonsymmetric

        The structurally symmetric codes (1 for real, 3 for complex) are never selected.

        Returns
        -------
        int
        """
        if self.is_real:
            if self.is_symmetric:
                return 2 if self.is_positive_definite else -2
            return 11
        if self.is_symmetric:
            return 6
        if self.is_hermitian:
            return 4 if self.is_positive_definite else -4
        return 13

    def factor(self, A=None):
        """(Re)factor the A matrix.

        Parameters
        ----------
        A : scipy.sparse.spmatrix, optional
            The matrix to be factorized. If omitted, the currently stored matrix is
            refactored. If a previous factorization has been performed, this will
            reuse the previous factorization's analysis.

        Notes
        -----
        The matrix type handed to pydiso was fixed at construction and is not updated here.
        Objects returned by :meth:`transpose` share this object's pydiso solver, so
        refactoring one refactors the other as well; only this object's stored ``A`` is
        replaced.
        """
        if A is not None and self.A is not A:
            self._A = A
        # Always refactor, so that calling ``factor()`` (or passing the same object whose
        # values changed) recomputes the factorization from the stored matrix.
        self.solver.refactor(self.A)

    def _solve_multiple(self, rhs):
        """Solve for ``rhs`` using the pydiso solver's current factorization.

        The transposed system is solved when ``_transposed`` is True (see :meth:`transpose`).
        """
        return self.solver.solve(rhs, transpose=self._transposed)

    def transpose(self):
        """Return a solver for the transposed system that shares this factorization.

        No new factorization of ``A.T`` is computed: the returned object reuses
        ``self.solver`` and has ``_transposed`` toggled, which is forwarded as pydiso's
        ``transpose`` argument when solving. Transposing the result toggles it back.

        Returns
        -------
        Pardiso
            A new instance of the same class as ``self``, created without calling
            ``__init__``.
        """
        # Use the concrete class so subclasses keep their overridden behavior;
        # ``__new__`` bypasses ``__init__``, so no new pydiso solver is constructed.
        cls = type(self)
        trans_obj = cls.__new__(cls)
        # The untransposed matrix is kept on purpose: transposition is handled by the
        # shared solver via ``_transposed``.
        # NOTE(review): code in the base class that uses ``self.A`` directly on the
        # returned object (e.g. an accuracy check) sees the untransposed matrix — confirm.
        trans_obj._A = self.A
        for attr, value in self.get_attributes().items():
            setattr(trans_obj, attr, value)
        trans_obj.solver = self.solver
        trans_obj._transposed = not self._transposed
        return trans_obj

    @property
    def n_threads(self):
        """Number of threads to use for the Pardiso solver routine.

        This property is global to all Pardiso solver objects for a single python process.

        Returns
        -------
        int
            The value reported by pydiso's ``get_mkl_pardiso_max_threads``.
        """
        return get_mkl_pardiso_max_threads()

    @n_threads.setter
    def n_threads(self, n_threads):
        # Delegates to pydiso's process-global setter.
        set_mkl_pardiso_threads(n_threads)

    # A single right-hand side goes through exactly the same pydiso call as multiple ones.
    _solve_single = _solve_multiple