-
Notifications
You must be signed in to change notification settings - Fork 32
Reactant prototype #325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Reactant prototype #325
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
ed381bc
Reactant prototype
gdalle 36ad7b0
Fix compilation
gdalle 36308c1
Format
gdalle 5f5b335
Deactivate noisy workflows
gdalle 815a03c
Deactivate DIT test
gdalle bca66dd
Try proper tests
gdalle 370cd5b
Merge branch 'main' into gd/reactant
gdalle 8e5a2cc
Trigger CI
gdalle 3bb5143
Merge branch 'main' into gd/reactant
gdalle bbbca2a
Merge remote-tracking branch 'origin/main' into gd/reactant
gdalle 78a2927
Fewer tests
gdalle 899cc51
Merge remote-tracking branch 'origin/main' into gd/reactant
gdalle c8149fd
Update
gdalle 26271d4
CI
gdalle 0fd9544
Logging
gdalle 38df870
No linalg
gdalle 095e2ca
Merge remote-tracking branch 'origin/main' into gd/reactant
gdalle f96bfd0
Fixes
gdalle b02b12c
Merge remote-tracking branch 'origin/main' into gd/reactant
gdalle 4e320da
Adapt to new interface, use main branch
gdalle 2cd3531
Compile value and gradient
gdalle 3542827
Merge branch 'main' into gd/reactant
gdalle 90b4dd7
More prep
gdalle a5260d8
Handmade scdnarios
gdalle 38c7374
Merge remote-tracking branch 'origin/main' into gd/reactant
gdalle 428b407
Remove Compat
gdalle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
...nInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| module DifferentiationInterfaceReactantExt | ||
|
|
||
| using ADTypes: ADTypes | ||
| using Compat | ||
| import DifferentiationInterface as DI | ||
| using DifferentiationInterface: | ||
| ReactantBackend, | ||
| DerivativePrep, | ||
| GradientPrep, | ||
| HessianPrep, | ||
| HVPPrep, | ||
| JacobianPrep, | ||
| PullbackPrep, | ||
| PushforwardPrep, | ||
| SecondDerivativePrep | ||
| using Reactant: @compile, to_rarray | ||
|
|
||
| ADTypes.mode(rebackend::ReactantBackend) = ADTypes.mode(rebackend.backend) | ||
| DI.check_available(rebackend::ReactantBackend) = DI.check_available(rebackend.backend) | ||
| DI.inplace_support(rebackend::ReactantBackend) = DI.inplace_support(rebackend.backend) | ||
|
|
||
| include("onearg.jl") | ||
|
|
||
| end # module |
36 changes: 36 additions & 0 deletions
36
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| struct ReactantGradientPrep{F,G} <: GradientPrep | ||
| compiled_function::F | ||
| compiled_gradient::G | ||
| end | ||
|
|
||
| function DI.prepare_gradient(f, rebackend::ReactantBackend, x) | ||
| xr = to_rarray(x) | ||
| gradient_closure(xr) = DI.gradient(f, rebackend.backend, xr) | ||
| compiled_function = @compile f(xr) | ||
| compiled_gradient = @compile gradient_closure(xr) | ||
| return ReactantGradientPrep(compiled_function, compiled_gradient) | ||
| end | ||
|
|
||
| function DI.gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x) | ||
| @compat (; compiled_gradient) = prep | ||
| xr = to_rarray(x) | ||
| return compiled_gradient(xr) | ||
| end | ||
|
|
||
| function DI.value_and_gradient(f, prep::ReactantGradientPrep, ::ReactantBackend, x) | ||
| @compat (; compiled_function, compiled_gradient) = prep | ||
| xr = to_rarray(x) | ||
| return compiled_function(xr), compiled_gradient(xr) | ||
| end | ||
|
|
||
| function DI.gradient!(f, grad, prep::ReactantGradientPrep, rebackend::ReactantBackend, x) | ||
| gradr = DI.gradient(f, prep, rebackend, x) | ||
| return copyto!(grad, gradr) | ||
| end | ||
|
|
||
| function DI.value_and_gradient!( | ||
| f, grad, prep::ReactantGradientPrep, rebackend::ReactantBackend, x | ||
| ) | ||
| y, gradr = DI.value_and_gradient(f, prep, rebackend, x) | ||
| return y, copyto!(grad, gradr) | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| using Pkg | ||
| Pkg.add("Enzyme") | ||
| Pkg.add(; url="https://github.com/EnzymeAD/Reactant.jl") | ||
|
|
||
| using DifferentiationInterface | ||
| using DifferentiationInterface: ReactantBackend | ||
| using DifferentiationInterfaceTest | ||
| using Enzyme: Enzyme | ||
| using LinearAlgebra | ||
| using Reactant: Reactant | ||
| using Test | ||
|
|
||
| LOGGING = get(ENV, "CI", "false") == "false" | ||
|
|
||
| rebackend = ReactantBackend(AutoEnzyme()) | ||
|
|
||
| test_differentiation( | ||
| ReactantBackend(AutoEnzyme()), | ||
| default_scenarios(; linalg=true); | ||
| excluded=[ | ||
| :derivative, :jacobian, :hessian, :hvp, :pullback, :pushforward, :second_derivative | ||
| ], | ||
| logging=LOGGING, | ||
| ) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.