Skip to content

Commit

Permalink
Merge pull request #2611 from jiahao/tridiag-simpleops
Browse files Browse the repository at this point in the history
Elementary operations for Tridiagonal and SymTridiagonal matrices
  • Loading branch information
ViralBShah committed Mar 19, 2013
2 parents 2e2755d + 764f89e commit 8b29497
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
52 changes: 45 additions & 7 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#### Specialized matrix types ####

import Base.conj, Base.transpose, Base.ctranspose

## Hermitian tridiagonal matrices
type SymTridiagonal{T<:BlasFloat} <: AbstractMatrix{T}
dv::Vector{T} # diagonal
Expand All @@ -23,8 +25,6 @@ end

SymTridiagonal(A::AbstractMatrix) = SymTridiagonal(diag(A), diag(A,1))

copy(S::SymTridiagonal) = SymTridiagonal(S.dv,S.ev)

function full(S::SymTridiagonal)
M = diagm(S.dv)
for i in 1:length(S.ev)
Expand All @@ -45,9 +45,30 @@ end
size(m::SymTridiagonal) = (length(m.dv), length(m.dv))
size(m::SymTridiagonal, d::Integer) = d<1 ? error("dimension out of range") : (d<2 ? length(m.dv) : 1)

eig(m::SymTridiagonal) = LAPACK.stegr!('V', copy(m.dv), copy(m.ev))
#Elementary operations
copy(S::SymTridiagonal) = SymTridiagonal(copy(S.dv), copy(S.ev))
round(M::SymTridiagonal) = SymTridiagonal(round(M.dv), round(M.ev))
iround(M::SymTridiagonal) = SymTridiagonal(iround(M.dv), iround(M.ev))

conj(M::SymTridiagonal) = SymTridiagonal(conj(M.dv), conj(M.ev))
transpose(M::SymTridiagonal) = M #Identity operation
ctranspose(M::SymTridiagonal) = conj(M)

+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev)
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev)
#XXX Returns dense matrix but really should be banded
*(A::SymTridiagonal, B::SymTridiagonal) = full(A)*full(B)

## Solver
function \{T<:BlasFloat}(M::SymTridiagonal{T}, rhs::StridedVecOrMat{T})
if stride(rhs, 1) == 1
return LAPACK.gtsv!(copy(M.dv), copy(M.ev), copy(M.dv), copy(rhs))
end
solve(Tridiagonal(M), rhs) # use the Julia "fallback"
end

#Wrap LAPACK DSTEBZ to compute eigenvalues
#Wrap LAPACK DSTE{GR,BZ} to compute eigenvalues
eig(m::SymTridiagonal) = LAPACK.stegr!('V', copy(m.dv), copy(m.ev))
eigvals(m::SymTridiagonal, il::Int, iu::Int) = LAPACK.stebz!('I', 'E', 0.0, 0.0, il, iu, -1.0, copy(m.dv), copy(m.ev))[1]
eigvals(m::SymTridiagonal, vl::Float64, vu::Float64) = LAPACK.stebz!('V', 'E', vl, vu, 0, 0, -1.0, copy(m.dv), copy(m.ev))[1]
eigvals(m::SymTridiagonal) = LAPACK.stebz!('A', 'E', 0.0, 0.0, 0, 0, -1.0, copy(m.dv), copy(m.ev))[1]
Expand Down Expand Up @@ -83,8 +104,6 @@ function Tridiagonal{Tl<:Number, Td<:Number, Tu<:Number}(dl::Vector{Tl}, d::Vect
Tridiagonal(convert(Vector{R}, dl), convert(Vector{R}, d), convert(Vector{R}, du))
end

copy(A::Tridiagonal) = Tridiagonal(copy(A.dl), copy(A.d), copy(A.du))

size(M::Tridiagonal) = (length(M.d), length(M.d))
function show(io::IO, M::Tridiagonal)
println(io, summary(M), ":")
Expand Down Expand Up @@ -113,12 +132,31 @@ function similar(M::Tridiagonal, T, dims::Dims)
end
return Tridiagonal{T}(dims[1])
end
copy(M::Tridiagonal) = Tridiagonal(M.dl, M.d, M.du)

# Operations on Tridiagonal matrices
copy(A::Tridiagonal) = Tridiagonal(copy(A.dl), copy(A.d), copy(A.du))
round(M::Tridiagonal) = Tridiagonal(round(M.dl), round(M.d), round(M.du))
iround(M::Tridiagonal) = Tridiagonal(iround(M.dl), iround(M.d), iround(M.du))

conj(M::Tridiagonal) = Tridiagonal(conj(M.du), conj(M.d), conj(M.dl))
transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl)
ctranspose(M::Tridiagonal) = conj(transpose(M))

+(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl+B.dl, A.d+B.d, A.du+B.du)
-(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl-B.dl, A.d-B.d, A.du+B.du)
#XXX Returns dense matrix but really should be banded
*(A::Tridiagonal, B::Tridiagonal) = full(A)*full(B)

# Elementary operations that mix Tridiagonal and SymTridiagonal matrices
Tridiagonal(A::SymTridiagonal) = Tridiagonal(A.dv, A.ev, A.dv)
+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.dv, A.d+B.ev, A.du+B.dv)
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.dv+B.dl, A.ev+B.d, A.dv+B.du)
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.dv, A.d-B.ev, A.du-B.dv)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.dv-B.dl, A.ev-B.d, A.dv-B.du)
#XXX Returns dense matrix but really should be banded
*(A::SymTridiagonal, B::Tridiagonal) = full(A)*full(B)
*(A::Tridiagonal, B::SymTridiagonal) = full(A)*full(B)

## Solvers

#### Tridiagonal matrix routines ####
Expand Down
5 changes: 5 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ for elty in (Float32, Float64, Complex64, Complex128)
F[i+1,i] = dl[i]
end
@test full(T) == F
# elementary operations on tridiagonals
@test conj(T) == Tridiagonal(conj(dl), conj(d), conj(du))
@test transpose(T) == Tridiagonal(du, d, du)
@test ctranspose(T) == Tridiagonal(conj(du), conj(d), conj(dl))

# tridiagonal linear algebra
v = convert(Vector{elty}, v)
@test_approx_eq T*v F*v
Expand Down

0 comments on commit 8b29497

Please sign in to comment.