Skip to content

Commit

Permalink
ForwardDiff support (#17)
Browse files Browse the repository at this point in the history
* Add ForwardDiff extension

* Better tests

* Simpler test

* ForwardDiff only for 1.9 and above

* Fix 1.6 tests

* Fix version check for Aqua formatting
  • Loading branch information
gdalle authored Oct 12, 2023
1 parent 84151f5 commit 4e6203a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 4 deletions.
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@ version = "1.2.1"
[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
LogarithmicNumbersForwardDiffExt = "ForwardDiff"

[compat]
ForwardDiff = "0.10"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test"]
test = ["Aqua", "ForwardDiff", "Test"]
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,12 @@ possibilities for `func`:
`pdf`, `cdf`, `ccdf`.
- [SpecialFunctions.jl](https://github.com/JuliaMath/SpecialFunctions.jl):
`gamma`, `factorial`, `beta`, `erfc`, `erfcx`.

## Autodiff

On Julia >= 1.9, if you load [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), you should be allowed to compute

- derivatives of functions involving `exp(Logarithmic, x)`
- derivatives of functions evaluated at `Logarithmic(x)`

This functionality is experimental, please report any bug or unexpected behavior.
44 changes: 44 additions & 0 deletions ext/LogarithmicNumbersForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module LogarithmicNumbersForwardDiffExt

using ForwardDiff: ForwardDiff, Dual, partials
using LogarithmicNumbers: LogarithmicNumbers, AnyLogarithmic, Logarithmic, ULogarithmic

## Promotion rules

function Base.promote_rule(::Type{Logarithmic{R}}, ::Type{Dual{T,V,N}}) where {R<:Real,T,V,N}
return Dual{T,promote_rule(Logarithmic{R}, V),N}
end

function Base.promote_rule(::Type{ULogarithmic{R}}, ::Type{Dual{T,V,N}}) where {R<:Real,T,V,N}
return Dual{T,promote_rule(ULogarithmic{R}, V),N}
end

## Constructors

# Based on the unary_definition macro in ForwardDiff.jl (https://github.com/JuliaDiff/ForwardDiff.jl/blob/6a6443b754b0fcfb4d671c9a3d01776df801f498/src/dual.jl#L230-L244)

function Base.exp(::Type{ULogarithmic{R}}, d::Dual{T,V,N}) where {R<:Real,T,V,N}
x = ForwardDiff.value(d)
val = exp(ULogarithmic{R}, x)
deriv = exp(ULogarithmic{R}, x)
return ForwardDiff.dual_definition_retval(Val{T}(), val, deriv, partials(d))
end

function Base.exp(::Type{Logarithmic{R}}, d::Dual{T,V,N}) where {R<:Real,T,V,N}
x = ForwardDiff.value(d)
val = exp(Logarithmic{R}, x)
deriv = exp(Logarithmic{R}, x)
return ForwardDiff.dual_definition_retval(Val{T}(), val, deriv, partials(d))
end

function Base.exp(::Type{ULogarithmic}, d::Dual{T,V,N}) where {T,V,N}
return exp(ULogarithmic{V}, d)
end

function Base.exp(::Type{Logarithmic}, d::Dual{T,V,N}) where {T,V,N}
return exp(Logarithmic{V}, d)
end

# TODO: do we need more constructors?

end
20 changes: 20 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using ForwardDiff: derivative, gradient
using LogarithmicNumbers
using Test

f(x) = log(exp(x) * x)
g1(x) = log(exp(ULogarithmic, x) * x)
g2(x) = log(exp(ULogFloat64, x) * x)
h1(x) = log(exp(Logarithmic, x) * x)
h2(x) = log(exp(LogFloat64, x) * x)

x = 1000
d = 1 + inv(x)

@test isnan(derivative(f, x))
@test derivative(f, LogFloat64(x)) d
@test derivative(f, ULogFloat64(x)) d
@test derivative(g1, x) d
@test derivative(g2, x) d
@test derivative(h1, x) d
@test derivative(h2, x) d
13 changes: 10 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using LogarithmicNumbers, Test, Aqua

Aqua.test_all(LogarithmicNumbers)

function _approx(x,y)
ans = isapprox(x, y, atol=1e-3) || (isnan(x) && isnan(y))
ans || @show x y
Expand Down Expand Up @@ -37,7 +35,11 @@ Int[x for x in vals if x isa Int && x ≥ 0],
atypes = (ULogarithmic, Logarithmic)
atypes2 = (ULogarithmic, ULogFloat32, Logarithmic, LogFloat32)

@testset "LogarithmicNumbers" begin
@testset verbose=true "LogarithmicNumbers" begin

@testset verbose=true "Aqua" begin
Aqua.test_all(LogarithmicNumbers, project_toml_formatting=(VERSION >= v"1.7"))
end

@testset "types" begin
@test @isdefined ULogarithmic
Expand Down Expand Up @@ -569,4 +571,9 @@ atypes2 = (ULogarithmic, ULogFloat32, Logarithmic, LogFloat32)

end

@testset "ForwardDiff" begin
if VERSION >= v"1.9"
include("forwarddiff.jl")
end
end
end

0 comments on commit 4e6203a

Please sign in to comment.