From 8cb762b223e8f9c0aaa5498acf4f78b237febc3b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 16 May 2024 23:20:34 +0530 Subject: [PATCH] Specialize `copyto!` for triangular matrices (#52730) This provides a performance boost in copying a triangular matrix to a `StridedMatrix`, which is a common operation (e.g. in broadcasting or in `Matrix(::UpperTriangular)`). The main improvement is improved cache locality for strided triangular matrices by fusing the loops. On master ```julia julia> U = UpperTriangular(rand(4000,4000)); julia> @btime Matrix($U); 64.649 ms (3 allocations: 122.07 MiB) ``` This PR ```julia julia> @btime Matrix($U); 48.332 ms (3 allocations: 122.07 MiB) ``` --- stdlib/LinearAlgebra/src/triangular.jl | 83 +++++++++++++++---------- stdlib/LinearAlgebra/test/triangular.jl | 15 +++-- 2 files changed, 62 insertions(+), 36 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 153472040f6e4..9d7e7c8d4eb5d 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -185,7 +185,6 @@ function imag(A::UnitUpperTriangular) return Uim end -Array(A::AbstractTriangular) = Matrix(A) parent(A::UpperOrLowerTriangular) = A.data # For strided matrices, we may only loop over the filled triangle @@ -193,37 +192,6 @@ copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(si # then handle all methods that requires specific handling of upper/lower and unit diagonal -function Matrix{T}(A::LowerTriangular) where T - B = Matrix{T}(undef, size(A, 1), size(A, 1)) - copyto!(B, A.data) - tril!(B) - B -end -function Matrix{T}(A::UnitLowerTriangular) where T - B = Matrix{T}(undef, size(A, 1), size(A, 1)) - copyto!(B, A.data) - tril!(B) - for i = 1:size(B,1) - B[i,i] = oneunit(T) - end - B -end -function Matrix{T}(A::UpperTriangular) where T - B = Matrix{T}(undef, size(A, 1), size(A, 1)) - copyto!(B, A.data) - triu!(B) - B -end -function Matrix{T}(A::UnitUpperTriangular) where T - B = Matrix{T}(undef, size(A, 1), size(A, 1)) - copyto!(B, A.data) - triu!(B) - for i = 1:size(B,1) - B[i,i] = oneunit(T) - end - B -end - function full!(A::LowerTriangular) B = A.data tril!(B) @@ -544,6 +512,57 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular return A end +_triangularize!(::UpperOrUnitUpperTriangular) = triu! +_triangularize!(::LowerOrUnitLowerTriangular) = tril! + +function copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular) + if axes(dest) != axes(U) + @invoke copyto!(dest::StridedMatrix, U::AbstractArray) + else + _copyto!(dest, U) + end + return dest +end +function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular) + copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L') + _triangularize!(U)(dest) + if U isa Union{UnitUpperTriangular, UnitLowerTriangular} + dest[diagind(dest)] .= @view U[diagind(U, IndexCartesian())] + end + return dest +end +function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix}) + U2 = Base.unalias(dest, U) + copyto_unaliased!(dest, U2) + return dest +end +# for strided matrices, we explicitly loop over the arrays to improve cache locality +# This fuses the copytrito! and triu/l operations +function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix}) + isunit = U isa UnitUpperTriangular + for col in axes(dest,2) + for row in 1:col-isunit + @inbounds dest[row,col] = U.data[row,col] + end + for row in col+!isunit:size(U,1) + @inbounds dest[row,col] = U[row,col] + end + end + return dest +end +function copyto_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix}) + isunit = L isa UnitLowerTriangular + for col in axes(dest,2) + for row in 1:col-!isunit + @inbounds dest[row,col] = L[row,col] + end + for row in col+isunit:size(L,1) + @inbounds dest[row,col] = L.data[row,col] + end + end + return dest +end + @inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) = @stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta)) @inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) = diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index c793c5a3e9924..cf08d23e8cf92 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -28,7 +28,7 @@ debug && println("Test basic type functionality") # The following test block tries to call all methods in base/linalg/triangular.jl in order for a combination of input element types. Keep the ordering when adding code. @testset for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int) # Begin loop for first Triangular matrix - for (t1, uplo1) in ((UpperTriangular, :U), + @testset for (t1, uplo1) in ((UpperTriangular, :U), (UnitUpperTriangular, :U), (LowerTriangular, :L), (UnitLowerTriangular, :L)) @@ -339,8 +339,8 @@ debug && println("Test basic type functionality") @test ((A1\A1)::t1) ≈ M1 \ M1 # Begin loop for second Triangular matrix - for elty2 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int) - for (t2, uplo2) in ((UpperTriangular, :U), + @testset for elty2 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int) + @testset for (t2, uplo2) in ((UpperTriangular, :U), (UnitUpperTriangular, :U), (LowerTriangular, :L), (UnitLowerTriangular, :L)) @@ -970,7 +970,7 @@ end end end -@testset "arithmetic with an immutable parent" begin +@testset "immutable and non-strided parent" begin F = FillArrays.Fill(2, (4,4)) for UT in (UnitUpperTriangular, UnitLowerTriangular) U = UT(F) @@ -981,6 +981,13 @@ end for U in (UnitUpperTriangular(F), UnitLowerTriangular(F)) @test imag(F) == imag(collect(F)) end + + @testset "copyto!" begin + for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @test Matrix(T(F)) == T(F) + end + @test copyto!(zeros(eltype(F), length(F)), UpperTriangular(F)) == vec(UpperTriangular(F)) + end end @testset "error paths" begin