Skip to content

Commit

Permalink
Rename _has_matching_storage to _has_matching_zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 25, 2024
1 parent 3ee4650 commit 16a2b0d
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,9 @@ end
end
out
end
_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
_has_matching_storage(out, A) = false
_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
_has_matching_zeros(out, A) = false
function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
isunit = B isa UnitUpperTriangular
1:min(col-isunit, size(B,1))
Expand All @@ -416,31 +416,27 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
isunit = B isa UnitLowerTriangular
col+isunit:size(B,1)
end
_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
isunit = B isa UnitUpperOrUnitLowerTriangular
out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B)
for j in axes(B, 2)
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
end
# indices of out corresponding to the stored indices of B
# The indices of out corresponding to the stored indices of B
rowrange = _rowrange_tri_stored(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of B
# Fill the indices of out corresponding to the zeros of B
# we only fill these if out and B don't have matching zeros
if !_has_matching_storage(out, B)
rowrange = _rowrange_tri_nonstored(B, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
if !_has_matching_zeros(out, B)
rowrange = _rowrange_tri_zeros(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
end
Expand Down Expand Up @@ -476,7 +472,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
for j in axes(A, 2)
dja = _add(@inbounds D.diag[j])
# store the diagonal separately for unit triangular matrices
Expand All @@ -488,16 +484,12 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of A
# Fill the indices of out corresponding to the zeros of A
# we only fill these if out and A don't have matching zeros
if !_has_matching_storage(out, A)
rowrange = _rowrange_tri_nonstored(A, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, A[i,j] * dja, out, (i,j))
end
if !_has_matching_zeros(out, A)
rowrange = _rowrange_tri_zeros(A, j)
@inbounds @simd for i in rowrange
_modify!(_add, A[i,j] * dja, out, (i,j))
end
end
end
Expand Down

0 comments on commit 16a2b0d

Please sign in to comment.