Skip to content

Commit 29364cd

Browse files
committed
Replace ReverseDiff with Flux done (#457)
1 parent db1ac66 commit 29364cd

File tree

6 files changed

+4
-34
lines changed

6 files changed

+4
-34
lines changed

src/Turing.jl

-6
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,9 @@ using Markdown
2020
using Stan
2121
import Stan: Adapt, Hmc
2222
end
23-
@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
24-
using ReverseDiff: GradientTape, GradientConfig, gradient!, compile, TrackedArray
25-
import ReverseDiff: gradient
26-
end
27-
2823
import Base: ~, convert, promote_rule, rand, getindex, setindex!
2924
import Distributions: sample
3025
import ForwardDiff: gradient
31-
import Flux: gradient
3226
using Flux: Tracker
3327
import MCMCChain: AbstractChains, Chains
3428

src/core/ad.jl

+1-9
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
108108
g(vi[spl])
109109
end
110110

111-
# @init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
112-
113111
gradient_r(theta::Vector{Float64}, vi::VarInfo, model::Function) = gradient_r(theta, vi, model, nothing)
114112
gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
115113
f_r(ipts) = begin
@@ -119,12 +117,6 @@ gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Uni
119117

120118
grad = Tracker.gradient(f_r, theta)
121119

122-
# grad = ReverseDiff.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), inputs)
123-
124-
# vi[spl] = realpart(vi[spl])
125-
# vi.logp = 0
126-
127-
map(x -> isa(x, Tracker.TrackedReal) ? x.data : x, grad)
120+
first(grad).data
128121
end
129122

130-
# end

src/core/util.jl

-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44

55
@inline invlogit(x::Union{T,Vector{T},Matrix{T}}) where T<:Real = one(T) ./ (one(T) .+ exp.(-x))
66
@inline logit(x::Union{T,Vector{T},Matrix{T}}) where T<:Real = log.(x ./ (one(T) - x))
7-
@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
8-
@inline invlogit(x::TrackedArray) = one(Real) ./ (one(Real) + exp.(-x))
9-
@inline logit(x::TrackedArray) = log.(x ./ (one(Real) - x))
10-
end
117

128
# More stable, faster version of rand(Categorical)
139
function randcat(p::Vector{Float64})

src/helper.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
@inline realpart(ds::Matrix{Any}) = [realpart(col) for col in ds]
1111
@inline realpart(ds::Array) = map(d -> realpart(d), ds) # NOTE: this function is not optimized
1212
# @inline realpart(ds::TArray) = realpart(Array(ds)) # TODO: is it disabled temporary
13-
@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
14-
@inline realpart(ta::ReverseDiff.TrackedReal) = ta.value
15-
end
13+
@inline realpart(ta::Tracker.TrackedReal) = ta.data
1614

1715
@inline dualpart(d::ForwardDiff.Dual) = d.partials.values
1816
@inline dualpart(ds::Union{Array,SubArray}) = map(d -> dualpart(d), ds)

src/samplers/hmcda.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,6 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
112112
push!(spl.info[:accept_his], false)
113113

114114
# Reset Θ
115-
# NOTE: ForwardDiff and ReverseDiff need different implementation
116-
# due to struct Dual vs mutable TrackedReal
117115
if ADBACKEND == :forward_diff
118116

119117
vi[spl] = θ
@@ -122,11 +120,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
122120

123121
vi_spl = vi[spl]
124122
for i = 1:length(θ)
125-
if isa(vi_spl[i], ReverseDiff.TrackedReal)
126-
vi_spl[i].value = θ[i]
127-
else
128-
vi_spl[i] = θ[i]
129-
end
123+
vi_spl[i] = θ[i]
130124
end
131125

132126
end

src/samplers/support/hmc_core.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,7 @@ function gen_rev_func(vi, spl)
5757
elseif ADBACKEND == :reverse_diff
5858
vi_spl = vi[spl]
5959
for i = 1:length(θ_old)
60-
if isa(vi_spl[i], ReverseDiff.TrackedReal)
61-
vi_spl[i].value = θ_old[i]
62-
else
63-
vi_spl[i] = θ_old[i]
64-
end
60+
vi_spl[i] = θ_old[i]
6561
end
6662
end
6763
setlogp!(vi, old_logp)

0 commit comments

Comments
 (0)