Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7ff2fa2
Bump patch version and add Tapir ext
willtebbutt Jun 26, 2024
47d8d31
Make Tapir available at test time
willtebbutt Jun 26, 2024
5069e82
Add Tapir runs to AD testing
willtebbutt Jun 26, 2024
de15572
Add single rule to Tapir to handle bisection
willtebbutt Jun 26, 2024
ee78772
using Tapir
willtebbutt Jun 26, 2024
65092b9
Run on 1.6 only and add Tapir to AD tests
willtebbutt Jun 26, 2024
02f6fff
Disable more tests
willtebbutt Jun 26, 2024
f421606
Update ext/BijectorsTapirExt.jl
willtebbutt Jun 26, 2024
a1773cb
Fix formatting
willtebbutt Jun 26, 2024
1743ae0
Restrict version
willtebbutt Jun 27, 2024
13b7fb5
Remove Tapir from Project
willtebbutt Jun 27, 2024
fe722a9
Do not run Tapir CI on 1.6
willtebbutt Jun 27, 2024
f1a069d
Enable 1.6 tests in general
willtebbutt Jun 27, 2024
f5888fa
Enable 1.6 on interface tests
willtebbutt Jun 27, 2024
c94e1ef
Tweak versioning
willtebbutt Jun 27, 2024
ece43a1
Cancel when multiple things are pushed
willtebbutt Jun 27, 2024
e386fd2
Add Tapir to extras
willtebbutt Jun 27, 2024
23fc7f6
Comment out tapir usage
willtebbutt Jun 27, 2024
787da57
Try allowing more versions of Tapir
willtebbutt Jun 27, 2024
4383e9f
Allow more versions of Tapir
willtebbutt Jun 27, 2024
b5d3725
More tweaks
willtebbutt Jun 27, 2024
a0994e3
Add Pkg to test deps
willtebbutt Jun 27, 2024
79dc479
Refine CI
willtebbutt Jun 27, 2024
9c20891
Use Tapir on 1.10
willtebbutt Jun 27, 2024
6a2803a
Remove CI modifications
willtebbutt Jun 27, 2024
7737223
Formatting
willtebbutt Jun 27, 2024
422aaa7
add comment to Tapir installation
willtebbutt Jun 27, 2024
45564b1
Support a range of types
willtebbutt Jul 2, 2024
ef64b62
Merge in main
willtebbutt Jul 2, 2024
d367fe1
Fix Project.toml
willtebbutt Jul 2, 2024
b71fc75
Fix formatting
willtebbutt Jul 2, 2024
1360676
Fix formatting
willtebbutt Jul 2, 2024
ccde34e
Fix formatting
willtebbutt Jul 2, 2024
a72af0a
Apply suggestions from code review
willtebbutt Jul 3, 2024
883b924
Sort out formatting
willtebbutt Jul 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ jobs:
- x64
AD:
- ForwardDiff
- Tapir
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -37,6 +38,7 @@ BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsTapirExt = "Tapir"
BijectorsZygoteExt = "Zygote"

[compat]
Expand All @@ -60,6 +62,7 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"
Expand All @@ -69,5 +72,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
40 changes: 40 additions & 0 deletions ext/BijectorsTapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module BijectorsTapirExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha
using ChainRulesCore: rrule
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, rrule
using ..ChainRulesCore: rrule
end

for P in [Float16, Float32, Float64]
@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P})
end

# The final argument could be an Integer of some kind. This should be fine provided that
# it has tangent type equal to `NoTangent`, which means that it's non-differentiable and
# can be safely dropped. We verify that the concrete type of the Integer satisfies this
# constraint, and error if (for some reason) it does not. This should be fine unless a very
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
15 changes: 15 additions & 0 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
Tapir.TestUtils.test_rrule!!(
Xoshiro(123), Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none
)
Tapir.TestUtils.test_rrule!!(
Xoshiro(123), Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none
)
#! format: off
Tapir.TestUtils.test_rrule!!(
Xoshiro(123), Bijectors.find_alpha, x, y, UInt32(3);
is_primitive=true, perf_flag=:none,
)
#! format: on
end

test_rrule(
Bijectors.combine,
Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(),
Expand Down
23 changes: 23 additions & 0 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,28 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
end
end

return nothing
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
end

const GROUP = get(ENV, "GROUP", "All")

# Always include this since it can be useful for other tests.
Expand Down