Skip to content
6 changes: 3 additions & 3 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.4"
version = "0.7.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.13.0"
ADTypes = "1.17.0"
Aqua = "0.8.12"
ChainRulesCore = "1.23.0"
ComponentArrays = "0.15.27"
Expand All @@ -77,7 +77,7 @@ JET = "0.9"
JLArrays = "0.2.0"
JuliaFormatter = "1,2"
LinearAlgebra = "1"
Mooncake = "0.4.122"
Mooncake = "0.4.147"
Pkg = "1"
PolyesterForwardDiff = "0.1.2"
Random = "1"
Expand Down
6 changes: 4 additions & 2 deletions DifferentiationInterface/docs/src/explanation/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
- [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncakeForward) (the latter is experimental)
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
- [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
Expand Down Expand Up @@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand All @@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends:
| `AutoForwardDiff` | ✅ | ✅ |
| `AutoGTPSA` | ✅ | ❌ |
| `AutoMooncake` | ✅ | ✅ |
| `AutoMooncakeForward` | ✅ | ✅ |
| `AutoPolyesterForwardDiff` | ✅ | ✅ |
| `AutoReverseDiff` | ✅ | ❌ |
| `AutoSymbolics` | ✅ | ✅ |
Expand Down Expand Up @@ -95,7 +97,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).

## Implementations

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
module DifferentiationInterfaceMooncakeExt

using ADTypes: ADTypes, AutoMooncake
using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
Dual,
prepare_derivative_cache,
prepare_gradient_cache,
prepare_pullback_cache,
primal,
tangent,
tangent_type,
value_and_derivative!!,
value_and_gradient!!,
value_and_pullback!!,
zero_dual,
zero_tangent,
rdata_type,
fdata,
Expand All @@ -25,17 +31,17 @@ using Mooncake:
_copy_output,
_copy_to_output!!

DI.check_available(::AutoMooncake) = true
# Union alias matching either Mooncake backend type imported from ADTypes
# (`AutoMooncake` or `AutoMooncakeForward`), parameterized by the shared config type `C`.
const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}}

get_config(::AutoMooncake{Nothing}) = Config()
get_config(backend::AutoMooncake{<:Config}) = backend.config
# Defined inside the Mooncake extension module: report any `AnyAutoMooncake` backend as available.
DI.check_available(::AnyAutoMooncake{C}) where {C} = true

# tangents need to be copied before returning, otherwise they are still aliased in the cache
mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
mycopy(x) = deepcopy(x)
# Resolve the Mooncake `Config` for a backend:
# - a backend whose config type parameter is `Nothing` gets a freshly constructed `Config()`;
# - a backend carrying a `Config` returns its stored `config` field unchanged.
get_config(::AnyAutoMooncake{Nothing}) = Config()
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config

include("onearg.jl")
include("twoarg.jl")
include("forward_onearg.jl")
include("forward_twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
## Pushforward

# Preparation object for pushforwards of one-argument functions `f(x, contexts...)`
# with the `AutoMooncakeForward` backend. Subtype of `DI.PushforwardPrep{SIG}`.
struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG}
# Value returned by `DI.signature(...)` at preparation time.
_sig::Val{SIG}
# Result of `prepare_derivative_cache`, passed to every `value_and_derivative!!` call.
cache::Tcache
# `zero_tangent(x)` from preparation; used as the destination of `_copy_to_output!!`
# when an incoming `dx` is not an instance of `tangent_type(X)`.
dx_righttype::DX
end

# Build the preparation object for one-argument pushforwards with `AutoMooncakeForward`.
#
# Steps, in order:
#   1. record the call signature via `DI.signature` (forwarding `strict`),
#   2. read the Mooncake config attached to the backend,
#   3. build the derivative cache for `f(x, unwrapped_contexts...)`,
#   4. store a zero tangent of `x` as a buffer for later tangent conversion.
function DI.prepare_pushforward_nokwarg(
    strict::Val,
    f::F,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    sig = DI.signature(f, backend, x, tx, contexts...; strict=strict)
    cfg = get_config(backend)
    derivative_cache = prepare_derivative_cache(
        f,
        x,
        map(DI.unwrap, contexts)...;
        debug_mode=cfg.debug_mode,
        silence_debug_messages=cfg.silence_debug_messages,
    )
    x_tangent_buffer = zero_tangent(x)
    return MooncakeOneArgPushforwardPrep(sig, derivative_cache, x_tangent_buffer)
end

# Out-of-place value and pushforward for one-argument functions.
#
# For every seed `dx` in `tx`, one call to `value_and_derivative!!` is made with the
# prepared cache. A seed that is not already an instance of `tangent_type(X)` is first
# copied into the buffer stored in `prep.dx_righttype`. The function itself and every
# unwrapped context are wrapped with `zero_dual`.
#
# Returns `(y, ty)`, where `y` is the primal from the first seed's evaluation and `ty`
# is the tuple of copied output tangents (one per seed).
function DI.value_and_pushforward(
    f::F,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x::X,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C,X}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    results = map(tx) do dx
        seed = if dx isa tangent_type(X)
            dx
        else
            _copy_to_output!!(prep.dx_righttype, dx)
        end
        out_dual = value_and_derivative!!(
            prep.cache,
            zero_dual(f),
            Dual(x, seed),
            map(zero_dual ∘ DI.unwrap, contexts)...,
        )
        # The tangent is copied with `_copy_output` before being returned.
        return primal(out_dual), _copy_output(tangent(out_dual))
    end
    y = first(results[1])
    ty = map(last, results)
    return y, ty
end

# Out-of-place pushforward for one-argument functions: validates the preparation,
# delegates to `DI.value_and_pushforward`, and keeps only the tangents.
function DI.pushforward(
    f::F,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    value_and_tangents = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
    return value_and_tangents[2]
end

# In-place value and pushforward for one-argument functions.
#
# Computes the out-of-place result first, then copies each freshly computed tangent
# into the matching element of the caller-provided `ty` with `copyto!`.
# Returns `(y, ty)` with `ty` being the tuple passed in.
function DI.value_and_pushforward!(
    f::F,
    ty::NTuple,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    y, fresh_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
    for (dst, src) in zip(ty, fresh_ty)
        copyto!(dst, src)
    end
    return y, ty
end

# In-place pushforward for one-argument functions: validates the preparation, lets
# `DI.value_and_pushforward!` fill `ty`, discards the primal value, and returns `ty`.
function DI.pushforward!(
    f::F,
    ty::NTuple,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    # Called for its effect on `ty`; the returned tuple is not needed here.
    _ = DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)
    return ty
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
## Pushforward

# Preparation object for pushforwards of two-argument (mutating) functions
# `f!(y, x, contexts...)` with the `AutoMooncakeForward` backend.
# Subtype of `DI.PushforwardPrep{SIG}`.
struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG}
# Value returned by `DI.signature(...)` at preparation time.
_sig::Val{SIG}
# Result of `prepare_derivative_cache`, passed to every `value_and_derivative!!` call.
cache::Tcache
# `zero_tangent(x)` from preparation; destination of `_copy_to_output!!` when an
# incoming `dx` is not an instance of `tangent_type(X)`.
dx_righttype::DX
# `zero_tangent(y)` from preparation; destination of `_copy_to_output!!` when a
# caller-provided `dy` is not an instance of `tangent_type(Y)`.
dy_righttype::DY
end

# Build the preparation object for two-argument pushforwards with `AutoMooncakeForward`.
#
# Steps, in order:
#   1. record the call signature via `DI.signature` (forwarding `strict`),
#   2. read the Mooncake config attached to the backend,
#   3. build the derivative cache for `f!(y, x, unwrapped_contexts...)`,
#   4. store zero tangents of `x` and `y` as buffers for later tangent conversion.
function DI.prepare_pushforward_nokwarg(
    strict::Val,
    f!::F,
    y,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    sig = DI.signature(f!, y, backend, x, tx, contexts...; strict=strict)
    cfg = get_config(backend)
    derivative_cache = prepare_derivative_cache(
        f!,
        y,
        x,
        map(DI.unwrap, contexts)...;
        debug_mode=cfg.debug_mode,
        silence_debug_messages=cfg.silence_debug_messages,
    )
    x_tangent_buffer = zero_tangent(x)
    y_tangent_buffer = zero_tangent(y)
    return MooncakeTwoArgPushforwardPrep(
        sig, derivative_cache, x_tangent_buffer, y_tangent_buffer
    )
end

# Out-of-place (in the tangents) value and pushforward for two-argument functions.
#
# For every seed `dx` in `tx`, `y` is wrapped with `zero_dual` and one call to
# `value_and_derivative!!` is made with the prepared cache; afterwards the tangent of
# that dual is copied with `_copy_output` and collected. A seed that is not already an
# instance of `tangent_type(X)` is first copied into `prep.dx_righttype`.
#
# Returns `(y, ty)` with `y` being the argument passed in.
# NOTE(review): presumably `f!` overwrites `y` in place during the call — confirm
# against Mooncake's forward-mode semantics.
function DI.value_and_pushforward(
    f!::F,
    y,
    prep::MooncakeTwoArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x::X,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C,X}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    ty = map(tx) do dx
        seed = if dx isa tangent_type(X)
            dx
        else
            _copy_to_output!!(prep.dx_righttype, dx)
        end
        out_dual = zero_dual(y)
        value_and_derivative!!(
            prep.cache,
            zero_dual(f!),
            out_dual,
            Dual(x, seed),
            map(zero_dual ∘ DI.unwrap, contexts)...,
        )
        return _copy_output(tangent(out_dual))
    end
    return y, ty
end

# Pushforward for two-argument functions: validates the preparation, delegates to
# `DI.value_and_pushforward`, and keeps only the tangents.
function DI.pushforward(
    f!::F,
    y,
    prep::MooncakeTwoArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    value_and_tangents = DI.value_and_pushforward(
        f!, y, prep, backend, x, tx, contexts...
    )
    return value_and_tangents[2]
end

# In-place value and pushforward for two-argument functions.
#
# Seeds in `tx` and output tangents in `ty` are processed pairwise. Each of `dx` and `dy`
# is used directly when it is already an instance of the corresponding `tangent_type`;
# otherwise it is copied into the matching buffer held by `prep`. After the call to
# `value_and_derivative!!`, the output tangent is copied back into `dy` with `copyto!`
# unless `dy` itself was the object handed to Mooncake.
#
# Returns `(y, ty)` with both being the arguments passed in.
# NOTE(review): the copy-back uses `copyto!` from the Mooncake-typed buffer into a `dy`
# that is not of tangent type — confirm this is valid for non-array tangent types.
function DI.value_and_pushforward!(
    f!::F,
    y::Y,
    ty::NTuple,
    prep::MooncakeTwoArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x::X,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C,X,Y}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    for (dx, dy) in zip(tx, ty)
        seed = if dx isa tangent_type(X)
            dx
        else
            _copy_to_output!!(prep.dx_righttype, dx)
        end
        out_tangent = if dy isa tangent_type(Y)
            dy
        else
            _copy_to_output!!(prep.dy_righttype, dy)
        end
        value_and_derivative!!(
            prep.cache,
            zero_dual(f!),
            Dual(y, out_tangent),
            Dual(x, seed),
            map(zero_dual ∘ DI.unwrap, contexts)...,
        )
        if dy !== out_tangent
            copyto!(dy, out_tangent)
        end
    end
    return y, ty
end

# In-place pushforward for two-argument functions: validates the preparation, lets
# `DI.value_and_pushforward!` fill `ty`, discards the rest, and returns `ty`.
function DI.pushforward!(
    f!::F,
    y,
    ty::NTuple,
    prep::MooncakeTwoArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    # Called for its effect on `ty`; the returned tuple is not needed here.
    _ = DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
    return ty
end
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using ADTypes:
AutoForwardDiff,
AutoGTPSA,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
Expand Down Expand Up @@ -115,6 +116,7 @@ export AutoFiniteDifferences
export AutoForwardDiff
export AutoGTPSA
export AutoMooncake
export AutoMooncakeForward
export AutoPolyesterForwardDiff
export AutoReverseDiff
export AutoSymbolics
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

!!! warning
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or a backend that automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

!!! warning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ end;
@testset for scen in filter(differentiatewith_scenarios()) do scen
DIT.operator(scen) == :pullback
end
Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
Mooncake.TestUtils.test_rule(
StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode
)
end
end;

Expand Down
6 changes: 5 additions & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface)

LOGGING = get(ENV, "CI", "false") == "false"

backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]
backends = [
AutoMooncake(; config=nothing),
AutoMooncake(; config=Mooncake.Config()),
AutoMooncakeForward(; config=nothing),
]

for backend in backends
@test check_available(backend)
Expand Down