From e33c6a8551e070e7936a2ac95180a6c834f56549 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 18 Oct 2024 17:04:44 +0530 Subject: [PATCH] Specialize adding/subtracting mixed Upper/LowerTriangular (#56149) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes https://github.com/JuliaLang/julia/issues/56134 After this, ```julia julia> using LinearAlgebra julia> A = hermitianpart(rand(4, 4)) 4×4 Hermitian{Float64, Matrix{Float64}}: 0.387617 0.277226 0.67629 0.60678 0.277226 0.894101 0.388416 0.489141 0.67629 0.388416 0.100907 0.619955 0.60678 0.489141 0.619955 0.452605 julia> B = UpperTriangular(A) 4×4 UpperTriangular{Float64, Hermitian{Float64, Matrix{Float64}}}: 0.387617 0.277226 0.67629 0.60678 ⋅ 0.894101 0.388416 0.489141 ⋅ ⋅ 0.100907 0.619955 ⋅ ⋅ ⋅ 0.452605 julia> B - B' 4×4 Matrix{Float64}: 0.0 0.277226 0.67629 0.60678 -0.277226 0.0 0.388416 0.489141 -0.67629 -0.388416 0.0 0.619955 -0.60678 -0.489141 -0.619955 0.0 ``` This preserves the band structure of the parent, if any: ```julia julia> U = UpperTriangular(Diagonal(ones(4))) 4×4 UpperTriangular{Float64, Diagonal{Float64, Vector{Float64}}}: 1.0 0.0 0.0 0.0 ⋅ 1.0 0.0 0.0 ⋅ ⋅ 1.0 0.0 ⋅ ⋅ ⋅ 1.0 julia> U - U' 4×4 Diagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ ⋅ 0.0 ⋅ ⋅ ⋅ ⋅ 0.0 ⋅ ⋅ ⋅ ⋅ 0.0 ``` This doesn't fully work with partly initialized matrices, and would need https://github.com/JuliaLang/julia/pull/55312 for that. The abstract triangular methods now construct matrices using `similar(parent(U), size(U))` so that the destinations are fully mutable. ```julia julia> @invoke B::LinearAlgebra.AbstractTriangular - B'::LinearAlgebra.AbstractTriangular 4×4 Matrix{Float64}: 0.0 0.277226 0.67629 0.60678 -0.277226 0.0 0.388416 0.489141 -0.67629 -0.388416 0.0 0.619955 -0.60678 -0.489141 -0.619955 0.0 ``` --------- Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/triangular.jl | 19 +++++++++-- stdlib/LinearAlgebra/test/triangular.jl | 43 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 71660bc5ca28c..83ef221329d33 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -142,6 +142,7 @@ UnitUpperTriangular const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTriangular{T,S}} const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}} const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}} +const UnitUpperOrUnitLowerTriangular{T,S} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}} uppertriangular(M) = UpperTriangular(M) lowertriangular(M) = LowerTriangular(M) @@ -181,6 +182,16 @@ copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(si # then handle all methods that requires specific handling of upper/lower and unit diagonal +function full(A::Union{UpperTriangular,LowerTriangular}) + return _triangularize(A)(parent(A)) +end +function full(A::UnitUpperOrUnitLowerTriangular) + isupper = A isa UnitUpperTriangular + Ap = _triangularize(A)(parent(A), isupper ? 1 : -1) + Ap[diagind(Ap, IndexStyle(Ap))] = @view A[diagind(A, IndexStyle(A))] + return Ap +end + function full!(A::LowerTriangular) B = A.data tril!(B) @@ -571,6 +582,8 @@ end return A end +_triangularize(::UpperOrUnitUpperTriangular) = triu +_triangularize(::LowerOrUnitLowerTriangular) = tril _triangularize!(::UpperOrUnitUpperTriangular) = triu! _triangularize!(::LowerOrUnitLowerTriangular) = tril! @@ -880,7 +893,8 @@ function +(A::UnitLowerTriangular, B::UnitLowerTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I) end -+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B) ++(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) + full(B) ++(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) + copyto!(similar(parent(B), size(B)), B) function -(A::UpperTriangular, B::UpperTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B @@ -914,7 +928,8 @@ function -(A::UnitLowerTriangular, B::UnitLowerTriangular) (parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B LowerTriangular(tril(A.data, -1) - tril(B.data, -1)) end --(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B) +-(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) - full(B) +-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) - copyto!(similar(parent(B), size(B)), B) function kron(A::UpperTriangular{T,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{S,<:StridedMaybeAdjOrTransMat}) where {T,S} C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B))) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index ec9a3079e2643..7acb3cbfc0c57 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1322,4 +1322,47 @@ end end end +@testset "addition/subtraction of mixed triangular" begin + for A in (Hermitian(rand(4, 4)), Diagonal(rand(5))) + for T in (UpperTriangular, LowerTriangular, + UnitUpperTriangular, UnitLowerTriangular) + B = T(A) + M = Matrix(B) + R = B - B' + if A isa Diagonal + @test R isa Diagonal + end + @test R == M - M' + R = B + B' + if A isa Diagonal + @test R isa Diagonal + end + @test R == M + M' + C = MyTriangular(B) + @test C - C' == M - M' + @test C + C' == M + M' + end + end + @testset "unfilled parent" begin + @testset for T in (UpperTriangular, LowerTriangular, + UnitUpperTriangular, UnitLowerTriangular) + F = Matrix{BigFloat}(undef, 2, 2) + B = T(F) + isupper = B isa Union{UpperTriangular, UnitUpperTriangular} + B[1+!isupper, 1+isupper] = 2 + if !(B isa Union{UnitUpperTriangular, UnitLowerTriangular}) + B[1,1] = B[2,2] = 3 + end + M = Matrix(B) + @test B - B' == M - M' + @test B + B' == M + M' + @test B - copy(B') == M - M' + @test B + copy(B') == M + M' + C = MyTriangular(B) + @test C - C' == M - M' + @test C + C' == M + M' + end + end +end + end # module TestTriangular