From bd9ddc841787000f4ae71bb914dbe399ee6f2b8c Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Sat, 28 Oct 2017 11:57:16 -0700 Subject: [PATCH] Fix and specialize equality comparison with UniformScaling. --- base/linalg/uniformscaling.jl | 18 ++++++++++++++++++ test/linalg/uniformscaling.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/base/linalg/uniformscaling.jl b/base/linalg/uniformscaling.jl index cc42919650a05..2b0fe631475e6 100644 --- a/base/linalg/uniformscaling.jl +++ b/base/linalg/uniformscaling.jl @@ -211,6 +211,24 @@ broadcast(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) ==(J1::UniformScaling,J2::UniformScaling) = (J1.λ == J2.λ) +## equality comparison with UniformScaling +==(J::UniformScaling, A::AbstractMatrix) = A == J +function ==(A::AbstractMatrix, J::UniformScaling) + size(A, 1) == size(A, 2) || return false + iszero(J.λ) && return iszero(A) + isone(J.λ) && return isone(A) + return A == J.λ*one(A) +end +function ==(A::StridedMatrix, J::UniformScaling) + size(A, 1) == size(A, 2) || return false + iszero(J.λ) && return iszero(A) + isone(J.λ) && return isone(A) + for j in indices(A, 2), i in indices(A, 1) + ifelse(i == j, A[i, j] == J.λ, iszero(A[i, j])) || return false + end + return true +end + function isapprox(J1::UniformScaling{T}, J2::UniformScaling{S}; atol::Real=0, rtol::Real=Base.rtoldefault(T,S,atol), nans::Bool=false) where {T<:Number,S<:Number} isapprox(J1.λ, J2.λ, rtol=rtol, atol=atol, nans=nans) diff --git a/test/linalg/uniformscaling.jl b/test/linalg/uniformscaling.jl index d3ce6677b4be9..1d92a57e38bcf 100644 --- a/test/linalg/uniformscaling.jl +++ b/test/linalg/uniformscaling.jl @@ -188,3 +188,30 @@ end @test_throws LinAlg.PosDefException chol(-λ*I) end end + +@testset "equality comparison of matrices with UniformScaling" begin + # AbstractMatrix methods + diagI = Diagonal(fill(1, 3)) + rdiagI = view(diagI, 1:2, 1:3) + bidiag = Bidiagonal(fill(2, 3), fill(2, 2), :U) + @test diagI == I == diagI # test isone(I) path / equality + @test 2diagI != I != 2diagI # test isone(I) path / inequality + @test 0diagI == 0I == 0diagI # test iszero(I) path / equality + @test 2diagI != 0I != 2diagI # test iszero(I) path / inequality + @test 2diagI == 2I == 2diagI # test generic path / equality + @test 0diagI != 2I != 0diagI # test generic path / inequality on diag + @test bidiag != 2I != bidiag # test generic path / inequality off diag + @test rdiagI != I != rdiagI # test square matrix check + # StridedMatrix specialization + denseI = eye(3) + rdenseI = eye(3, 4) + alltwos = fill(2, (3, 3)) + @test denseI == I == denseI # test isone(I) path / equality + @test 2denseI != I != 2denseI # test isone(I) path / inequality + @test 0denseI == 0I == 0denseI # test iszero(I) path / equality + @test 2denseI != 0I != 2denseI # test iszero(I) path / inequality + @test 2denseI == 2I == 2denseI # test generic path / equality + @test 0denseI != 2I != 0denseI # test generic path / inequality on diag + @test alltwos != 2I != alltwos # test generic path / inequality off diag + @test rdenseI != I != rdenseI # test square matrix check +end