Skip to content

Commit

Permalink
added promotions for SymTri and Tri (#48536)
Browse files Browse the repository at this point in the history
Co-authored-by: zzeuuus <[email protected]>
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent 4d677a5 commit 24e43ad
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
23 changes: 19 additions & 4 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,14 @@ julia> A[2,1]
SymTridiagonal(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T}(dv, ev)
SymTridiagonal{T}(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T,V}(dv, ev)
function SymTridiagonal{T}(dv::AbstractVector, ev::AbstractVector) where {T}
SymTridiagonal(convert(AbstractVector{T}, dv)::AbstractVector{T},
convert(AbstractVector{T}, ev)::AbstractVector{T})
d = convert(AbstractVector{T}, dv)::AbstractVector{T}
e = convert(AbstractVector{T}, ev)::AbstractVector{T}
typeof(d) == typeof(e) ?
SymTridiagonal{T}(d, e) :
throw(ArgumentError("diagonal vectors needed to be convertible to same type"))
end
SymTridiagonal(d::AbstractVector{T}, e::AbstractVector{S}) where {T,S} =
SymTridiagonal{promote_type(T, S)}(d, e)

"""
SymTridiagonal(A::AbstractMatrix)
Expand Down Expand Up @@ -513,11 +518,21 @@ julia> Tridiagonal(dl, d, du)
"""
Tridiagonal(dl::V, d::V, du::V) where {T,V<:AbstractVector{T}} = Tridiagonal{T,V}(dl, d, du)
Tridiagonal(dl::V, d::V, du::V, du2::V) where {T,V<:AbstractVector{T}} = Tridiagonal{T,V}(dl, d, du, du2)
Tridiagonal(dl::AbstractVector{T}, d::AbstractVector{S}, du::AbstractVector{U}) where {T,S,U} =
Tridiagonal{promote_type(T, S, U)}(dl, d, du)
Tridiagonal(dl::AbstractVector{T}, d::AbstractVector{S}, du::AbstractVector{U}, du2::AbstractVector{V}) where {T,S,U,V} =
Tridiagonal{promote_type(T, S, U, V)}(dl, d, du, du2)
function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector) where {T}
Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du))...)
l, d, u = map(x->convert(AbstractVector{T}, x), (dl, d, du))
typeof(l) == typeof(d) == typeof(u) ?
Tridiagonal(l, d, u) :
throw(ArgumentError("diagonal vectors needed to be convertible to same type"))
end
function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector, du2::AbstractVector) where {T}
Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du, du2))...)
l, d, u, u2 = map(x->convert(AbstractVector{T}, x), (dl, d, du, du2))
typeof(l) == typeof(d) == typeof(u) == typeof(u2) ?
Tridiagonal(l, d, u, u2) :
throw(ArgumentError("diagonal vectors needed to be convertible to same type"))
end

"""
Expand Down
17 changes: 11 additions & 6 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ end
@test isa(ST, SymTridiagonal{elty,Vector{elty}})
TT = Tridiagonal{elty,Vector{elty}}(GenericArray(dl), d, GenericArray(dl))
@test isa(TT, Tridiagonal{elty,Vector{elty}})
@test_throws MethodError SymTridiagonal(d, GenericArray(dl))
@test_throws MethodError SymTridiagonal(GenericArray(d), dl)
@test_throws MethodError Tridiagonal(GenericArray(dl), d, GenericArray(dl))
@test_throws MethodError Tridiagonal(dl, GenericArray(d), dl)
@test_throws MethodError SymTridiagonal{elty}(d, GenericArray(dl))
@test_throws MethodError Tridiagonal{elty}(GenericArray(dl), d,GenericArray(dl))
@test_throws ArgumentError SymTridiagonal(d, GenericArray(dl))
@test_throws ArgumentError SymTridiagonal(GenericArray(d), dl)
@test_throws ArgumentError Tridiagonal(GenericArray(dl), d, GenericArray(dl))
@test_throws ArgumentError Tridiagonal(dl, GenericArray(d), dl)
@test_throws ArgumentError SymTridiagonal{elty}(d, GenericArray(dl))
@test_throws ArgumentError Tridiagonal{elty}(GenericArray(dl), d,GenericArray(dl))
STI = SymTridiagonal([1,2,3,4], [1,2,3])
TTI = Tridiagonal([1,2,3], [1,2,3,4], [1,2,3])
TTI2 = Tridiagonal([1,2,3], [1,2,3,4], [1,2,3], [1,2])
Expand Down Expand Up @@ -505,6 +505,11 @@ end
@test SymTridiagonal([1, 2], [0])^3 == [1 0; 0 8]
end

@testset "Issue #48505" begin
@test SymTridiagonal([1,2,3],[4,5.0]) == [1.0 4.0 0.0; 4.0 2.0 5.0; 0.0 5.0 3.0]
@test Tridiagonal([1, 2], [4, 5, 1], [6.0, 7]) == [4.0 6.0 0.0; 1.0 5.0 7.0; 0.0 2.0 1.0]
end

@testset "convert for SymTridiagonal" begin
STF32 = SymTridiagonal{Float32}(fill(1f0, 5), fill(1f0, 4))
@test convert(SymTridiagonal{Float64}, STF32)::SymTridiagonal{Float64} == STF32
Expand Down

0 comments on commit 24e43ad

Please sign in to comment.