1- struct PriorExtractorContext{D<: OrderedDict{VarName,Any} ,Ctx<: AbstractContext } < :
2- AbstractContext
1+ struct PriorDistributionAccumulator{D<: OrderedDict{VarName,Any} } <: AbstractAccumulator
32 priors:: D
4- context:: Ctx
53end
64
7- PriorExtractorContext (context ) = PriorExtractorContext (OrderedDict {VarName,Any} (), context )
5+ PriorDistributionAccumulator ( ) = PriorDistributionAccumulator (OrderedDict {VarName,Any} ())
86
9- NodeTrait (:: PriorExtractorContext ) = IsParent ()
10- childcontext (context:: PriorExtractorContext ) = context. context
11- function setchildcontext (parent:: PriorExtractorContext , child)
12- return PriorExtractorContext (parent. priors, child)
7+ accumulator_name (:: PriorDistributionAccumulator ) = :PriorDistributionAccumulator
8+
9+ split (acc:: PriorDistributionAccumulator ) = PriorDistributionAccumulator (empty (acc. priors))
10+ function combine (acc1:: PriorDistributionAccumulator , acc2:: PriorDistributionAccumulator )
11+ return PriorDistributionAccumulator (merge (acc1. priors, acc2. priors))
1312end
1413
15- function setprior! (context:: PriorExtractorContext , vn:: VarName , dist:: Distribution )
16- return context. priors[vn] = dist
14+ function setprior! (acc:: PriorDistributionAccumulator , vn:: VarName , dist:: Distribution )
15+ acc. priors[vn] = dist
16+ return acc
1717end
1818
1919function setprior! (
20- context :: PriorExtractorContext , vns:: AbstractArray{<:VarName} , dist:: Distribution
20+ acc :: PriorDistributionAccumulator , vns:: AbstractArray{<:VarName} , dist:: Distribution
2121)
2222 for vn in vns
23- context . priors[vn] = dist
23+ acc . priors[vn] = dist
2424 end
25+ return acc
2526end
2627
2728function setprior! (
28- context :: PriorExtractorContext ,
29+ acc :: PriorDistributionAccumulator ,
2930 vns:: AbstractArray{<:VarName} ,
3031 dists:: AbstractArray{<:Distribution} ,
3132)
3233 for (vn, dist) in zip (vns, dists)
33- context . priors[vn] = dist
34+ acc . priors[vn] = dist
3435 end
36+ return acc
3537end
3638
37- function DynamicPPL. tilde_assume (context:: PriorExtractorContext , right, vn, vi)
38- setprior! (context, vn, right)
39- return DynamicPPL. tilde_assume (childcontext (context), right, vn, vi)
39+ function accumulate_assume!! (acc:: PriorDistributionAccumulator , val, logjac, vn, right)
40+ return setprior! (acc, vn, right)
4041end
4142
43+ accumulate_observe!! (acc:: PriorDistributionAccumulator , right, left, vn) = acc
44+
4245"""
4346 extract_priors([rng::Random.AbstractRNG, ]model::Model)
4447
@@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)])
108111extract_priors (args:: Union{Model,AbstractVarInfo} ...) =
109112 extract_priors (Random. default_rng (), args... )
110113function extract_priors (rng:: Random.AbstractRNG , model:: Model )
111- context = PriorExtractorContext (SamplingContext (rng))
112- evaluate!! (model, VarInfo (), context)
113- return context. priors
114+ varinfo = VarInfo ()
115+ # TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
116+ # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117+ # can't push new variables without knowing the num_produce. Remove this when possible.
118+ varinfo = setaccs!! (varinfo, (PriorDistributionAccumulator (), NumProduceAccumulator ()))
119+ varinfo = last (evaluate!! (model, varinfo, SamplingContext (rng)))
120+ return getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
114121end
115122
116123"""
@@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo`
122129and recording the distributions that are present at each tilde statement.
123130"""
124131function extract_priors (model:: Model , varinfo:: AbstractVarInfo )
125- context = PriorExtractorContext (DefaultContext ())
126- evaluate!! (model, deepcopy (varinfo), context)
127- return context. priors
132+ # TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
133+ # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
134+ # can't push new variables without knowing the num_produce. Remove this when possible.
135+ varinfo = setaccs!! (
136+ deepcopy (varinfo), (PriorDistributionAccumulator (), NumProduceAccumulator ())
137+ )
138+ varinfo = last (evaluate!! (model, varinfo, DefaultContext ()))
139+ return getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
128140end
0 commit comments