-
Notifications
You must be signed in to change notification settings - Fork 33
/
matrix_functions_types.py
126 lines (83 loc) · 4.31 KB
/
matrix_functions_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from dataclasses import dataclass
from commons import AbstractDataclass
@dataclass(init=False)
class MatrixFunctionConfig(AbstractDataclass):
    """Abstract root of the matrix function configuration hierarchy.

    Declares no fields of its own; concrete configurations derive from it either
    directly or through one of the intermediate base classes.
    """
@dataclass(kw_only=True)
class EigenvalueDecompositionConfig(MatrixFunctionConfig):
    """Configuration for eigenvalue decomposition.

    Args:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if it
            fails in lower precision due to a cuSOLVER failure. (Default: True)

    """

    retry_double_precision: bool = True
@dataclass(init=False)
class RootInvConfig(MatrixFunctionConfig):
    """Common base for configurations that select a matrix root inverse method.

    Declares no fields; each concrete root inverse method subclasses it and adds
    its own settings.
    """
@dataclass(kw_only=True)
class EigenConfig(RootInvConfig, EigenvalueDecompositionConfig):
    """Configuration for matrix root inverse via an eigendecomposition.

    Args:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if it
            fails in lower precision due to a cuSOLVER failure. Inherited from EigenvalueDecompositionConfig.
            (Default: True)
        make_positive_semidefinite (bool): Whether to perturb the matrix eigenvalues so that the matrix is
            numerically positive semi-definite. (Default: True)
        exponent_multiplier (float): Factor multiplied into the numerator of the inverse root exponent, i.e., eta
            where the exponent is -eta / (2 * p). (Default: 1.0)

    """

    # retry_double_precision is inherited from EigenvalueDecompositionConfig.
    make_positive_semidefinite: bool = True
    exponent_multiplier: float = 1.0
# Shared module-level instance of EigenConfig with every field left at its default value.
DefaultEigenConfig: EigenConfig = EigenConfig()
@dataclass(kw_only=True)
class CoupledNewtonConfig(RootInvConfig):
    """Settings for computing a matrix root inverse with the coupled Newton iteration.

    Args:
        max_iterations (int): Upper limit on the number of coupled Newton iterations. (Default: 100)
        tolerance (float): Tolerance used when computing the root inverse with the coupled Newton
            iteration. (Default: 1e-6)

    """

    max_iterations: int = 100
    tolerance: float = 1e-6
@dataclass(kw_only=True)
class CoupledHigherOrderConfig(RootInvConfig):
    """Configuration for matrix root inverse via coupled higher-order method.

    Args:
        rel_epsilon (float): Relative epsilon for the coupled higher-order method. Adds epsilon * lambda_max * I to
            the matrix before taking the matrix root, where lambda_max is an upper bound on the maximum eigenvalue.
            (Default: 0.0)
        max_iterations (int): Maximum number of iterations for the coupled higher-order method. (Default: 100)
        tolerance (float): Tolerance for computing the root inverse using the coupled higher-order method.
            (Default: 1e-8)
        order (int): Order of the method. Must be >= 2. Higher-order methods accelerate convergence (fewer
            iterations) but can take more matmuls per iteration. order=2 represents Newton's method. (Default: 3)
        disable_tf32 (bool): Whether to disable tf32 matmuls internally. Keeping this True is highly recommended,
            since tf32 is numerically challenging here. (Default: True)

    Raises:
        ValueError: If order < 2.

    """

    rel_epsilon: float = 0.0
    max_iterations: int = 100
    tolerance: float = 1e-8
    order: int = 3
    disable_tf32: bool = True

    def __post_init__(self) -> None:
        # Enforce the documented "order >= 2" contract as soon as the config is constructed.
        if self.order < 2:
            raise ValueError(f"Invalid order: {self.order}. Must be >= 2.")
@dataclass(init=False)
class EigenvectorConfig(MatrixFunctionConfig):
    """Common base for configurations that select a matrix eigenvector method.

    Declares no fields; each concrete eigenvector method subclasses it and adds
    its own settings.
    """
@dataclass(kw_only=True)
class EighEigenvectorConfig(EigenvectorConfig, EigenvalueDecompositionConfig):
    """Configuration for eigenvectors via an eigendecomposition.

    Declares no fields of its own; its only setting is inherited from EigenvalueDecompositionConfig.

    Args:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if it
            fails in lower precision due to a cuSOLVER failure. (Default: True)

    """
# Shared module-level instance of EighEigenvectorConfig with every field left at its default value.
DefaultEighEigenvectorConfig: EighEigenvectorConfig = EighEigenvectorConfig()
@dataclass(kw_only=True)
class QRConfig(EigenvectorConfig):
    """Settings for estimating eigenvectors via orthogonal (a.k.a. simultaneous) iteration / the QR algorithm.

    Args:
        max_iterations (int): Upper limit on the number of iterations carried out. (Default: 1)
        tolerance (float): Convergence threshold, expressed as the relative change of the eigenvector
            estimate. (Default: 1e-5)

    """

    max_iterations: int = 1
    tolerance: float = 1e-5