Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ include("symbols.jl")
# Automatic Differentiation
export AbstractADType
export AutoChainRules,
AutoDiffractor,
AutoEnzyme,
AutoFastDifferentiation,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTaylorDiff,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError,
AutoReactant
AutoDiffractor,
AutoEnzyme,
AutoFastDifferentiation,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTaylorDiff,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError,
AutoReactant
@public AbstractMode
@public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode
@public mode
Expand All @@ -58,6 +58,10 @@ export AutoChainRules,
export AutoSparse
@public dense_ad

# DI Automatic Differentiation
export AutoDI
@public inner_ad

# Sparsity detection
export AbstractSparsityDetector
export jacobian_sparsity, hessian_sparsity
Expand Down
56 changes: 54 additions & 2 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ struct AutoEnzyme{M, A} <: AbstractADType
end

function AutoEnzyme(;
mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A}
mode::M = nothing, function_annotation::Type{A} = Nothing
) where {M, A}
return AutoEnzyme{M, A}(mode)
end

Expand Down Expand Up @@ -106,7 +107,8 @@ struct AutoReactant{M<:AutoEnzyme} <: AbstractADType
end

function AutoReactant(;
mode::Union{AutoEnzyme,Nothing} = nothing)
mode::Union{AutoEnzyme, Nothing} = nothing
)
if mode === nothing
mode = AutoEnzyme()
end
Expand Down Expand Up @@ -586,3 +588,53 @@ NoAutoDiffSelectedError() = NoAutoDiffSelectedError("Automatic differentiation c
function mode(::NoAutoDiff)
throw(NoAutoDiffSelectedError())
end

"""
AutoDI{I<:AbstractADType}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call it AutoDifferentiationInterface? I know it is longer but it is more explicit, and coherent with other types in here. People can always define a shortcut if needed


Wraps an AD type to signify that the DifferentiationInterface wrapper should be used instead of calling the backend directly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this formulation, because it seems to suggest that anything not wrapped in AutoDI will not use DI. However we agreed that the current behavior of packages who call, say, AutoEnzyme through DI should be preserved. Maybe add a word of caution specifying that "using an ADType directly instead of wrapping it inside AutoDI does not forbid the use of DI"?


This allows packages to distinguish between an intention to directly call a corresponding AD tool vs. the DI wrapper for said tool, enabling the ability to use, test, and validate both approaches.

# Fields

- `inner_ad::I`: the underlying AD package, subtyping [`AbstractADType`](@ref)

# Constructors

AutoDI(inner_ad)

# Example

```jldoctest
julia> using ADTypes

julia> ad = AutoDI(AutoForwardDiff())
AutoDI(AutoForwardDiff())

julia> inner_ad(ad)
AutoForwardDiff()
```
"""
struct AutoDI{I <: AbstractADType} <: AbstractADType
inner_ad::I
end

function Base.show(io::IO, backend::AutoDI)
print(io, AutoDI, "(", repr(backend.inner_ad, context = io), ")")
end

"""
inner_ad(ad::AutoDI)::AbstractADType
inner_ad(ad::AbstractADType)::AbstractADType

Return the underlying AD package for a DI AD choice, acts as the identity on a non-DI AD choice.

# See also

- [`AutoDI`](@ref)
"""
inner_ad(ad::AutoDI) = ad.inner_ad
inner_ad(ad::AbstractADType) = ad

mode(di_ad::AutoDI) = mode(inner_ad(di_ad))
65 changes: 65 additions & 0 deletions test/di.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
@testset "AutoDI" begin
@testset "Subtyping and wrapping $ad_name" for (ad_name, ad) in [
("AutoForwardDiff", AutoForwardDiff()),
("AutoZygote", AutoZygote()),
("AutoEnzyme", AutoEnzyme()),
("AutoReverseDiff", AutoReverseDiff()),
("AutoChainRules", AutoChainRules(; ruleconfig = ForwardOrReverseRuleConfig())),
]
di_ad = AutoDI(ad)
@test di_ad isa AbstractADType
@test di_ad isa AutoDI

# Test mode propagation
if mode(ad) isa ForwardMode
@test mode(di_ad) isa ForwardMode
elseif mode(ad) isa ForwardOrReverseMode
@test mode(di_ad) isa ForwardOrReverseMode
elseif mode(ad) isa ReverseMode
@test mode(di_ad) isa ReverseMode
elseif mode(ad) isa SymbolicMode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is never hit

@test mode(di_ad) isa SymbolicMode
end

# Test inner_ad accessor
@test inner_ad(ad) == ad
@test inner_ad(di_ad) == ad
end

@testset "All AD backends" begin
for ad in every_ad()
di_ad = AutoDI(ad)
@test di_ad isa AbstractADType
@test inner_ad(di_ad) == ad
@test mode(di_ad) == mode(ad)
end
end

@testset "Nested wrapping" begin
# Test that we can wrap AutoDI with AutoSparse and vice versa
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we want to support arbitrary wrappings like that. Shouldn't there be a proper way to wrap and an improper one?

ad = AutoForwardDiff()
di_ad = AutoDI(ad)
sparse_di_ad = AutoSparse(di_ad)

@test sparse_di_ad isa AutoSparse
@test dense_ad(sparse_di_ad) isa AutoDI
@test inner_ad(dense_ad(sparse_di_ad)) == ad

# Test AutoDI wrapping AutoSparse
sparse_ad = AutoSparse(ad)
di_sparse_ad = AutoDI(sparse_ad)

@test di_sparse_ad isa AutoDI
@test inner_ad(di_sparse_ad) isa AutoSparse
@test dense_ad(inner_ad(di_sparse_ad)) == ad
end

@testset "Display" begin
ad = AutoForwardDiff(chunksize = 5)
di_ad = AutoDI(ad)

str = sprint(show, di_ad)
@test occursin("AutoDI", str)
@test occursin("AutoForwardDiff", str)
end
end
40 changes: 22 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
using ADTypes
using ADTypes: AbstractADType,
mode,
ForwardMode,
ForwardOrReverseMode,
ReverseMode,
SymbolicMode
mode,
ForwardMode,
ForwardOrReverseMode,
ReverseMode,
SymbolicMode
using ADTypes: dense_ad,
NoSparsityDetector,
KnownJacobianSparsityDetector,
KnownHessianSparsityDetector,
sparsity_detector,
jacobian_sparsity,
hessian_sparsity,
NoColoringAlgorithm,
coloring_algorithm,
column_coloring,
row_coloring,
symmetric_coloring
inner_ad,
NoSparsityDetector,
KnownJacobianSparsityDetector,
KnownHessianSparsityDetector,
sparsity_detector,
jacobian_sparsity,
hessian_sparsity,
NoColoringAlgorithm,
coloring_algorithm,
column_coloring,
row_coloring,
symmetric_coloring
using Aqua: Aqua
using ChainRulesCore: ChainRulesCore, RuleConfig,
HasForwardsMode, HasReverseMode,
NoForwardsMode, NoReverseMode
HasForwardsMode, HasReverseMode,
NoForwardsMode, NoReverseMode
using EnzymeCore: EnzymeCore
using JET: JET
using Test
Expand Down Expand Up @@ -101,6 +102,9 @@ end
@testset "Sparse" begin
include("sparse.jl")
end
@testset "DI" begin
include("di.jl")
end
@testset "Symbols" begin
include("symbols.jl")
end
Expand Down