diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 609087e..622c8c2 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -29,26 +29,26 @@ include("symbols.jl") # Automatic Differentiation export AbstractADType export AutoChainRules, - AutoDiffractor, - AutoEnzyme, - AutoFastDifferentiation, - AutoFiniteDiff, - AutoFiniteDifferences, - AutoForwardDiff, - AutoGTPSA, - AutoModelingToolkit, - AutoMooncake, - AutoMooncakeForward, - AutoPolyesterForwardDiff, - AutoReverseDiff, - AutoSymbolics, - AutoTapir, - AutoTaylorDiff, - AutoTracker, - AutoZygote, - NoAutoDiff, - NoAutoDiffSelectedError, - AutoReactant + AutoDiffractor, + AutoEnzyme, + AutoFastDifferentiation, + AutoFiniteDiff, + AutoFiniteDifferences, + AutoForwardDiff, + AutoGTPSA, + AutoModelingToolkit, + AutoMooncake, + AutoMooncakeForward, + AutoPolyesterForwardDiff, + AutoReverseDiff, + AutoSymbolics, + AutoTapir, + AutoTaylorDiff, + AutoTracker, + AutoZygote, + NoAutoDiff, + NoAutoDiffSelectedError, + AutoReactant @public AbstractMode @public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode @public mode @@ -58,6 +58,10 @@ export AutoChainRules, export AutoSparse @public dense_ad +# DI Automatic Differentiation +export AutoDI +@public inner_ad + # Sparsity detection export AbstractSparsityDetector export jacobian_sparsity, hessian_sparsity diff --git a/src/dense.jl b/src/dense.jl index c75ceb9..5400f6c 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -68,7 +68,8 @@ struct AutoEnzyme{M, A} <: AbstractADType end function AutoEnzyme(; - mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A} + mode::M = nothing, function_annotation::Type{A} = Nothing + ) where {M, A} return AutoEnzyme{M, A}(mode) end @@ -106,7 +107,8 @@ struct AutoReactant{M<:AutoEnzyme} <: AbstractADType end function AutoReactant(; - mode::Union{AutoEnzyme,Nothing} = nothing) + mode::Union{AutoEnzyme, Nothing} = nothing + ) if mode === nothing mode = AutoEnzyme() end @@ -586,3 +588,53 @@ NoAutoDiffSelectedError() = NoAutoDiffSelectedError("Automatic differentiation c function mode(::NoAutoDiff) throw(NoAutoDiffSelectedError()) end + +""" + AutoDI{I<:AbstractADType} + +Wraps an AD type to signify that the DifferentiationInterface wrapper should be used instead of calling the backend directly. + +This allows packages to distinguish between an intention to directly call a corresponding AD tool vs. the DI wrapper for said tool, enabling the ability to use, test, and validate both approaches. + +# Fields + + - `inner_ad::I`: the underlying AD package, subtyping [`AbstractADType`](@ref) + +# Constructors + + AutoDI(inner_ad) + +# Example + +```jldoctest +julia> using ADTypes + +julia> ad = AutoDI(AutoForwardDiff()) +AutoDI(AutoForwardDiff()) + +julia> inner_ad(ad) +AutoForwardDiff() +``` +""" +struct AutoDI{I <: AbstractADType} <: AbstractADType + inner_ad::I +end + +function Base.show(io::IO, backend::AutoDI) + print(io, AutoDI, "(", repr(backend.inner_ad, context = io), ")") +end + +""" + inner_ad(ad::AutoDI)::AbstractADType + inner_ad(ad::AbstractADType)::AbstractADType + +Return the underlying AD package for a DI AD choice, acts as the identity on a non-DI AD choice. + +# See also + + - [`AutoDI`](@ref) +""" +inner_ad(ad::AutoDI) = ad.inner_ad +inner_ad(ad::AbstractADType) = ad + +mode(di_ad::AutoDI) = mode(inner_ad(di_ad)) diff --git a/test/di.jl b/test/di.jl new file mode 100644 index 0000000..cec90d8 --- /dev/null +++ b/test/di.jl @@ -0,0 +1,65 @@ +@testset "AutoDI" begin + @testset "Subtyping and wrapping $ad_name" for (ad_name, ad) in [ + ("AutoForwardDiff", AutoForwardDiff()), + ("AutoZygote", AutoZygote()), + ("AutoEnzyme", AutoEnzyme()), + ("AutoReverseDiff", AutoReverseDiff()), + ("AutoChainRules", AutoChainRules(; ruleconfig = ForwardOrReverseRuleConfig())), + ] + di_ad = AutoDI(ad) + @test di_ad isa AbstractADType + @test di_ad isa AutoDI + + # Test mode propagation + if mode(ad) isa ForwardMode + @test mode(di_ad) isa ForwardMode + elseif mode(ad) isa ForwardOrReverseMode + @test mode(di_ad) isa ForwardOrReverseMode + elseif mode(ad) isa ReverseMode + @test mode(di_ad) isa ReverseMode + elseif mode(ad) isa SymbolicMode + @test mode(di_ad) isa SymbolicMode + end + + # Test inner_ad accessor + @test inner_ad(ad) == ad + @test inner_ad(di_ad) == ad + end + + @testset "All AD backends" begin + for ad in every_ad() + di_ad = AutoDI(ad) + @test di_ad isa AbstractADType + @test inner_ad(di_ad) == ad + @test mode(di_ad) == mode(ad) + end + end + + @testset "Nested wrapping" begin + # Test that we can wrap AutoDI with AutoSparse and vice versa + ad = AutoForwardDiff() + di_ad = AutoDI(ad) + sparse_di_ad = AutoSparse(di_ad) + + @test sparse_di_ad isa AutoSparse + @test dense_ad(sparse_di_ad) isa AutoDI + @test inner_ad(dense_ad(sparse_di_ad)) == ad + + # Test AutoDI wrapping AutoSparse + sparse_ad = AutoSparse(ad) + di_sparse_ad = AutoDI(sparse_ad) + + @test di_sparse_ad isa AutoDI + @test inner_ad(di_sparse_ad) isa AutoSparse + @test dense_ad(inner_ad(di_sparse_ad)) == ad + end + + @testset "Display" begin + ad = AutoForwardDiff(chunksize = 5) + di_ad = AutoDI(ad) + + str = sprint(show, di_ad) + @test occursin("AutoDI", str) + @test occursin("AutoForwardDiff", str) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b4a5f5b..3fc3ab7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,26 +1,27 @@ using ADTypes using ADTypes: AbstractADType, - mode, - ForwardMode, - ForwardOrReverseMode, - ReverseMode, - SymbolicMode + mode, + ForwardMode, + ForwardOrReverseMode, + ReverseMode, + SymbolicMode using ADTypes: dense_ad, - NoSparsityDetector, - KnownJacobianSparsityDetector, - KnownHessianSparsityDetector, - sparsity_detector, - jacobian_sparsity, - hessian_sparsity, - NoColoringAlgorithm, - coloring_algorithm, - column_coloring, - row_coloring, - symmetric_coloring + inner_ad, + NoSparsityDetector, + KnownJacobianSparsityDetector, + KnownHessianSparsityDetector, + sparsity_detector, + jacobian_sparsity, + hessian_sparsity, + NoColoringAlgorithm, + coloring_algorithm, + column_coloring, + row_coloring, + symmetric_coloring using Aqua: Aqua using ChainRulesCore: ChainRulesCore, RuleConfig, - HasForwardsMode, HasReverseMode, - NoForwardsMode, NoReverseMode + HasForwardsMode, HasReverseMode, + NoForwardsMode, NoReverseMode using EnzymeCore: EnzymeCore using JET: JET using Test @@ -101,6 +102,9 @@ end @testset "Sparse" begin include("sparse.jl") end + @testset "DI" begin + include("di.jl") + end @testset "Symbols" begin include("symbols.jl") end