-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Store info in Cholesky type #21976
Store info in Cholesky type #21976
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,11 @@ | |
# through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which | ||
# is checked for and an error is thrown if the check fails. | ||
|
||
# The internal structure is as follows | ||
# - _chol! return the factor and info without checking positive definiteness | ||
# - chol/chol! return the factor and checks for positive definiteness | ||
# - cholfact/cholfact! return Cholesky with checking positive definiteness | ||
|
||
# FixMe? The dispatch below seems overly complicated. One simplification could be to | ||
# merge the two Cholesky types into one. It would remove the need for Val completely but | ||
# the cost would be extra unnecessary/unused fields for the unpivoted Cholesky and runtime | ||
|
@@ -27,9 +32,12 @@ | |
struct Cholesky{T,S<:AbstractMatrix} <: Factorization{T} | ||
factors::S | ||
uplo::Char | ||
info::BlasInt | ||
end | ||
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol) = Cholesky{T,typeof(A)}(A, char_uplo(uplo)) | ||
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char) = Cholesky{T,typeof(A)}(A, uplo) | ||
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol, info::BlasInt) = | ||
Cholesky{T,typeof(A)}(A, char_uplo(uplo), info) | ||
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char, info::BlasInt) = | ||
Cholesky{T,typeof(A)}(A, uplo, info) | ||
|
||
struct CholeskyPivoted{T,S<:AbstractMatrix} <: Factorization{T} | ||
factors::S | ||
|
@@ -49,11 +57,11 @@ end | |
## BLAS/LAPACK element types | ||
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{UpperTriangular}) | ||
C, info = LAPACK.potrf!('U', A) | ||
return @assertposdef UpperTriangular(C) info | ||
return UpperTriangular(C), info | ||
end | ||
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular}) | ||
C, info = LAPACK.potrf!('L', A) | ||
return @assertposdef LowerTriangular(C) info | ||
return LowerTriangular(C), info | ||
end | ||
|
||
## Non BLAS/LAPACK element types (generic) | ||
|
@@ -64,7 +72,10 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular}) | |
for i = 1:k - 1 | ||
A[k,k] -= A[i,k]'A[i,k] | ||
end | ||
Akk = _chol!(A[k,k], UpperTriangular) | ||
Akk, info = _chol!(A[k,k], UpperTriangular) | ||
if info != 0 | ||
return UpperTriangular(A), info | ||
end | ||
A[k,k] = Akk | ||
AkkInv = inv(Akk') | ||
for j = k + 1:n | ||
|
@@ -75,7 +86,7 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular}) | |
end | ||
end | ||
end | ||
return UpperTriangular(A) | ||
return UpperTriangular(A), convert(BlasInt, 0) # TODO: If we get here, do we know A is pos. def? | ||
end | ||
function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) | ||
n = checksquare(A) | ||
|
@@ -84,7 +95,10 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) | |
for i = 1:k - 1 | ||
A[k,k] -= A[k,i]*A[k,i]' | ||
end | ||
Akk = _chol!(A[k,k], LowerTriangular) | ||
Akk, info = _chol!(A[k,k], LowerTriangular) | ||
if info != 0 | ||
return LowerTriangular(A), info | ||
end | ||
A[k,k] = Akk | ||
AkkInv = inv(Akk) | ||
for j = 1:k | ||
|
@@ -99,30 +113,33 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) | |
end | ||
end | ||
end | ||
return LowerTriangular(A) | ||
return LowerTriangular(A), convert(BlasInt, 0) # TODO: If we get here, do we know A is pos. def? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd think so. What could be wrong? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay thanks. Just wanted to make sure, will remove the comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was paranoid so confirmed with experiments: for N in (10, 100, 1000), i in 1:1000
A = rand(N, N); B = A'A
for M in (A, B)
C, info = invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(M), LowerTriangular)
infoblas = cholfact(Hermitian(M, :L)).info
info == infoblas == 0 || (info > 0 && infoblas > 0) || error()
end
for M in (A, B)
C, info = invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(M), UpperTriangular)
infoblas = cholfact(Hermitian(M, :U)).info
info == infoblas == 0 || (info > 0 && infoblas > 0) || error()
end
end |
||
end | ||
|
||
## Numbers | ||
function _chol!(x::Number, uplo) | ||
rx = real(x) | ||
if rx != abs(x) | ||
throw(ArgumentError("x must be positive semidefinite")) | ||
end | ||
rxr = sqrt(rx) | ||
convert(promote_type(typeof(x), typeof(rxr)), rxr) | ||
rxr = sqrt(abs(rx)) | ||
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr) | ||
rx == abs(x) ? (rval, convert(BlasInt, 0)) : (rval, convert(BlasInt, 1)) | ||
end | ||
|
||
chol!(x::Number, uplo) = ((C, info) = _chol!(x, uplo); @assertposdef C info) | ||
|
||
non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" * | ||
"Hermitian. This error can be avoided by calling $f(Hermitian(A)) " * | ||
"which will ignore either the upper or lower triangle of the matrix.")) | ||
|
||
# chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian | ||
# matrix | ||
chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = | ||
_chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) | ||
function chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) | ||
C, info = _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) | ||
@assertposdef C info | ||
end | ||
function chol!(A::StridedMatrix) | ||
ishermitian(A) || non_hermitian_error("chol!") | ||
return _chol!(A, UpperTriangular) | ||
C, info = _chol!(A, UpperTriangular) | ||
@assertposdef C info | ||
end | ||
|
||
|
||
|
@@ -184,7 +201,7 @@ julia> chol(16) | |
4.0 | ||
``` | ||
""" | ||
chol(x::Number, args...) = _chol!(x, nothing) | ||
chol(x::Number, args...) = ((C, info) = _chol!(x, nothing); @assertposdef C info) | ||
|
||
|
||
|
||
|
@@ -193,9 +210,11 @@ chol(x::Number, args...) = _chol!(x, nothing) | |
## No pivoting | ||
function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}}) | ||
if A.uplo == 'U' | ||
Cholesky(_chol!(A.data, UpperTriangular).data, 'U') | ||
CU, info = _chol!(A.data, UpperTriangular) | ||
Cholesky(CU.data, 'U', info) | ||
else | ||
Cholesky(_chol!(A.data, LowerTriangular).data, 'L') | ||
CL, info = _chol!(A.data, LowerTriangular) | ||
Cholesky(CL.data, 'L', info) | ||
end | ||
end | ||
|
||
|
@@ -354,14 +373,15 @@ end | |
|
||
## Number | ||
function cholfact(x::Number, uplo::Symbol=:U) | ||
xf = fill(chol(x), 1, 1) | ||
Cholesky(xf, uplo) | ||
C, info = _chol!(x, uplo) | ||
xf = fill(C, 1, 1) | ||
Cholesky(xf, uplo, info) | ||
end | ||
|
||
|
||
function convert(::Type{Cholesky{T}}, C::Cholesky) where T | ||
Cnew = convert(AbstractMatrix{T}, C.factors) | ||
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo) | ||
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo, C.info) | ||
end | ||
convert(::Type{Factorization{T}}, C::Cholesky{T}) where {T} = C | ||
convert(::Type{Factorization{T}}, C::Cholesky) where {T} = convert(Cholesky{T}, C) | ||
|
@@ -386,7 +406,7 @@ convert(::Type{Matrix}, F::CholeskyPivoted) = convert(Array, convert(AbstractArr | |
convert(::Type{Array}, F::CholeskyPivoted) = convert(Matrix, F) | ||
full(F::CholeskyPivoted) = convert(AbstractArray, F) | ||
|
||
copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo) | ||
copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo, C.info) | ||
copy(C::CholeskyPivoted) = CholeskyPivoted(copy(C.factors), C.uplo, C.piv, C.rank, C.tol, C.info) | ||
|
||
size(C::Union{Cholesky, CholeskyPivoted}) = size(C.factors) | ||
|
@@ -417,7 +437,7 @@ show(io::IO, C::Cholesky{<:Any,<:AbstractMatrix}) = | |
(println(io, "$(typeof(C)) with factor:");show(io,C[:UL])) | ||
|
||
A_ldiv_B!(C::Cholesky{T,<:AbstractMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = | ||
LAPACK.potrs!(C.uplo, C.factors, B) | ||
@assertposdef LAPACK.potrs!(C.uplo, C.factors, B) C.info | ||
|
||
function A_ldiv_B!(C::Cholesky{<:Any,<:AbstractMatrix}, B::StridedVecOrMat) | ||
if C.uplo == 'L' | ||
|
@@ -465,16 +485,18 @@ function A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) | |
end | ||
|
||
function det(C::Cholesky) | ||
C.info == 0 || throw(PosDefException(C.info)) | ||
dd = one(real(eltype(C))) | ||
for i in 1:size(C.factors,1) | ||
@inbounds for i in 1:size(C.factors,1) | ||
dd *= real(C.factors[i,i])^2 | ||
end | ||
dd | ||
end | ||
|
||
function logdet(C::Cholesky) | ||
C.info == 0 || throw(PosDefException(C.info)) | ||
dd = zero(real(eltype(C))) | ||
for i in 1:size(C.factors,1) | ||
@inbounds for i in 1:size(C.factors,1) | ||
dd += log(real(C.factors[i,i])) | ||
end | ||
dd + dd # instead of 2.0dd which can change the type | ||
|
@@ -505,10 +527,9 @@ function logdet(C::CholeskyPivoted) | |
end | ||
|
||
inv!(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = | ||
copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) | ||
@assertposdef copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) C.info | ||
|
||
inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = | ||
inv!(copy(C)) | ||
inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = inv!(copy(C)) | ||
|
||
function inv(C::CholeskyPivoted) | ||
chkfullrank(C) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ debug = false | |
|
||
using Base.Test | ||
|
||
using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted | ||
using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException | ||
|
||
n = 10 | ||
|
||
|
@@ -60,7 +60,7 @@ for eltya in (Float32, Float64, Complex64, Complex128, BigFloat, Int) | |
|
||
apos = apd[1,1] # test chol(x::Number), needs x>0 | ||
@test all(x -> x ≈ √apos, cholfact(apos).factors) | ||
@test_throws ArgumentError chol(-one(eltya)) | ||
@test_throws PosDefException chol(-one(eltya)) | ||
|
||
if eltya <: Real | ||
capds = cholfact(apds) | ||
|
@@ -194,10 +194,9 @@ end | |
|
||
begin | ||
# Cholesky factor of Matrix with non-commutative elements, here 2x2-matrices | ||
|
||
X = Matrix{Float64}[0.1*rand(2,2) for i in 1:3, j = 1:3] | ||
L = full(Base.LinAlg._chol!(X*X', LowerTriangular)) | ||
U = full(Base.LinAlg._chol!(X*X', UpperTriangular)) | ||
L = full(Base.LinAlg._chol!(X*X', LowerTriangular)[1]) | ||
U = full(Base.LinAlg._chol!(X*X', UpperTriangular)[1]) | ||
XX = full(X*X') | ||
|
||
@test sum(sum(norm, L*L' - XX)) < eps() | ||
|
@@ -212,8 +211,8 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64}) | |
A = randn(5,5) | ||
end | ||
A = convert(Matrix{elty}, A'A) | ||
@test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)) | ||
@test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)) | ||
@test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)[1]) | ||
@test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)[1]) | ||
end | ||
|
||
# Test up- and downdates | ||
|
@@ -272,3 +271,14 @@ end | |
|
||
# Fail for non-BLAS element types | ||
@test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val{true}) | ||
|
||
@testset "throw for non positive matrix" begin | ||
for T in (Float32, Float64, Complex64, Complex128) | ||
A = T[1 2; 2 1]; B = T[1, 1] | ||
C = cholfact(A) | ||
@show typeof(A), typeof(B), typeof(C.factors) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. debugging output left in? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, sry |
||
@test_throws PosDefException C\B | ||
@test_throws PosDefException det(C) | ||
@test_throws PosDefException logdet(C) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
returns ?