From a9cbdaaa1f3a1dcf167d3ba8f0384333eae8ee25 Mon Sep 17 00:00:00 2001 From: Andreas Noack Jensen Date: Sun, 4 Nov 2012 10:31:48 +0100 Subject: [PATCH] Fix of det for singular matrices by changing solvers from Lapack to return info parameter for singular systems. Added info parameter to LUDense. Updated code that use the Lapack solvers. Added fester det solutions for triangular matrices. Changes Lapack LU routine getrf to allow rectangular matrices. --- base/lapack.jl | 22 ++++++++--------- base/linalg_dense.jl | 59 +++++++++++++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 26 deletions(-) diff --git a/base/lapack.jl b/base/lapack.jl index 1599f37fd488e..90ff2cfd3deae 100644 --- a/base/lapack.jl +++ b/base/lapack.jl @@ -327,13 +327,13 @@ for (gebrd, gelqf, geqlf, geqrf, geqp3, gerqf, getrf, elty) in info = Array(Int32, 1) m, n = size(A) lda = stride(A, 2) - ipiv = Array(Int32, n) + ipiv = Array(Int32, min(m,n)) ccall(dlsym(Base.liblapack, $(string(getrf))), Void, (Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), &m, &n, A, &lda, ipiv, info) - if info[1] != 0 throw(LapackException(info[1])) end - A, ipiv + if info[1] < 0 throw(LapackException(info[1])) end + A, ipiv, info[1] end end end @@ -392,8 +392,8 @@ for (gels, gesv, getrs, getri, elty) in (Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}), &n, &size(B,2), A, &stride(A,2), ipiv, B, &stride(B,2), info) - if info[1] != 0 throw(LapackException(info[1])) end - A, ipiv, B + if info[1] < 0 throw(LapackException(info[1])) end + B, A, ipiv, info[1] end # SUBROUTINE DGETRS( TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO ) #* .. Scalar Arguments .. @@ -1096,8 +1096,8 @@ for (trtri, trtrs, elty) in (Ptr{Uint8}, Ptr{Uint8}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}), &uplo, &diag, &n, A, &lda, info) - if info[1] != 0 error("trtri!: error $(info[1])") end - A + if info[1] < 0 error("trtri!: error $(info[1])") end + A, info[1] end # SUBROUTINE DTRTRS( UPLO, TRANS, DIAG, N, NRHS, A, LDA, B, LDB, INFO ) # * .. Scalar Arguments .. @@ -1117,8 +1117,8 @@ for (trtri, trtrs, elty) in Ptr{$elty}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}), &uplo, &trans, &diag, &n, &size(B,2), A, &stride(A,2), B, &stride(B,2), info) - if info[1] != 0 throw(LapackException(info[1])) end - B + if info[1] < 0 throw(LapackException(info[1])) end + B, info[1] end end end @@ -1237,13 +1237,13 @@ for (syconv, syev, sysv, sytrf, sytri, sytrs, elty) in Ptr{$elty}, Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}), &uplo, &n, &size(B,2), A, &stride(A,2), ipiv, B, &stride(B,2), work, &lwork, info) - if info[1] != 0 throw(LapackException(info[1])) end + if info[1] < 0 throw(LapackException(info[1])) end if lwork < 0 lwork = int32(real(work[1])) work = Array($elty, lwork) end end - B, A, ipiv + B, A, ipiv, info[1] end # SUBROUTINE DSYTRF( UPLO, N, A, LDA, IPIV, WORK, LWORK, INFO ) # * .. Scalar Arguments .. diff --git a/base/linalg_dense.jl b/base/linalg_dense.jl index a42169ca409f5..1e9629cb76e2d 100644 --- a/base/linalg_dense.jl +++ b/base/linalg_dense.jl @@ -455,9 +455,10 @@ cholpd{T<:LapackType}(A::Matrix{T}) = cholpd!(copy(A), true, -1.) type LUDense{T} <: Factorization{T} lu::Matrix{T} ipiv::Vector{Int32} - function LUDense(lu::Matrix{T}, ipiv::Vector{Int32}) + info::Int32 + function LUDense(lu::Matrix{T}, ipiv::Vector{Int32}, info::Int32) m, n = size(lu) - m == numel(ipiv) ? new(lu, ipiv) : error("LUDense: dimension mismatch") + m == n ? new(lu, ipiv, info) : error("LUDense only defined for square matrices") end end @@ -480,8 +481,8 @@ function factors{T<:LapackType}(lu::LUDense{T}) end function lud!{T<:LapackType}(A::Matrix{T}) - lu, ipiv = Lapack.getrf!(A) - LUDense{T}(lu, ipiv) + lu, ipiv, info = Lapack.getrf!(A) + LUDense{T}(lu, ipiv, info) end lud{T<:LapackType}(A::Matrix{T}) = lud!(copy(A)) @@ -491,17 +492,28 @@ lud{T<:Number}(A::Matrix{T}) = lud(float64(A)) lu{T<:Number}(A::Matrix{T}) = factors(lud(A)) function det(lu::LUDense) - m, n = size(lu.lu) - if m != n error("det only defined for square matrices") end + if lu.info > 0; return zero(typeof(lu.lu[1])); end prod(diag(lu.lu)) * (bool(sum(lu.ipiv .!= 1:n) % 2) ? -1 : 1) end -det(A::Matrix) = det(lud(A)) +function det(A::Matrix) + m, n = size(A) + if m != n; error("det only defined for square matrices"); end + if istriu(A) | istril(A); return prod(diag(A)); end + return det(lud(A)) +end -(\){T<:LapackType}(lu::LUDense{T}, B::StridedVecOrMat{T}) = +function (\){T<:LapackType}(lu::LUDense{T}, B::StridedVecOrMat{T}) + if lu.info > 0; error("Singular system"); end Lapack.getrs!('N', lu.lu, lu.ipiv, copy(B)) +end -inv{T<:LapackType}(lu::LUDense{T}) = Lapack.getri!(copy(lu.lu), lu.ipiv) +function inv{T<:LapackType}(lu::LUDense{T}) + m, n = size(lu.lu) + if m != n; error("inv only defined for square matrices"); end + if lu.info > 0; return error("Singular system"); end + Lapack.getri!(copy(lu.lu), lu.ipiv) +end ## QR decomposition without column pivots type QRDense{T} <: Factorization{T} @@ -537,7 +549,9 @@ Ac_mul_B{T<:LapackType}(A::QRDense{T}, B::StridedVecOrMat{T}) = ## Least squares solution. Should be more careful about cases with m < n function (\){T<:LapackType}(A::QRDense{T}, B::StridedVecOrMat{T}) n = length(A.tau) - Lapack.trtrs!('U','N','N',A.hh[1:n,:],(A'*B)[1:n,:]) + ans, info = Lapack.trtrs!('U','N','N',A.hh[1:n,:],(A'*B)[1:n,:]) + if info > 0; error("Singular system"); end + return ans end type QRPDense{T} <: Factorization{T} @@ -576,7 +590,8 @@ qrp{T<:Real}(x::StridedMatrix{T}) = qrp(float64(x)) function (\){T<:LapackType}(A::QRPDense{T}, B::StridedVecOrMat{T}) n = length(A.tau) - x = Lapack.trtrs!('U','N','N',A.hh[1:n,:],(A'*B)[1:n,:]) + x, info = Lapack.trtrs!('U','N','N',A.hh[1:n,:],(A'*B)[1:n,:]) + if info > 0; error("Singular system"); end isa(B, Vector) ? x[invperm(A.jpvt)] : x[:,invperm(A.jpvt)] end @@ -659,10 +674,24 @@ function (\){T<:LapackType}(A::StridedMatrix{T}, B::StridedVecOrMat{T}) X = copy(B) if m == n # Square - if istriu(A) return Lapack.trtrs!('U', 'N', 'N', Acopy, X) end - if istril(A) return Lapack.trtrs!('L', 'N', 'N', Acopy, X) end - if ishermitian(A) return Lapack.sysv!('U', Acopy, X)[1] end - return Lapack.gesv!(Acopy, X)[3] + if istriu(A) + ans, info = Lapack.trtrs!('U', 'N', 'N', Acopy, X) + if info > 0; error("Singular system"); end + return ans + end + if istril(A) + ans, info = Lapack.trtrs!('L', 'N', 'N', Acopy, X) + if info > 0; error("Singular system"); end + return ans + end + if ishermitian(A) + ans, ~, ~, info = Lapack.sysv!('U', Acopy, X) + if info > 0; error("Singular system"); end + return ans + end + ans, ~, ~, info = Lapack.gesv!(Acopy, X) + if info > 0; error("Singular system"); end + return ans end Lapack.gelsd!(Acopy, X)[1] end