Skip to content

Commit

Permalink
Fix multiplying a triangular matrix and a Diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 21, 2024
1 parent 04259da commit 5229375
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
68 changes: 52 additions & 16 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,22 +404,44 @@ end
end
out
end
_maybe_unwrap_tri(out, A) = out, A
_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
_has_matching_storage(out, A) = false
function _rowrange_tri_stored(B::UpperOrUpperTriangular, col)
isunit = B isa UnitUpperTriangular
1:min(col-isunit, size(B,1))
end
function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
isunit = B isa UnitLowerTriangular
col+isunit:size(B,1)
end
_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
# if both B and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
isunit = B isa UnitUpperOrUnitLowerTriangular
out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
for j in axes(B, 2)
# store the diagonal separately for unit triangular matrices
if isunit
_modify!(_add, D.diag[j] * B[j,j], out, (j,j))
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
end
rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
# indices of out corresponding to the stored indices of B
rowrange = _rowrange_tri_stored(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of B
# we only fill these if out and B don't have matching zeros
if !_has_matching_storage(out, B)
rowrange = _rowrange_tri_nonstored(B, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
end
end
out
end
Expand All @@ -446,23 +468,37 @@ end
out
end
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
isunit = A isa UnitUpperOrUnitLowerTriangular
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
for j in axes(A, 2)
dja = _add(@inbounds D.diag[j])
# store the diagonal separately for unit triangular matrices
if isunit
_modify!(_add_aisone, A[j,j] * dja, out, (j,j))
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
end
rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
@simd for i in rowrange
# indices of out corresponding to the stored indices of A
rowrange = _rowrange_tri_stored(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of A
# we only fill these if out and A don't have matching zeros
if !_has_matching_storage(out, A)
rowrange = _rowrange_tri_nonstored(A, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, A[i,j] * dja, out, (i,j))
end
end
end
end
out
end
Expand Down
22 changes: 22 additions & 0 deletions stdlib/LinearAlgebra/test/addmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,26 @@ end
end
end

@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
U = MT(reshape([1:9;],3,3))
M = Array(U)
D = Diagonal(1:3)
A = reshape([1:9;],3,3)
@test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
@test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3

# nan values with iszero(alpha)
D = Diagonal(fill(NaN,3))
@test mul!(copy(A), U, D, 0, 3) == A * 3
@test mul!(copy(A), D, U, 0, 3) == A * 3

# nan values with iszero(beta)
A = fill(NaN,3,3)
D = Diagonal(1:3)
@test mul!(copy(A), U, D, 2, 0) == M * D * 2
@test mul!(copy(A), D, U, 2, 0) == D * M * 2
end
end

end # module

0 comments on commit 5229375

Please sign in to comment.