Skip to content

Commit

Permalink
Override SpecialFunctions (#65)
Browse files Browse the repository at this point in the history
* Add flipsign

* Fix bug in inv(::Dual)

* Fix deprecations

* REQUIRE Julia v0.7
  • Loading branch information
dlfivefifty authored and mlubin committed Sep 4, 2018
1 parent f9747e3 commit c77ee18
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ os:
- linux
- osx
julia:
- 0.6
- 0.7
- 1.0
- nightly
notifications:
email: false
Expand Down
3 changes: 2 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
julia 0.6
julia 0.7
Compat 0.49.0
Calculus
NaNMath
SpecialFunctions 0.7
4 changes: 1 addition & 3 deletions src/DualNumbers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
__precompile__()

module DualNumbers

using Compat
using Compat, SpecialFunctions
import NaNMath
import Calculus

Expand Down
31 changes: 22 additions & 9 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ end

Base.show(io::IO, z::Dual) = dual_show(io, z, get(IOContext(io), :compact, false))

@static if VERSION < v"0.7.0-DEV.4524"
Base.showcompact(io::IO, z::Dual) = dual_show(io, z, true)
end

function Base.read(s::IO, ::Type{Dual{T}}) where T<:ReComp
x = read(s, T)
y = read(s, T)
Expand Down Expand Up @@ -212,6 +208,10 @@ function Base.angle(z::Dual{Complex{T}}) where T<:Real
end
end

Base.flipsign(x::Dual,y::Dual) = y == 0 ? flipsign(x, epsilon(y)) : flipsign(x, value(y))
Base.flipsign(x, y::Dual) = y == 0 ? flipsign(x, epsilon(y)) : flipsign(x, value(y))
Base.flipsign(x::Dual, y) = dual(flipsign(value(x), y), flipsign(epsilon(x), y))

# algebraic definitions
conjdual(z::Dual) = Dual(value(z),-epsilon(z))
absdual(z::Dual) = abs(value(z))
Expand Down Expand Up @@ -264,6 +264,8 @@ Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))

inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2)

# force use of NaNMath functions in derivative calculations
function to_nanmath(x::Expr)
if x.head == :call
Expand All @@ -275,14 +277,25 @@ function to_nanmath(x::Expr)
end
to_nanmath(x) = x




for (funsym, exp) in Calculus.symbolic_derivatives_1arg()
funsym == :exp && continue
funsym == :abs2 && continue
isdefined(Base, funsym) || continue
@eval function Base.$(funsym)(z::Dual)
x = value(z)
xp = epsilon(z)
Dual($(funsym)(x),xp*$exp)
funsym == :inv && continue
if isdefined(SpecialFunctions, funsym)
@eval function SpecialFunctions.$(funsym)(z::Dual)
x = value(z)
xp = epsilon(z)
Dual($(funsym)(x),xp*$exp)
end
elseif isdefined(Base, funsym)
@eval function Base.$(funsym)(z::Dual)
x = value(z)
xp = epsilon(z)
Dual($(funsym)(x),xp*$exp)
end
end
# extend corresponding NaNMath methods
if funsym in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
Expand Down
21 changes: 19 additions & 2 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using DualNumbers
using DualNumbers, SpecialFunctions
using Compat
using Compat.Test
using Test
using Compat.LinearAlgebra
import DualNumbers: value
import NaNMath
Expand Down Expand Up @@ -106,6 +106,10 @@ a = angle(z)

@test angle(Dual(0.0+im,0.0+im)) == π/2


# check bug in inv
@test inv(dual(1.0+1.0im,1.0)) == 1/dual(1.0+1.0im,1.0) == dual(1.0+1.0im,1.0)^(-1)

#
# Tests limit definition. Let z = a + b ɛ, where a and b ∈ C.
#
Expand Down Expand Up @@ -145,3 +149,16 @@ test(x, y) = x^2 + y

@test epsilon(Dual(-2.0,1.0)^2.0) == -4
@test epsilon(Dual(-2.0,1.0)^Dual(2.0,0.0)) == -4


# test for flipsign
flipsign(Dual(1.0,1.0),2.0) == Dual(1.0,1.0)
flipsign(Dual(1.0,1.0),-2.0) == Dual(-1.0,-1.0)
flipsign(Dual(1.0,1.0),Dual(1.0,1.0)) == Dual(1.0,1.0)
flipsign(Dual(1.0,1.0),Dual(0.0,-1.0)) == Dual(-1.0,-1.0)
flipsign(-1.0,Dual(1.0,1.0)) == -1.0


# test SpecialFunctions
@test erf(dual(1.0,1.0)) == dual(erf(1.0), 2exp(-1.0^2)/sqrt(π))
@test gamma(dual(1.,1)) == dual(gamma(1.0),polygamma(0,1.0))
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using DualNumbers
using Compat
using Compat.Test
using Test

@test checkindex(Bool, 1:3, dual(2))

Expand Down

0 comments on commit c77ee18

Please sign in to comment.