Skip to content

Commit

Permalink
Merge pull request #11518 from JuliaLang/ksh/prettify
Browse files Browse the repository at this point in the history
RFC: prettify single line throws/returns and add more descriptive errors for linalg
  • Loading branch information
StefanKarpinski committed Jun 1, 2015
2 parents 0623698 + bbc47a8 commit 8ef420c
Show file tree
Hide file tree
Showing 17 changed files with 590 additions and 216 deletions.
36 changes: 28 additions & 8 deletions base/linalg/arnoldi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@ function eigs(A, B;
isgeneral = B !== I
sym = issym(A) && !iscmplx
nevmax=sym ? n-1 : n-2
nevmax > 0 || throw(ArgumentError("Input matrix A is too small. Use eigfact instead."))
if nevmax <= 0
throw(ArgumentError("Input matrix A is too small. Use eigfact instead."))
end
if nev > nevmax
warn("Adjusting nev from $nev to $nevmax")
nev = nevmax
end
nev > 0 || throw(ArgumentError("requested number of eigenvalues (nev) must be ≥ 1, got $nev"))
if nev <= 0
throw(ArgumentError("requested number of eigenvalues (nev) must be ≥ 1, got $nev"))
end
ncvmin = nev + (sym ? 1 : 2)
if ncv < ncvmin
warn("Adjusting ncv from $ncv to $ncvmin")
ncv = ncvmin
end
ncv = blas_int(min(ncv, n))
isgeneral && !isposdef(B) && throw(PosDefException(0))
if isgeneral && !isposdef(B)
throw(PosDefException(0))
end
bmat = isgeneral ? "G" : "I"
isshift = sigma !== nothing

Expand All @@ -42,7 +48,9 @@ function eigs(A, B;
which != :LI && which != :SI && which != :BE)
throw(ArgumentError("which must be :LM, :SM, :LR, :SR, :LI, :SI, or :BE, got $(repr(which))"))
end
which != :BE || sym || throw(ArgumentError("which=:BE only possible for real symmetric problem"))
if which == :BE && !sym
throw(ArgumentError("which=:BE only possible for real symmetric problem"))
end
isshift && which == :SM && warn("use of :SM in shift-and-invert mode is not recommended, use :LM to find eigenvalues closest to sigma")

if which==:SM && !isshift # transform into shift-and-invert method with sigma = 0
Expand All @@ -57,8 +65,12 @@ function eigs(A, B;
sigma = isshift ? convert(T,sigma) : zero(T)

if !isempty(v0)
length(v0)==n || throw(DimensionMismatch())
eltype(v0)==T || throw(ArgumentError("starting vector must have element type $T, got $(eltype(v0))"))
if length(v0) != n
throw(DimensionMismatch())
end
if eltype(v0) != T
throw(ArgumentError("starting vector must have element type $T, got $(eltype(v0))"))
end
end

whichstr = "LM"
Expand All @@ -72,10 +84,18 @@ function eigs(A, B;
whichstr = (!sym ? "SR" : "SA")
end
if which == :LI
whichstr = (!sym ? "LI" : throw(ArgumentError("largest imaginary is meaningless for symmetric eigenvalue problems")))
if !sym
whichstr = "LI"
else
throw(ArgumentError("largest imaginary is meaningless for symmetric eigenvalue problems"))
end
end
if which == :SI
whichstr = (!sym ? "SI" : throw(ArgumentError("smallest imaginary is meaningless for symmetric eigenvalue problems")))
if !sym
whichstr = "SI"
else
throw(ArgumentError("smallest imaginary is meaningless for symmetric eigenvalue problems"))
end
end

# Refer to ex-*.doc files in ARPACK/DOCUMENTS for calling sequence
Expand Down
17 changes: 11 additions & 6 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ type Bidiagonal{T} <: AbstractMatrix{T}
ev::Vector{T} # sub/super diagonal
isupper::Bool # is upper bidiagonal (true) or lower (false)
function Bidiagonal{T}(dv::Vector{T}, ev::Vector{T}, isupper::Bool)
length(ev)==length(dv)-1 ? new(dv, ev, isupper) : throw(DimensionMismatch())
if length(ev)==length(dv)-1
new(dv, ev, isupper)
else
throw(DimensionMismatch("Length of diagonal vector is $(length(dv)), length of off-diagonal vector is $(length(ev))"))
end
end
end
Bidiagonal{T}(dv::AbstractVector{T}, ev::AbstractVector{T}, isupper::Bool)=Bidiagonal{T}(copy(dv), copy(ev), isupper)
Expand Down Expand Up @@ -103,7 +107,7 @@ function diag{T}(M::Bidiagonal{T}, n::Integer=0)
elseif -size(M,1)<n<size(M,1)
return zeros(T, size(M,1)-abs(n))
else
throw(BoundsError())
throw(BoundsError("Matrix size is $(size(M)), n is $n"))
end
end

Expand Down Expand Up @@ -146,7 +150,7 @@ function A_ldiv_B!(A::Union(Bidiagonal, AbstractTriangular), B::AbstractMatrix)
tmp = similar(B,size(B,1))
n = size(B, 1)
if nA != n
throw(DimensionMismatch())
throw(DimensionMismatch("Size of A is ($nA,$mA), corresponding dimension of B is $n"))
end
for i = 1:size(B,2)
copy!(tmp, 1, B, (i - 1)*n + 1, n)
Expand All @@ -163,7 +167,7 @@ for func in (:Ac_ldiv_B!, :At_ldiv_B!)
tmp = similar(B,size(B,1))
n = size(B, 1)
if mA != n
throw(DimensionMismatch())
throw(DimensionMismatch("Size of A' is ($mA,$nA), corresponding dimension of B is $n"))
end
for i = 1:size(B,2)
copy!(tmp, 1, B, (i - 1)*n + 1, n)
Expand All @@ -179,8 +183,9 @@ At_ldiv_B(A::Union(Bidiagonal, AbstractTriangular), B::AbstractMatrix) = At_ldiv
#Generic solver using naive substitution
function naivesub!{T}(A::Bidiagonal{T}, b::AbstractVector, x::AbstractVector = b)
N = size(A, 2)
N == length(b) == length(x) || throw(DimensionMismatch())

if N != length(b) || N != length(x)
throw(DimensionMismatch())
end
if !A.isupper #do forward substitution
for j = 1:N
x[j] = b[j]
Expand Down
94 changes: 68 additions & 26 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,23 @@ for (fname, elty) in ((:cblas_zdotu_sub,:Complex128),
end
function dot{T<:BlasReal}(DX::Union(DenseArray{T},StridedVector{T}), DY::Union(DenseArray{T},StridedVector{T}))
n = length(DX)
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dot(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotc{T<:BlasComplex}(DX::Union(DenseArray{T},StridedVector{T}), DY::Union(DenseArray{T},StridedVector{T}))
n = length(DX)
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dotc(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotu{T<:BlasComplex}(DX::Union(DenseArray{T},StridedVector{T}), DY::Union(DenseArray{T},StridedVector{T}))
n = length(DX)
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dotu(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end

Expand Down Expand Up @@ -219,18 +225,24 @@ for (fname, elty) in ((:daxpy_,:Float64),
end
end
function axpy!{T<:BlasFloat,Ta<:Number}(alpha::Ta, x::Union(DenseArray{T},StridedVector{T}), y::Union(DenseArray{T},StridedVector{T}))
length(x) == length(y) || throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
axpy!(length(x), convert(T,alpha), pointer(x), stride(x, 1), pointer(y), stride(y, 1))
y
end

function axpy!{T<:BlasFloat,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(UnitRange{Ti},Range{Ti}),
y::Array{T}, ry::Union(UnitRange{Ti},Range{Ti}))

length(rx)==length(ry) || throw(DimensionMismatch())

if minimum(rx) < 1 || maximum(rx) > length(x) || minimum(ry) < 1 || maximum(ry) > length(y)
throw(BoundsError())
if length(rx) != length(ry)
throw(DimensionMismatch("Ranges of differing lengths"))
end
if minimum(rx) < 1 || maximum(rx) > length(x)
throw(BoundsError("Range out of bounds for x, of length $(length(x))"))
end
if minimum(ry) < 1 || maximum(ry) > length(y)
throw(BoundsError("Range out of bounds for y, of length $(length(y))"))
end
axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
y
Expand Down Expand Up @@ -269,7 +281,13 @@ for (fname, elty) in ((:dgemv_,:Float64),
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::Char, alpha::($elty), A::StridedVecOrMat{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
m,n = size(A,1),size(A,2)
length(X) == (trans == 'N' ? n : m) && length(Y) == (trans == 'N' ? m : n) || throw(DimensionMismatch())
if trans == 'N' && (length(X) != n || length(Y) != m)
throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
elseif trans == 'C' && (length(X) != m || length(Y) != n)
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
elseif trans == 'T' && (length(X) != m || length(Y) != n)
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Expand Down Expand Up @@ -340,7 +358,9 @@ for (fname, elty) in ((:dsymv_,:Float64),
function symv!(uplo::Char, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},beta::($elty), y::StridedVector{$elty})
m, n = size(A)
if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
if m != length(x) || m != length(y) throw(DimensionMismatch()) end
if m != length(x) || m != length(y)
throw(DimensionMismatch("A has size ($m,$n), x has length $(length(x)), y has length $(length(y))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty},
Expand All @@ -365,8 +385,12 @@ for (fname, elty) in ((:zhemv_,:Complex128),
@eval begin
function hemv!(uplo::Char, α::$elty, A::StridedMatrix{$elty}, x::StridedVector{$elty}, β::$elty, y::StridedVector{$elty})
n = size(A, 2)
n == length(x) || throw(DimensionMismatch())
size(A, 1) == length(y) || throw(DimensionMismatch())
if n != length(x)
throw(DimensionMismatch("A has size $(size(A)), and x has length $(length(x))"))
end
if size(A, 1) != length(y)
throw(DimensionMismatch("A has size $(size(A)), and y has length $(length(x))"))
end
lda = max(1, stride(A, 2))
incx = stride(x, 1)
incy = stride(y, 1)
Expand Down Expand Up @@ -464,7 +488,9 @@ for (fname, elty) in ((:dtrsv_,:Float64),
# DOUBLE PRECISION A(LDA,*),X(*)
function trsv!(uplo::Char, trans::Char, diag::Char, A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = chksquare(A)
n==length(x) || throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
if n != length(x)
throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}),
Expand All @@ -486,8 +512,9 @@ for (fname, elty) in ((:dger_,:Float64),
@eval begin
function ger!::$elty, x::StridedVector{$elty}, y::StridedVector{$elty}, A::StridedMatrix{$elty})
m, n = size(A)
m == length(x) || throw(DimensionMismatch())
n == length(y) || throw(DimensionMismatch())
if m != length(x) || n != length(y)
throw(DimensionMismatch("A has size ($m,$n), x has length $(length(x)), y has length $(length(y))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty},
Expand All @@ -508,7 +535,9 @@ for (fname, elty) in ((:dsyr_,:Float64),
@eval begin
function syr!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
if length(x) != n
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}),
Expand All @@ -525,7 +554,9 @@ for (fname, elty) in ((:zher_,:Complex128),
@eval begin
function her!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
length(x) == A || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
if length(x) != A
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}),
Expand Down Expand Up @@ -559,7 +590,7 @@ for (gemm, elty) in
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2)
throw(DimensionMismatch())
throw(DimensionMismatch("A has size ($m,$k), B has size ($k,$n), C has size $(size(C))"))
end
ccall(($(blasfunc(gemm)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Expand Down Expand Up @@ -667,7 +698,7 @@ for (fname, elty) in ((:dsyrk_,:Float64),
beta::($elty), C::StridedMatrix{$elty})
n = chksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("syrk!")) end
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Expand Down Expand Up @@ -701,7 +732,10 @@ for (fname, elty, relty) in ((:zherk_, :Complex128, :Float64),
function herk!(uplo::Char, trans::Char, α::$relty, A::StridedVecOrMat{$elty},
β::$relty, C::StridedMatrix{$elty})
n = chksquare(C)
n == size(A, trans == 'N' ? 1 : 2) || throw(DimensionMismatch("the matrix to update has dimension $n but the implied dimension of the update is $(size(A, trans == 'N' ? 1 : 2))"))
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n
throw(DimensionMismatch("the matrix to update has dimension $n but the implied dimension of the update is $(size(A, trans == 'N' ? 1 : 2))"))
end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Expand Down Expand Up @@ -740,7 +774,7 @@ for (fname, elty) in ((:dsyr2k_,:Float64),
beta::($elty), C::StridedMatrix{$elty})
n = chksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("syr2k!")) end
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Expand Down Expand Up @@ -776,7 +810,8 @@ for (fname, elty1, elty2) in ((:zher2k_,:Complex128,:Float64), (:cher2k_,:Comple
A::StridedVecOrMat{$elty1}, B::StridedVecOrMat{$elty1},
beta::($elty2), C::StridedMatrix{$elty1})
n = chksquare(C)
n == size(A, trans == 'N' ? 1 : 2) || throw(DimensionMismatch("her2k!"))
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(blasfunc(fname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Expand Down Expand Up @@ -836,7 +871,9 @@ for (mmname, smname, elty) in
alpha::$elty, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m, n = size(B)
k = chksquare(A)
k==(side == 'L' ? m : n) || throw(DimensionMismatch("size of A is $n, size(B)=($m,$n) and transa='$transa'"))
if k != (side == 'L' ? m : n)
throw(DimensionMismatch("size of A is $n, size(B)=($m,$n) and transa='$transa'"))
end
ccall(($(blasfunc(smname)), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8},
Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Expand All @@ -856,10 +893,15 @@ end # module

function copy!{T<:BlasFloat,Ti<:Integer}(dest::Array{T}, rdest::Union(UnitRange{Ti},Range{Ti}),
src::Array{T}, rsrc::Union(UnitRange{Ti},Range{Ti}))
if minimum(rdest) < 1 || maximum(rdest) > length(dest) || minimum(rsrc) < 1 || maximum(rsrc) > length(src)
throw(BoundsError())
if minimum(rdest) < 1 || maximum(rdest) > length(dest)
throw(BoundsError("Range out of bounds for dest, of length $(length(dest))"))
end
if minimum(rsrc) < 1 || maximum(rsrc) > length(src)
throw(BoundsError("Range out of bounds for src, of length $(length(src))"))
end
if length(rdest) != length(rsrc)
throw(DimensionMismatch("Ranges must be of the same length"))
end
length(rdest)==length(rsrc) || throw(DimensionMismatch("Ranges must be of the same length"))
BLAS.blascopy!(length(rsrc), pointer(src)+(first(rsrc)-1)*sizeof(T), step(rsrc),
pointer(dest)+(first(rdest)-1)*sizeof(T), step(rdest))
dest
Expand Down
4 changes: 3 additions & 1 deletion base/linalg/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ end
BunchKaufman{T}(LD::AbstractMatrix{T}, ipiv::Vector{BlasInt}, uplo::Char, symmetric::Bool) = BunchKaufman{T,typeof(LD)}(LD, ipiv, uplo, symmetric)

function bkfact!{T<:BlasReal}(A::StridedMatrix{T}, uplo::Symbol=:U, symmetric::Bool=issym(A))
symmetric || throw(ArgumentError("Bunch-Kaufman decomposition is only valid for symmetric matrices"))
if !symmetric
throw(ArgumentError("Bunch-Kaufman decomposition is only valid for symmetric matrices"))
end
LD, ipiv = LAPACK.sytrf!(char_uplo(uplo) , A)
BunchKaufman(LD, ipiv, char_uplo(uplo), symmetric)
end
Expand Down
4 changes: 3 additions & 1 deletion base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ function chol{T}(A::AbstractMatrix{T}, uplo::Union(Type{Val{:L}}, Type{Val{:U}})
end
function chol!(x::Number, uplo)
rx = real(x)
rx == abs(x) || throw(DomainError())
if rx != abs(x)
throw(DomainError("x must be positive semidefinite"))
end
rxr = sqrt(rx)
convert(promote_type(typeof(x), typeof(rxr)), rxr)
end
Expand Down
Loading

0 comments on commit 8ef420c

Please sign in to comment.