Skip to content

Commit

Permalink
Make svdvals(::Matrix{<:Complex}) type inferrable (#22443)
Browse files Browse the repository at this point in the history
* Make svdvals(::Matrix{<:Complex}) type inferrable

Ensure that svdvals(zeros(Complex128,0,0)) returns a complex real matrix
to avoid type instability.  Also add some simplistic but explicit tests
for svdvals and svdfact, including ensuring this case is inferred.

* Add unitarity test for SVD U and Vt parts.

* Use \approxeq in tests instead of clobbering \approx
  • Loading branch information
c42f authored and andreasnoack committed Jun 21, 2017
1 parent 236e486 commit 2de5dab
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion base/linalg/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end
Returns the singular values of `A`, saving space by overwriting the input.
See also [`svdvals`](@ref).
"""
svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = findfirst(size(A), 0) > 0 ? zeros(T, 0) : LAPACK.gesdd!('N', A)[2]
svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = isempty(A) ? zeros(real(T), 0) : LAPACK.gesdd!('N', A)[2]
svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A))

"""
Expand Down
25 changes: 25 additions & 0 deletions test/linalg/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@ using Base.Test

using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted

@testset "Simple svdvals / svdfact tests" begin
(x,y) = isapprox(x,y,rtol=1e-15)

m1 = [2 0; 0 0]
m2 = [2 -2; 1 1]/sqrt(2)
m2c = Complex.([2 -2; 1 1]/sqrt(2))
@test @inferred(svdvals(m1)) [2, 0]
@test @inferred(svdvals(m2)) [2, 1]
@test @inferred(svdvals(m2c)) [2, 1]

sf1 = svdfact(m1)
sf2 = svdfact(m2)
@test sf1.S [2, 0]
@test sf2.S [2, 1]
# U & Vt are unitary
@test sf1.U*sf1.U' eye(2)
@test sf1.Vt*sf1.Vt' eye(2)
@test sf2.U*sf2.U' eye(2)
@test sf2.Vt*sf2.Vt' eye(2)
# SVD not uniquely determined, so just test we can reconstruct the
# matrices from the factorization as expected.
@test sf1.U*Diagonal(sf1.S)*sf1.Vt' m1
@test sf2.U*Diagonal(sf2.S)*sf2.Vt' m2
end

n = 10

# Split n into 2 parts for tests needing two matrices
Expand Down

0 comments on commit 2de5dab

Please sign in to comment.