Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using a strong zero in Dual(0.0, 1)^0 to avoid NaN #84

Merged
merged 8 commits into from
Apr 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DualNumbers"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.4"
version = "0.6.5"

[deps]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
Expand Down
37 changes: 24 additions & 13 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,40 @@ Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)^2)
Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x)

for f in [:(Base.:^), :(NaNMath.pow)]
@eval function ($f)(z::Dual, w::Dual)
if epsilon(w) == 0.0
return $f(z, value(w))
end
@eval function ($f)(z::Dual{T1}, w::Dual{T2}) where {T1, T2}
T = promote_type(T1, T2) # for type stability in ? : statements
val = $f(value(z), value(w))

du = epsilon(z) * value(w) * $f(value(z), value(w) - 1) +
epsilon(w) * $f(value(z), value(w)) * log(value(z))
ezvw = epsilon(z) * value(w) # for using in ? : statement
du1 = iszero(ezvw) ? zero(T) : ezvw * $f(value(z), value(w) - 1)
ew = epsilon(w) # for using in ? : statement
# the float is for type stability because log promotes to floats
du2 = iszero(ew) ? zero(float(T)) : ew * val * log(value(z))
du = du1 + du2

Dual(val, du)
end
end

Base.mod(z::Dual, n::Number) = Dual(mod(value(z), n), epsilon(z))

# these two definitions are needed to fix ambiguity warnings
Base.:^(z::Dual, n::Unsigned) = z^Signed(n)
Base.:^(z::Dual, n::Integer) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
Base.:^(z::Dual, n::Rational) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
# introduce a boolean !iszero(n) for hard zero behaviour to combat NaNs
function pow(z::Dual, n::AbstractFloat)
return Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
end
function pow(z::Dual{T}, n::Integer) where T
iszero(n) && return Dual(one(T), zero(T)) # avoid DomainError Int^(negative Int)
isone(z) && return Dual(one(T), epsilon(z) * n)
return Dual(value(z)^n, epsilon(z) * n * value(z)^(n - 1))
end
# these first two definitions are needed to fix ambiguity warnings
for T1 ∈ (:Integer, :Rational, :Number)
@eval Base.:^(z::Dual{T}, n::$T1) where T = pow(z, n)
end


Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))
NaNMath.pow(z::Dual{T}, n::Number) where T = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual{T}) where T = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))

Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2)

Expand Down
44 changes: 44 additions & 0 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,52 @@ y = x^3.0
@test value(y) ≈ 2.0^3
@test epsilon(y) ≈ 3.0*2^2

# taking care with divides by zero where there shouldn't be any on paper
for (y, n) ∈ Iterators.product((float(x), Dual(0.0, 1)), (0, 0.0))
z = y^n
@test value(z) == 1
@test !isnan(epsilon(z))
@test epsilon(z) == 0
end

# acting on floats works as expected
for (y, n) ∈ ((float(x), Dual(0.0, 1)), -1:1)
@test float(y)^n == float(y)^float(n)
end

@test !isnan(epsilon(Dual(0, 1)^1))
@test Dual(0, 1)^1 == Dual(0, 1)

# power_by_squaring error for integers
# needs to be wrapped to make n a literal
powwrap(z, n, epspart=0) = Dual(z, epspart)^n
@test_throws DomainError powwrap(0, -1)
@test_throws DomainError powwrap(2, -1)
@test_throws DomainError powwrap(123, -1) # etc
# these ones don't DomainError
@test powwrap(0, 0, 0) == Dual(1, 0) # special case is handled
@test powwrap(0, 0, 1) == Dual(1, 0) # special case is handled
@test powwrap(1, -1) == powwrap(1.0, -1) # special case is handled
@test powwrap(1, -2) == powwrap(1.0, -2) # special case is handled
@test powwrap(1, -123) == powwrap(1.0, -123) # special case is handled
@test powwrap(1, 0) == Dual(1, 1)
@test powwrap(123, 0) == Dual(1, 1)
for i ∈ -3:3
@test powwrap(1, i) == Dual(1, i)
end

# this no longer throws 1/0 DomainError
@test powwrap(0, Dual(0, 1)) == Dual(1, 0)
# this never did DomainError because it starts off with a float
@test 0.0^Dual(0, 1) == Dual(1.0, NaN)
# and Dual^Dual uses a log and is now type stable
# because the log promotes ints to floats for all values
@test typeof(value(powwrap(0, Dual(0, 1)))) == Float64
@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0)

y = Dual(2.0, 1)^UInt64(0)
@test !isnan(epsilon(y))
@test epsilon(y) == 0

y = sin(x)+exp(x)
@test value(y) ≈ sin(2)+exp(2)
Expand Down