
[WIP] Use ReverseDiff/ForwardDiff in KL divergence code #445

Merged · merged 11 commits on Dec 10, 2016

Conversation

@jrevels (Collaborator) commented Dec 4, 2016

This is basically done, but I have to fix whatever tests were broken and resolve any thread-safety issues that might exist with the consts this PR adds. I haven't found any thread issues yet, but I'm wary... Luckily, anything that comes up should be easy to fix by making the state thread-local. It also relies on ReverseDiff being registered and a new version of ForwardDiff getting tagged (this PR needs JuliaDiff/ForwardDiff.jl#166). So the current TODO list is:

- fix the broken tests
- resolve any thread-safety issues with the new consts
- register ReverseDiff
- tag a new ForwardDiff release (needs JuliaDiff/ForwardDiff.jl#166)

This PR refactors the KL divergence code into neat objective functions, then uses ReverseDiff to take the gradient and ForwardDiff to take the jacobian of the ReverseDiff gradient (i.e. the hessian is calculated via mixed-mode AD). I used Unicode to make the notation really nice, but I can switch back to pure ASCII if we want to avoid Unicode.
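For anyone unfamiliar with the forward-over-reverse idea, here is a rough standalone sketch, not this PR's actual implementation (which records the gradient onto a tape ahead of time and writes into preallocated buffers), assuming, as this PR relies on, that ReverseDiff's gradient code accepts ForwardDiff's dual numbers; the objective f below is just a hypothetical stand-in:

using ForwardDiff, ReverseDiff

# hypothetical scalar objective standing in for the KL divergence terms
f(x) = sum(abs2, x) / 2 + log(sum(exp, x))

x = rand(32)

# reverse mode computes the full gradient in one sweep...
∇f(y) = ReverseDiff.gradient(f, y)

# ...and forward mode differentiates that gradient again: the Jacobian
# of the gradient is the Hessian, computed here via mixed-mode AD
H = ForwardDiff.jacobian(∇f, x)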

The KL divergence code is now non-allocating where possible, and gradients and hessians are totally non-allocating. I kept the old source here for the sake of convenient benchmarking, but I'll remove it once this is ready to merge. In the meantime, here's my benchmark setup for a single source (EDIT: I've removed the old code from the source tree, so this will no longer entirely work):

begin
    using Celeste, BenchmarkTools, DiffBase
    # 32 parameters for a single source; the boolean flags toggle gradient/hessian storage
    sfg = Celeste.SensitiveFloats.SensitiveFloat{Float64}(32, 1, true, false)  # gradient only
    sfh = Celeste.SensitiveFloats.SensitiveFloat{Float64}(32, 1, true, true)   # gradient + hessian
    vs = rand(32)
    kl_grad = DiffBase.GradientResult(vs)
    kl_hess = zeros(32, 32)
    new_f! = Celeste.DeterministicVI.subtract_kl_source!
    old_f! = Celeste.DeterministicVI.subtract_kl_source_old!
end

Here are the timings for the old KL divergence code:

# gradient + objective
julia> @benchmark old_f!($sfg, $vs)
BenchmarkTools.Trial:
  memory estimate:  39.50 kb
  allocs estimate:  336
  --------------
  minimum time:     83.125 μs (0.00% GC)
  median time:      85.378 μs (0.00% GC)
  mean time:        90.929 μs (5.52% GC)
  maximum time:     3.187 ms (93.74% GC)

# hessian + gradient + objective
julia> @benchmark old_f!($sfh, $vs)
BenchmarkTools.Trial:
  memory estimate:  161.05 kb
  allocs estimate:  354
  --------------
  minimum time:     129.111 μs (0.00% GC)
  median time:      138.667 μs (0.00% GC)
  mean time:        166.267 μs (13.80% GC)
  maximum time:     3.649 ms (92.72% GC)

...and here are the timings for the new code:

# gradient + objective
julia> @benchmark new_f!($sfg, $vs, $(kl_grad), $(kl_hess))
BenchmarkTools.Trial:
  memory estimate:  0.00 bytes
  allocs estimate:  0
  --------------
  minimum time:     22.881 μs (0.00% GC)
  median time:      26.180 μs (0.00% GC)
  mean time:        31.430 μs (0.00% GC)
  maximum time:     451.272 μs (0.00% GC)

# hessian + gradient + objective
julia> @benchmark new_f!($sfh, $vs, $(kl_grad), $(kl_hess))
BenchmarkTools.Trial:
  memory estimate:  0.00 bytes
  allocs estimate:  0
  --------------
  minimum time:     184.854 μs (0.00% GC)
  median time:      187.164 μs (0.00% GC)
  mean time:        187.971 μs (0.00% GC)
  maximum time:     373.564 μs (0.00% GC)

As you can see, the objective + gradient computation is roughly 3.6x faster, while the objective + gradient + hessian computation is about 40% slower. Hopefully the improved memory efficiency makes up for that.
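(For reference, those figures come straight from the minimum times reported above:)

# ratios of the minimum times from the benchmarks above (μs)
83.125 / 22.881    # ≈ 3.6, objective + gradient speedup
184.854 / 129.111  # ≈ 1.4, i.e. ~40% slowdown for objective + gradient + hessian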

With benchmark_infer.jl and benchmark_elbo.jl, I'm seeing slight memory/time improvements (my guess is that elbo_likelihood dwarfs the KL divergence calculations performance-wise).

@jeff-regier (Owner)

This is great! The code is way clearer than what came before it, and the autodiff makes it easy to add KL divergences for additional distributions in the future. I like the Unicode. The run times all look more than good enough. If you can maintain this level of overhead, there's really no drawback to using autodiff in the KL module and the transform module. I'd like to hear more about how your autodiff routine avoids allocations entirely, maybe during a conference call.

@jeff-regier (Owner)

To compute the Jacobian of the gradient with forward diff, do you have to compute the gradient 32 times per source, once per parameter, each time with a different perturbation? If so, why isn't there more overhead from computing the Jacobian, i.e. >32x rather than <10x?

@mlubin commented Dec 5, 2016

If so, why isn't there more overhead from computing the Jacobian, i.e. >32x rather than <10x?

I'm guessing that's the magic of partials.

@jrevels (Collaborator, Author) commented Dec 5, 2016

If so, why isn't there more overhead from computing the Jacobian, i.e. >32x rather than <10x?

Yup, what @mlubin said. ForwardDiff does some special tricks to partition the input dimension and compute a "chunk" of partial derivatives per function evaluation (in this case, per gradient evaluation). Here, the chunk size is 8 (which divides evenly into the input dimension of 32).
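As a rough illustration of how that looks in code (the exact config API may differ between ForwardDiff versions, and g here is just a hypothetical stand-in for the gradient function):

using ForwardDiff

# hypothetical vector-valued function standing in for the ReverseDiff gradient
g(x) = x .* sum(abs2, x)

x = rand(32)

# with a chunk size of 8, ForwardDiff only evaluates g 32 / 8 = 4 times
# to fill all 32 columns of the Jacobian, instead of once per column
cfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{8}())
J = ForwardDiff.jacobian(g, x, cfg)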

I'd like to hear more about how your auto diff routine avoids allocations entirely, maybe during a conference call.

It's the magic of ReverseDiff's pre-recorded tapes - it often "converts" allocating Julia code into a non-allocating tape 😊
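A minimal sketch of the tape idea (with a hypothetical objective f; the PR's real setup additionally keeps the compiled tape and result buffers around in consts so they're built only once):

using ReverseDiff, DiffBase

# hypothetical scalar objective
f(x) = sum(abs2, x) / 2 + log(sum(exp, x))

x = rand(32)
result = DiffBase.GradientResult(x)  # preallocated storage for value + gradient

# record f's operations once and compile the tape; replaying it later doesn't allocate
const f_tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))

# repeated calls reuse the tape and write into result in place
ReverseDiff.gradient!(result, f_tape, x)
DiffBase.value(result), DiffBase.gradient(result)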

@jeff-regier (Owner)

Awesome!!!

@jrevels (Collaborator, Author) commented Dec 6, 2016

Note that tests pass locally for me, but Travis won't be able to resolve dependencies correctly until JuliaNLSolvers/Optim.jl#313 is merged/tagged.

@jrevels force-pushed the jr/adkl branch 2 times, most recently from d57a8d7 to 83900b5 on December 8, 2016 22:20
@jrevels force-pushed the jr/adkl branch 3 times, most recently from cc19876 to 4259c6c on December 8, 2016 23:36
@jrevels (Collaborator, Author) commented Dec 9, 2016

There's no more work to do here, AFAICT, besides getting things to pass on Travis on Julia v0.6 (it passes on v0.5). The most recent test failure on v0.6 isn't even related to Celeste - a bunch of JuliaStats packages are simply failing to precompile, likely due to JuliaLang/julia#15850.

I've restarted the most recent build to see if any fixes have percolated into METADATA by now. EDIT: ref JuliaLang/METADATA.jl#7239

@andreasnoack (Collaborator)

Distributions.jl is not fixed yet. Will probably happen today.

@jeff-regier (Owner)

This code looks great! The 0.6 tests are currently failing on master too, so we might as well just merge this already.

@jrevels deleted the jr/adkl branch December 10, 2016 14:26