From 4e6203aa2d95e378e4029e01aeb6ead0f33dba95 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 12 Oct 2023 23:00:59 +0200 Subject: [PATCH] ForwardDiff support (#17) * Add ForwardDiff extension * Better tests * Simpler test * ForwardDiff only for 1.9 and above * Fix 1.6 tests * Fix version check for Aqua formatting --- Project.toml | 10 +++++- README.md | 9 +++++ ext/LogarithmicNumbersForwardDiffExt.jl | 44 +++++++++++++++++++++++++ test/forwarddiff.jl | 20 +++++++++++ test/runtests.jl | 13 ++++++-- 5 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 ext/LogarithmicNumbersForwardDiffExt.jl create mode 100644 test/forwarddiff.jl diff --git a/Project.toml b/Project.toml index 2985b88..2f92855 100644 --- a/Project.toml +++ b/Project.toml @@ -6,12 +6,20 @@ version = "1.2.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +LogarithmicNumbersForwardDiffExt = "ForwardDiff" + [compat] +ForwardDiff = "0.10" julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test"] +test = ["Aqua", "ForwardDiff", "Test"] diff --git a/README.md b/README.md index 8c06376..d523e62 100644 --- a/README.md +++ b/README.md @@ -83,3 +83,12 @@ possibilities for `func`: `pdf`, `cdf`, `ccdf`. - [SpecialFunctions.jl](https://github.com/JuliaMath/SpecialFunctions.jl): `gamma`, `factorial`, `beta`, `erfc`, `erfcx`. + +## Autodiff + +On Julia >= 1.9, if you load [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), you should be allowed to compute + +- derivatives of functions involving `exp(Logarithmic, x)` +- derivatives of functions evaluated at `Logarithmic(x)` + +This functionality is experimental, please report any bug or unexpected behavior. diff --git a/ext/LogarithmicNumbersForwardDiffExt.jl b/ext/LogarithmicNumbersForwardDiffExt.jl new file mode 100644 index 0000000..75692eb --- /dev/null +++ b/ext/LogarithmicNumbersForwardDiffExt.jl @@ -0,0 +1,44 @@ +module LogarithmicNumbersForwardDiffExt + +using ForwardDiff: ForwardDiff, Dual, partials +using LogarithmicNumbers: LogarithmicNumbers, AnyLogarithmic, Logarithmic, ULogarithmic + +## Promotion rules + +function Base.promote_rule(::Type{Logarithmic{R}}, ::Type{Dual{T,V,N}}) where {R<:Real,T,V,N} + return Dual{T,promote_rule(Logarithmic{R}, V),N} +end + +function Base.promote_rule(::Type{ULogarithmic{R}}, ::Type{Dual{T,V,N}}) where {R<:Real,T,V,N} + return Dual{T,promote_rule(ULogarithmic{R}, V),N} +end + +## Constructors + +# Based on the unary_definition macro in ForwardDiff.jl (https://github.com/JuliaDiff/ForwardDiff.jl/blob/6a6443b754b0fcfb4d671c9a3d01776df801f498/src/dual.jl#L230-L244) + +function Base.exp(::Type{ULogarithmic{R}}, d::Dual{T,V,N}) where {R<:Real,T,V,N} + x = ForwardDiff.value(d) + val = exp(ULogarithmic{R}, x) + deriv = exp(ULogarithmic{R}, x) + return ForwardDiff.dual_definition_retval(Val{T}(), val, deriv, partials(d)) +end + +function Base.exp(::Type{Logarithmic{R}}, d::Dual{T,V,N}) where {R<:Real,T,V,N} + x = ForwardDiff.value(d) + val = exp(Logarithmic{R}, x) + deriv = exp(Logarithmic{R}, x) + return ForwardDiff.dual_definition_retval(Val{T}(), val, deriv, partials(d)) +end + +function Base.exp(::Type{ULogarithmic}, d::Dual{T,V,N}) where {T,V,N} + return exp(ULogarithmic{V}, d) +end + +function Base.exp(::Type{Logarithmic}, d::Dual{T,V,N}) where {T,V,N} + return exp(Logarithmic{V}, d) +end + +# TODO: do we need more constructors? + +end diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl new file mode 100644 index 0000000..7b845bf --- /dev/null +++ b/test/forwarddiff.jl @@ -0,0 +1,20 @@ +using ForwardDiff: derivative, gradient +using LogarithmicNumbers +using Test + +f(x) = log(exp(x) * x) +g1(x) = log(exp(ULogarithmic, x) * x) +g2(x) = log(exp(ULogFloat64, x) * x) +h1(x) = log(exp(Logarithmic, x) * x) +h2(x) = log(exp(LogFloat64, x) * x) + +x = 1000 +d = 1 + inv(x) + +@test isnan(derivative(f, x)) +@test derivative(f, LogFloat64(x)) ≈ d +@test derivative(f, ULogFloat64(x)) ≈ d +@test derivative(g1, x) ≈ d +@test derivative(g2, x) ≈ d +@test derivative(h1, x) ≈ d +@test derivative(h2, x) ≈ d diff --git a/test/runtests.jl b/test/runtests.jl index ee9561e..6567590 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,5 @@ using LogarithmicNumbers, Test, Aqua -Aqua.test_all(LogarithmicNumbers) - function _approx(x,y) ans = isapprox(x, y, atol=1e-3) || (isnan(x) && isnan(y)) ans || @show x y @@ -37,7 +35,11 @@ Int[x for x in vals if x isa Int && x ≥ 0], atypes = (ULogarithmic, Logarithmic) atypes2 = (ULogarithmic, ULogFloat32, Logarithmic, LogFloat32) -@testset "LogarithmicNumbers" begin +@testset verbose=true "LogarithmicNumbers" begin + + @testset verbose=true "Aqua" begin + Aqua.test_all(LogarithmicNumbers, project_toml_formatting=(VERSION >= v"1.7")) + end @testset "types" begin @test @isdefined ULogarithmic @@ -569,4 +571,9 @@ atypes2 = (ULogarithmic, ULogFloat32, Logarithmic, LogFloat32) end + @testset "ForwardDiff" begin + if VERSION >= v"1.9" + include("forwarddiff.jl") + end + end end