Skip to content

Commit

Permalink
Merge pull request #30 from Uditgulati/linalg
Browse files Browse the repository at this point in the history
Implement NumRuby::Linalg
  • Loading branch information
prasunanand authored Dec 2, 2019
2 parents 5428e46 + 4e4daf2 commit ba6ab0c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 15 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
Gemfile.lock
/.yardoc
.rake_tasks~
*.so
*.so
.vscode

120 changes: 106 additions & 14 deletions lib/nmatrix/lapack.rb
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
module NumRuby::Linalg
def self.inv(obj)
if obj.is_a?(NMatrix)
return obj.invert
def self.inv(matrix)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]
raise("Invalid shape. Expected square matrix.")
end
m, n = matrix.shape

lu, ipiv = NumRuby::Lapack.getrf(matrix)
inv_a = NumRuby::Lapack.getri(lu, ipiv)

return inv_a
end

def self.dot(lha, rha)
Expand All @@ -13,12 +25,22 @@ def self.norm

end

def self.solve

def self.solve(a, b, sym_pos: False, lower: False, assume_a: "gen", transposed: False)
# TODO: implement this and remove NMatrix.solve
end

def self.det
def self.det(matrix)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]
raise("Invalid shape. Expected square matrix.")
end

return matrix.det
end

def self.least_square
Expand Down Expand Up @@ -48,23 +70,47 @@ def self.eigvalsh
# Matrix Decomposition


def self.lu(matrix)
def self.lu(matrix, permute_l: False)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end

lu, ipiv = NumRuby::Linalg.getrf(matrix)

# TODO: calulate p, l, u
end

def self.lu_factor(matrix)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]
raise("Invalid shape. Expected square matrix.")
end

lu, ipiv = NumRuby::Linalg.getrf(matrix)

return [lu, ipiv]
end

def self.lu_solve(matrix, rhs_val)
def self.lu_solve(lu, ipiv, b, trans: 0)
if lu.shape[0] != b.shape[0]
raise("Incompatibel dimensions.")
end

x = NumRuby::Lapack.getrs(lu, ipiv, b, trans)
return x
end

# Computes the QR decomposition of matrix.
# Computes the SVD decomposition of matrix.
# Args:
# - input matrix, type: NMatrix
# - mode, type: String
# - pivoting, type: Boolean
def self.svd(matrix)

end
Expand All @@ -89,11 +135,24 @@ def self.cholesky_solve(matrix)

end

# Computes the QR decomposition of matrix.
# Computes QR decomposition of a matrix.
#
# Calculates the decomposition A = Q*R where Q is unitary/orthogonal and R is upper triangular.
#
# Args:
# - input matrix, type: NMatrix
# - matrix, type: NMatrix
# Matrix to be decomposed
# - mode, type: String
# Determines what information is to be returned: either both Q and R
# ('full', default), only R ('r') or both Q and R but computed in
# economy-size ('economic', see Notes). The final option 'raw'
# (added in Scipy 0.11) makes the function return two matrices
# (Q, TAU) in the internal format used by LAPACK.
# - pivoting, type: Boolean
# Whether or not factorization should include pivoting for rank-revealing
# qr decomposition. If pivoting, compute the decomposition
# A*P = Q*R as above, but where P is chosen such that the diagonal
# of R is non-increasing.
def self.qr(matrix, mode: "full", pivoting: false)
if not ['full', 'r', 'economic', 'raw'].include?(mode.downcase)
raise("Invalid mode. Should be one of ['full', 'r', 'economic', 'raw']")
Expand All @@ -106,9 +165,42 @@ def self.qr(matrix, mode: "full", pivoting: false)
end
m, n = matrix.shape

if pivoting == false
if pivoting == true
qr, tau, jpvt = NumRuby::Lapack.geqp3(matrix)
jpvt -= 1
else
qr, tau = NumRuby::Lapack.geqrf(matrix)
end

# calculate R here for both pivot true & false

if ['economic', 'raw'].include?(mode.downcase) or m < n
r = NumRuby.triu(matrix)
else
r = NumRuby.triu(matrix[0...n, 0...n])
end

if pivoting == true
rj = r, jpvt
else
rj = r
end

if mode == 'r'
return rj
elsif mode == 'raw'
return [qr, tau]
end

if m < n
q = NumRuby::Lapack.orgqr(qr[0...m, 0...m], tau)
elsif mode == 'economic'
q = NumRuby::Lapack.orgqr(qr, tau)
else
# TODO: Implement slice view and set slice
q = NumRuby::Lapack.orgqr(qr, tau)
end

return q, rj
end
end

0 comments on commit ba6ab0c

Please sign in to comment.