Skip to content

Commit

Permalink
Merge pull request #21184 from iamnapo/anj/powm
Browse files Browse the repository at this point in the history
Fixed the algorithm for powers of a matrix.
  • Loading branch information
tkelman authored Apr 24, 2017
2 parents 8df5fbe + 2fbeba3 commit f56147d
Show file tree
Hide file tree
Showing 7 changed files with 414 additions and 29 deletions.
76 changes: 64 additions & 12 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function scale!(X::Array{T}, s::Real) where T<:BlasComplex
X
end

#Test whether a matrix is positive-definite
# Test whether a matrix is positive-definite
isposdef!(A::StridedMatrix{<:BlasFloat}, UL::Symbol) = LAPACK.potrf!(char_uplo(UL), A)[2] == 0

"""
Expand Down Expand Up @@ -323,18 +323,67 @@ kron(a::AbstractVector, b::AbstractVector)=vec(kron(reshape(a,length(a),1),resha
kron(a::AbstractMatrix, b::AbstractVector)=kron(a,reshape(b,length(b),1))
kron(a::AbstractVector, b::AbstractMatrix)=kron(reshape(a,length(a),1),b)

^(A::AbstractMatrix, p::Integer) = p < 0 ? inv(A^-p) : Base.power_by_squaring(A,p)

function ^(A::AbstractMatrix, p::Number)
# Matrix power
^{T}(A::AbstractMatrix{T}, p::Integer) = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p)
function ^{T}(A::AbstractMatrix{T}, p::Real)
# For integer powers, use repeated squaring
if isinteger(p)
return A^Integer(real(p))
TT = Base.promote_op(^, eltype(A), typeof(p))
return (TT == eltype(A) ? A : copy!(similar(A, TT), A))^Integer(p)
end

# If possible, use diagonalization
if T <: Real && issymmetric(A)
return (Symmetric(A)^p)
end
if ishermitian(A)
return (Hermitian(A)^p)
end

n = checksquare(A)

# Quicker return if A is diagonal
if isdiag(A)
retmat = copy(A)
for i in 1:n
retmat[i, i] = retmat[i, i] ^ p
end
return retmat
end

# Otherwise, use Schur decomposition
if istriu(A)
# Integer part
retmat = A ^ floor(p)
# Real part
if p - floor(p) == 0.5
# special case: A^0.5 === sqrtm(A)
retmat = retmat * sqrtm(A)
else
retmat = retmat * powm!(UpperTriangular(float.(A)), real(p - floor(p)))
end
else
S,Q,d = schur(complex(A))
# Integer part
R = S ^ floor(p)
# Real part
if p - floor(p) == 0.5
# special case: A^0.5 === sqrtm(A)
R = R * sqrtm(S)
else
R = R * powm!(UpperTriangular(float.(S)), real(p - floor(p)))
end
retmat = Q * R * Q'
end

# if A has nonpositive real eigenvalues, retmat is a nonprincipal matrix power.
if isreal(retmat)
return real(retmat)
else
return retmat
end
checksquare(A)
v, X = eig(A)
any(v.<0) && (v = complex(v))
Xinv = ishermitian(A) ? X' : inv(X)
(X * Diagonal(v.^p)) * Xinv
end
^(A::AbstractMatrix, p::Number) = expm(p*logm(A))

# Matrix exponential

Expand Down Expand Up @@ -466,7 +515,7 @@ function rcswap!(i::Integer, j::Integer, X::StridedMatrix{<:Number})
end

"""
logm(A::StridedMatrix)
logm(A{T}::StridedMatrix{T})
If `A` has no negative real eigenvalue, compute the principal matrix logarithm of `A`, i.e.
the unique matrix ``X`` such that ``e^X = A`` and ``-\\pi < Im(\\lambda) < \\pi`` for all
Expand Down Expand Up @@ -497,8 +546,11 @@ julia> logm(A)
0.0 1.0
```
"""
function logm(A::StridedMatrix)
function logm{T}(A::StridedMatrix{T})
# If possible, use diagonalization
if issymmetric(A) && T <: Real
return full(logm(Symmetric(A)))
end
if ishermitian(A)
return full(logm(Hermitian(A)))
end
Expand Down
1 change: 1 addition & 0 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ end
# identity matrices via eye(Diagonal{type},n)
eye(::Type{Diagonal{T}}, n::Int) where {T} = Diagonal(ones(T,n))

# Matrix functions
expm(D::Diagonal) = Diagonal(exp.(D.diag))
expm(D::Diagonal{<:AbstractMatrix}) = Diagonal(expm.(D.diag))
logm(D::Diagonal) = Diagonal(log.(D.diag))
Expand Down
3 changes: 2 additions & 1 deletion base/linalg/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ include("hessenberg.jl")
include("lq.jl")
include("eigen.jl")
include("svd.jl")
include("schur.jl")
include("symmetric.jl")
include("cholesky.jl")
include("lu.jl")
Expand All @@ -274,6 +273,8 @@ include("givens.jl")
include("special.jl")
include("bitarray.jl")
include("ldlt.jl")
include("schur.jl")


include("arpack.jl")
include("arnoldi.jl")
Expand Down
6 changes: 6 additions & 0 deletions base/linalg/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ function schur(A::StridedMatrix)
SchurF = schurfact(A)
SchurF[:T], SchurF[:Z], SchurF[:values]
end
schur(A::Symmetric) = schur(full(A))
schur(A::Hermitian) = schur(full(A))
schur(A::UpperTriangular) = schur(full(A))
schur(A::LowerTriangular) = schur(full(A))
schur(A::Tridiagonal) = schur(full(A))


"""
ordschur!(F::Schur, select::Union{Vector{Bool},BitVector}) -> F::Schur
Expand Down
61 changes: 52 additions & 9 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

#Symmetric and Hermitian matrices
# Symmetric and Hermitian matrices
struct Symmetric{T,S<:AbstractMatrix} <: AbstractMatrix{T}
data::S
uplo::Char
Expand Down Expand Up @@ -181,7 +181,7 @@ trace(A::Hermitian) = real(trace(A.data))
Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo)

#tril/triu
# tril/triu
function tril(A::Hermitian, k::Integer=0)
if A.uplo == 'U' && k <= 0
return tril!(A.data',k)
Expand Down Expand Up @@ -235,7 +235,7 @@ end
## Matvec
A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}) = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
A_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}) = BLAS.hemv!(A.uplo, one(T), A.data, x, zero(T), y)
##Matmat
## Matmat
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::Symmetric{T,<:StridedMatrix}, B::StridedMatrix{T}) = BLAS.symm!('L', A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Symmetric{T,<:StridedMatrix}) = BLAS.symm!('R', B.uplo, one(T), B.data, A, zero(T), C)
A_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::Hermitian{T,<:StridedMatrix}, B::StridedMatrix{T}) = BLAS.hemm!('L', A.uplo, one(T), A.data, B, zero(T), C)
Expand Down Expand Up @@ -403,7 +403,54 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
return sort!(vals, rev = true)
end

#Matrix-valued functions
# Matrix functions
function ^{T<:Real}(A::Symmetric{T}, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^{T<:Real}(A::Symmetric{T}, p::Real)
F = eigfact(A)
if all-> λ 0, F.values)
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
else
retmat = (F.vectors * Diagonal((complex(F.values)).^p)) * F.vectors'
end
return Symmetric(retmat)
end
function ^(A::Hermitian, p::Integer)
n = checksquare(A)
if p < 0
retmat = Base.power_by_squaring(inv(A), -p)
else
retmat = Base.power_by_squaring(A, p)
end
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
function ^{T}(A::Hermitian{T}, p::Real)
n = checksquare(A)
F = eigfact(A)
if all-> λ 0, F.values)
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
if T <: Real
return Hermitian(retmat)
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
else
retmat = (F.vectors * Diagonal((complex(F.values).^p))) * F.vectors'
return retmat
end
end

function expm(A::Symmetric)
F = eigfact(A)
return Symmetric((F.vectors * Diagonal(exp.(F.values))) * F.vectors')
Expand All @@ -423,10 +470,8 @@ function expm(A::Hermitian{T}) where T
end

for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])

@eval begin

function ($funm)(A::Symmetric)
function ($funm){T<:Real}(A::Symmetric{T})
F = eigfact(A)
if isposdef(F)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
Expand Down Expand Up @@ -454,7 +499,5 @@ for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])
return retmat
end
end

end

end
Loading

0 comments on commit f56147d

Please sign in to comment.