Skip to content

Commit d1d24f7

Browse files
dev10110dkarrasch
authored andcommitted
Add a fast method for diag of Cholesky matrices (JuliaLang#53767)
Co-authored-by: Daniel Karrasch <[email protected]>
1 parent 753e737 commit d1d24f7

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

stdlib/LinearAlgebra/src/cholesky.jl

+27
Original file line numberDiff line numberDiff line change
@@ -860,3 +860,30 @@ then `CC = cholesky(C.U'C.U - v*v')` but the computation of `CC` only uses
860860
`O(n^2)` operations.
861861
"""
862862
lowrankdowndate(C::Cholesky, v::AbstractVector) = lowrankdowndate!(copy(C), copy(v))
863+
864+
function diag(C::Cholesky{T}, k::Int = 0) where {T}
865+
N = size(C, 1)
866+
absk = abs(k)
867+
iabsk = N - absk
868+
z = Vector{T}(undef, iabsk)
869+
UL = C.factors
870+
if C.uplo == 'U'
871+
for i in 1:iabsk
872+
z[i] = zero(T)
873+
for j in 1:min(i, i+absk)
874+
z[i] += UL[j, i]'UL[j, i+absk]
875+
end
876+
end
877+
else
878+
for i in 1:iabsk
879+
z[i] = zero(T)
880+
for j in 1:min(i, i+absk)
881+
z[i] += UL[i, j]*UL[i+absk, j]'
882+
end
883+
end
884+
end
885+
if !(T <: Real) && k < 0
886+
z .= adjoint.(z)
887+
end
888+
return z
889+
end

stdlib/LinearAlgebra/src/generic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,7 @@ julia> tr(A)
10861086
5
10871087
```
10881088
"""
1089-
function tr(A::AbstractMatrix)
1089+
function tr(A)
10901090
checksquare(A)
10911091
sum(diag(A))
10921092
end

stdlib/LinearAlgebra/test/cholesky.jl

+10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ function unary_ops_tests(a, ca, tol; n=size(a, 1))
2020
@test_throws ErrorException ca.Z
2121
@test size(ca) == size(a)
2222
@test Array(copy(ca)) a
23+
@test tr(ca) tr(a) skip=ca isa CholeskyPivoted
2324
end
2425

2526
function factor_recreation_tests(a_U, a_L)
@@ -561,4 +562,13 @@ end
561562
end
562563
end
563564

565+
@testset "diag" begin
566+
for T in (Float64, ComplexF64), k in (0, 1, -3), uplo in (:U, :L)
567+
A = randn(T, 100, 100)
568+
P = Hermitian(A' * A, uplo)
569+
C = cholesky(P)
570+
@test diag(P, k) diag(C, k)
571+
end
572+
end
573+
564574
end # module TestCholesky

0 commit comments

Comments
 (0)