Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 69 additions & 68 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,27 @@ jobs:
matrix:
version:
- "1.10"
- "1"
# - "1"
group:
- Misc/Internals
- Misc/DifferentiateWith
- Misc/FromPrimitive
- Misc/SparsityDetector
- Misc/ZeroBackends
- Back/ChainRulesBackends
- Back/Enzyme
- Back/FiniteDiff
- Back/FiniteDifferences
- Back/ForwardDiff
- Back/Mooncake
- Back/PolyesterForwardDiff
- Back/ReverseDiff
- Back/SymbolicBackends
- Back/Tracker
- Back/Zygote
- Down/Flux
- Down/Lux
# - Misc/Internals
# - Misc/DifferentiateWith
# - Misc/FromPrimitive
# - Misc/SparsityDetector
# - Misc/ZeroBackends
# - Back/ChainRulesBackends
# - Back/Enzyme
# - Back/FiniteDiff
# - Back/FiniteDifferences
# - Back/ForwardDiff
# - Back/Mooncake
# - Back/PolyesterForwardDiff
# - Back/ReverseDiff
# - Back/SymbolicBackends
# - Back/Tracker
# - Back/Zygote
# - Down/Flux
# - Down/Lux
- Down/Reactant
skip_lts:
- ${{ github.event.pull_request.draft }}
exclude:
Expand Down Expand Up @@ -86,52 +87,52 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true

test-DIT:
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
timeout-minutes: 60
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
actions: write
contents: read
strategy:
fail-fast: true
matrix:
version:
- "1.10"
- "1"
group:
- Formalities
- Zero
- Standard
- Weird
skip_lts:
- ${{ github.event.pull_request.draft }}
# exclude:
# - skip_lts: true
# version: "1.10"
env:
JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: x64
- uses: julia-actions/cache@v2
- name: Install dependencies & run tests
run: julia --project=./DifferentiationInterfaceTest -e '
using Pkg;
Pkg.Registry.update();
Pkg.develop(path="./DifferentiationInterface");
Pkg.test("DifferentiationInterfaceTest"; coverage=true);'
- uses: julia-actions/julia-processcoverage@v1
with:
directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
- uses: codecov/codecov-action@v4
with:
files: lcov.info
flags: DIT
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
# test-DIT:
# name: ${{ matrix.version }} - DIT (${{ matrix.group }})
# runs-on: ubuntu-latest
# if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
# timeout-minutes: 60
# permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
# actions: write
# contents: read
# strategy:
# fail-fast: true
# matrix:
# version:
# - "1.10"
# - "1"
# group:
# - Formalities
# # - Zero
# # - Standard
# # - Weird
# skip_lts:
# - ${{ github.event.pull_request.draft }}
# # exclude:
# # - skip_lts: true
# # version: "1.10"
# env:
# JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
# steps:
# - uses: actions/checkout@v4
# - uses: julia-actions/setup-julia@v2
# with:
# version: ${{ matrix.version }}
# arch: x64
# - uses: julia-actions/cache@v2
# - name: Install dependencies & run tests
# run: julia --project=./DifferentiationInterfaceTest -e '
# using Pkg;
# Pkg.Registry.update();
# Pkg.develop(path="./DifferentiationInterface");
# Pkg.test("DifferentiationInterfaceTest"; coverage=true);'
# - uses: julia-actions/julia-processcoverage@v1
# with:
# directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
# - uses: codecov/codecov-action@v4
# with:
# files: lcov.info
# flags: DIT
# name: ${{ matrix.version }} - DIT (${{ matrix.group }})
# token: ${{ secrets.CODECOV_TOKEN }}
# fail_ci_if_error: true
3 changes: 3 additions & 0 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand All @@ -35,6 +36,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReactantExt = "Reactant"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
Expand All @@ -56,6 +58,7 @@ ForwardDiff = "0.10.36"
LinearAlgebra = "<0.0.1,1"
Mooncake = "0.4.0"
PolyesterForwardDiff = "0.1.2"
Reactant = "0.2.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module DifferentiationInterfaceReactantExt

using ADTypes: ADTypes
import DifferentiationInterface as DI
using DifferentiationInterface:
ReactantBackend,
DerivativePrep,
GradientPrep,
HessianPrep,
HVPPrep,
JacobianPrep,
PullbackPrep,
PushforwardPrep,
SecondDerivativePrep
using Reactant: @compile, to_rarray

ADTypes.mode(rebackend::ReactantBackend) = ADTypes.mode(rebackend.backend)
DI.check_available(rebackend::ReactantBackend) = DI.check_available(rebackend.backend)
DI.inplace_support(rebackend::ReactantBackend) = DI.inplace_support(rebackend.backend)

include("onearg.jl")

end # module
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
struct ReactantGradientPrep{XR,GR,CG,CG!,CVG,CVG!} <: GradientPrep
xr::XR
gr::GR
compiled_gradient::CG
compiled_gradient!::CG!
compiled_value_and_gradient::CVG
compiled_value_and_gradient!::CVG!
end

function DI.prepare_gradient(f, rebackend::ReactantBackend, x)
(; backend) = rebackend
xr = to_rarray(x)
gr = to_rarray(similar(x))
_gradient(_xr) = DI.gradient(f, backend, _xr)
_gradient!(_gr, _xr) = DI.gradient!(f, _gr, backend, _xr)
_value_and_gradient(_xr) = DI.value_and_gradient(f, backend, _xr)
_value_and_gradient!(_gr, _xr) = DI.value_and_gradient!(f, _gr, backend, _xr)
compiled_gradient = @compile _gradient(xr)
compiled_gradient! = @compile _gradient!(gr, xr)
compiled_value_and_gradient = @compile _value_and_gradient(xr)
compiled_value_and_gradient! = @compile _value_and_gradient!(gr, xr)
return ReactantGradientPrep(
xr,
gr,
compiled_gradient,
compiled_gradient!,
compiled_value_and_gradient,
compiled_value_and_gradient!,
)
end

function DI.gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x)
(; xr, compiled_gradient) = prep
copyto!(xr, x)
return compiled_gradient(xr)
end

function DI.value_and_gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x)
(; xr, compiled_value_and_gradient) = prep
copyto!(xr, x)
return compiled_value_and_gradient(xr)
end

function DI.gradient!(f, grad, prep::ReactantGradientPrep, ::ReactantBackend, x)
(; xr, gr, compiled_gradient!) = prep
copyto!(xr, x)
prep.compiled_gradient!(gr, xr)
return copyto!(grad, gr)
end

function DI.value_and_gradient!(f, grad, prep::ReactantGradientPrep, ::ReactantBackend, x)
(; xr, gr, compiled_value_and_gradient!) = prep
copyto!(xr, x)
y, gr = compiled_value_and_gradient!(gr, xr)
return y, copyto!(grad, gr)
end
4 changes: 4 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ include("misc/from_primitive.jl")
include("misc/sparsity_detector.jl")
include("misc/zero_backends.jl")

struct ReactantBackend{B} <: ADTypes.AbstractADType
backend::B
end

## Exported

export Context, Constant
Expand Down
24 changes: 24 additions & 0 deletions DifferentiationInterface/test/Down/Reactant/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Pkg
Pkg.add("Enzyme")
Pkg.add(; url="https://github.com/EnzymeAD/Reactant.jl")

using DifferentiationInterface
using DifferentiationInterface: ReactantBackend
using DifferentiationInterfaceTest
using Enzyme: Enzyme
using LinearAlgebra
using Reactant: Reactant
using Test

@assert !isnothing(
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceReactantExt)
)

LOGGING = get(ENV, "CI", "false") == "false"

scenarios = [
Scenario{:gradient,:out}(sum, [1.0, 2.0]; res1=ones(2)),
Scenario{:gradient,:in}(sum, [1.0, 2.0]; res1=ones(2)),
]

test_differentiation(ReactantBackend(AutoEnzyme()), scenarios; logging=LOGGING)