diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl
index 8ba4c3d457e838..2f02ec77b80b3b 100644
--- a/stdlib/LinearAlgebra/src/diagonal.jl
+++ b/stdlib/LinearAlgebra/src/diagonal.jl
@@ -404,22 +404,49 @@ end
     end
     out
 end
-_maybe_unwrap_tri(out, A) = out, A
-_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
-_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
+# `true` when `out` and `A` are triangular with the same upper/lower orientation,
+# in which case the methods below read and write their parents directly
+_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
+_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
+_has_matching_storage(out, A) = false
+# rows of column `col` in the stored triangle of `B`
+# (the diagonal is excluded for unit triangular matrices)
+function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
+    isunit = B isa UnitUpperTriangular
+    1:min(col-isunit, size(B,1))
+end
+function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
+    isunit = B isa UnitLowerTriangular
+    col+isunit:size(B,1)
+end
+# rows of column `col` that lie in the structurally zero triangle of `B`
+_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
+_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
 @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
-    isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
-    # if both B and out have the same upper/lower triangular structure,
-    # we may directly read and write from the parents
-    out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
+    isunit = B isa UnitUpperOrUnitLowerTriangular
+    out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
     for j in axes(B, 2)
+        # store the diagonal separately for unit triangular matrices
         if isunit
-            _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
+            @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
         end
-        rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
+        # indices of out corresponding to the stored indices of B
+        rowrange = _rowrange_tri_stored(B, j)
         @inbounds @simd for i in rowrange
             _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
         end
+        # indices of out corresponding to the zeros of B
+        # we only fill these if out and B don't have matching zeros
+        if !_has_matching_storage(out, B)
+            rowrange = _rowrange_tri_nonstored(B, j)
+            if haszero(eltype(out))
+                _rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
+            else
+                @inbounds @simd for i in rowrange
+                    _modify!(_add, D.diag[i] * B[i,j], out, (i,j))
+                end
+            end
+        end
     end
     out
 end
@@ -446,23 +473,38 @@ end
     out
 end
 @inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
-    isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
+    isunit = A isa UnitUpperOrUnitLowerTriangular
     beta = _add.beta
     # since alpha is multiplied to the diagonal element of D,
     # we may skip alpha in the second multiplication by setting ais1 to true
     _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
     # if both A and out have the same upper/lower triangular structure,
     # we may directly read and write from the parents
-    out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
-    @inbounds for j in axes(A, 2)
-        dja = _add(D.diag[j])
+    out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
+    for j in axes(A, 2)
+        dja = _add(@inbounds D.diag[j])
+        # store the diagonal separately for unit triangular matrices
         if isunit
-            _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
+            @inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
         end
-        rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
-        @simd for i in rowrange
+        # indices of out corresponding to the stored indices of A
+        rowrange = _rowrange_tri_stored(A, j)
+        @inbounds @simd for i in rowrange
             _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
         end
+        # indices of out corresponding to the zeros of A
+        # we only fill these if out and A don't have matching zeros
+        if !_has_matching_storage(out, A)
+            rowrange = _rowrange_tri_nonstored(A, j)
+            if haszero(eltype(out))
+                _rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
+            else
+                # dja already includes alpha, so alpha must not be applied again
+                @inbounds @simd for i in rowrange
+                    _modify!(_add_aisone, A[i,j] * dja, out, (i,j))
+                end
+            end
+        end
     end
     out
 end
diff --git a/stdlib/LinearAlgebra/test/addmul.jl b/stdlib/LinearAlgebra/test/addmul.jl
index 208fa930e8ee14..fcd0b51b2e4c0b 100644
--- a/stdlib/LinearAlgebra/test/addmul.jl
+++ b/stdlib/LinearAlgebra/test/addmul.jl
@@ -239,4 +239,26 @@ end
     end
 end
 
+@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
+    for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
+        U = MT(reshape([1:9;],3,3))
+        M = Array(U)
+        D = Diagonal(1:3)
+        A = reshape([1:9;],3,3)
+        @test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
+        @test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3
+
+        # nan values with iszero(alpha)
+        D = Diagonal(fill(NaN,3))
+        @test mul!(copy(A), U, D, 0, 3) == A * 3
+        @test mul!(copy(A), D, U, 0, 3) == A * 3
+
+        # nan values with iszero(beta)
+        A = fill(NaN,3,3)
+        D = Diagonal(1:3)
+        @test mul!(copy(A), U, D, 2, 0) == M * D * 2
+        @test mul!(copy(A), D, U, 2, 0) == D * M * 2
+    end
+end
+
 end # module