Skip to content

Commit

Permalink
type-stable inner loop for sqrtm (#20214)
Browse files Browse the repository at this point in the history
* type-stable inner loop for sqrtm

As suggested by Ralph_Smith on [discourse](https://discourse.julialang.org/t/review-schur-pade-matrix-powers-speedup/1650/6)

On my machine: speedup x15

* dispatch sqrtm on real-or-not bool

As suggested by @stevengj

* sylvester for numbers
  • Loading branch information
felixrehren authored and andreasnoack committed Feb 3, 2017
1 parent 6cfd329 commit 9c067b6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
2 changes: 2 additions & 0 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,8 @@ function sylvester{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T},C::Stri
end
sylvester{T<:Integer}(A::StridedMatrix{T},B::StridedMatrix{T},C::StridedMatrix{T}) = sylvester(float(A), float(B), float(C))

sylvester(a::Union{Real,Complex},b::Union{Real,Complex},c::Union{Real,Complex}) = -c / (a + b)

# AX + XA' + C = 0

"""
Expand Down
47 changes: 25 additions & 22 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1867,48 +1867,51 @@ function logm{T<:Union{Float64,Complex{Float64}}}(A0::UpperTriangular{T})
end
logm(A::LowerTriangular) = logm(A.').'

function sqrtm{T}(A::UpperTriangular{T})
n = checksquare(A)
function sqrtm(A::UpperTriangular)
realmatrix = false
if isreal(A)
realmatrix = true
for i = 1:n
for i = 1:checksquare(A)
if real(A[i,i]) < 0
realmatrix = false
break
end
end
end
if realmatrix
TT = typeof(sqrt(zero(T)))
else
TT = typeof(sqrt(complex(-one(T))))
end
R = zeros(TT, n, n)
for j = 1:n
R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j]))
sqrtm(A,Val{realmatrix})
end
function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}})
B = A.data
n = checksquare(B)
t = realmatrix ? typeof(sqrt(zero(T))) : typeof(sqrt(complex(zero(T))))
R = zeros(t, n, n)
tt = typeof(zero(t)*zero(t))
@inbounds for j = 1:n
R[j,j] = realmatrix ? sqrt(B[j,j]) : sqrt(complex(B[j,j]))
for i = j-1:-1:1
r = A[i,j]
for k = i+1:j-1
r::tt = B[i,j]
@simd for k = i+1:j-1
r -= R[i,k]*R[k,j]
end
r==0 || (R[i,j] = r / (R[i,i] + R[j,j]))
r==0 || (R[i,j] = sylvester(R[i,i],R[j,j],-r))
end
end
return UpperTriangular(R)
end
function sqrtm{T}(A::UnitUpperTriangular{T})
n = checksquare(A)
TT = typeof(sqrt(zero(T)))
R = zeros(TT, n, n)
for j = 1:n
R[j,j] = one(T)
B = A.data
n = checksquare(B)
t = typeof(sqrt(zero(T)))
R = eye(t, n, n)
tt = typeof(zero(t)*zero(t))
half = inv(R[1,1]+R[1,1]) # for general, algebraic cases. PR#20214
@inbounds for j = 1:n
for i = j-1:-1:1
r = A[i,j]
for k = i+1:j-1
r::tt = B[i,j]
@simd for k = i+1:j-1
r -= R[i,k]*R[k,j]
end
r==0 || (R[i,j] = r / (R[i,i] + R[j,j]))
r==0 || (R[i,j] = half*r)
end
end
return UnitUpperTriangular(R)
Expand Down

0 comments on commit 9c067b6

Please sign in to comment.