Refactoring AD Tests #2307
We should probably implement some functionality to benchmark various AD backends during this refactoring.
Thinking about this and #2338 a little, it's not super easy to dive into the internals of
I don't know how much this will really speed it up, but it could be worth a shot. Thoughts? (Let me use the magic ping: @TuringLang/maintainers.) Alternatively, I guess we could call
We only need to test that sampling works for one AD backend. The other AD backends can be tested via
I agree with the general sentiment here. In my view we should
It might be that we want to run these test cases in Turing.jl for each backend that we claim supports Turing.jl, in addition to running the tests in Mooncake.jl, but that's something we could decide on later. Either way, I really think we should decouple "is this sampler correct?" from "does this AD work correctly?". Edit: to phrase the second point differently: in order to have robust AD support for Turing.jl, it is necessary that
OK, so it seems we want to construct a list of
I'll start playing around with this in a new package; we can decide whether to put it in DPPL proper at a later point in time.
Agreed. We should think about the precise format though. For the sake of both running the tests in downstream AD packages, and for running them in the Turing.jl integration tests, might it make sense to produce things which can be consumed by DifferentiationInterface's testing functionality? See https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/ for context.
Ooh. I like that. But I don't think LogDensityProblemsAD uses DI for all its backends just yet, right? In which case, I think we'd still need to test
Hmmm, maybe you're right? That being said, LogDensityProblemsAD does now support arbitrary

```julia
using ADTypes, Mooncake, Turing

# Use Mooncake, selected via an ADTypes backend object, as the AD backend for NUTS
NUTS(; adtype=ADTypes.AutoMooncake(; config=nothing))
```

so maybe it's fine?
Mostly notes for myself, because Julia really lends itself to creating a real maze of function calls, but in case anyone's curious: Mooncake goes via DI, so that's all fine. (NOTE: actually running However, running
I managed to force AutoForwardDiff to go via DI by modifying both of these functions (plus some other stuff), and it does give the right results, but given that this has already been discussed in tpapp/LogDensityProblemsAD.jl#29, I don't think it's my job to upstream it. TL;DR: we can't really short-circuit LogDensityProblemsAD unless they decide to adopt DI for all backends, so I'll move forward with directly testing
Also, the code is here for now: https://github.com/penelopeysm/ModelTests.jl
Currently, there are many situations in which we run an entire sampler in order to test AD, such as many of the cases in https://github.com/TuringLang/Turing.jl/blob/master/test/mcmc/abstractmcmc.jl .
This is quite compute-intensive. @yebai pointed out that perhaps we should just check that the gradient runs and is correct, rather than running thousands of MCMC chains.
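A minimal sketch of what "check that the gradient runs and is correct" could look like, assuming ForwardDiff.jl as the backend under test. The model, `logdensity`, `fd_gradient`, the evaluation point, and the tolerance are all hypothetical and chosen for illustration; this is not Turing's or DynamicPPL's API. Instead of sampling, it evaluates the AD gradient once at a fixed parameter vector and compares it against a central finite-difference reference.

```julia
using ForwardDiff

# Hypothetical observed data for a toy model:
#   mu ~ Normal(0, 1), logsigma ~ Normal(0, 1), y_i ~ Normal(mu, exp(logsigma))
const y_obs = [0.5, -0.3, 1.2]

# Hand-written log density (up to an additive constant), so the example
# does not depend on Turing internals.
function logdensity(θ::AbstractVector)
    mu, logsigma = θ[1], θ[2]
    sigma = exp(logsigma)
    lp = -0.5 * mu^2          # standard normal prior on mu
    lp += -0.5 * logsigma^2   # standard normal prior on logsigma
    for yi in y_obs
        # normal likelihood term; -logsigma comes from the 1/sigma normaliser
        lp += -0.5 * ((yi - mu) / sigma)^2 - logsigma
    end
    return lp
end

# Central finite differences: a backend-independent reference gradient.
function fd_gradient(f, x::AbstractVector; h=1e-6)
    g = zeros(Float64, length(x))
    for i in eachindex(x)
        e = zeros(Float64, length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

θ = [0.1, -0.2]
g_ad = ForwardDiff.gradient(logdensity, θ)   # gradient from the AD backend under test
g_ref = fd_gradient(logdensity, θ)           # reference gradient
@assert isapprox(g_ad, g_ref; rtol=1e-5)
```

Testing a different backend would mean replacing the `ForwardDiff.gradient` call. In the Turing test suite, the log density would presumably come from the model itself rather than being written by hand.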