Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize copyto! for triangular matrices #52730

Merged
merged 6 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,45 +185,13 @@ function imag(A::UnitUpperTriangular)
return Uim
end

# Converting a triangular matrix to an `Array` simply defers to the `Matrix` constructor.
function Array(A::AbstractTriangular)
    return Matrix(A)
end
# The parent of a triangular wrapper is the array stored in its `data` field.
function parent(A::UpperOrLowerTriangular)
    return A.data
end

# With a strided (possibly adjoint/transpose-wrapped) parent, copy by allocating a
# `similar` triangular matrix and filling it through `copyto!`, so that only the
# filled triangle may need to be looped over.
function copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat})
    dest = similar(A)
    return copyto!(dest, A)
end

# then handle all methods that require specific handling of upper/lower and unit diagonal

# Dense `Matrix{T}` from a lower triangular matrix: copy the whole parent, then
# zero everything above the diagonal with `tril!`.
function Matrix{T}(A::LowerTriangular) where T
    n = size(A, 1)
    dense = Matrix{T}(undef, n, n)
    copyto!(dense, A.data)
    return tril!(dense)
end
# Dense `Matrix{T}` from a unit lower triangular matrix: copy the whole parent,
# zero everything above the diagonal, then force the diagonal to `oneunit(T)`.
function Matrix{T}(A::UnitLowerTriangular) where T
    n = size(A, 1)
    dense = Matrix{T}(undef, n, n)
    copyto!(dense, A.data)
    tril!(dense)
    # the parent's diagonal is ignored for unit triangular types
    for k in 1:size(dense, 1)
        dense[k, k] = oneunit(T)
    end
    return dense
end
# Dense `Matrix{T}` from an upper triangular matrix: copy the whole parent, then
# zero everything below the diagonal with `triu!`.
function Matrix{T}(A::UpperTriangular) where T
    n = size(A, 1)
    dense = Matrix{T}(undef, n, n)
    copyto!(dense, A.data)
    return triu!(dense)
end
# Dense `Matrix{T}` from a unit upper triangular matrix: copy the whole parent,
# zero everything below the diagonal, then force the diagonal to `oneunit(T)`.
function Matrix{T}(A::UnitUpperTriangular) where T
    n = size(A, 1)
    dense = Matrix{T}(undef, n, n)
    copyto!(dense, A.data)
    triu!(dense)
    # the parent's diagonal is ignored for unit triangular types
    for k in 1:size(dense, 1)
        dense[k, k] = oneunit(T)
    end
    return dense
end

function full!(A::LowerTriangular)
B = A.data
tril!(B)
Expand Down Expand Up @@ -544,6 +512,57 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular
return A
end

# Select the in-place function that zeroes the non-stored half of a dense matrix:
# `triu!` for (unit) upper triangular arguments, `tril!` for (unit) lower ones.
function _triangularize!(::UpperOrUnitUpperTriangular)
    return triu!
end
function _triangularize!(::LowerOrUnitLowerTriangular)
    return tril!
end

# Copy a triangular matrix into a strided dense matrix.
# When the axes agree, the specialised `_copyto!` is used; otherwise the call is
# forwarded to the generic `AbstractArray` method. `dest` is returned in both cases.
function copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
    if axes(dest) == axes(U)
        _copyto!(dest, U)
    else
        @invoke copyto!(dest::StridedMatrix, U::AbstractArray)
    end
    return dest
end
# Fallback for arbitrary parents: copy one triangle of the parent with `copytrito!`,
# apply `triu!`/`tril!` (as chosen by `_triangularize!`) to `dest`, and, for unit
# triangular types, overwrite the diagonal of `dest` with the diagonal entries of `U`.
function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
    uplo = U isa UpperOrUnitUpperTriangular ? 'U' : 'L'
    copytrito!(dest, parent(U), uplo)
    triangularize! = _triangularize!(U)
    triangularize!(dest)
    if U isa Union{UnitUpperTriangular, UnitLowerTriangular}
        # read the diagonal through `U` itself rather than through its parent
        dest[diagind(dest)] .= @view U[diagind(U, IndexCartesian())]
    end
    return dest
end
# Strided parents: pass `U` through `Base.unalias` with respect to `dest` first,
# then run the explicit-loop kernel `copyto_unaliased!` on the result.
function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
    copyto_unaliased!(dest, Base.unalias(dest, U))
    return dest
end
# for strided matrices, we explicitly loop over the arrays to improve cache locality
# This fuses the copytrito! and triu/l operations
function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
    # Column by column: rows up to `lastdata` are read directly from the parent,
    # the remaining rows are read through `U` (so the unit diagonal and the
    # entries below the diagonal come from the triangular wrapper's `getindex`).
    unitdiag = U isa UnitUpperTriangular
    nrows = size(U, 1)
    for j in axes(dest, 2)
        # a unit diagonal is not taken from the parent
        lastdata = unitdiag ? j - 1 : j
        for i in 1:lastdata
            @inbounds dest[i, j] = U.data[i, j]
        end
        for i in (lastdata + 1):nrows
            @inbounds dest[i, j] = U[i, j]
        end
    end
    return dest
end
function copyto_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
    # Column by column: rows up to `lastwrapped` are read through `L` (entries above
    # the diagonal, plus the diagonal itself for unit types), the remaining rows
    # are read directly from the parent.
    unitdiag = L isa UnitLowerTriangular
    nrows = size(L, 1)
    for j in axes(dest, 2)
        # a unit diagonal is not taken from the parent
        lastwrapped = unitdiag ? j : j - 1
        for i in 1:lastwrapped
            @inbounds dest[i, j] = L[i, j]
        end
        for i in (lastwrapped + 1):nrows
            @inbounds dest[i, j] = L.data[i, j]
        end
    end
    return dest
end

# Triangular (A, B) with a scalar third argument C: forwards to `_triscale!` with a
# `MulAddMul(alpha, beta)` object, wrapped in the `@stable_muladdmul` macro.
# NOTE(review): presumably the macro exists to keep the `MulAddMul` construction
# type-stable — confirm against its definition, which is not visible here.
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
@inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =
Expand Down
15 changes: 11 additions & 4 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ debug && println("Test basic type functionality")
# The following test block tries to call all methods in stdlib/LinearAlgebra/src/triangular.jl in order for a combination of input element types. Keep the ordering when adding code.
@testset for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int)
# Begin loop for first Triangular matrix
for (t1, uplo1) in ((UpperTriangular, :U),
@testset for (t1, uplo1) in ((UpperTriangular, :U),
(UnitUpperTriangular, :U),
(LowerTriangular, :L),
(UnitLowerTriangular, :L))
Expand Down Expand Up @@ -339,8 +339,8 @@ debug && println("Test basic type functionality")
@test ((A1\A1)::t1) ≈ M1 \ M1

# Begin loop for second Triangular matrix
for elty2 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int)
for (t2, uplo2) in ((UpperTriangular, :U),
@testset for elty2 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int)
@testset for (t2, uplo2) in ((UpperTriangular, :U),
(UnitUpperTriangular, :U),
(LowerTriangular, :L),
(UnitLowerTriangular, :L))
Expand Down Expand Up @@ -970,7 +970,7 @@ end
end
end

@testset "arithmetic with an immutable parent" begin
@testset "immutable and non-strided parent" begin
F = FillArrays.Fill(2, (4,4))
for UT in (UnitUpperTriangular, UnitLowerTriangular)
U = UT(F)
Expand All @@ -981,6 +981,13 @@ end
for U in (UnitUpperTriangular(F), UnitLowerTriangular(F))
@test imag(F) == imag(collect(F))
end

@testset "copyto!" begin
    # Densifying a triangular wrapper of the non-strided parent `F` must agree
    # with the wrapper itself, for every triangular type.
    for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
        tri = T(F)
        @test Matrix(tri) == tri
    end
    # Destination whose axes differ from the source's (a vector of matching length).
    upper = UpperTriangular(F)
    flatdest = zeros(eltype(F), length(F))
    @test copyto!(flatdest, upper) == vec(upper)
end
end

@testset "error paths" begin
Expand Down