Commit a307f82

Merge pull request #485 from TuringLang/flux-support
Flux support (#457)
2 parents 8dc84e1 + a73904b commit a307f82

11 files changed: +41 -67 lines changed

REQUIRE (+2)

@@ -5,6 +5,8 @@ Markdown
 Distributions 0.11.0
 ForwardDiff
 MCMCChain 0.1.0
+Flux
+Stan

 ProgressMeter

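Flux and Stan move from optional, @require-guarded packages to hard dependencies. The rest of the diff builds on Flux's Tracker module for reverse-mode AD; below is a minimal standalone sketch of that API (based on Flux's documented Tracker interface, not code from this commit):

    using Flux: Tracker

    # Reverse-mode derivative of a scalar function:
    g = Tracker.gradient(x -> 3x^2 + 2x + 1, 5.0)  # returns a 1-tuple of tracked values
    first(g)        # 32.0 (tracked), since d/dx = 6x + 2 at x = 5
    first(g).data   # 32.0, the plain Float64 underneath
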
src/Turing.jl (+4 -8)

@@ -16,25 +16,21 @@ using LinearAlgebra
 using ProgressMeter
 using Markdown

-@init @require Stan="682df890-35be-576f-97d0-3d8c8b33a550" begin
+# @init @require Stan="682df890-35be-576f-97d0-3d8c8b33a550" begin
 using Stan
 import Stan: Adapt, Hmc
-end
-@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-using ReverseDiff: GradientTape, GradientConfig, gradient!, compile, TrackedArray
-import ReverseDiff: gradient
-end
-
+# end
 import Base: ~, convert, promote_rule, rand, getindex, setindex!
 import Distributions: sample
 import ForwardDiff: gradient
+using Flux: Tracker
 import MCMCChain: AbstractChains, Chains

 ##############################
 # Global variables/constants #
 ##############################

-global ADBACKEND = :forward_diff
+global ADBACKEND = :reverse_diff
 setadbackend(backend_sym) = begin
   @assert backend_sym == :forward_diff || backend_sym == :reverse_diff
   global ADBACKEND = backend_sym

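With ReverseDiff removed, the default ADBACKEND switches to :reverse_diff, now backed by Flux.Tracker. A usage sketch of the setter shown just above (assuming it is called as Turing.setadbackend; the two symbols are the ones checked by the @assert):

    using Turing

    Turing.setadbackend(:reverse_diff)   # the new default: reverse-mode AD via Flux.Tracker
    Turing.setadbackend(:forward_diff)   # ForwardDiff remains available
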
src/core/ad.jl (+17 -35)

@@ -79,6 +79,8 @@ gradient(vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
     spl.info[:grad_cache][θ_hash] = grad
   end

+  vi.logp = realpart(vi.logp)
+
   grad
 end

@@ -93,8 +95,7 @@ verifygrad(grad::Vector{Float64}) = begin
 end

 # Direct call of ForwardDiff.gradient; this is slow
-
-gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
+gradient_slow(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin

   vi = deepcopy(_vi)

@@ -108,38 +109,19 @@ gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
   g(vi[spl])
 end

-@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-
-gradient_r(theta::Vector{Float64}, vi::VarInfo, model::Function) = gradient_r(theta, vi, model, nothing)
-gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
-  inputs = (theta)
-
-  if Turing.ADSAFE || (spl == nothing || length(spl.info[:reverse_diff_cache]) == 0)
-    f_r(ipts) = begin
-      vi[spl][:] = ipts[:]
-      -runmodel(model, vi, spl).logp
-    end
-    gtape = GradientTape(f_r, inputs)
-    ctape = compile(gtape)
-    res = (similar(theta))
-
-    if spl != nothing
-      spl.info[:reverse_diff_cache][:ctape] = ctape
-      spl.info[:reverse_diff_cache][:res] = res
-    end
-  else
-    ctape = spl.info[:reverse_diff_cache][:ctape]
-    res = spl.info[:reverse_diff_cache][:res]
-  end
-
-  grad = ReverseDiff.gradient!(res, ctape, inputs)
-
-  # grad = ReverseDiff.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), inputs)
-
-  # vi[spl] = realpart(vi[spl])
-  # vi.logp = 0
-
-  grad
+gradient_r(theta::AbstractVector{<:Real}, vi::VarInfo, model::Function) =
+  gradient_r(theta, vi, model, nothing)
+gradient_r(theta::AbstractVector{<:Real}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
+  # Use Flux.Tracker to get gradient
+  grad = Tracker.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), theta)
+  # Clean tracked numbers
+  # Numbers do not need to be tracked between two gradient calls
+  vi.logp = vi.logp.data
+  vi_spl = vi[spl]
+  for i = 1:length(theta)
+    vi_spl[i] = vi_spl[i].data
+  end
+  # Return non-tracked graident value
+  return first(grad).data
 end

-end

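The taped ReverseDiff implementation of gradient_r is replaced by a single Tracker.gradient call, after which the tracking is stripped (via .data) from vi.logp and from the parameter vector. A standalone sketch of the same pattern, with neg_logp as a hypothetical stand-in for -runmodel(model, vi, spl).logp:

    using Flux: Tracker

    neg_logp(θ) = 0.5 * sum(θ .^ 2)        # stand-in for -runmodel(model, vi, spl).logp

    θ = randn(3)
    grads = Tracker.gradient(neg_logp, θ)  # 1-tuple holding a tracked gradient array
    grad = first(grads).data               # plain Vector{Float64}, mirroring `return first(grad).data`
    # grad ≈ θ, since ∇(½‖θ‖²) = θ
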
src/core/util.jl (-4)

@@ -4,10 +4,6 @@

 @inline invlogit(x::Union{T,Vector{T},Matrix{T}}) where T<:Real = one(T) ./ (one(T) .+ exp.(-x))
 @inline logit(x::Union{T,Vector{T},Matrix{T}}) where T<:Real = log.(x ./ (one(T) - x))
-@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-  @inline invlogit(x::TrackedArray) = one(Real) ./ (one(Real) + exp.(-x))
-  @inline logit(x::TrackedArray) = log.(x ./ (one(Real) - x))
-end

 # More stable, faster version of rand(Categorical)
 function randcat(p::Vector{Float64})

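The ReverseDiff-specific TrackedArray methods for invlogit and logit are dropped; only the generic definitions remain. A standalone sketch (untyped re-definitions for illustration, not the package's exact signatures), assuming Tracker's standard scalar math so the same expressions also evaluate on a tracked value:

    using Flux: Tracker

    invlogit(x) = 1 ./ (1 .+ exp.(-x))
    logit(p)    = log.(p ./ (1 .- p))

    invlogit(0.3)                  # ≈ 0.5744
    logit(invlogit(0.3))           # ≈ 0.3 (the two are inverses)
    invlogit(Tracker.param(0.3))   # also evaluates on a tracked scalar
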
src/helper.jl (+3 -3)

@@ -10,14 +10,14 @@
 @inline realpart(ds::Matrix{Any}) = [realpart(col) for col in ds]
 @inline realpart(ds::Array) = map(d -> realpart(d), ds) # NOTE: this function is not optimized
 # @inline realpart(ds::TArray) = realpart(Array(ds)) # TODO: is it disabled temporary
-@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-  @inline realpart(ta::ReverseDiff.TrackedReal) = ta.value
-end
+@inline realpart(ta::Tracker.TrackedReal) = ta.data

 @inline dualpart(d::ForwardDiff.Dual) = d.partials.values
 @inline dualpart(ds::Union{Array,SubArray}) = map(d -> dualpart(d), ds)

 # Base.promote_rule(D1::Type{Real}, D2::Type{ForwardDiff.Dual}) = D2
+import Base: <=
+<=(a::Tracker.TrackedReal, b::Tracker.TrackedReal) = a.data <= b.data

 #####################################################
 # Helper functions for vectorize/reconstruct values #

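realpart now unwraps Flux's Tracker.TrackedReal through its data field, and a <= comparison between two tracked reals is defined on the underlying values (presumably because it was missing for tracked operands at the time). A small sketch mirroring these two definitions:

    using Flux: Tracker
    import Base: <=

    # Same overload as in helper.jl:
    <=(a::Tracker.TrackedReal, b::Tracker.TrackedReal) = a.data <= b.data

    a, b = Tracker.param(1.0), Tracker.param(2.0)
    a.data      # 1.0 — what realpart(ta::Tracker.TrackedReal) now returns
    a <= b      # true, compared on the underlying Float64 values
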
src/samplers/hmc.jl (+5 -2)

@@ -41,10 +41,10 @@ end
 # Please see https://github.com/TuringLang/Turing.jl/pull/459 for explanations
 DEFAULT_ADAPT_CONF_TYPE = Nothing
 STAN_DEFAULT_ADAPT_CONF = nothing
-@init @require Stan="682df890-35be-576f-97d0-3d8c8b33a550" begin
+# @init @require Stan="682df890-35be-576f-97d0-3d8c8b33a550" begin
 DEFAULT_ADAPT_CONF_TYPE = Union{DEFAULT_ADAPT_CONF_TYPE,Stan.Adapt}
 STAN_DEFAULT_ADAPT_CONF = Stan.Adapt()
-end
+# end

 # NOTE: the implementation of HMC is removed,
 # it now reuses the one of HMCDA

@@ -181,6 +181,9 @@ assume(spl::Sampler{T}, dist::Distribution, vn::VarName, vi::VarInfo) where T<:H
   r = vi[vn]
   # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
   # r
+  @debug "dist = $dist"
+  @debug "vn = $vn"
+  @debug "r = $r" "typeof(r)=$(typeof(r))"
   r, logpdf_with_trans(dist, r, istrans(vi, vn))
 end

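Besides commenting out the Stan @require guard (Stan is now a hard dependency), the assume method gains @debug statements. These are hidden at the default log level; a sketch of how one might surface them using the standard Julia Logging stdlib (not part of the commit):

    using Logging

    # Print debug-level messages, including the new @debug lines, to stderr:
    global_logger(ConsoleLogger(stderr, Logging.Debug))

    # Alternatively, enable debug output for the Turing module only:
    # ENV["JULIA_DEBUG"] = "Turing"
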
src/samplers/hmcda.jl (+1 -7)

@@ -112,8 +112,6 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
   push!(spl.info[:accept_his], false)

   # Reset Θ
-  # NOTE: ForwardDiff and ReverseDiff need different implementation
-  # due to struct Dual vs mutable TrackedReal
   if ADBACKEND == :forward_diff

     vi[spl] = θ

@@ -122,11 +120,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)

     vi_spl = vi[spl]
     for i = 1:length(θ)
-      if isa(vi_spl[i], ReverseDiff.TrackedReal)
-        vi_spl[i].value = θ[i]
-      else
-        vi_spl[i] = θ[i]
-      end
+      vi_spl[i] = θ[i]
     end

   end

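The ReverseDiff branch that mutated TrackedReal.value in place is gone: with Flux's Tracker the entries are simply overwritten, and tracking is re-established at the next gradient call. A toy sketch of that reset, with vi_spl as a hypothetical stand-in for vi[spl]:

    using Flux: Tracker

    θ = [0.1, 0.2]
    vi_spl = Real[Tracker.param(0.0), Tracker.param(0.0)]  # stand-in for vi[spl]

    for i = 1:length(θ)
        vi_spl[i] = θ[i]   # plain assignment replaces the tracked entry
    end
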
src/samplers/sampler.jl (+4)

@@ -44,6 +44,7 @@ assume(spl::Nothing, dist::Distribution, vn::VarName, vi::VarInfo) = begin
   # NOTE: The importance weight is not correctly computed here because
   # r is genereated from some uniform distribution which is different from the prior
   # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
+
   r, logpdf_with_trans(dist, r, istrans(vi, vn))
 end

@@ -90,6 +91,9 @@ end

 observe(spl::Nothing, dist::Distribution, value::Any, vi::VarInfo) = begin
   vi.num_produce += 1
+  @debug "dist = $dist"
+  @debug "value = $value"
+
   # acclogp!(vi, logpdf(dist, value))
   logpdf(dist, value)
 end

src/samplers/support/hmc_core.jl (+1 -5)

@@ -57,11 +57,7 @@ function gen_rev_func(vi, spl)
   elseif ADBACKEND == :reverse_diff
     vi_spl = vi[spl]
     for i = 1:length(θ_old)
-      if isa(vi_spl[i], ReverseDiff.TrackedReal)
-        vi_spl[i].value = θ_old[i]
-      else
-        vi_spl[i] = θ_old[i]
-      end
+      vi_spl[i] = θ_old[i]
     end
   end
   setlogp!(vi, old_logp)

src/transform.jl (+3 -3)

@@ -348,7 +348,7 @@ invlink(d::PDMatDistribution, Z::Vector{Matrix{T}}) where {T<:Real} = begin
   Z
 end

-logpdf_with_trans(d::PDMatDistribution, x::Array{T,2}, transform::Bool) where {T<:Real} = begin
+logpdf_with_trans(d::PDMatDistribution, x::Array{T0,2}, transform::Bool) where {T0<:Union{T,Tracker.TrackedReal{T}}} where {T<:Real} = begin
   lp = logpdf(d, x)
   if transform && isfinite(lp)
     U = cholesky(x).U

@@ -361,11 +361,11 @@ logpdf_with_trans(d::PDMatDistribution, x::Array{T,2}, transform::Bool) where {T
   lp
 end

-logpdf_with_trans(d::PDMatDistribution, X::Vector{Matrix{T}}, transform::Bool) where {T<:Real} = begin
+logpdf_with_trans(d::PDMatDistribution, X::Vector{Matrix{T0}}, transform::Bool) where {T0<:Union{T,Tracker.TrackedReal{T}}} where {T<:Real} = begin
   lp = logpdf(d, X)
   if transform && all(isfinite.(lp))
     n = length(X)
-    U = Vector{Matrix{T}}(undef, n)
+    U = Vector{Matrix{T0}}(undef, n)
     for i = 1:n
       U[i] = cholesky(X[i]).U'
     end

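The logpdf_with_trans signatures for PDMatDistribution are widened so the matrix element type may be either a plain real or a Tracker.TrackedReal wrapping it. A sketch of the same where-clause on a hypothetical helper, showing that both plain and tracked matrices dispatch to it:

    using Flux: Tracker

    # Hypothetical helper with the widened signature from transform.jl:
    accepts(x::Array{T0,2}) where {T0<:Union{T,Tracker.TrackedReal{T}}} where {T<:Real} = true

    accepts(rand(2, 2))                       # Matrix{Float64}
    accepts(map(Tracker.param, rand(2, 2)))   # Matrix of TrackedReal{Float64}
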
test/hmcda.jl/hmcda.jl (+1)

@@ -1,6 +1,7 @@
 using Turing
 using Test
 using Random
+using Distributions

 Random.seed!(128)
