Skip to content

Commit

Permalink
[WIP] Use ReverseDiff/ForwardDiff in KL divergence code (#445)
Browse files Browse the repository at this point in the history
* refactor elbo_kl objective functions for AD

* small touch-up to add_sources_sf

* improve entry point to non-allocating KL divergence code

* small diff code cleanup

* version bump ForwardDiff and ReverseDiff in REQUIRE

* refactor rebase against master

* get tests passing

* add derivative test against manually generated derivatives for KL divergence

* bump REQUIRE versions now that Optim supports ForwardDiff 0.3

* refactor for thread safety

* fix namespace qualification
  • Loading branch information
jrevels authored and jeff-regier committed Dec 10, 2016
1 parent 927fa6a commit 42d869e
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 476 deletions.
6 changes: 4 additions & 2 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ julia 0.5-
DataFrames
Distributions
FITSIO 0.8.4
ForwardDiff 0.2
DiffBase 0.0.2
ForwardDiff 0.3.3 0.4
ReverseDiff 0.0.1 0.1
JLD
Optim 0.6
Optim 0.7
WCS 0.1.3
StaticArrays
9 changes: 6 additions & 3 deletions benchmark/speed/benchmark_elbo.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env julia

import Celeste.DeterministicVI: elbo, ElboArgs
using DiffBase

include(string(Pkg.dir("Celeste"), "/test/Synthetic.jl"))
include(string(Pkg.dir("Celeste"), "/test/SampleData.jl"))
Expand All @@ -17,19 +18,21 @@ function main()
calculate_gradient=true,
calculate_hessian=false)

param_length = length(Celeste.Model.CanonicalParams)
kl_source = Celeste.SensitiveFloats.SensitiveFloat{Float64}(param_length, 1, true, false)

println("Warm-up / compiling.")
# do a trial run first, so we don't profile/time compling the code
elbo(ea)
elbo(ea, kl_source)
Profile.clear_malloc_data()

println("Calculating ELBO and gradient.")
if isempty(ARGS)
# let's time it without any overhead from profiling
@time elbo(ea)
@time elbo(ea, kl_source)
elseif ARGS[1] == "--profile"
Profile.init(delay=0.001)
@profile elbo(ea)
@profile elbo(ea, kl_source)
Profile.print(format=:flat, sortedby=:count)
end
end
Expand Down
3 changes: 3 additions & 0 deletions src/DeterministicVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Calculate value, gradient, and hessian of the variational ELBO.
"""
module DeterministicVI

using Base.Threads: threadid, nthreads
using ..Model
import ..Model: BivariateNormalDerivatives, BvnComponent, GalaxyCacheComponent,
GalaxySigmaDerivs, SkyPatch,
Expand All @@ -16,6 +17,8 @@ using ..Transform
import DataFrames
import Optim
using StaticArrays
using ForwardDiff
using ReverseDiff

export ElboArgs

Expand Down
6 changes: 3 additions & 3 deletions src/DeterministicVIImagePSF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ image rather than a mixture of Gaussians.

module DeterministicVIImagePSF

using StaticArrays
using StaticArrays, DiffBase

import ..DeterministicVI:
ElboArgs, ElboIntermediateVariables,
StarPosParams, GalaxyPosParams, CanonicalParams, VariationalParams,
SourceBrightness, GalaxyComponent, SkyPatch,
load_source_brightnesses, add_elbo_log_term!,
accumulate_source_pixel_brightness!, subtract_kl!
accumulate_source_pixel_brightness!, subtract_kl_all_sources!, KL_HELPER_POOL

import ..Model:
populate_gal_fsm!, getids, ParamSet, linear_world_to_pix, lidx,
Expand All @@ -23,7 +23,7 @@ import ..Model:
import ..SensitiveFloats:
SensitiveFloat, zero_sensitive_float_array,
multiply_sfs!, add_scaled_sfs!, clear!

import ..Infer: load_active_pixels!

import ..PSF: get_psf_at_point
Expand Down
12 changes: 6 additions & 6 deletions src/SensitiveFloats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,12 @@ function add_sources_sf!{NumType <: Number}(
@assert size(sf_all.d, 1) == size(sf_s.d, 1)

P = sf_all.local_P
P_shifted = P * (s - 1)

if sf_all.has_gradient
@assert size(sf_s.d) == (P, 1)
@inbounds for s_ind1 in 1:sf_all.local_P
s_all_ind1 = P * (s - 1) + s_ind1
@inbounds for s_ind1 in 1:P
s_all_ind1 = P_shifted + s_ind1
sf_all.d[s_all_ind1] = sf_all.d[s_all_ind1] + sf_s.d[s_ind1]
end
end
Expand All @@ -224,12 +225,11 @@ function add_sources_sf!{NumType <: Number}(
@assert Ph >= P * s

@inbounds for s_ind1 in 1:P
s_all_ind1 = P * (s - 1) + s_ind1
s_all_ind1 = P_shifted + s_ind1

@inbounds for s_ind2 in 1:s_ind1
s_all_ind2 = P * (s - 1) + s_ind2
sf_all.h[s_all_ind2, s_all_ind1] =
sf_all.h[s_all_ind2, s_all_ind1] + sf_s.h[s_ind2, s_ind1]
s_all_ind2 = P_shifted + s_ind2
sf_all.h[s_all_ind2, s_all_ind1] += sf_s.h[s_ind2, s_ind1]
# TODO: move outside the loop?
sf_all.h[s_all_ind1, s_all_ind2] = sf_all.h[s_all_ind2, s_all_ind1]
end
Expand Down
Loading

0 comments on commit 42d869e

Please sign in to comment.