-
Notifications
You must be signed in to change notification settings - Fork 41
Tapir.jl Usage #319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Tapir.jl Usage #319
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
7ff2fa2
Bump patch version and add Tapir ext
willtebbutt 47d8d31
Make Tapir available at test time
willtebbutt 5069e82
Add Tapir runs to AD testing
willtebbutt de15572
Add single rule to Tapir to handle bisection
willtebbutt ee78772
using Tapir
willtebbutt 65092b9
Run on 1.6 only and add Tapir to AD tests
willtebbutt 02f6fff
Disable more tests
willtebbutt f421606
Update ext/BijectorsTapirExt.jl
willtebbutt a1773cb
Fix formatting
willtebbutt 1743ae0
Restrict version
willtebbutt 13b7fb5
Remove Tapir from Project
willtebbutt fe722a9
Do not run Tapir CI on 1.6
willtebbutt f1a069d
Enable 1.6 tests in general
willtebbutt f5888fa
Enable 1.6 on interface tests
willtebbutt c94e1ef
Tweak versioning
willtebbutt ece43a1
Cancel when multiple things are pushed
willtebbutt e386fd2
Add Tapir to extras
willtebbutt 23fc7f6
Comment out tapir usage
willtebbutt 787da57
Try allowing more versions of Tapir
willtebbutt 4383e9f
Allow more versions of Tapir
willtebbutt b5d3725
More tweaks
willtebbutt a0994e3
Add Pkg to test deps
willtebbutt 79dc479
Refine CI
willtebbutt 9c20891
Use Tapir on 1.10
willtebbutt 6a2803a
Remove CI modifications
willtebbutt 7737223
Formatting
willtebbutt 422aaa7
add comment to Tapir installation
willtebbutt 45564b1
Support a range of types
willtebbutt ef64b62
Merge in main
willtebbutt d367fe1
Fix Project.toml
willtebbutt b71fc75
Fix formatting
willtebbutt 1360676
Fix formatting
willtebbutt ccde34e
Fix formatting
willtebbutt a72af0a
Apply suggestions from code review
willtebbutt 883b924
Sort out formatting
willtebbutt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| module BijectorsTapirExt | ||
|
|
||
| if isdefined(Base, :get_extension) | ||
| using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule | ||
| using Bijectors: find_alpha | ||
| using ChainRulesCore: rrule | ||
| else | ||
| using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule | ||
| using ..Bijectors: find_alpha, rrule | ||
| using ..ChainRulesCore: rrule | ||
willtebbutt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
| for P in [Float16, Float32, Float64] | ||
| @from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) | ||
| end | ||
|
|
||
| # The final argument could be an Integer of some kind. This should be fine provided that | ||
| # it has tangent type equal to `NoTangent`, which means that it's non-differentiable and | ||
| # can be safely dropped. We verify that the concrete type of the Integer satisfies this | ||
| # constraint, and error if (for some reason) it does not. This should be fine unless a very | ||
| # unusual Integer type is encountered. | ||
| @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) | ||
|
|
||
| function Tapir.rrule!!( | ||
| ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} | ||
| ) where {P<:Base.IEEEFloat,I<:Integer} | ||
| # Require that the integer is non-differentiable. | ||
| if tangent_type(I) != Tapir.NoTangent | ||
| msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." | ||
| throw(ArgumentError(msg)) | ||
| end | ||
| out, pb = rrule(find_alpha, primal(x), primal(y), primal(z)) | ||
willtebbutt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| function find_alpha_pb(dout::P) | ||
| _, dx, dy, _ = pb(dout) | ||
| return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() | ||
| end | ||
| return Tapir.zero_fcodual(out), find_alpha_pb | ||
| end | ||
|
|
||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.