Skip to content

Commit

Permalink
Fix definitions of == and isless
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Mar 13, 2024
1 parent d9251a7 commit d52f6a3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
12 changes: 6 additions & 6 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,17 @@ end
Base.convert(::Type{Dual}, z::Dual) = z
Base.convert(::Type{Dual}, x::Number) = Dual(x)

Base.:(==)(z::Dual, w::Dual) = value(z) == value(w)
Base.:(==)(z::Dual, x::Number) = value(z) == x
Base.:(==)(x::Number, z::Dual) = value(z) == x
Base.:(==)(z::Dual, w::Dual) = value(z) == value(w) && epsilon(z) == epsilon(w)
Base.:(==)(z::Dual, x::Number) = value(z) == x && iszero(epsilon(z))
Base.:(==)(x::Number, z::Dual) = z == x

Base.isequal(z::Dual, w::Dual) = isequal(value(z),value(w)) && isequal(epsilon(z), epsilon(w))
Base.isequal(z::Dual, x::Number) = isequal(value(z), x) && isequal(epsilon(z), zero(x))
Base.isequal(x::Number, z::Dual) = isequal(z, x)

Base.isless(z::Dual{<:Real},w::Dual{<:Real}) = value(z) < value(w)
Base.isless(z::Real,w::Dual{<:Real}) = z < value(w)
Base.isless(z::Dual{<:Real},w::Real) = value(z) < w
Base.isless(z::Dual{<:Real},w::Dual{<:Real}) = isless(value(z), value(w)) || (isequal(value(z), value(w)) && isless(epsilon(z), epsilon(w)))
Base.isless(z::Real,w::Dual{<:Real}) = isless(z, value(w)) || (isequal(z, value(w)) && isless(zero(epsilon(w)), epsilon(w)))
Base.isless(z::Dual{<:Real},w::Real) = isless(value(z), w) || (isequal(value(z), w) && isless(epsilon(z), zero(epsilon(z))))

Check warning on line 183 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L183

Added line #L183 was not covered by tests

Base.hash(z::Dual) = (x = hash(value(z)); epsilon(z)==0 ? x : bitmix(x,hash(epsilon(z))))

Expand Down
18 changes: 12 additions & 6 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,26 @@ powwrap(z, n, epspart=0) = Dual(z, epspart)^n
@test powwrap(1, -1) == powwrap(1.0, -1) # special case is handled
@test powwrap(1, -2) == powwrap(1.0, -2) # special case is handled
@test powwrap(1, -123) == powwrap(1.0, -123) # special case is handled
@test powwrap(1, 0) == Dual(1, 1)
@test powwrap(123, 0) == Dual(1, 1)
@test powwrap(1, 0) == Dual(1, 0)
@test powwrap(1, 0) != Dual(1, 1)
@test powwrap(123, 0) == Dual(1, 0)
@test powwrap(123, 0) != Dual(1, 1)
for i -3:3
@test powwrap(1, i) == Dual(1, i)
@test powwrap(1, i) == Dual(1, 0)
@test i == 0 || (powwrap(1, i) != Dual(1, i))
end

# this no longer throws 1/0 DomainError
@test powwrap(0, Dual(0, 1)) == Dual(1, 0)
@test powwrap(0, Dual(0, 1)) == Dual(1, -Inf)
@test powwrap(0, Dual(0, 1)) != Dual(1, 0)
# this never did DomainError because it starts off with a float
@test 0.0^Dual(0, 1) == Dual(1.0, NaN)
@test 0.0^Dual(0, 1) == Dual(1.0, -Inf)
@test 0.0^Dual(0, 1) != Dual(1.0, NaN)
# and Dual^Dual uses a log and is now type stable
# because the log promotes ints to floats for all values
@test typeof(value(powwrap(0, Dual(0, 1)))) == Float64
@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0)
@test Dual(0, 1)^Dual(0, 1) == Dual(1, -Inf)
@test Dual(0, 1)^Dual(0, 1) != Dual(1, 0)

y = Dual(2.0, 1)^UInt64(0)
@test !isnan(epsilon(y))
Expand Down

0 comments on commit d52f6a3

Please sign in to comment.