@@ -124,75 +124,112 @@ end
124124# #####################
125125# Default Transition #
126126# #####################
127- # Default
128- getstats (t) = nothing
127+ getstats (:: Any ) = NamedTuple ()
129128
129+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130131abstract type AbstractTransition end
131132
132- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133134 θ:: T
134- lp:: F # TODO : merge `lp` with `stat`
135- stat:: S
136- end
135+ logprior:: F
136+ loglikelihood:: F
137+ stat:: N
138+
139+ """
140+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141+
142+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143+ sampler step.
144+
145+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146+ already been set_. However, the accumulators (e.g. logp) may in general
147+ have junk contents. The role of this method is to re-evaluate `model` and
148+ thus set the accumulators to the correct values.
149+
150+ `sampler_transition` is the transition object returned by the sampler
151+ itself and is only used to extract statistics of interest.
152+ """
153+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
154+ vi = DynamicPPL. setaccs!! (
155+ vi,
156+ (
157+ DynamicPPL. ValuesAsInModelAccumulator (true ),
158+ DynamicPPL. LogPriorAccumulator (),
159+ DynamicPPL. LogLikelihoodAccumulator (),
160+ ),
161+ )
162+ _, vi = DynamicPPL. evaluate!! (model, vi)
163+
164+ # Extract all the information we need
165+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
166+ logprior = DynamicPPL. getlogprior (vi)
167+ loglikelihood = DynamicPPL. getloglikelihood (vi)
168+
169+ # # Convert values to the format needed (i.e. a Vector of (varname,
170+ # # value) tuples, where value isa Real: all vector-valued varnames must
171+ # # be split up.)
172+ # # TODO (penelopeysm): This wouldn't be necessary if not for MCMCChains's
173+ # # poor representation...
174+ # values_split = if isempty(vals_as_in_model)
175+ # # If there are no values, we return an empty vector.
176+ # # This is the case for models with no parameters.
177+ # Vector{Tuple{VarName,Any}}()
178+ # else
179+ # iters = map(
180+ # DynamicPPL.varname_and_value_leaves,
181+ # keys(vals_as_in_model),
182+ # values(vals_as_in_model),
183+ # )
184+ # mapreduce(collect, vcat, iters)
185+ # end
186+
187+ # Get additional statistics
188+ stats = getstats (sampler_transition)
189+ return new {typeof(vals_as_in_model),typeof(logprior),typeof(stats)} (
190+ vals_as_in_model, logprior, loglikelihood, stats
191+ )
192+ end
137193
138- Transition (θ, lp) = Transition (θ, lp, nothing )
139- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140- # TODO (DPPL0.37/penelopeysm): Fix this
141- θ = getparams (model, vi)
142- lp = getlogjoint_internal (vi)
143- return Transition (θ, lp, getstats (t))
194+ function Transition (
195+ model:: DynamicPPL.Model ,
196+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
197+ sampler_transition,
198+ )
199+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
200+ # much faster to convert it to a typed varinfo first, hence this method.
201+ # https://github.com/TuringLang/Turing.jl/issues/2604
202+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
203+ end
144204end
145205
146- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147206function metadata (t:: Transition )
148- stat = t. stat
149- if stat === nothing
150- return (lp= t. lp,)
151- else
152- return merge ((lp= t. lp,), stat)
153- end
207+ return merge (
208+ t. stat,
209+ (
210+ lp= t. logprior + t. loglikelihood,
211+ logprior= t. logprior,
212+ loglikelihood= t. loglikelihood,
213+ ),
214+ )
215+ end
216+ function metadata (vi:: AbstractVarInfo )
217+ return (
218+ lp= DynamicPPL. getlogjoint (vi),
219+ logprior= DynamicPPL. getlogp (vi),
220+ loglikelihood= DynamicPPL. getloglikelihood (vi),
221+ )
154222end
155-
156- # TODO (DPPL0.37/penelopeysm): Fix this
157- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158-
159- # Metadata of VarInfo object
160- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
162223
163224# #########################
164225# Chain making utilities #
165226# #########################
166227
167- """
168- getparams(model, t)
169-
170- Return a named tuple of parameters.
171- """
172- getparams (model, t) = t. θ
173- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176- # of the parameters change depending on the realizations. Hence we have to use
177- # `values_as_in_model`, which re-runs the model and extracts the parameters
178- # as they are seen in the model, i.e. in the constrained space. Moreover,
179- # this means that the code below will work both of linked and invlinked `vi`.
180- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182- return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
228+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
229+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
230+ t = Transition (model, vi, nothing )
231+ return getparams (model, t)
183232end
184- function getparams (
185- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
186- )
187- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
188- # much faster to convert it to a typed varinfo before calling getparams.
189- # https://github.com/TuringLang/Turing.jl/issues/2604
190- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
191- end
192- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
193- return Dict {VarName,Any} ()
194- end
195-
196233function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
197234 names_set = OrderedSet {VarName} ()
198235 # Extract the parameter names and values from each transition.
@@ -208,7 +245,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
208245 iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
209246 mapreduce (collect, vcat, iters)
210247 end
211-
212248 nms = map (first, nms_and_vs)
213249 vs = map (last, nms_and_vs)
214250 for nm in nms
@@ -224,7 +260,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
224260end
225261
226262function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
227- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
263+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
228264 return [:lp ], valmat
229265end
230266
@@ -466,16 +502,17 @@ function transitions_from_chain(
466502 chain:: MCMCChains.Chains ;
467503 sampler= DynamicPPL. SampleFromPrior (),
468504)
469- vi = Turing . VarInfo (model)
505+ vi = VarInfo (model)
470506
471507 iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
472508 transitions = map (iters) do (sample_idx, chain_idx)
473509 # Set variables present in `chain` and mark those NOT present in chain to be resampled.
510+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
474511 DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
475512 model (rng, vi, sampler)
476513
477514 # Convert `VarInfo` into `NamedTuple` and save.
478- Transition (model, vi)
515+ Transition (model, vi, nothing )
479516 end
480517
481518 return transitions
0 commit comments