Skip to content

Commit

Permalink
Using a strong zero in Dual(0.0, 1)^0 to avoid NaN (#84)
Browse files Browse the repository at this point in the history
* test and fix for nan epsilon of dual^UInt64(0)

* "0.6.3" -> "0.6.4"

* removed NaNs, codified behaviour with Ints

* improve testing, remove type instability

* handle Dual(Integer(1), n)^Integer

* version semver bump 0.6.4 -> 0.6.5
  • Loading branch information
jwscook committed Apr 3, 2021
1 parent a1128cf commit 4603cc1
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DualNumbers"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.4"
version = "0.6.5"

[deps]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
Expand Down
37 changes: 24 additions & 13 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,40 @@ Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)^2)
Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x)

for f in [:(Base.:^), :(NaNMath.pow)]
@eval function ($f)(z::Dual, w::Dual)
if epsilon(w) == 0.0
return $f(z, value(w))
end
@eval function ($f)(z::Dual{T1}, w::Dual{T2}) where {T1, T2}
T = promote_type(T1, T2) # for type stability in ? : statements
val = $f(value(z), value(w))

du = epsilon(z) * value(w) * $f(value(z), value(w) - 1) +
epsilon(w) * $f(value(z), value(w)) * log(value(z))
ezvw = epsilon(z) * value(w) # for using in ? : statement
du1 = iszero(ezvw) ? zero(T) : ezvw * $f(value(z), value(w) - 1)
ew = epsilon(w) # for using in ? : statement
# the float is for type stability because log promotes to floats
du2 = iszero(ew) ? zero(float(T)) : ew * val * log(value(z))
du = du1 + du2

Dual(val, du)
end
end

Base.mod(z::Dual, n::Number) = Dual(mod(value(z), n), epsilon(z))

# these two definitions are needed to fix ambiguity warnings
Base.:^(z::Dual, n::Unsigned) = z^Signed(n)
Base.:^(z::Dual, n::Integer) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
Base.:^(z::Dual, n::Rational) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
# introduce a boolean !iszero(n) for hard zero behaviour to combat NaNs
function pow(z::Dual, n::AbstractFloat)
return Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
end
function pow(z::Dual{T}, n::Integer) where T
iszero(n) && return Dual(one(T), zero(T)) # avoid DomainError Int^(negative Int)
isone(z) && return Dual(one(T), epsilon(z) * n)
return Dual(value(z)^n, epsilon(z) * n * value(z)^(n - 1))
end
# these first two definitions are needed to fix ambiguity warnings
for T1 (:Integer, :Rational, :Number)
@eval Base.:^(z::Dual{T}, n::$T1) where T = pow(z, n)
end


Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))
NaNMath.pow(z::Dual{T}, n::Number) where T = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual{T}) where T = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))

Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2)

Expand Down
44 changes: 44 additions & 0 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,52 @@ y = x^3.0
@test value(y) 2.0^3
@test epsilon(y) 3.0*2^2

# taking care with divides by zero where there shouldn't be any on paper
for (y, n) Iterators.product((float(x), Dual(0.0, 1)), (0, 0.0))
z = y^n
@test value(z) == 1
@test !isnan(epsilon(z))
@test epsilon(z) == 0
end

# acting on floats works as expected
for (y, n) ((float(x), Dual(0.0, 1)), -1:1)
@test float(y)^n == float(y)^float(n)
end

@test !isnan(epsilon(Dual(0, 1)^1))
@test Dual(0, 1)^1 == Dual(0, 1)

# power_by_squaring error for integers
# needs to be wrapped to make n a literal
powwrap(z, n, epspart=0) = Dual(z, epspart)^n
@test_throws DomainError powwrap(0, -1)
@test_throws DomainError powwrap(2, -1)
@test_throws DomainError powwrap(123, -1) # etc
# these ones don't DomainError
@test powwrap(0, 0, 0) == Dual(1, 0) # special case is handled
@test powwrap(0, 0, 1) == Dual(1, 0) # special case is handled
@test powwrap(1, -1) == powwrap(1.0, -1) # special case is handled
@test powwrap(1, -2) == powwrap(1.0, -2) # special case is handled
@test powwrap(1, -123) == powwrap(1.0, -123) # special case is handled
@test powwrap(1, 0) == Dual(1, 1)
@test powwrap(123, 0) == Dual(1, 1)
for i -3:3
@test powwrap(1, i) == Dual(1, i)
end

# this no longer throws 1/0 DomainError
@test powwrap(0, Dual(0, 1)) == Dual(1, 0)
# this never did DomainError because it starts off with a float
@test 0.0^Dual(0, 1) == Dual(1.0, NaN)
# and Dual^Dual uses a log and is now type stable
# because the log promotes ints to floats for all values
@test typeof(value(powwrap(0, Dual(0, 1)))) == Float64
@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0)

y = Dual(2.0, 1)^UInt64(0)
@test !isnan(epsilon(y))
@test epsilon(y) == 0

y = sin(x)+exp(x)
@test value(y) sin(2)+exp(2)
Expand Down

2 comments on commit 4603cc1

@dlfivefifty
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33459

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.5 -m "<description of version>" 4603cc1e07dc6363d394875046bca74cd7ce6fc6
git push origin v0.6.5

Please sign in to comment.