From 168f7e76c211a50070f28cc9cb50eeb44b87253b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 21 Oct 2024 12:38:11 +0530 Subject: [PATCH] Specialize adding/subtracting mixed Upper/LowerTriangular (#56149) --- stdlib/LinearAlgebra/src/triangular.jl | 20 +++++++++- stdlib/LinearAlgebra/test/triangular.jl | 52 +++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index b1de141b109c0..e3c96533a9864 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -152,6 +152,7 @@ UnitUpperTriangular const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTriangular{T,S}} const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}} const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}} +const UnitUpperOrUnitLowerTriangular{T,S} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}} uppertriangular(M) = UpperTriangular(M) lowertriangular(M) = LowerTriangular(M) @@ -221,6 +222,16 @@ function Matrix{T}(A::UnitUpperTriangular) where T B end +function full(A::Union{UpperTriangular,LowerTriangular}) + return _triangularize(A)(parent(A)) +end +function full(A::UnitUpperOrUnitLowerTriangular) + isupper = A isa UnitUpperTriangular + Ap = _triangularize(A)(parent(A), isupper ? 1 : -1) + Ap[diagind(Ap, IndexStyle(Ap))] = @view A[diagind(A, IndexStyle(A))] + return Ap +end + function full!(A::LowerTriangular) B = A.data tril!(B) @@ -553,6 +564,9 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular return A end +_triangularize(::UpperOrUnitUpperTriangular) = triu +_triangularize(::LowerOrUnitLowerTriangular) = tril + @inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) = _triscale!(A, B, C, MulAddMul(alpha, beta)) @inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) = @@ -812,7 +826,8 @@ function +(A::UnitLowerTriangular, B::UnitLowerTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I) end -+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B) ++(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) + full(B) ++(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) + copyto!(similar(parent(B), size(B)), B) function -(A::UpperTriangular, B::UpperTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B @@ -846,7 +861,8 @@ function -(A::UnitLowerTriangular, B::UnitLowerTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B LowerTriangular(tril(A.data, -1) - tril(B.data, -1)) end --(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B) +-(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) - full(B) +-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) - copyto!(similar(parent(B), size(B)), B) # use broadcasting if the parents are strided, where we loop only over the triangular part for op in (:+, :-) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index f37bf05650452..f3e6e4a744088 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -25,6 +25,12 @@ debug && println("Test basic type functionality") @test_throws DimensionMismatch LowerTriangular(randn(5, 4)) @test LowerTriangular(randn(3, 3)) |> t -> [size(t, i) for i = 1:3] == [size(Matrix(t), i) for i = 1:3] +struct MyTriangular{T, A<:LinearAlgebra.AbstractTriangular{T}} <: LinearAlgebra.AbstractTriangular{T} + data :: A +end +Base.size(A::MyTriangular) = size(A.data) +Base.getindex(A::MyTriangular, i::Int, j::Int) = A.data[i,j] + # The following test block tries to call all methods in base/linalg/triangular.jl in order for a combination of input element types. Keep the ordering when adding code. @testset for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int) # Begin loop for first Triangular matrix @@ -1078,4 +1084,50 @@ end end end +@testset "addition/subtraction of mixed triangular" begin + for A in (Hermitian(rand(4, 4)), Diagonal(rand(5))) + for T in (UpperTriangular, LowerTriangular, + UnitUpperTriangular, UnitLowerTriangular) + B = T(A) + M = Matrix(B) + R = B - B' + if A isa Diagonal + @test R isa Diagonal + end + @test R == M - M' + R = B + B' + if A isa Diagonal + @test R isa Diagonal + end + @test R == M + M' + C = MyTriangular(B) + @test C - C' == M - M' + @test C + C' == M + M' + end + end + @testset "unfilled parent" begin + @testset for T in (UpperTriangular, LowerTriangular, + UnitUpperTriangular, UnitLowerTriangular) + F = Matrix{BigFloat}(undef, 2, 2) + B = T(F) + isupper = B isa Union{UpperTriangular, UnitUpperTriangular} + B[1+!isupper, 1+isupper] = 2 + if !(B isa Union{UnitUpperTriangular, UnitLowerTriangular}) + B[1,1] = B[2,2] = 3 + end + M = Matrix(B) + # These are broken, as triu/tril don't work with + # unfilled adjoint matrices + # See https://github.com/JuliaLang/julia/pull/55312 + @test_broken B - B' == M - M' + @test_broken B + B' == M + M' + @test B - copy(B') == M - M' + @test B + copy(B') == M + M' + C = MyTriangular(B) + @test C - C' == M - M' + @test C + C' == M + M' + end + end +end + end # module TestTriangular