Skip to content

Commit

Permalink
Merge pull request #24380 from Sacha0/equalsI
Browse files Browse the repository at this point in the history
fix and specialize equality comparison with UniformScaling
  • Loading branch information
fredrikekre authored Oct 30, 2017
2 parents fb46fe8 + bd9ddc8 commit 75d95f9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
18 changes: 18 additions & 0 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,24 @@ broadcast(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x)

==(J1::UniformScaling,J2::UniformScaling) = (J1.λ == J2.λ)

## equality comparison with UniformScaling
==(J::UniformScaling, A::AbstractMatrix) = A == J
function ==(A::AbstractMatrix, J::UniformScaling)
size(A, 1) == size(A, 2) || return false
iszero(J.λ) && return iszero(A)
isone(J.λ) && return isone(A)
return A == J.λ*one(A)
end
function ==(A::StridedMatrix, J::UniformScaling)
size(A, 1) == size(A, 2) || return false
iszero(J.λ) && return iszero(A)
isone(J.λ) && return isone(A)
for j in indices(A, 2), i in indices(A, 1)
ifelse(i == j, A[i, j] == J.λ, iszero(A[i, j])) || return false
end
return true
end

function isapprox(J1::UniformScaling{T}, J2::UniformScaling{S};
atol::Real=0, rtol::Real=Base.rtoldefault(T,S,atol), nans::Bool=false) where {T<:Number,S<:Number}
isapprox(J1.λ, J2.λ, rtol=rtol, atol=atol, nans=nans)
Expand Down
27 changes: 27 additions & 0 deletions test/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,30 @@ end
@test_throws LinAlg.PosDefException chol(-λ*I)
end
end

@testset "equality comparison of matrices with UniformScaling" begin
# AbstractMatrix methods
diagI = Diagonal(fill(1, 3))
rdiagI = view(diagI, 1:2, 1:3)
bidiag = Bidiagonal(fill(2, 3), fill(2, 2), :U)
@test diagI == I == diagI # test isone(I) path / equality
@test 2diagI != I != 2diagI # test isone(I) path / inequality
@test 0diagI == 0I == 0diagI # test iszero(I) path / equality
@test 2diagI != 0I != 2diagI # test iszero(I) path / inequality
@test 2diagI == 2I == 2diagI # test generic path / equality
@test 0diagI != 2I != 0diagI # test generic path / inequality on diag
@test bidiag != 2I != bidiag # test generic path / inequality off diag
@test rdiagI != I != rdiagI # test square matrix check
# StridedMatrix specialization
denseI = eye(3)
rdenseI = eye(3, 4)
alltwos = fill(2, (3, 3))
@test denseI == I == denseI # test isone(I) path / equality
@test 2denseI != I != 2denseI # test isone(I) path / inequality
@test 0denseI == 0I == 0denseI # test iszero(I) path / equality
@test 2denseI != 0I != 2denseI # test iszero(I) path / inequality
@test 2denseI == 2I == 2denseI # test generic path / equality
@test 0denseI != 2I != 0denseI # test generic path / inequality on diag
@test alltwos != 2I != alltwos # test generic path / inequality off diag
@test rdenseI != I != rdenseI # test square matrix check
end

0 comments on commit 75d95f9

Please sign in to comment.