From 786ba2f69b3f7274a4fa4a27ba8fa9929aa523bd Mon Sep 17 00:00:00 2001
From: Jishnu Bhattacharya
Date: Mon, 21 Oct 2024 17:55:08 +0530
Subject: [PATCH] Fix multiplying a triangular matrix and a Diagonal

---
Review note: in the triangular right-multiplication kernel
`__muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add)`,
the fallback loop over the non-stored entries of `A` now uses `_add_aisone`
instead of `_add`. `dja = _add(D.diag[j])` already carries `alpha`, so using
`_add` there applied `alpha` a second time, unlike every other `_modify!`
call in that method and in the generic `(out, A, D::Diagonal)` method.

 stdlib/LinearAlgebra/src/LinearAlgebra.jl |   2 +
 stdlib/LinearAlgebra/src/diagonal.jl      | 190 +++++++++++++++-------
 stdlib/LinearAlgebra/test/addmul.jl       |  22 +++
 stdlib/LinearAlgebra/test/diagonal.jl     |  40 ++++-
 4 files changed, 190 insertions(+), 64 deletions(-)

diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl
index 15354603943c22..88fc3476c9d7fd 100644
--- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl
+++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl
@@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
 matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
 matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
 _matprod_dest_diag(A, TS) = similar(A, TS)
+_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
+_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
 function _matprod_dest_diag(A::SymTridiagonal, TS)
     n = size(A, 1)
     ev = similar(A, TS, max(0, n-1))
diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl
index 6e8ce96259fc1a..260098aa88fef3 100644
--- a/stdlib/LinearAlgebra/src/diagonal.jl
+++ b/stdlib/LinearAlgebra/src/diagonal.jl
@@ -396,82 +396,156 @@ function lmul!(D::Diagonal, T::Tridiagonal)
     return T
 end
 
-function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
+@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
+    @inbounds for j in axes(B, 2)
+        @simd for i in axes(B, 1)
+            _modify!(_add, D.diag[i] * B[i,j], out, (i,j))
+        end
+    end
+    out
+end
+_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
+_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
+_has_matching_storage(out, A) = false
+function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
+    isunit = B isa UnitUpperTriangular
+    1:min(col-isunit, size(B,1))
+end
+function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
+    isunit = B isa UnitLowerTriangular
+    col+isunit:size(B,1)
+end
+_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
+_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
+function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
+    isunit = B isa UnitUpperOrUnitLowerTriangular
+    out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
+    for j in axes(B, 2)
+        # store the diagonal separately for unit triangular matrices
+        if isunit
+            @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
+        end
+        # indices of out corresponding to the stored indices of B
+        rowrange = _rowrange_tri_stored(B, j)
+        @inbounds @simd for i in rowrange
+            _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
+        end
+        # indices of out corresponding to the zeros of B
+        # we only fill these if out and B don't have matching zeros
+        if !_has_matching_storage(out, B)
+            rowrange = _rowrange_tri_nonstored(B, j)
+            if haszero(eltype(out))
+                _rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
+            else
+                @inbounds @simd for i in rowrange
+                    _modify!(_add, D.diag[i] * B[i,j], out, (i,j))
+                end
+            end
+        end
+    end
+    out
+end
+function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
     require_one_based_indexing(out, B)
     alpha, beta = _add.alpha, _add.beta
     if iszero(alpha)
         _rmul_or_fill!(out, beta)
     else
-        if bis0
-            @inbounds for j in axes(B, 2)
-                @simd for i in axes(B, 1)
-                    out[i,j] = D.diag[i] * B[i,j] * alpha
-                end
-            end
-        else
-            @inbounds for j in axes(B, 2)
-                @simd for i in axes(B, 1)
-                    out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
+        __muldiag_nonzeroalpha!(out, D, B, _add)
+    end
+    return out
+end
+
+@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
+    beta = _add.beta
+    _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
+    @inbounds for j in axes(A, 2)
+        dja = _add(D.diag[j])
+        @simd for i in axes(A, 1)
+            _modify!(_add_aisone, A[i,j] * dja, out, (i,j))
+        end
+    end
+    out
+end
+function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
+    isunit = A isa UnitUpperOrUnitLowerTriangular
+    beta = _add.beta
+    # since alpha is multiplied to the diagonal element of D,
+    # we may skip alpha in the second multiplication by setting ais1 to true
+    _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
+    # if both A and out have the same upper/lower triangular structure,
+    # we may directly read and write from the parents
+    out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
+    for j in axes(A, 2)
+        dja = _add(@inbounds D.diag[j])
+        # store the diagonal separately for unit triangular matrices
+        if isunit
+            @inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
+        end
+        # indices of out corresponding to the stored indices of A
+        rowrange = _rowrange_tri_stored(A, j)
+        @inbounds @simd for i in rowrange
+            _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
+        end
+        # indices of out corresponding to the zeros of A
+        # we only fill these if out and A don't have matching zeros
+        if !_has_matching_storage(out, A)
+            rowrange = _rowrange_tri_nonstored(A, j)
+            if haszero(eltype(out))
+                _rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
+            else
+                @inbounds @simd for i in rowrange
+                    _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) # dja already includes alpha
                 end
             end
         end
     end
-    return out
+    out
 end
-function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
+function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
     require_one_based_indexing(out, A)
     alpha, beta = _add.alpha, _add.beta
     if iszero(alpha)
         _rmul_or_fill!(out, beta)
     else
-        if bis0
-            @inbounds for j in axes(A, 2)
-                dja = D.diag[j] * alpha
-                @simd for i in axes(A, 1)
-                    out[i,j] = A[i,j] * dja
-                end
-            end
-        else
-            @inbounds for j in axes(A, 2)
-                dja = D.diag[j] * alpha
-                @simd for i in axes(A, 1)
-                    out[i,j] = A[i,j] * dja + out[i,j] * beta
-                end
-            end
-        end
+        __muldiag_nonzeroalpha!(out, A, D, _add)
     end
     return out
 end
-function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
+
+@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
     d1 = D1.diag
     d2 = D2.diag
+    outd = out.diag
+    @inbounds @simd for i in eachindex(d1, d2, outd)
+        _modify!(_add, d1[i] * d2[i], outd, i)
+    end
+    out
+end
+function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
     alpha, beta = _add.alpha, _add.beta
     if iszero(alpha)
         _rmul_or_fill!(out.diag, beta)
     else
-        if bis0
-            @inbounds @simd for i in eachindex(out.diag)
-                out.diag[i] = d1[i] * d2[i] * alpha
-            end
-        else
-            @inbounds @simd for i in eachindex(out.diag)
-                out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
-            end
-        end
+        __muldiag_nonzeroalpha!(out, D1, D2, _add)
     end
     return out
 end
-function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
-    require_one_based_indexing(out)
-    alpha, beta = _add.alpha, _add.beta
-    mA = size(D1, 1)
+@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
     d1 = D1.diag
     d2 = D2.diag
+    @inbounds @simd for i in eachindex(d1, d2)
+        _modify!(_add, d1[i] * d2[i], out, (i,i))
+    end
+    out
+end
+function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
+    require_one_based_indexing(out)
+    alpha, beta = _add.alpha, _add.beta
     _rmul_or_fill!(out, beta)
     if !iszero(alpha)
-        @inbounds @simd for i in 1:mA
-            out[i,i] += d1[i] * d2[i] * alpha
-        end
+        _add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true)
+        __muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
     end
     return out
 end
@@ -658,31 +732,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
         @eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
         @eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
     end
+    @eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
+            @invoke *(A::AbstractMatrix, D::Diagonal)
+    @eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
+            @invoke *(A::AbstractMatrix, D::Diagonal)
     for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
         @eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
         @eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
     end
+    @eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
+            @invoke *(D::Diagonal, A::AbstractMatrix)
+    @eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
+            @invoke *(D::Diagonal, A::AbstractMatrix)
     # 3-arg ldiv!
     @eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
     @eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
-    # 3-arg mul! is disambiguated in special.jl
-    # 5-arg mul!
-    @eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
-    @eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
-        α, β = _add.alpha, _add.beta
-        iszero(α) && return _rmul_or_fill!(C, β)
-        diag′ = bis0 ? nothing : diag(C)
-        data = mul!(C.data, D, A.data, α, β)
-        $Tri(_setdiag!(data, _add, D.diag, diag′))
-    end
-    @eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
-    @eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
-        α, β = _add.alpha, _add.beta
-        iszero(α) && return _rmul_or_fill!(C, β)
-        diag′ = bis0 ? nothing : diag(C)
-        data = mul!(C.data, A.data, D, α, β)
-        $Tri(_setdiag!(data, _add, D.diag, diag′))
-    end
 end
 
 @inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
diff --git a/stdlib/LinearAlgebra/test/addmul.jl b/stdlib/LinearAlgebra/test/addmul.jl
index 208fa930e8ee14..fcd0b51b2e4c0b 100644
--- a/stdlib/LinearAlgebra/test/addmul.jl
+++ b/stdlib/LinearAlgebra/test/addmul.jl
@@ -239,4 +239,26 @@ end
     end
 end
 
+@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
+    for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
+        U = MT(reshape([1:9;],3,3))
+        M = Array(U)
+        D = Diagonal(1:3)
+        A = reshape([1:9;],3,3)
+        @test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
+        @test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3
+
+        # nan values with iszero(alpha)
+        D = Diagonal(fill(NaN,3))
+        @test mul!(copy(A), U, D, 0, 3) == A * 3
+        @test mul!(copy(A), D, U, 0, 3) == A * 3
+
+        # nan values with iszero(beta)
+        A = fill(NaN,3,3)
+        D = Diagonal(1:3)
+        @test mul!(copy(A), U, D, 2, 0) == M * D * 2
+        @test mul!(copy(A), D, U, 2, 0) == D * M * 2
+    end
+end
+
 end # module
diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl
index 1c3a9dfa676acc..380a0465028d17 100644
--- a/stdlib/LinearAlgebra/test/diagonal.jl
+++ b/stdlib/LinearAlgebra/test/diagonal.jl
@@ -1188,7 +1188,7 @@ end
     @test oneunit(D3) isa typeof(D3)
 end
 
-@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
+@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
     A = randn(4, 4)
     TriA = Tri(A)
     UTriA = UTri(A)
@@ -1218,6 +1218,44 @@ end
     @test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
     @test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
     @test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)
+
+    # we may write to a Unit triangular if the diagonal is preserved
+    ID = Diagonal(ones(size(UTriA,2)))
+    @test mul!(copy(UTriA), UTriA, ID) == UTriA
+    @test mul!(copy(UTriA), ID, UTriA) == UTriA
+
+    @testset "partly filled parents" begin
+        M = Matrix{BigFloat}(undef, 2, 2)
+        M[1,1] = M[2,2] = 3
+        isupper = Tri == UpperTriangular
+        M[1+!isupper, 1+isupper] = 3
+        D = Diagonal(1:2)
+        T = Tri(M)
+        TA = Array(T)
+        @test T * D == TA * D
+        @test D * T == D * TA
+        @test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
+        @test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T
+
+        U = UTri(M)
+        UA = Array(U)
+        @test U * D == UA * D
+        @test D * U == D * UA
+        @test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
+        @test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA
+
+        M2 = Matrix{BigFloat}(undef, 2, 2)
+        M2[1+!isupper, 1+isupper] = 3
+        U = UTri(M2)
+        UA = Array(U)
+        @test U * D == UA * D
+        @test D * U == D * UA
+        ID = Diagonal(ones(size(U,2)))
+        @test mul!(copy(U), U, ID) == U
+        @test mul!(copy(U), ID, U) == U
+        @test mul!(copy(U), U, ID, 2, -1) == U
+        @test mul!(copy(U), ID, U, 2, -1) == U
+    end
 end
 
 struct SMatrix1{T} <: AbstractArray{T,2}