From 613d4b97ddcda4bb42deee41a43db49f5a236346 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 19 Aug 2024 19:29:44 +0530 Subject: [PATCH] Fix tr for Symmetric/Hermitian block matrices (#55522) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since `Symmetric` and `Hermitian` symmetrize the diagonal elements of the parent, we can't forward `tr` to the parent unless it is already symmetric. This limits the existing `tr` methods to matrices of `Number`s, which is the common use-case. `tr` for `Symmetric` block matrices would now use the fallback implementation that explicitly computes the `diag`. This resolves the following discrepancy: ```julia julia> S = Symmetric(fill([1 2; 3 4], 3, 3)) 3×3 Symmetric{AbstractMatrix, Matrix{Matrix{Int64}}}: [1 2; 2 4] [1 2; 3 4] [1 2; 3 4] [1 3; 2 4] [1 2; 2 4] [1 2; 3 4] [1 3; 2 4] [1 3; 2 4] [1 2; 2 4] julia> tr(S) 2×2 Matrix{Int64}: 3 6 9 12 julia> sum(diag(S)) 2×2 Symmetric{Int64, Matrix{Int64}}: 3 6 6 12 ``` --- stdlib/LinearAlgebra/src/symmetric.jl | 4 ++-- stdlib/LinearAlgebra/test/symmetric.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 55630595f6fb2f..c3367857925889 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -449,8 +449,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) = Base.copy(A::Transpose{<:Any,<:Hermitian}) = Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U)) -tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) -tr(A::Hermitian) = real(tr(A.data)) +tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) +tr(A::Hermitian{<:Number}) = real(tr(A.data)) Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo)) Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo)) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 89e9ca0d6a51d6..5f1293ab2cdd7d 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -1116,4 +1116,15 @@ end end end +@testset "tr for block matrices" begin + m = [1 2; 3 4] + for b in (m, m * (1 + im)) + M = fill(b, 3, 3) + for ST in (Symmetric, Hermitian) + S = ST(M) + @test tr(S) == sum(diag(S)) + end + end +end + end # module TestSymmetric