diff --git a/Project.toml b/Project.toml index 74aa3da..e6bdb26 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,12 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +[weakdeps] +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[extensions] +DualNumbersSpecialFunctionsExt = "SpecialFunctions" + [compat] Calculus = "0.5" NaNMath = "0.3, 1" @@ -15,7 +21,8 @@ julia = "0.7, 1.0" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "Test"] +test = ["LinearAlgebra", "Test", "SpecialFunctions"] diff --git a/ext/DualNumbersSpecialFunctionsExt.jl b/ext/DualNumbersSpecialFunctionsExt.jl new file mode 100644 index 0000000..750393b --- /dev/null +++ b/ext/DualNumbersSpecialFunctionsExt.jl @@ -0,0 +1,29 @@ +module DualNumbersSpecialFunctionsExt + +using DualNumbers +using DualNumbers: value, epsilon, to_nanmath +using NaNMath +using Calculus +using SpecialFunctions + +for (funsym, expr) in Calculus.symbolic_derivatives_1arg() + if isdefined(SpecialFunctions, funsym) && !isdefined(Base, funsym) + @eval function SpecialFunctions.$(funsym)(z::Dual) + x = value(z) + xp = epsilon(z) + Dual($(funsym)(x),xp*$expr) + end + end + # extend corresponding NaNMath methods + if funsym in (:lgamma,) + funsym = Expr(:.,:NaNMath,Base.Meta.quot(funsym)) + @eval function $(funsym)(z::Dual) + x = value(z) + xp = epsilon(z) + Dual($(funsym)(x),xp*$(to_nanmath(expr))) + end + end +end + + +end diff --git a/src/DualNumbers.jl b/src/DualNumbers.jl index 53e8c53..99642b4 100644 --- a/src/DualNumbers.jl +++ b/src/DualNumbers.jl @@ -1,6 +1,5 @@ module DualNumbers -using SpecialFunctions import NaNMath import Calculus @@ -26,4 +25,8 @@ export ɛ, imɛ +if !isdefined(Base, :get_extension) + include("../ext/DualNumbersSpecialFunctionsExt.jl") +end + end # module diff --git a/src/dual.jl b/src/dual.jl index 385c172..b84f531 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -303,31 +303,25 @@ to_nanmath(x) = x -for (funsym, exp) in Calculus.symbolic_derivatives_1arg() +for (funsym, expr) in Calculus.symbolic_derivatives_1arg() funsym == :exp && continue funsym == :abs2 && continue funsym == :inv && continue - if isdefined(SpecialFunctions, funsym) - @eval function SpecialFunctions.$(funsym)(z::Dual) - x = value(z) - xp = epsilon(z) - Dual($(funsym)(x),xp*$exp) - end - elseif isdefined(Base, funsym) + if isdefined(Base, funsym) @eval function Base.$(funsym)(z::Dual) x = value(z) xp = epsilon(z) - Dual($(funsym)(x),xp*$exp) + Dual($(funsym)(x),xp*$expr) end end # extend corresponding NaNMath methods if funsym in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, - :lgamma, :log1p) + :log1p) funsym = Expr(:.,:NaNMath,Base.Meta.quot(funsym)) @eval function $(funsym)(z::Dual) x = value(z) xp = epsilon(z) - Dual($(funsym)(x),xp*$(to_nanmath(exp))) + Dual($(funsym)(x),xp*$(to_nanmath(expr))) end end end