remove LogDensityProblemsAD #2490
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##  mhauru/dppl-0.35    #2490      +/-  ##
===========================================
+ Coverage      26.93%   29.77%    +2.83%
===========================================
  Files             21       21
  Lines           1407     1387       -20
===========================================
+ Hits             379      413       +34
+ Misses          1028      974       -54

☔ View full report in Codecov by Sentry.
I removed the entire test/ad.jl file; my reasoning is explained below.
@test l ≈ logp
@test sort(∇E) ≈ grad_FWAD atol = 1e-9
end
This checks that the gradient obtained with LogDensityProblemsAD equals the gradient obtained directly with ForwardDiff. That isn't Turing's job to test, hence removed.
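For reference, a minimal self-contained sketch of the kind of check being deleted (toy target and hypothetical names, not the model from the removed test):

```julia
using ForwardDiff, LogDensityProblems, LogDensityProblemsAD

# Toy stand-in for the log-density object Turing builds from a model.
struct ToyTarget end
LogDensityProblems.logdensity(::ToyTarget, x) = -0.5 * sum(abs2, x)
LogDensityProblems.dimension(::ToyTarget) = 2
LogDensityProblems.capabilities(::Type{ToyTarget}) =
    LogDensityProblems.LogDensityOrder{0}()

ℓ = ToyTarget()
x = randn(2)

# Gradient via the LogDensityProblemsAD wrapper ...
∇E = LogDensityProblems.logdensity_and_gradient(ADgradient(:ForwardDiff, ℓ), x)[2]
# ... versus the gradient computed directly with ForwardDiff.
grad_FWAD = ForwardDiff.gradient(z -> LogDensityProblems.logdensity(ℓ, z), x)

∇E ≈ grad_FWAD  # true; this agreement is what the deleted test asserted
```

Agreement between the AD wrapper and ForwardDiff itself is the wrapper library's responsibility to test, which is the point being made above.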
@test zygoteℓ.ℓ === ℓ
∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2]
@test sort(∇E2) ≈ grad_FWAD atol = 1e-9
This is checking that Zygote gives the same gradient as ForwardDiff does.
Firstly, Zygote isn't formally supported, so I don't think we need to test its correctness on this one particular model; secondly, if we did want to support it, we should add this test to the DPPL test suite.
@test sort(∇E2) ≈ grad_FWAD atol = 1e-9
end
@testset "general AD tests" begin |
This whole testset just tests AD correctness on various models, and it should all be moved into DynamicPPL if it's not already there.
It's there already, in test/compat/ad.jl.
Okay, what? That's so weird. Doubly happy to delete it then.
test_model_ad(wishart_ad(), logp3, [:v])
end
@testset "Simplex Zygote and ReverseDiff (with and without caching) AD" begin |
Same with this one: it's not even testing correctness, it's just testing that AD can run.
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
end
@testset "Hessian test" begin |
Same with this one: it tests the correctness of Hessians against an analytic value, which isn't really our job imo.
@test H_f == H_r
end
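For context, a toy version of that kind of check (hypothetical target, not the deleted test's model): comparing an AD Hessian against its analytic value.

```julia
using ForwardDiff, LinearAlgebra

# Log density of a standard normal, up to an additive constant.
f(x) = -0.5 * dot(x, x)

x = randn(3)
H = ForwardDiff.hessian(f, x)

# The analytic Hessian of f is -I, independent of x.
H ≈ -Matrix(I, 3, 3)  # true
```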
@testset "memoization: issue #1393" begin |
I don't think this testset is relevant anymore, because the memoisation code was removed long ago in #1414.
end
end
@testset "ReverseDiff compiled without linking" begin |
This testset is fully contained in DynamicPPL already.
LGTM. One question mostly out of curiosity.
@@ -202,7 +202,7 @@ end
 All the ADTypes on which we want to run the tests.
 """
 adbackends = [
-    Turing.AutoForwardDiff(; chunksize=0),
+    Turing.AutoForwardDiff(),
What's the significance of not setting chunksize?
If chunksize isn't set it will default to nothing, and both 0 and nothing will get 'optimised' into the same thing by this function in DynamicPPLForwardDiffExt: https://github.com/TuringLang/DynamicPPL.jl/blob/90c7b26c852b7d0ae87dee0a5a0010b097a0c1d3/ext/DynamicPPLForwardDiffExt.jl#L10-L40
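Roughly, that normalisation amounts to something like the following hypothetical helper (a sketch of the behaviour, not the actual DynamicPPLForwardDiffExt code):

```julia
using ForwardDiff

# chunksize of 0 or nothing means "pick one for me": defer to ForwardDiff's
# heuristic based on the length of the parameter vector. Any other value is
# kept as the user's explicit choice.
function standardise_chunksize(chunksize, dimension::Int)
    if chunksize === nothing || chunksize == 0
        return ForwardDiff.pickchunksize(dimension)
    else
        return chunksize
    end
end

standardise_chunksize(nothing, 25)  # ForwardDiff's heuristic choice for 25 inputs
standardise_chunksize(0, 25)        # same result as the `nothing` case
standardise_chunksize(4, 25)        # 4: an explicit choice is respected
```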
But also, technically chunksize=0 is 'wrong' and not supported in DifferentiationInterface; the fact that it works in Turing is solely because of that extension, which I'd honestly like to get rid of but have mainly kept around for compatibility - see TuringLang/DynamicPPL.jl#806 (comment).
If the extension isn't present, then chunksize=0 actually errors, and I believe it was the source of this error: #2369 (it's because ForwardDiff tries to count the number of chunks by dividing by the chunksize). I suspect that error appeared because Optimisation.jl started using DifferentiationInterface under the hood.
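A toy illustration of that failure mode, assuming (as the parenthetical says) that chunking involves an integer division by the chunk size; this is not ForwardDiff's actual code:

```julia
# With an unhandled chunksize of 0, counting chunks by division blows up.
n = 10          # number of input dimensions
chunksize = 0
nchunks = cld(n, chunksize)  # throws DivideError
```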
This PR removes LogDensityProblemsAD in favour of using the new interface from TuringLang/DynamicPPL.jl#806.
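Roughly, the shape of the change looks like the sketch below. This is hedged: the adtype keyword on LogDensityFunction is my reading of DynamicPPL.jl#806, and the exact signatures may differ.

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
using LogDensityProblems

@model function toy()   # hypothetical one-parameter model for illustration
    x ~ Normal(0, 1)
end

# Before this PR: wrap the log density with LogDensityProblemsAD to get gradients.
#   ℓ  = DynamicPPL.LogDensityFunction(toy())
#   ∇ℓ = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ℓ)
#   LogDensityProblems.logdensity_and_gradient(∇ℓ, [0.5])

# After this PR (assumed shape of the new DynamicPPL interface): the
# LogDensityFunction carries the adtype itself and implements
# logdensity_and_gradient directly, so no wrapper is needed.
ℓ = DynamicPPL.LogDensityFunction(toy(); adtype=AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(ℓ, [0.5])
```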