@@ -79,6 +79,8 @@ gradient(vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
     spl.info[:grad_cache][θ_hash] = grad
   end
 
+  vi.logp = realpart(vi.logp)
+
   grad
 end
 
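For context on the new `vi.logp = realpart(vi.logp)` line: after the ForwardDiff pass above, the accumulated log density in `vi.logp` is a dual number, and only its value part should survive once the gradient has been read off. A minimal sketch of that unwrapping, using `ForwardDiff.value` as a stand-in for Turing's internal `realpart` (the dual's contents here are illustrative):

```julia
using ForwardDiff: Dual, value

# A dual number like the one a ForwardDiff pass leaves in vi.logp:
# the primal value plus one partial per input dimension.
logp = Dual(-3.2, 1.0, 0.5)

# Keep only the primal so downstream code sees a plain Float64,
# mirroring `vi.logp = realpart(vi.logp)` in the hunk above.
logp = value(logp)   # -3.2
```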
@@ -93,8 +95,7 @@ verifygrad(grad::Vector{Float64}) = begin
   end
 
 # Direct call of ForwardDiff.gradient; this is slow
-
-gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
+gradient_slow(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
 
   vi = deepcopy(_vi)
 
@@ -108,38 +109,19 @@ gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
   g(vi[spl])
 end
 
-@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-
-gradient_r(theta::Vector{Float64}, vi::VarInfo, model::Function) = gradient_r(theta, vi, model, nothing)
-gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
-  inputs = (theta)
-
-  if Turing.ADSAFE || (spl == nothing || length(spl.info[:reverse_diff_cache]) == 0)
-    f_r(ipts) = begin
-      vi[spl][:] = ipts[:]
-      -runmodel(model, vi, spl).logp
-    end
-    gtape = GradientTape(f_r, inputs)
-    ctape = compile(gtape)
-    res = (similar(theta))
-
-    if spl != nothing
-      spl.info[:reverse_diff_cache][:ctape] = ctape
-      spl.info[:reverse_diff_cache][:res] = res
-    end
-  else
-    ctape = spl.info[:reverse_diff_cache][:ctape]
-    res = spl.info[:reverse_diff_cache][:res]
-  end
-
-  grad = ReverseDiff.gradient!(res, ctape, inputs)
-
-  # grad = ReverseDiff.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), inputs)
-
-  # vi[spl] = realpart(vi[spl])
-  # vi.logp = 0
-
-  grad
+gradient_r(theta::AbstractVector{<:Real}, vi::VarInfo, model::Function) =
+  gradient_r(theta, vi, model, nothing)
+gradient_r(theta::AbstractVector{<:Real}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
+  # Use Flux.Tracker to get gradient
+  grad = Tracker.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), theta)
+  # Clean tracked numbers
+  # Numbers do not need to be tracked between two gradient calls
+  vi.logp = vi.logp.data
+  vi_spl = vi[spl]
+  for i = 1:length(theta)
+    vi_spl[i] = vi_spl[i].data
+  end
+  # Return non-tracked gradient value
+  return first(grad).data
 end
 
-end
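The hunk above swaps ReverseDiff's compiled-tape gradient for Flux's Tracker. A minimal standalone sketch of the new pattern, assuming the pre-Zygote Flux where Tracker ships as `Flux.Tracker` (the `neg_logp` closure is a hypothetical stand-in for `x -> -runmodel(model, vi, spl).logp`):

```julia
using Flux.Tracker

# Stand-in negative log density: -logp(x) of a standard normal, up to a constant.
neg_logp(x) = 0.5 * sum(x .* x)

theta = [0.3, -1.2]

# Tracker.gradient returns a tuple with one tracked gradient per argument;
# .data unwraps it to a plain Vector{Float64}, as gradient_r does.
grad = Tracker.gradient(neg_logp, theta)
first(grad).data   # == theta, since d/dx 0.5*x'x = x
```

Unlike the removed ReverseDiff path, no tape is built, compiled, or cached in `spl.info`; each call re-traces the model, trading the caching machinery for simplicity.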