Skip to content

Commit

Permalink
Redesign of triangular matrix types. Avoid Upper/lower and unit diago…
Browse files Browse the repository at this point in the history
…nal specification by parameters. This fixes #9191, type instability of \ for BlasFloat types when no pormotion is necessary.
  • Loading branch information
andreasnoack committed Jan 15, 2015
1 parent 1763fef commit b3c1c82
Show file tree
Hide file tree
Showing 18 changed files with 785 additions and 393 deletions.
3 changes: 2 additions & 1 deletion base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export
IOBuffer,
IOStream,
LocalProcess,
LowerTriangular,
MathConst,
Matrix,
MergeSort,
Expand Down Expand Up @@ -108,9 +109,9 @@ export
SymTridiagonal,
Timer,
TmStruct,
Triangular,
Tridiagonal,
UnitRange,
UpperTriangular,
UTF16String,
UTF32String,
Val,
Expand Down
3 changes: 2 additions & 1 deletion base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ export
SVD,
Hermitian,
Symmetric,
Triangular,
LowerTriangular,
UpperTriangular,
Diagonal,
UniformScaling,

Expand Down
13 changes: 6 additions & 7 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,23 @@ end
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper)
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)

SpecialMatrix = Union(Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, Triangular)
SpecialMatrix = Union(Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, TriangularUnion)
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)

#Generic multiplication
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
@eval begin
($func){T}(A::Bidiagonal{T}, B::AbstractVector{T}) = ($func)(full(A), B)
#($func){T}(A::AbstractArray{T}, B::Triangular{T}) = ($func)(full(A), B)
end
end


#Linear solvers
A_ldiv_B!(A::Union(Bidiagonal, Triangular), b::AbstractVector) = naivesub!(A, b)
At_ldiv_B!(A::Union(Bidiagonal, Triangular), b::AbstractVector) = naivesub!(transpose(A), b)
Ac_ldiv_B!(A::Union(Bidiagonal, Triangular), b::AbstractVector) = naivesub!(ctranspose(A), b)
A_ldiv_B!(A::Union(Bidiagonal, TriangularUnion), b::AbstractVector) = naivesub!(A, b)
At_ldiv_B!(A::Union(Bidiagonal, TriangularUnion), b::AbstractVector) = naivesub!(transpose(A), b)
Ac_ldiv_B!(A::Union(Bidiagonal, TriangularUnion), b::AbstractVector) = naivesub!(ctranspose(A), b)
for func in (:A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!) @eval begin
function ($func)(A::Union(Bidiagonal, Triangular), B::AbstractMatrix)
function ($func)(A::Union(Bidiagonal, TriangularUnion), B::AbstractMatrix)
tmp = similar(B[:,1])
n = size(B, 1)
for i = 1:size(B,2)
Expand All @@ -143,7 +142,7 @@ for func in (:A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!) @eval begin
end
end end
for func in (:A_ldiv_Bt!, :Ac_ldiv_Bt!, :At_ldiv_Bt!) @eval begin
function ($func)(A::Union(Bidiagonal, Triangular), B::AbstractMatrix)
function ($func)(A::Union(Bidiagonal, TriangularUnion), B::AbstractMatrix)
tmp = similar(B[:, 2])
m, n = size(B)
nm = n*m
Expand Down
49 changes: 36 additions & 13 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,35 @@ immutable CholeskyPivoted{T,S<:AbstractMatrix} <: Factorization{T}
end
CholeskyPivoted{T}(UL::AbstractMatrix{T}, uplo::Char, piv::Vector{BlasInt}, rank::BlasInt, tol::Real, info::BlasInt) = CholeskyPivoted{T,typeof(UL)}(UL, uplo, piv, rank, tol, info)

function chol!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U)
function chol!{T<:BlasFloat}(A::StridedMatrix{T})
C, info = LAPACK.potrf!('U', A)
return @assertposdef UpperTriangular(C) info
end
function chol!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol)
C, info = LAPACK.potrf!(char_uplo(uplo), A)
return @assertposdef Triangular{eltype(C),typeof(C),uplo,false}(C) info
return @assertposdef uplo == :U ? UpperTriangular(C) : LowerTriangular(C) info
end

function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol=:U)
function chol!{T}(A::AbstractMatrix{T})
n = chksquare(A)
@inbounds begin
for k = 1:n
for i = 1:k - 1
A[k,k] -= A[i,k]'A[i,k]
end
A[k,k] = chol!(A[k,k], uplo)
AkkInv = inv(A[k,k])
for j = k + 1:n
for i = 1:k - 1
A[k,j] -= A[i,k]'A[i,j]
end
A[k,j] = A[k,k]'\A[k,j]
end
end
end
return UpperTriangular(A)
end
function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol)
n = chksquare(A)
@inbounds begin
if uplo == :L
Expand Down Expand Up @@ -58,7 +81,7 @@ function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol=:U)
throw(ArgumentError("uplo must be either :U or :L but was $(uplo)"))
end
end
return Triangular(A, uplo, false)
return uplo == :U ? UpperTriangular(A) : LowerTriangular(A)
end

function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
Expand Down Expand Up @@ -118,14 +141,14 @@ size(C::Union(Cholesky, CholeskyPivoted)) = size(C.UL)
size(C::Union(Cholesky, CholeskyPivoted), d::Integer) = size(C.UL,d)

function getindex{T,S,UpLo}(C::Cholesky{T,S,UpLo}, d::Symbol)
d == :U && return Triangular(UpLo == d ? C.UL : C.UL',:U)
d == :L && return Triangular(UpLo == d ? C.UL : C.UL',:L)
d == :UL && return Triangular(C.UL, UpLo)
d == :U && return UpperTriangular(UpLo == d ? C.UL : C.UL')
d == :L && return LowerTriangular(UpLo == d ? C.UL : C.UL')
d == :UL && return UpLo == :U ? UpperTriangular(C.UL) : LowerTriangular(C.UL)
throw(KeyError(d))
end
function getindex{T<:BlasFloat}(C::CholeskyPivoted{T}, d::Symbol)
d == :U && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :U)
d == :L && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :L)
d == :U && return UpperTriangular(symbol(C.uplo) == d ? C.UL : C.UL')
d == :L && return LowerTriangular(symbol(C.uplo) == d ? C.UL : C.UL')
d == :p && return C.piv
if d == :P
n = size(C, 1)
Expand All @@ -142,8 +165,8 @@ show{T,S<:AbstractMatrix,UpLo}(io::IO, C::Cholesky{T,S,UpLo}) = (println("$(type

A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat{T}) = LAPACK.potrs!('U', C.UL, B)
A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat{T}) = LAPACK.potrs!('L', C.UL, B)
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat) = Ac_ldiv_B!(Triangular(C.UL, :L, false), A_ldiv_B!(Triangular(C.UL, :L, false), B))
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat) = A_ldiv_B!(Triangular(C.UL, :U, false), Ac_ldiv_B!(Triangular(C.UL, :U, false), B))
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat) = Ac_ldiv_B!(LowerTriangular(C.UL), A_ldiv_B!(LowerTriangular(C.UL), B))
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat) = A_ldiv_B!(UpperTriangular(C.UL), Ac_ldiv_B!(UpperTriangular(C.UL), B))

function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedVector{T})
chkfullrank(C)
Expand All @@ -161,8 +184,8 @@ function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedMatrix{T})
end
B
end
A_ldiv_B!(C::CholeskyPivoted, B::StridedVector) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)]
A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:]
A_ldiv_B!(C::CholeskyPivoted, B::StridedVector) = C.uplo=='L' ? Ac_ldiv_B!(LowerTriangular(C.UL), A_ldiv_B!(LowerTriangular(C.UL), B[C.piv]))[invperm(C.piv)] : A_ldiv_B!(UpperTriangular(C.UL), Ac_ldiv_B!(UpperTriangular(C.UL), B[C.piv]))[invperm(C.piv)]
A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) = C.uplo=='L' ? Ac_ldiv_B!(LowerTriangular(C.UL), A_ldiv_B!(LowerTriangular(C.UL), B[C.piv,:]))[invperm(C.piv),:] : A_ldiv_B!(UpperTriangular(C.UL), Ac_ldiv_B!(UpperTriangular(C.UL), B[C.piv,:]))[invperm(C.piv),:]

function det{T,S,UpLo}(C::Cholesky{T,S,UpLo})
dd = one(T)
Expand Down
16 changes: 8 additions & 8 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,15 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
issym(A) && return sqrtm(Symmetric(A))
n = chksquare(A)
SchurF = schurfact(complex(A))
R = full(sqrtm(Triangular(SchurF[:T], :U, false)))
R = full(sqrtm(UpperTriangular(SchurF[:T])))
retmat = SchurF[:vectors]*R*SchurF[:vectors]'
all(imag(retmat) .== 0) ? real(retmat) : retmat
end
function sqrtm{T<:Complex}(A::StridedMatrix{T})
ishermitian(A) && return sqrtm(Hermitian(A))
n = chksquare(A)
SchurF = schurfact(A)
R = full(sqrtm(Triangular(SchurF[:T], :U, false)))
R = full(sqrtm(UpperTriangular(SchurF[:T])))
SchurF[:vectors]*R*SchurF[:vectors]'
end
sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
Expand All @@ -325,9 +325,9 @@ function inv{S}(A::StridedMatrix{S})
T = typeof(one(S)/one(S))
Ac = convert(AbstractMatrix{T}, A)
if istriu(Ac)
Ai = inv(Triangular(A, :U, false))
Ai = inv(UpperTriangular(A))
elseif istril(Ac)
Ai = inv(Triangular(A, :L, false))
Ai = inv(LowerTriangular(A))
else
Ai = inv(lufact(Ac))
end
Expand Down Expand Up @@ -377,7 +377,7 @@ function factorize{T}(A::Matrix{T})
if utri1
return Bidiagonal(diag(A), diag(A, -1), false)
end
return Triangular(A, :L)
return LowerTriangular(A)
end
if utri
return Bidiagonal(diag(A), diag(A, 1), true)
Expand All @@ -392,7 +392,7 @@ function factorize{T}(A::Matrix{T})
end
end
if utri
return Triangular(A, :U)
return UpperTriangular(A)
end
if herm
try
Expand All @@ -413,9 +413,9 @@ function (\)(A::StridedMatrix, B::StridedVecOrMat)
m, n = size(A)
if m == n
if istril(A)
return istriu(A) ? \(Diagonal(A),B) : \(Triangular(A, :L),B)
return istriu(A) ? \(Diagonal(A),B) : \(LowerTriangular(A),B)
end
istriu(A) && return \(Triangular(A, :U),B)
istriu(A) && return \(UpperTriangular(A),B)
return \(lufact(A),B)
end
return qrfact(A,pivot=eltype(A)<:BlasFloat)\B
Expand Down
3 changes: 2 additions & 1 deletion base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Diagonal(A::Matrix) = Diagonal(diag(A))
convert{T}(::Type{Diagonal{T}}, D::Diagonal{T}) = D
convert{T}(::Type{Diagonal{T}}, D::Diagonal) = Diagonal{T}(convert(Vector{T}, D.diag))
convert{T}(::Type{AbstractMatrix{T}}, D::Diagonal) = convert(Diagonal{T}, D)
convert{T}(::Type{Triangular}, A::Diagonal{T}) = Triangular{T, Diagonal{T}, :U, false}(A)
convert{T}(::Type{UpperTriangular}, A::Diagonal{T}) = UpperTriangular(A)
convert{T}(::Type{LowerTriangular}, A::Diagonal{T}) = LowerTriangular(A)

function similar{T}(D::Diagonal, ::Type{T}, d::(Int,Int))
d[1] == d[2] || throw(ArgumentError("Diagonal matrix must be square"))
Expand Down
8 changes: 4 additions & 4 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ function A_mul_Bc!{T}(A::AbstractMatrix{T},Q::QRPackedQ{T})
end
A
end
A_mul_Bc(A::Triangular, B::Union(QRCompactWYQ,QRPackedQ)) = A_mul_Bc(full(A), B)
A_mul_Bc(A::TriangularUnion, B::Union(QRCompactWYQ,QRPackedQ)) = A_mul_Bc(full(A), B)
function A_mul_Bc{TA,TB}(A::AbstractArray{TA}, B::Union(QRCompactWYQ{TB},QRPackedQ{TB}))
TAB = promote_type(TA,TB)
A_mul_Bc!(size(A,2)==size(B.factors,1) ? (TA == TAB ? copy(A) : convert(AbstractMatrix{TAB}, A)) :
Expand All @@ -284,8 +284,8 @@ function A_mul_Bc{TA,TB}(A::AbstractArray{TA}, B::Union(QRCompactWYQ{TB},QRPacke
convert(AbstractMatrix{TAB}, B))
end

A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, b::StridedVector{T}) = (A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], b), 1:size(A, 2))); b)
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = (A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2))); B)
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, b::StridedVector{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], b), 1:size(A, 2))); b)
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2))); B)

# Julia implementation similarly to xgelsy
function A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Real)
Expand Down Expand Up @@ -313,7 +313,7 @@ function A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Re
# if cond(r[1:rnk, 1:rnk])*rcond < 1 break end
end
C, τ = LAPACK.tzrzf!(A.factors[1:rnk,:])
A_ldiv_B!(Triangular(C[1:rnk,1:rnk],:U),sub(Ac_mul_B!(getq(A),sub(B, 1:mA, 1:nrhs)),1:rnk,1:nrhs))
A_ldiv_B!(UpperTriangular(C[1:rnk,1:rnk]),sub(Ac_mul_B!(getq(A),sub(B, 1:mA, 1:nrhs)),1:rnk,1:nrhs))
B[rnk+1:end,:] = zero(T)
LAPACK.ormrz!('L', iseltype(B, Complex) ? 'C' : 'T', C, τ, sub(B,1:nA,1:nrhs))
B[1:nA,:] = sub(B, 1:nA, :)[invperm(A[:p]::Vector{BlasInt}),:]
Expand Down
6 changes: 5 additions & 1 deletion base/linalg/generic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## linalg.jl: Some generic Linear Algebra definitions

# Fall back arithmetic
+(A::AbstractMatrix, B::AbstractMatrix) = full(A) + full(B)
-(A::AbstractMatrix, B::AbstractMatrix) = full(A) - full(B)

scale(X::AbstractArray, s::Number) = scale!(copy(X), s)
scale(s::Number, X::AbstractArray) = scale!(copy(X), s)

Expand Down Expand Up @@ -424,7 +428,7 @@ function elementaryRightTrapezoid!(A::AbstractMatrix, row::Integer)
end

function det(A::AbstractMatrix)
(istriu(A) || istril(A)) && return det(Triangular(A, :U, false))
(istriu(A) || istril(A)) && return det(UpperTriangular(A))
return det(lufact(A))
end
det(x::Number) = x
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ function getindex{T,S<:StridedMatrix}(A::LU{T,S}, d::Symbol)
end

A_ldiv_B!{T<:BlasFloat, S<:StridedMatrix}(A::LU{T, S}, B::StridedVecOrMat{T}) = @assertnonsingular LAPACK.getrs!('N', A.factors, A.ipiv, B) A.info
A_ldiv_B!{T,S<:StridedMatrix}(A::LU{T,S}, b::StridedVector) = A_ldiv_B!(Triangular(A.factors, :U, false), A_ldiv_B!(Triangular(A.factors, :L, true), b[ipiv2perm(A.ipiv, length(b))]))
A_ldiv_B!{T,S<:StridedMatrix}(A::LU{T,S}, B::StridedMatrix) = A_ldiv_B!(Triangular(A.factors, :U, false), A_ldiv_B!(Triangular(A.factors, :L, true), B[ipiv2perm(A.ipiv, size(B, 1)),:]))
A_ldiv_B!{T,S<:StridedMatrix}(A::LU{T,S}, b::StridedVector) = A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(LowerTriangularUnit(A.factors), b[ipiv2perm(A.ipiv, length(b))]))
A_ldiv_B!{T,S<:StridedMatrix}(A::LU{T,S}, B::StridedMatrix) = A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(LowerTriangularUnit(A.factors), B[ipiv2perm(A.ipiv, size(B, 1)),:]))
At_ldiv_B{T<:BlasFloat,S<:StridedMatrix}(A::LU{T,S}, B::StridedVecOrMat{T}) = @assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, copy(B)) A.info
Ac_ldiv_B{T<:BlasComplex,S<:StridedMatrix}(A::LU{T,S}, B::StridedVecOrMat{T}) = @assertnonsingular LAPACK.getrs!('C', A.factors, A.ipiv, copy(B)) A.info
At_ldiv_Bt{T<:BlasFloat,S<:StridedMatrix}(A::LU{T,S}, B::StridedVecOrMat{T}) = @assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, transpose(B)) A.info
Expand Down
2 changes: 1 addition & 1 deletion base/linalg/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ end
*{TvA,TiA}(X::BitArray{2}, A::SparseMatrixCSC{TvA,TiA}) = invoke(*, (AbstractMatrix, SparseMatrixCSC), X, A)
# TODO: Tridiagonal * Sparse should be implemented more efficiently
*{TX,TvA,TiA}(X::Tridiagonal{TX}, A::SparseMatrixCSC{TvA,TiA}) = invoke(*, (Tridiagonal, AbstractMatrix), X, A)
*{TvA,TiA}(X::Triangular, A::SparseMatrixCSC{TvA,TiA}) = full(X)*A
*{TvA,TiA}(X::TriangularUnion, A::SparseMatrixCSC{TvA,TiA}) = full(X)*A
function *{TX,TvA,TiA}(X::AbstractMatrix{TX}, A::SparseMatrixCSC{TvA,TiA})
mX, nX = size(X)
nX == A.m || throw(DimensionMismatch())
Expand Down
26 changes: 15 additions & 11 deletions base/linalg/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
convert{T}(::Type{Bidiagonal}, A::Diagonal{T})=Bidiagonal(A.diag, zeros(T, size(A.diag,1)-1), true)
convert{T}(::Type{SymTridiagonal}, A::Diagonal{T})=SymTridiagonal(A.diag, zeros(T, size(A.diag,1)-1))
convert{T}(::Type{Tridiagonal}, A::Diagonal{T})=Tridiagonal(zeros(T, size(A.diag,1)-1), A.diag, zeros(T, size(A.diag,1)-1))
convert(::Type{Triangular}, A::Diagonal) = Triangular(full(A), :L)
convert(::Type{Triangular}, A::Bidiagonal) = Triangular(full(A), A.isupper ? :U : :L)
convert(::Type{UpperTriangular}, A::Diagonal) = UpperTriangular(full(A), :L)
convert(::Type{UpperTriangularUnit}, A::Diagonal) = UpperTriangularUnit(full(A), :L)
convert(::Type{LowerTriangular}, A::Diagonal) = LowerTriangular(full(A), :L)
convert(::Type{LowerTriangularUnit}, A::Diagonal) = LowerTriangularUnit(full(A), :L)
convert(::Type{LowerTriangular}, A::Bidiagonal) = !A.isupper ? LowerTriangular(full(A)) : throw(ArgumentError("Bidiagonal matrix must have lower off diagonal to be converted to LowerTriangular"))
convert(::Type{UpperTriangular}, A::Bidiagonal) = A.isupper ? UpperTriangular(full(A)) : throw(ArgumentError("Bidiagonal matrix must have upper off diagonal to be converted to UpperTriangular"))
convert(::Type{Matrix}, D::Diagonal) = diagm(D.diag)

function convert(::Type{Diagonal}, A::Union(Bidiagonal, SymTridiagonal))
Expand Down Expand Up @@ -42,12 +46,12 @@ function convert(::Type{SymTridiagonal}, A::Tridiagonal)
SymTridiagonal(A.d, A.dl)
end

function convert(::Type{Diagonal}, A::Triangular)
function convert(::Type{Diagonal}, A::TriangularUnion)
full(A) == diagm(diag(A)) || throw(ArgumentError("Matrix cannot be represented as Diagonal"))
Diagonal(diag(A))
end

function convert(::Type{Bidiagonal}, A::Triangular)
function convert(::Type{Bidiagonal}, A::TriangularUnion)
fA = full(A)
if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1)
return Bidiagonal(diag(A), diag(fA,1), true)
Expand All @@ -58,9 +62,9 @@ function convert(::Type{Bidiagonal}, A::Triangular)
end
end

convert(::Type{SymTridiagonal}, A::Triangular) = convert(SymTridiagonal, convert(Tridiagonal, A))
convert(::Type{SymTridiagonal}, A::TriangularUnion) = convert(SymTridiagonal, convert(Tridiagonal, A))

function convert(::Type{Tridiagonal}, A::Triangular)
function convert(::Type{Tridiagonal}, A::TriangularUnion)
fA = full(A)
if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1) + diagm(diag(fA, -1), -1)
return Tridiagonal(diag(fA, -1), diag(A), diag(fA,1))
Expand All @@ -82,7 +86,7 @@ macro commutative(myexpr)
end

for op in (:+, :-)
SpecialMatrices = [:Diagonal, :Bidiagonal, :Tridiagonal, :Triangular, :Matrix]
SpecialMatrices = [:Diagonal, :Bidiagonal, :Tridiagonal, :Matrix]
for (idx, matrixtype1) in enumerate(SpecialMatrices) #matrixtype1 is the sparser matrix type
for matrixtype2 in SpecialMatrices[idx+1:end] #matrixtype2 is the denser matrix type
@eval begin #TODO quite a few of these conversions are NOT defined...
Expand All @@ -93,7 +97,7 @@ for op in (:+, :-)
end

for matrixtype1 in (:SymTridiagonal,) #matrixtype1 is the sparser matrix type
for matrixtype2 in (:Tridiagonal, :Triangular, :Matrix) #matrixtype2 is the denser matrix type
for matrixtype2 in (:Tridiagonal, :Matrix) #matrixtype2 is the denser matrix type
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
Expand All @@ -111,7 +115,7 @@ for op in (:+, :-)
end
end

A_mul_Bc!(A::Triangular, B::QRCompactWYQ) = A_mul_Bc!(full!(A),B)
A_mul_Bc!(A::Triangular, B::QRPackedQ) = A_mul_Bc!(full!(A),B)
A_mul_Bc(A::Triangular, B::Union(QRCompactWYQ,QRPackedQ)) = A_mul_Bc(full(A), B)
A_mul_Bc!(A::TriangularUnion, B::QRCompactWYQ) = A_mul_Bc!(full!(A),B)
A_mul_Bc!(A::TriangularUnion, B::QRPackedQ) = A_mul_Bc!(full!(A),B)
A_mul_Bc(A::TriangularUnion, B::Union(QRCompactWYQ,QRPackedQ)) = A_mul_Bc(full(A), B)

Loading

0 comments on commit b3c1c82

Please sign in to comment.