diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 8d6f60c3d42f4..5c32058a4b6e9 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -405,9 +405,9 @@ end end out end -_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true -_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true -_has_matching_storage(out, A) = false +_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true +_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true +_has_matching_zeros(out, A) = false function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col) isunit = B isa UnitUpperTriangular 1:min(col-isunit, size(B,1)) @@ -416,31 +416,27 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) isunit = B isa UnitLowerTriangular col+isunit:size(B,1) end -_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) -_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1 +_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) +_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1 function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) isunit = B isa UnitUpperOrUnitLowerTriangular - out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B) + out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B) for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) end - # indices of out corresponding to the stored indices of B + # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) @inbounds @simd for i in rowrange _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end - # indices of out corresponding to the zeros of B + # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros - if !_has_matching_storage(out, B) - rowrange = _rowrange_tri_nonstored(B, j) - if haszero(eltype(out)) - _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) - else - @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) - end + if !_has_matching_zeros(out, B) + rowrange = _rowrange_tri_zeros(B, j) + @inbounds @simd for i in rowrange + _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) end end end @@ -476,7 +472,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) # if both A and out have the same upper/lower triangular structure, # we may directly read and write from the parents - out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A) + out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) dja = _add(@inbounds D.diag[j]) # store the diagonal separately for unit triangular matrices @@ -488,16 +484,12 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a @inbounds @simd for i in rowrange _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) end - # indices of out corresponding to the zeros of A + # Fill the indices of out corresponding to the zeros of A # we only fill these if out and A don't have matching zeros - if !_has_matching_storage(out, A) - rowrange = _rowrange_tri_nonstored(A, j) - if haszero(eltype(out)) - _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) - else - @inbounds @simd for i in rowrange - _modify!(_add, A[i,j] * dja, out, (i,j)) - end + if !_has_matching_zeros(out, A) + rowrange = _rowrange_tri_zeros(A, j) + @inbounds @simd for i in rowrange + _modify!(_add, A[i,j] * dja, out, (i,j)) end end end