Skip to content
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,29 @@ jobs:
matrix:
version:
- "1"
- "lts"
- "pre"
# - "lts"
# - "pre"
group:
- Misc/Internals
- Misc/DifferentiateWith
- Misc/FromPrimitive
- Misc/SparsityDetector
- Misc/ZeroBackends
- Back/ChainRulesBackends
- Back/Enzyme
- Back/FiniteDiff
- Back/FiniteDifferences
- Back/ForwardDiff
- Back/Mooncake
- Back/PolyesterForwardDiff
- Back/ReverseDiff
- Back/SecondOrder
- Back/SymbolicBackends
- Back/Tracker
- Back/Zygote
- Down/Flux
- Down/Lux
# - Misc/DifferentiateWith
# - Misc/FromPrimitive
# - Misc/SparsityDetector
# - Misc/ZeroBackends
# - Back/ChainRulesBackends
# - Back/Enzyme
# - Back/FiniteDiff
# - Back/FiniteDifferences
# - Back/ForwardDiff
# - Back/Mooncake
# - Back/PolyesterForwardDiff
# - Back/ReverseDiff
# - Back/SecondOrder
# - Back/SymbolicBackends
# - Back/Tracker
# - Back/Zygote
# - Down/Flux
# - Down/Lux
- Down/Reactant
skip_lts:
- ${{ github.event.pull_request.draft }}
skip_pre:
Expand Down Expand Up @@ -83,6 +84,8 @@ jobs:
group: Down/Flux
- version: "lts"
group: Down/Lux
- version: "lts"
group: Down/Reactant
# pre-release
- version: "pre"
group: Back/ChainRulesBackends
Expand Down Expand Up @@ -135,13 +138,13 @@ jobs:
matrix:
version:
- "1"
- "lts"
- "pre"
# - "lts"
# - "pre"
group:
- Formalities
- Zero
- Standard
- Weird
# - Zero
# - Standard
# - Weird
skip_lts:
- ${{ github.event.pull_request.draft }}
skip_pre:
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand All @@ -36,6 +37,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReactantExt = "Reactant"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
Expand All @@ -57,6 +59,7 @@ LinearAlgebra = "<0.0.1,1"
Mooncake = "0.4.0"
PackageExtensionCompat = "1.0.2"
PolyesterForwardDiff = "0.1.1"
Reactant = "0.2.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module DifferentiationInterfaceReactantExt

using ADTypes: ADTypes
using Compat
import DifferentiationInterface as DI
using DifferentiationInterface:
ReactantBackend,
DerivativePrep,
GradientPrep,
HessianPrep,
HVPPrep,
JacobianPrep,
PullbackPrep,
PushforwardPrep,
SecondDerivativePrep
using Reactant: @compile, to_rarray

ADTypes.mode(rebackend::ReactantBackend) = ADTypes.mode(rebackend.backend)
DI.check_available(rebackend::ReactantBackend) = DI.check_available(rebackend.backend)
DI.inplace_support(rebackend::ReactantBackend) = DI.inplace_support(rebackend.backend)

include("onearg.jl")

end # module
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
struct ReactantGradientPrep{F,G} <: GradientPrep
compiled_function::F
compiled_gradient::G
end

function DI.prepare_gradient(f, rebackend::ReactantBackend, x)
xr = to_rarray(x)
gradient_closure(xr) = DI.gradient(f, rebackend.backend, xr)
compiled_function = @compile f(xr)
compiled_gradient = @compile gradient_closure(xr)
return ReactantGradientPrep(compiled_function, compiled_gradient)
end

function DI.gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x)
@compat (; compiled_gradient) = prep
xr = to_rarray(x)
return compiled_gradient(xr)
end

function DI.value_and_gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x)
@compat (; compiled_function, compiled_gradient) = prep
xr = to_rarray(x)
return compiled_function(xr), compiled_gradient(xr)
Comment thread
gdalle marked this conversation as resolved.
Outdated
end

function DI.gradient!(f, grad, prep::ReactantGradientPrep, rebackend::ReactantBackend, x)
gradr = DI.gradient(f, prep, rebackend, x)
return copyto!(grad, gradr)
end

function DI.value_and_gradient!(
f, grad, prep::ReactantGradientPrep, rebackend::ReactantBackend, x
)
y, gradr = DI.value_and_gradient(f, prep, rebackend, x)
return y, copyto!(grad, gradr)
end
4 changes: 4 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ include("misc/from_primitive.jl")
include("misc/sparsity_detector.jl")
include("misc/zero_backends.jl")

struct ReactantBackend{B} <: ADTypes.AbstractADType
backend::B
end

function __init__()
@require_extensions
end
Expand Down
24 changes: 24 additions & 0 deletions DifferentiationInterface/test/Down/Reactant/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Pkg
Pkg.add("Enzyme")
Pkg.add(; url="https://github.com/EnzymeAD/Reactant.jl")

using DifferentiationInterface
using DifferentiationInterface: ReactantBackend
using DifferentiationInterfaceTest
using Enzyme: Enzyme
using LinearAlgebra
using Reactant: Reactant
using Test

LOGGING = get(ENV, "CI", "false") == "false"

rebackend = ReactantBackend(AutoEnzyme())

test_differentiation(
ReactantBackend(AutoEnzyme()),
default_scenarios(; linalg=true);
excluded=[
:derivative, :jacobian, :hessian, :hvp, :pullback, :pushforward, :second_derivative
],
logging=LOGGING,
)