-
-
Notifications
You must be signed in to change notification settings - Fork 15
Introduce AutoDI wrapper type for DifferentiationInterface dispatch #140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b0c18cf
eebb071
da9c12c
5180120
540c6d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -68,7 +68,8 @@ struct AutoEnzyme{M, A} <: AbstractADType | |
| end | ||
|
|
||
| function AutoEnzyme(; | ||
| mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A} | ||
| mode::M = nothing, function_annotation::Type{A} = Nothing | ||
| ) where {M, A} | ||
| return AutoEnzyme{M, A}(mode) | ||
| end | ||
|
|
||
|
|
@@ -106,7 +107,8 @@ struct AutoReactant{M<:AutoEnzyme} <: AbstractADType | |
| end | ||
|
|
||
| function AutoReactant(; | ||
| mode::Union{AutoEnzyme,Nothing} = nothing) | ||
| mode::Union{AutoEnzyme, Nothing} = nothing | ||
| ) | ||
| if mode === nothing | ||
| mode = AutoEnzyme() | ||
| end | ||
|
|
@@ -586,3 +588,53 @@ NoAutoDiffSelectedError() = NoAutoDiffSelectedError("Automatic differentiation c | |
| function mode(::NoAutoDiff) | ||
| throw(NoAutoDiffSelectedError()) | ||
| end | ||
|
|
||
| """ | ||
| AutoDI{I<:AbstractADType} | ||
|
|
||
| Wraps an AD type to signify that the DifferentiationInterface wrapper should be used instead of calling the backend directly. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about this formulation, because it seems to suggest that anything not wrapped in |
||
|
|
||
| This allows packages to distinguish between an intention to directly call a corresponding AD tool vs. the DI wrapper for said tool, enabling the ability to use, test, and validate both approaches. | ||
|
|
||
| # Fields | ||
|
|
||
| - `inner_ad::I`: the underlying AD package, subtyping [`AbstractADType`](@ref) | ||
|
|
||
| # Constructors | ||
|
|
||
| AutoDI(inner_ad) | ||
|
|
||
| # Example | ||
|
|
||
| ```jldoctest | ||
| julia> using ADTypes | ||
|
|
||
| julia> ad = AutoDI(AutoForwardDiff()) | ||
| AutoDI(AutoForwardDiff()) | ||
|
|
||
| julia> inner_ad(ad) | ||
| AutoForwardDiff() | ||
| ``` | ||
| """ | ||
| struct AutoDI{I <: AbstractADType} <: AbstractADType | ||
| inner_ad::I | ||
| end | ||
|
|
||
| function Base.show(io::IO, backend::AutoDI) | ||
| print(io, AutoDI, "(", repr(backend.inner_ad, context = io), ")") | ||
| end | ||
|
|
||
| """ | ||
| inner_ad(ad::AutoDI)::AbstractADType | ||
| inner_ad(ad::AbstractADType)::AbstractADType | ||
|
|
||
| Return the underlying AD package for a DI AD choice, acts as the identity on a non-DI AD choice. | ||
|
|
||
| # See also | ||
|
|
||
| - [`AutoDI`](@ref) | ||
| """ | ||
| inner_ad(ad::AutoDI) = ad.inner_ad | ||
| inner_ad(ad::AbstractADType) = ad | ||
|
|
||
| mode(di_ad::AutoDI) = mode(inner_ad(di_ad)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| @testset "AutoDI" begin | ||
| @testset "Subtyping and wrapping $ad_name" for (ad_name, ad) in [ | ||
| ("AutoForwardDiff", AutoForwardDiff()), | ||
| ("AutoZygote", AutoZygote()), | ||
| ("AutoEnzyme", AutoEnzyme()), | ||
| ("AutoReverseDiff", AutoReverseDiff()), | ||
| ("AutoChainRules", AutoChainRules(; ruleconfig = ForwardOrReverseRuleConfig())), | ||
| ] | ||
| di_ad = AutoDI(ad) | ||
| @test di_ad isa AbstractADType | ||
| @test di_ad isa AutoDI | ||
|
|
||
| # Test mode propagation | ||
| if mode(ad) isa ForwardMode | ||
| @test mode(di_ad) isa ForwardMode | ||
| elseif mode(ad) isa ForwardOrReverseMode | ||
| @test mode(di_ad) isa ForwardOrReverseMode | ||
| elseif mode(ad) isa ReverseMode | ||
| @test mode(di_ad) isa ReverseMode | ||
| elseif mode(ad) isa SymbolicMode | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line is never hit |
||
| @test mode(di_ad) isa SymbolicMode | ||
| end | ||
|
|
||
| # Test inner_ad accessor | ||
| @test inner_ad(ad) == ad | ||
| @test inner_ad(di_ad) == ad | ||
| end | ||
|
|
||
| @testset "All AD backends" begin | ||
| for ad in every_ad() | ||
| di_ad = AutoDI(ad) | ||
| @test di_ad isa AbstractADType | ||
| @test inner_ad(di_ad) == ad | ||
| @test mode(di_ad) == mode(ad) | ||
| end | ||
| end | ||
|
|
||
| @testset "Nested wrapping" begin | ||
| # Test that we can wrap AutoDI with AutoSparse and vice versa | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we want to support arbitrary wrappings like that. Shouldn't there be a proper way to wrap and an improper one? |
||
| ad = AutoForwardDiff() | ||
| di_ad = AutoDI(ad) | ||
| sparse_di_ad = AutoSparse(di_ad) | ||
|
|
||
| @test sparse_di_ad isa AutoSparse | ||
| @test dense_ad(sparse_di_ad) isa AutoDI | ||
| @test inner_ad(dense_ad(sparse_di_ad)) == ad | ||
|
|
||
| # Test AutoDI wrapping AutoSparse | ||
| sparse_ad = AutoSparse(ad) | ||
| di_sparse_ad = AutoDI(sparse_ad) | ||
|
|
||
| @test di_sparse_ad isa AutoDI | ||
| @test inner_ad(di_sparse_ad) isa AutoSparse | ||
| @test dense_ad(inner_ad(di_sparse_ad)) == ad | ||
| end | ||
|
|
||
| @testset "Display" begin | ||
| ad = AutoForwardDiff(chunksize = 5) | ||
| di_ad = AutoDI(ad) | ||
|
|
||
| str = sprint(show, di_ad) | ||
| @test occursin("AutoDI", str) | ||
| @test occursin("AutoForwardDiff", str) | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we call it AutoDifferentiationInterface? I know it is longer but it is more explicit, and coherent with other types in here. People can always define a shortcut if needed