@@ -292,4 +292,290 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
292292 end
293293end
294294
295+ """
296+ DynamicPPL.pointwise_logdensities(
297+ model::DynamicPPL.Model,
298+ chain::MCMCChains.Chains,
299+ ::Type{Tout}=MCMCChains.Chains
300+ ::Val{whichlogprob}=Val(:both),
301+ )
302+
303+ Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
304+ the log-density of each variable at each sample is stored (rather than its value).
305+
306+ `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
307+ `:likelihood`.
308+
309+ You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName,
310+ Matrix{Float64}}` instead.
311+
312+ See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
313+ [`DynamicPPL.pointwise_prior_logdensities`](@ref).
314+
315+ # Examples
316+
317+ ```jldoctest pointwise-logdensities-chains; setup=:(using Distributions)
318+ julia> using MCMCChains
319+
320+ julia> @model function demo(xs, y)
321+ s ~ InverseGamma(2, 3)
322+ m ~ Normal(0, √s)
323+ for i in eachindex(xs)
324+ xs[i] ~ Normal(m, √s)
325+ end
326+ y ~ Normal(m, √s)
327+ end
328+ demo (generic function with 2 methods)
329+
330+ julia> # Example observations.
331+ model = demo([1.0, 2.0, 3.0], [4.0]);
332+
333+ julia> # A chain with 3 iterations.
334+ chain = Chains(
335+ reshape(1.:6., 3, 2),
336+ [:s, :m];
337+ info=(varname_to_symbol=Dict(
338+ @varname(s) => :s,
339+ @varname(m) => :m,
340+ ),),
341+ );
342+
343+ julia> plds = pointwise_logdensities(model, chain)
344+ Chains MCMC chain (3×6×1 Array{Float64, 3}):
345+
346+ Iterations = 1:1:3
347+ Number of chains = 1
348+ Samples per chain = 3
349+ parameters = s, m, xs[1], xs[2], xs[3], y
350+ [...]
351+
352+ julia> plds[:s]
353+ 2-dimensional AxisArray{Float64,2,...} with axes:
354+ :iter, 1:1:3
355+ :chain, 1:1
356+ And data, a 3×1 Matrix{Float64}:
357+ -0.8027754226637804
358+ -1.3822169643436162
359+ -2.0986122886681096
360+
361+ julia> # The above is the same as:
362+ logpdf.(InverseGamma(2, 3), chain[:s])
363+ 3×1 Matrix{Float64}:
364+ -0.8027754226637804
365+ -1.3822169643436162
366+ -2.0986122886681096
367+ ```
368+
369+ julia> # Alternatively:
370+ plds_dict = pointwise_logdensities(model, chain, OrderedDict)
371+ OrderedDict{VarName, Matrix{Float64}} with 6 entries:
372+ s => [-0.802775; -1.38222; -2.09861;;]
373+ m => [-8.91894; -7.51551; -7.46824;;]
374+ xs[1] => [-5.41894; -5.26551; -5.63491;;]
375+ xs[2] => [-2.91894; -3.51551; -4.13491;;]
376+ xs[3] => [-1.41894; -2.26551; -2.96824;;]
377+ y => [-0.918939; -1.51551; -2.13491;;]
378+ """
379+ function DynamicPPL. pointwise_logdensities (
380+ model:: DynamicPPL.Model ,
381+ chain:: MCMCChains.Chains ,
382+ :: Type{Tout} = MCMCChains. Chains,
383+ :: Val{whichlogprob} = Val (:both ),
384+ ) where {whichlogprob,Tout}
385+ vi = DynamicPPL. VarInfo (model)
386+ acc = DynamicPPL. PointwiseLogProbAccumulator {whichlogprob} ()
387+ accname = DynamicPPL. accumulator_name (acc)
388+ vi = DynamicPPL. setaccs!! (vi, (acc,))
389+ parameter_only_chain = MCMCChains. get_sections (chain, :parameters )
390+ iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
391+ pointwise_logps = map (iters) do (sample_idx, chain_idx)
392+ # Extract values from the chain
393+ values_dict = chain_sample_to_varname_dict (parameter_only_chain, sample_idx, chain_idx)
394+ # Re-evaluate the model
395+ _, vi = DynamicPPL. init!! (
396+ model,
397+ vi,
398+ DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
399+ )
400+ DynamicPPL. getacc (vi, Val (accname)). logps
401+ end
402+
403+ # pointwise_logps is a matrix of OrderedDicts
404+ all_keys = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
405+ for d in pointwise_logps
406+ union! (all_keys, DynamicPPL. OrderedCollections. OrderedSet (keys (d)))
407+ end
408+ # this is a 3D array: (iterations, variables, chains)
409+ new_data = [
410+ get (pointwise_logps[iter, chain], k, missing ) for
411+ iter in 1 : size (pointwise_logps, 1 ), k in all_keys,
412+ chain in 1 : size (pointwise_logps, 2 )
413+ ]
414+
415+ if Tout == MCMCChains. Chains
416+ return MCMCChains. Chains (new_data, Symbol .(collect (all_keys)))
417+ elseif Tout <: AbstractDict
418+ return Tout {DynamicPPL.VarName,Matrix{Float64}} (
419+ k => new_data[:, i, :] for (i, k) in enumerate (all_keys)
420+ )
421+ end
422+ end
423+
424+ """
425+ DynamicPPL.pointwise_loglikelihoods(
426+ model::DynamicPPL.Model,
427+ chain::MCMCChains.Chains,
428+ ::Type{Tout}=MCMCChains.Chains
429+ )
430+
431+ Compute the pointwise log-likelihoods of the model given the chain. This is the same as
432+ `pointwise_logdensities(model, chain)`, but only including the likelihood terms.
433+
434+ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
435+ """
436+ function DynamicPPL. pointwise_loglikelihoods (
437+ model:: DynamicPPL.Model , chain:: MCMCChains.Chains , :: Type{Tout} = MCMCChains. Chains
438+ ) where {Tout}
439+ return DynamicPPL. pointwise_logdensities (model, chain, Tout, Val (:likelihood ))
440+ end
441+
442+ """
443+ DynamicPPL.pointwise_prior_logdensities(
444+ model::DynamicPPL.Model,
445+ chain::MCMCChains.Chains
446+ )
447+
448+ Compute the pointwise log-prior-densities of the model given the chain. This is the same as
449+ `pointwise_logdensities(model, chain)`, but only including the prior terms.
450+
451+ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
452+ """
453+ function DynamicPPL. pointwise_prior_logdensities (
454+ model:: DynamicPPL.Model , chain:: MCMCChains.Chains , :: Type{Tout} = MCMCChains. Chains
455+ ) where {Tout}
456+ return DynamicPPL. pointwise_logdensities (model, chain, Tout, Val (:prior ))
457+ end
458+
459+ """
460+ logjoint(model::Model, chain::MCMCChains.Chains)
461+
462+ Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
463+
464+ # Examples
465+
466+ ```jldoctest
467+ julia> using MCMCChains, Distributions
468+
469+ julia> @model function demo_model(x)
470+ s ~ InverseGamma(2, 3)
471+ m ~ Normal(0, sqrt(s))
472+ for i in eachindex(x)
473+ x[i] ~ Normal(m, sqrt(s))
474+ end
475+ end;
476+
477+ julia> # Construct a chain of samples using MCMCChains.
478+ # This sets s = 0.5 and m = 1.0 for all three samples.
479+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
480+
481+ julia> logjoint(demo_model([1., 2.]), chain)
482+ 3×1 Matrix{Float64}:
483+ -5.440428709758045
484+ -5.440428709758045
485+ -5.440428709758045
486+ ```
487+ """
488+ function DynamicPPL. logjoint (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
489+ var_info = DynamicPPL. VarInfo (model) # extract variables info from the model
490+ map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do (iteration_idx, chain_idx)
491+ argvals_dict = DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
492+ vn_parent => DynamicPPL. values_from_chain (
493+ var_info, vn_parent, chain, chain_idx, iteration_idx
494+ ) for vn_parent in keys (var_info)
495+ )
496+ DynamicPPL. logjoint (model, argvals_dict)
497+ end
498+ end
499+
500+ """
501+ loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
502+
503+ Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
504+ # Examples
505+
506+ ```jldoctest
507+ julia> using MCMCChains, Distributions
508+
509+ julia> @model function demo_model(x)
510+ s ~ InverseGamma(2, 3)
511+ m ~ Normal(0, sqrt(s))
512+ for i in eachindex(x)
513+ x[i] ~ Normal(m, sqrt(s))
514+ end
515+ end;
516+
517+ julia> # Construct a chain of samples using MCMCChains.
518+ # This sets s = 0.5 and m = 1.0 for all three samples.
519+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
520+
521+ julia> loglikelihood(demo_model([1., 2.]), chain)
522+ 3×1 Matrix{Float64}:
523+ -2.1447298858494
524+ -2.1447298858494
525+ -2.1447298858494
526+ ```
527+ """
528+ function DynamicPPL. loglikelihood (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
529+ var_info = DynamicPPL. VarInfo (model) # extract variables info from the model
530+ map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do (iteration_idx, chain_idx)
531+ argvals_dict = DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
532+ vn_parent => DynamicPPL. values_from_chain (
533+ var_info, vn_parent, chain, chain_idx, iteration_idx
534+ ) for vn_parent in keys (var_info)
535+ )
536+ DynamicPPL. loglikelihood (model, argvals_dict)
537+ end
538+ end
539+
540+ """
541+ logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
542+
543+ Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
544+
545+ # Examples
546+
547+ ```jldoctest
548+ julia> using MCMCChains, Distributions
549+
550+ julia> @model function demo_model(x)
551+ s ~ InverseGamma(2, 3)
552+ m ~ Normal(0, sqrt(s))
553+ for i in eachindex(x)
554+ x[i] ~ Normal(m, sqrt(s))
555+ end
556+ end;
557+
558+ julia> # Construct a chain of samples using MCMCChains.
559+ # This sets s = 0.5 and m = 1.0 for all three samples.
560+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
561+
562+ julia> logprior(demo_model([1., 2.]), chain)
563+ 3×1 Matrix{Float64}:
564+ -3.2956988239086447
565+ -3.2956988239086447
566+ -3.2956988239086447
567+ ```
568+ """
569+ function DynamicPPL. logprior (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
570+ var_info = DynamicPPL. VarInfo (model) # extract variables info from the model
571+ map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do (iteration_idx, chain_idx)
572+ argvals_dict = DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
573+ vn_parent => DynamicPPL. values_from_chain (
574+ var_info, vn_parent, chain, chain_idx, iteration_idx
575+ ) for vn_parent in keys (var_info)
576+ )
577+ DynamicPPL. logprior (model, argvals_dict)
578+ end
579+ end
580+
295581end
0 commit comments