[WIP] Use ReverseDiff/ForwardDiff in KL divergence code #445
Conversation
This is great! The code is way clearer than what came before it, and the auto diff makes it easy to add KL divergences for additional distributions in the future. I like the Unicode. The run times all look more than good enough. If you can maintain this level of overhead, there's really no drawback to using auto diff in the kl module and the transform module. I'd like to hear more about how your auto diff routine avoids allocations entirely, maybe during a conference call.
To compute the Jacobian of the gradient with forward diff, do you have to compute the gradient 32 times per source, once per parameter, each time with a different perturbation? If so, why isn't there more overhead from computing the Jacobian, i.e. >32x rather than <10x?
I'm guessing that's the magic of partials.
Yup, what @mlubin said. ForwardDiff does some special tricks to partition the input dimension and compute a "chunk" of partial derivatives per function evaluation (in this case, per gradient evaluation). Here, the chunk size is 8 (which divides evenly into the input dimension of 32).
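For anyone unfamiliar with chunking, here's a minimal sketch of the pattern being described; the objective and sizes are placeholders, not Celeste's actual code:

```julia
using ForwardDiff

# Placeholder objective over a 32-parameter vector (one source's parameters).
f(x) = sum(abs2, x)

x = rand(32)

# With Chunk{8}, each Dual number carries 8 partial derivatives, so the 32
# entries of the gradient are filled in over 4 evaluations of f instead of
# 32 separate single-perturbation evaluations.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{8}())
g = ForwardDiff.gradient(f, x, cfg)
```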
It's the magic of ReverseDiff's pre-recorded tapes - it often "converts" allocating Julia code into a non-allocating tape 😊
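For reference, a minimal sketch of the compiled-tape workflow (not the exact code in this PR):

```julia
using ReverseDiff

# Placeholder objective; the PR records tapes for the KL divergence objectives.
f(x) = sum(abs2, x)

x = rand(32)

# Record f's operations once and compile the tape; replaying the compiled tape
# reuses its internal buffers, so repeated gradient evaluations don't allocate.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))

∇f = similar(x)
ReverseDiff.gradient!(∇f, tape, x)  # gradient is written into ∇f in place
```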
Awesome!!!
Note that tests pass locally for me, but Travis won't be able to resolve dependencies correctly until JuliaNLSolvers/Optim.jl#313 is merged/tagged.
There's no more work to do here, AFAICT, besides getting things to pass on Travis on Julia v0.6 (it passes on v0.5). The most recent test failure on v0.6 isn't even related to Celeste - a bunch of JuliaStats packages are simply failing to precompile, likely due to JuliaLang/julia#15850. I've restarted the most recent build to see if any fixes have percolated into METADATA by now. EDIT: ref JuliaLang/METADATA.jl#7239
This code looks great! The 0.6 tests are currently failing on master too, so we may as well just merge this already.
This is basically done, but I have to fix whatever tests were broken and resolve any thread-safety issues that might exist with the `const`s this PR adds. I haven't found any thread issues yet, but I'm wary. Luckily, anything that comes up should be easy to fix by making the state thread-local. It also relies on ReverseDiff being registered and a new version of ForwardDiff getting tagged (this PR needs JuliaDiff/ForwardDiff.jl#166). So the current TODO list is:

This PR refactors the KL divergence code into neat objective functions, then uses ReverseDiff to take the gradient and ForwardDiff to take the Jacobian of the ReverseDiff gradient (i.e. the Hessian is calculated via mixed-mode AD; see the sketch below). I used Unicode to make the notation really nice, but I can switch back to pure ASCII if we want to avoid Unicode.
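A rough sketch of that mixed-mode pattern, with a placeholder objective standing in for the actual KL divergence terms:

```julia
using ForwardDiff, ReverseDiff

# Placeholder objective; in the PR, these are the per-source KL objectives.
f(x) = sum(abs2, x) / 2

x = rand(32)

# Reverse mode for the gradient...
∇f(y) = ReverseDiff.gradient(f, y)

# ...and forward mode over that gradient for the Hessian (mixed-mode AD).
H = ForwardDiff.jacobian(∇f, x)
```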
The KL divergence code is now non-allocating where possible, and gradients and hessians are totally non-allocating. I kept the old source here for the sake of convenient benchmarking, but I'll remove it once this is ready to merge. In the meantime, here's my benchmark setup for a single source (EDIT: I've removed the old code from the source tree, so this will no longer entirely work):
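Roughly, the setup amounts to timing the objective, gradient, and Hessian for a single 32-parameter source with BenchmarkTools; a hypothetical stand-in (not the actual Celeste calls) looks something like:

```julia
using BenchmarkTools, ForwardDiff, ReverseDiff

# Hypothetical stand-in for the per-source KL objective (32 parameters);
# the real benchmark calls Celeste's old and new KL divergence code instead.
kl(x) = sum(abs2, x) / 2

x = rand(32)
tape = ReverseDiff.compile(ReverseDiff.GradientTape(kl, x))
g = similar(x)

@benchmark $kl($x)                                                      # objective
@benchmark ReverseDiff.gradient!($g, $tape, $x)                         # + gradient
@benchmark ForwardDiff.jacobian(y -> ReverseDiff.gradient($kl, y), $x)  # + Hessian
```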
Here are the timings for the old KL divergence code:
...and here's the new code:
As you can see, there is a ~370% speedup for computing the objective + gradient, while the objective + gradient + Hessian computation is slower by ~40%. Hopefully the improved memory efficiency makes up for that.
With `benchmark_infer.jl` and `benchmark_elbo.jl`, I'm seeing slight memory/time improvements (my guess is that `elbo_likelihood` dwarfs the KL divergence calculations performance-wise).