Skip to content

Commit

Permalink
Make similar(S<:[Sym]Tridiagonal, [T,] shape...) return a sparse rath…
Browse files Browse the repository at this point in the history
…er than dense array.

Also fix a minor bug where lufact(A::Tridiagonal{T}) for !(T<:AbstractFloat) would dispatch
to lufact!(B, ...) for non-tridiagonal rather than tridiagonal B downstream.
  • Loading branch information
Sacha0 committed Oct 16, 2017
1 parent 41aa54c commit bdd8985
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
8 changes: 8 additions & 0 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ function lufact!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(
LU{T,Tridiagonal{T,V}}(B, ipiv, convert(BlasInt, info))
end

# Absent the following methods, if A::Tridiagonal{T} where !(T<:AbstractFloat), then the
# method lufact(A) hits sends copy!(similar(A, afloattype, size(A)), A) (which is not a
# Tridagonal given the (unnecessary?) shape argument to similar) to lufact! rather than
# copy!(similar(A, afloattype), A) (which is a Tridiagonal). The following method
# makes certain that lufact handles Tridiagonal matrices appropriately.
lufact(A::Tridiagonal{T}, pivot::Union{Val{false},Val{true}} = Val(true)) where {T} =
lufact!(copy!(similar(A, float(T)), A), pivot)

factorize(A::Tridiagonal) = lufact(A)

function getindex(F::LU{T,Tridiagonal{T,V}}, d::Symbol) where {T,V}
Expand Down
8 changes: 2 additions & 6 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1430,13 +1430,9 @@ end
## the element type doesn't have to be stable under division whereas that is
## necessary in the general triangular solve problem.

## Some Triangular-Triangular cases. We might want to write taylored methods
## Some Triangular-Triangular cases. We might want to write tailored methods
## for these cases, but I'm not sure it is worth it.
for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
@eval begin
(*)(A::Tridiagonal, B::$t) = A_mul_B!(Matrix(A), B)
end
end
(*)(A::Union{Tridiagonal,SymTridiagonal}, B::AbstractTriangular) = A_mul_B!(Matrix(A), B)

for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B!))
@eval begin
Expand Down
15 changes: 11 additions & 4 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ function size(A::SymTridiagonal, d::Integer)
end
end

similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal{T}(similar(S.dv, T), similar(S.ev, T))
# For S<:SymTridiagonal, similar(S[, neweltype]) should yield a SymTridiagonal matrix.
# On the other hand, similar(S, [neweltype,] shape...) should yield a sparse matrix.
# The first method below effects the former, and the second the latter.
similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T))
similar(S::SymTridiagonal, ::Type{T}, dims::Dims{N}) where {T,N} = spzeros(T, dims...)

#Elementary operations
broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev))
Expand Down Expand Up @@ -499,9 +503,12 @@ end
convert(::Type{Matrix}, M::Tridiagonal{T}) where {T} = convert(Matrix{T}, M)
convert(::Type{Array}, M::Tridiagonal) = convert(Matrix, M)
full(M::Tridiagonal) = convert(Array, M)
function similar(M::Tridiagonal, ::Type{T}) where T
Tridiagonal{T}(similar(M.dl, T), similar(M.d, T), similar(M.du, T))
end

# For M<:Tridiagonal, similar(M[, neweltype]) should yield a Tridiagonal matrix.
# On the other hand, similar(M, [neweltype,] shape...) should yield a sparse matrix.
# The first method below effects the former, and the second the latter.
similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), similar(M.d, T), similar(M.du, T))
similar(M::Tridiagonal, ::Type{T}, dims::Dims{N}) where {T,N} = spzeros(T, dims...)

# Operations on Tridiagonal matrices
copy!(dest::Tridiagonal, src::Tridiagonal) = (copy!(dest.dl, src.dl); copy!(dest.d, src.d); copy!(dest.du, src.du); dest)
Expand Down
6 changes: 4 additions & 2 deletions test/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ guardsrand(123) do
end
@test isa(similar(A), mat_type{elty})
@test isa(similar(A, Int), mat_type{Int})
@test isa(similar(A, Int, (3, 2)), Matrix{Int})
@test isa(similar(A, (3, 2)), SparseMatrixCSC)
@test isa(similar(A, Int, (3, 2)), SparseMatrixCSC{Int})
@test size(A, 3) == 1
@test size(A, 1) == n
@test size(A) == (n, n)
Expand Down Expand Up @@ -289,7 +290,8 @@ guardsrand(123) do
@testset "similar" begin
@test isa(similar(Ts), SymTridiagonal{elty})
@test isa(similar(Ts, Int), SymTridiagonal{Int})
@test isa(similar(Ts, Int, (3,2)), Matrix{Int})
@test isa(similar(Ts, (3, 2)), SparseMatrixCSC)
@test isa(similar(Ts, Int, (3, 2)), SparseMatrixCSC{Int})
end

@test first(logabsdet(Tldlt)) first(logabsdet(Fs))
Expand Down

0 comments on commit bdd8985

Please sign in to comment.