diff --git a/.gitignore b/.gitignore index 6310525..881286d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ Gemfile.lock /.yardoc .rake_tasks~ -*.so \ No newline at end of file +*.so +.vscode + diff --git a/lib/nmatrix/lapack.rb b/lib/nmatrix/lapack.rb index 09e642f..9779c63 100644 --- a/lib/nmatrix/lapack.rb +++ b/lib/nmatrix/lapack.rb @@ -1,8 +1,20 @@ module NumRuby::Linalg - def self.inv(obj) - if obj.is_a?(NMatrix) - return obj.invert + def self.inv(matrix) + if not matrix.is_a?(NMatrix) + raise("Invalid matrix. Not of type NMatrix.") + end + if matrix.dim != 2 + raise("Invalid shape of matrix. Should be 2.") + end + if matrix.shape[0] != matrix.shape[1] + raise("Invalid shape. Expected square matrix.") end + m, n = matrix.shape + + lu, ipiv = NumRuby::Lapack.getrf(matrix) + inv_a = NumRuby::Lapack.getri(lu, ipiv) + + return inv_a end def self.dot(lha, rha) @@ -13,12 +25,22 @@ def self.norm end - def self.solve - + def self.solve(a, b, sym_pos: False, lower: False, assume_a: "gen", transposed: False) + # TODO: implement this and remove NMatrix.solve end - def self.det + def self.det(matrix) + if not matrix.is_a?(NMatrix) + raise("Invalid matrix. Not of type NMatrix.") + end + if matrix.dim != 2 + raise("Invalid shape of matrix. Should be 2.") + end + if matrix.shape[0] != matrix.shape[1] + raise("Invalid shape. Expected square matrix.") + end + return matrix.det end def self.least_square @@ -48,23 +70,47 @@ def self.eigvalsh # Matrix Decomposition - def self.lu(matrix) + def self.lu(matrix, permute_l: False) + if not matrix.is_a?(NMatrix) + raise("Invalid matrix. Not of type NMatrix.") + end + if matrix.dim != 2 + raise("Invalid shape of matrix. Should be 2.") + end + + lu, ipiv = NumRuby::Linalg.getrf(matrix) + # TODO: calulate p, l, u end def self.lu_factor(matrix) + if not matrix.is_a?(NMatrix) + raise("Invalid matrix. Not of type NMatrix.") + end + if matrix.dim != 2 + raise("Invalid shape of matrix. Should be 2.") + end + if matrix.shape[0] != matrix.shape[1] + raise("Invalid shape. Expected square matrix.") + end + lu, ipiv = NumRuby::Linalg.getrf(matrix) + + return [lu, ipiv] end - def self.lu_solve(matrix, rhs_val) + def self.lu_solve(lu, ipiv, b, trans: 0) + if lu.shape[0] != b.shape[0] + raise("Incompatibel dimensions.") + end + x = NumRuby::Lapack.getrs(lu, ipiv, b, trans) + return x end - # Computes the QR decomposition of matrix. + # Computes the SVD decomposition of matrix. # Args: # - input matrix, type: NMatrix - # - mode, type: String - # - pivoting, type: Boolean def self.svd(matrix) end @@ -89,11 +135,24 @@ def self.cholesky_solve(matrix) end - # Computes the QR decomposition of matrix. + # Computes QR decomposition of a matrix. + # + # Calculates the decomposition A = Q*R where Q is unitary/orthogonal and R is upper triangular. + # # Args: - # - input matrix, type: NMatrix + # - matrix, type: NMatrix + # Matrix to be decomposed # - mode, type: String + # Determines what information is to be returned: either both Q and R + # ('full', default), only R ('r') or both Q and R but computed in + # economy-size ('economic', see Notes). The final option 'raw' + # (added in Scipy 0.11) makes the function return two matrices + # (Q, TAU) in the internal format used by LAPACK. # - pivoting, type: Boolean + # Whether or not factorization should include pivoting for rank-revealing + # qr decomposition. If pivoting, compute the decomposition + # A*P = Q*R as above, but where P is chosen such that the diagonal + # of R is non-increasing. def self.qr(matrix, mode: "full", pivoting: false) if not ['full', 'r', 'economic', 'raw'].include?(mode.downcase) raise("Invalid mode. Should be one of ['full', 'r', 'economic', 'raw']") @@ -106,9 +165,42 @@ def self.qr(matrix, mode: "full", pivoting: false) end m, n = matrix.shape - if pivoting == false + if pivoting == true + qr, tau, jpvt = NumRuby::Lapack.geqp3(matrix) + jpvt -= 1 + else qr, tau = NumRuby::Lapack.geqrf(matrix) + end + + # calculate R here for both pivot true & false + + if ['economic', 'raw'].include?(mode.downcase) or m < n + r = NumRuby.triu(matrix) + else + r = NumRuby.triu(matrix[0...n, 0...n]) + end + + if pivoting == true + rj = r, jpvt + else + rj = r + end + + if mode == 'r' + return rj + elsif mode == 'raw' return [qr, tau] end + + if m < n + q = NumRuby::Lapack.orgqr(qr[0...m, 0...m], tau) + elsif mode == 'economic' + q = NumRuby::Lapack.orgqr(qr, tau) + else + # TODO: Implement slice view and set slice + q = NumRuby::Lapack.orgqr(qr, tau) + end + + return q, rj end end \ No newline at end of file