From 7e37432e6a459c0bcb64f714d4f4def6ee46fabe Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 18:45:54 +0100 Subject: [PATCH 01/14] Reimplement pointwise_logdensities (almost completely) --- ext/DynamicPPLMCMCChainsExt.jl | 133 +++++++++++++++++ src/pointwise_logdensities.jl | 254 ++------------------------------- test/pointwise_logdensities.jl | 29 ++-- 3 files changed, 164 insertions(+), 252 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7886ad468..600af4177 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -292,4 +292,137 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha end end +""" + pointwise_logdensities( + model::Model, + chain::Chains, + ::Val{whichlogprob}=Val(:both), + ) + +Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where +the log-density of each variable at each sample is stored (rather than its value). + +`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. + +See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref). + +# Examples + +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end +demo (generic function with 2 methods) + +julia> # Example observations. + model = demo([1.0, 2.0, 3.0], [4.0]); + +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m]; + info=(varname_to_symbol=Dict( + @varname(s) => :s, + @varname(m) => :m, + ),), + ); + +julia> plds = pointwise_logdensities(model, chain) +Chains MCMC chain (3×6×1 Array{Float64, 3}): + +Iterations = 1:1:3 +Number of chains = 1 +Samples per chain = 3 +parameters = s, m, xs[1], xs[2], xs[3], y +[...] 
+ +julia> plds[:s] +2-dimensional AxisArray{Float64,2,...} with axes: + :iter, 1:1:3 + :chain, 1:1 +And data, a 3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 + +julia> # The above is the same as: + logpdf.(InverseGamma(2, 3), chain[:s]) +3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 +``` +""" +function DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Val{whichlogprob}=Val(:both) +) where {whichlogprob} + vi = DynamicPPL.VarInfo(model) + acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() + accname = DynamicPPL.accumulator_name(acc) + vi = DynamicPPL.setaccs!!(vi, (acc,)) + + parameter_only_chain = MCMCChains.get_sections(chain, :parameters) + + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + pointwise_logps = map(iters) do (sample_idx, chain_idx) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Re-evaluate the model + _, vi = DynamicPPL.init!!( + model, + vi, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + DynamicPPL.getacc(vi, Val(accname)).logps + end + + # pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains + all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() + for d in pointwise_logps + union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) + end + new_data = [ + get(pointwise_logps[iter, chain], k, missing) for + iter in 1:size(pointwise_logps, 1), k in all_keys, + chain in 1:size(pointwise_logps, 2) + ] + return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) +end + +""" + pointwise_loglikelihoods(model, chain, ::Val{whichlogprob}=Val(:both)) + +Compute the pointwise log-likelihoods of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the likelihood terms. 
+ +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). +""" +function DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, chain::MCMCChains.Chains +) + return DynamicPPL.pointwise_logdensities(model, chain, Val(:likelihood)) +end + +""" + pointwise_prior_logdensities(model, chain, ::Val{whichlogprob}=Val(:both)) + +Compute the pointwise log-prior-densities of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the prior terms. + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). +""" +function DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, chain::MCMCChains.Chains +) + return DynamicPPL.pointwise_logdensities(model, chain, Val(:prior)) +end + end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 47ca62530..579fe703e 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,35 +1,21 @@ """ - PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator + PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator An accumulator that stores the log-probabilities of each variable in a model. -Internally this accumulator stores the log-probabilities in a dictionary, where -the keys are the variable names and the values are vectors of -log-probabilities. Each element in a vector corresponds to one execution of the -model. +Internally this accumulator stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are vectors of log-probabilities. Each element in a vector +corresponds to one execution of the model. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies -which log-probabilities to store in the accumulator. `KeyType` is the type by which variable -names are stored, and should be `String` or `VarName`. 
`D` is the type of the dictionary -used internally to store the log-probabilities, by default -`OrderedDict{KeyType, Vector{LogProbType}}`. +which log-probabilities to store in the accumulator. """ -struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: - AbstractAccumulator - logps::D -end - -function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) -end +struct PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator + logps::OrderedDict{VarName,LogProbType} -function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,VarName}() -end - -function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} - logps = OrderedDict{KeyType,Vector{LogProbType}}() - return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) + function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return new{whichlogprob}(OrderedDict{VarName,LogProbType}()) + end end function Base.:(==)( @@ -42,28 +28,14 @@ function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichl return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) end -function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) - logps = acc.logps - # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
- T = last(fieldtypes(eltype(logps))) - logpvec = get!(logps, vn, T()) - return push!(logpvec, logp) -end - -function Base.push!( - acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp -) where {whichlogprob} - return push!(acc, string(vn), logp) -end - function accumulator_name( ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} ) where {whichlogprob} return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) +function _zero(::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}() end reset(acc::PointwiseLogProbAccumulator) = _zero(acc) split(acc::PointwiseLogProbAccumulator) = _zero(acc) @@ -71,21 +43,14 @@ function combine( acc::PointwiseLogProbAccumulator{whichlogprob}, acc2::PointwiseLogProbAccumulator{whichlogprob}, ) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(+, acc.logps, acc2.logps)) end function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - # Note that in only accumulating LogPrior, we effectively ignore logjac - # (since we want to return log densities that don't depend on the - # linking status of the VarInfo). 
- subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) - push!(acc, vn, subacc.logp) + acc.logps[vn] = loglikelihood(right, val) end return acc end @@ -99,172 +64,11 @@ function accumulate_observe!!( return acc end if whichlogprob == :both || whichlogprob == :likelihood - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) - push!(acc, vn, subacc.logp) + acc.logps[vn] = loglikelihood(right, left) end return acc end -""" - pointwise_logdensities( - model::Model, - chain::Chains, - keytype=String, - ::Val{whichlogprob}=Val(:both), - ) - -Runs `model` on each sample in `chain` returning a `OrderedDict{VarName, Matrix{Float64}}` -with keys being model variables and values being matrices of shape -`(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported, with `VarName` being the default. -`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or -`:likelihood`. - -See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. 
- -This is important to keep in mind, in particular if the computation is used -for downstream computations. - -# Examples -## From chain -```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) -julia> using MCMCChains - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - y ~ Normal(m, √s) - end -demo (generic function with 2 methods) - -julia> # Example observations. - model = demo([1.0, 2.0, 3.0], [4.0]); - -julia> # A chain with 3 iterations. - chain = Chains( - reshape(1.:6., 3, 2), - [:s, :m] - ); - -julia> pointwise_logdensities(model, chain) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, String) -OrderedDict{String, Matrix{Float64}} with 6 entries: - "s" => [-0.802775; -1.38222; -2.09861;;] - "m" => [-8.91894; -7.51551; -7.46824;;] - "xs[1]" => [-5.41894; -5.26551; -5.63491;;] - "xs[2]" => [-2.91894; -3.51551; -4.13491;;] - "xs[3]" => [-1.41894; -2.26551; -2.96824;;] - "y" => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, VarName) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. 
- -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` -""" -function pointwise_logdensities( - model::Model, chain, ::Type{KeyType}=VarName, ::Val{whichlogprob}=Val(:both) -) where {KeyType,whichlogprob} - # Get the data by executing the model once - vi = VarInfo(model) - - # This accumulator tracks the pointwise log-probabilities in a single iteration. - AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} - vi = setaccs!!(vi, (AccType(),)) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - - # Maintain a separate accumulator that isn't tied to a VarInfo but rather - # tracks _all_ iterations. - all_logps = AccType() - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - vi = last(evaluate!!(model, vi)) - - # Get the log-probabilities - this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps - - # Merge into main acc - for (varname, this_lp) in this_iter_logps - # Because `this_lp` is obtained from one model execution, it should only - # contain one variable, hence `only()`. 
- push!(all_logps, varname, only(this_lp)) - end - end - - niters = size(chain, 1) - nchains = size(chain, 3) - logdensities = OrderedDict( - varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps - ) - return logdensities -end - function pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} @@ -274,38 +78,10 @@ function pointwise_logdensities( return getacc(varinfo, Val(accumulator_name(AccType))).logps end -""" - pointwise_loglikelihoods(model, chain[, keytype]) - -Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the likelihood terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=VarName) where {T} - return pointwise_logdensities(model, chain, T, Val(:likelihood)) -end - function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:likelihood)) end -""" - pointwise_prior_logdensities(model, chain[, keytype]) - -Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the prior terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
-""" -function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=VarName -) where {T} - return pointwise_logdensities(model, chain, T, Val(:prior)) -end - function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index aac59380c..be5f20010 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,4 +1,4 @@ -@testset "logdensities_likelihoods.jl" begin +@testset "pointwise_logdensities.jl" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -39,32 +39,35 @@ end @testset "pointwise_logdensities chain" begin - # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, - # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the version on `Chains`. model = DynamicPPL.TestUtils.demo_assume_index_observe() - # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced - # an impl of this for containers. - # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. vns = DynamicPPL.TestUtils.varnames(model) # Get some random `NamedTuple` samples from the prior. num_iters = 3 vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + chain = Chains( + permutedims(vals_arr), + map(Symbol, vns); + info=(varname_to_symbol=Dict(vn => Symbol(vn) for vn in vns),), + ) # Compute the different pointwise logdensities. 
logjoints_pointwise = pointwise_logdensities(model, chain) logpriors_pointwise = pointwise_prior_logdensities(model, chain) loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + # Check output type + @test logjoints_pointwise isa MCMCChains.Chains + @test logpriors_pointwise isa MCMCChains.Chains + @test loglikelihoods_pointwise isa MCMCChains.Chains + # Check that they contain the correct variables. - @test all(vn in keys(logjoints_pointwise) for vn in vns) - @test all(vn in keys(logpriors_pointwise) for vn in vns) - @test !any(Base.Fix1(subsumes, @varname(x)), keys(logpriors_pointwise)) - @test !any(vn in keys(loglikelihoods_pointwise) for vn in vns) - @test all(Base.Fix1(subsumes, @varname(x)), keys(loglikelihoods_pointwise)) + @test all(Symbol(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(Symbol(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix1(startswith, "x"), String.(keys(logpriors_pointwise))) + @test !any(Symbol(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix1(startswith, "x"), String.(keys(loglikelihoods_pointwise))) # Get the sum of the logjoints for each of the iterations. logjoints = [ From 4002b08376fa261c524ffa245defee6d935ca6d8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 18:51:06 +0100 Subject: [PATCH 02/14] Move logjoint, logprior, ... 
as well --- ext/DynamicPPLMCMCChainsExt.jl | 108 +++++++++++++++++++++++++++++++++ src/model.jl | 108 --------------------------------- src/varinfo.jl | 66 -------------------- 3 files changed, 108 insertions(+), 174 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 600af4177..69d83652a 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -425,4 +425,112 @@ function DynamicPPL.pointwise_prior_logdensities( return DynamicPPL.pointwise_logdensities(model, chain, Val(:prior)) end +""" + logjoint(model::Model, chain::MCMCChains.Chains) + +Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> logjoint(demo_model([1., 2.]), chain); +``` +""" +function logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict{VarName,Any}( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + DynamicPPL.logjoint(model, argvals_dict) + end +end + +""" + loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. 
+n +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> loglikelihood(demo_model([1., 2.]), chain); +``` +""" +function Distributions.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + DynamicPPL.loglikelihood(model, argvals_dict) + end +end + +""" + logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. 
+ +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> logprior(demo_model([1., 2.]), chain); +``` +""" +function logprior(model::Model, chain::AbstractMCMC.AbstractChains) + var_info = VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict{VarName,Any}( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + DynamicPPL.logprior(model, argvals_dict) + end +end + end diff --git a/src/model.jl b/src/model.jl index 6c7e8de94..d6682416b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1058,42 +1058,6 @@ function logjoint(model::Model, varinfo::AbstractVarInfo) return getlogjoint(last(evaluate!!(model, varinfo))) end -""" - logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logjoint(demo_model([1., 2.]), chain); -``` -""" -function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logjoint(model, argvals_dict) - end -end - """ logprior(model::Model, varinfo::AbstractVarInfo) @@ -1116,42 +1080,6 @@ function logprior(model::Model, varinfo::AbstractVarInfo) return getlogprior(last(evaluate!!(model, varinfo))) end -""" - logprior(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logprior(demo_model([1., 2.]), chain); -``` -""" -function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logprior(model, argvals_dict) - end -end - """ loglikelihood(model::Model, varinfo::AbstractVarInfo) @@ -1170,42 +1098,6 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getloglikelihood(last(evaluate!!(model, varinfo))) end -""" - loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> loglikelihood(demo_model([1., 2.]), chain); -``` -""" -function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - loglikelihood(model, argvals_dict) - end -end - # Implemented & documented in DynamicPPLMCMCChainsExt function predict end diff --git a/src/varinfo.jl b/src/varinfo.jl index 417766653..734bf3db5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1783,72 +1783,6 @@ function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) return missing_keys end -""" - setval!(vi::VarInfo, x) - setval!(vi::VarInfo, values, keys) - setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - -Set the values in `vi` to the provided values and leave those which are not present in -`x` or `chains` unchanged. - -## Notes -This is rather limited for two reasons: -1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood, - and therefore suffers from the same limitations as [`subsumes_string`](@ref). -2. It will set every `vn` present in `keys`. It will NOT however - set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`, - representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0]))` will - be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0]))`. 
- -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1]` - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 -``` -""" -setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) -setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) -end - -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - set_transformed!!(vi, false, vn) - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) From b4444198b9a1e9bfddb9a716a256705b93db2ba8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:23:20 +0100 Subject: [PATCH 03/14] Fix imports, etc --- ext/DynamicPPLMCMCChainsExt.jl | 26 +++++++++--------- src/simple_varinfo.jl | 15 ++++++----- test/varinfo.jl | 48 +--------------------------------- 3 files changed, 23 insertions(+), 66 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 69d83652a..59b796cec 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -449,13 +449,13 @@ 
julia> # construct a chain of samples using MCMCChains julia> logjoint(demo_model([1., 2.]), chain); ``` """ -function logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) +function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) ) DynamicPPL.logjoint(model, argvals_dict) end @@ -485,13 +485,13 @@ julia> # construct a chain of samples using MCMCChains julia> loglikelihood(demo_model([1., 2.]), chain); ``` """ -function Distributions.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) +function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) ) DynamicPPL.loglikelihood(model, argvals_dict) end @@ -521,13 +521,13 @@ julia> # construct a chain of samples using MCMCChains julia> logprior(demo_model([1., 2.]), chain); ``` """ -function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model +function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = 
DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) ) DynamicPPL.logprior(model, argvals_dict) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d6c0cbcad..2ba25f142 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -510,7 +510,7 @@ function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} end """ - logjoint(model::Model, θ) + logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log joint probability of variables `θ` for the probabilistic `model`. @@ -539,10 +539,11 @@ julia> # Truth. -9902.33787706641 ``` """ -logjoint(model::Model, θ) = logjoint(model, SimpleVarInfo(θ)) +logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logjoint(model, SimpleVarInfo(θ)) """ - logprior(model::Model, θ) + logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log prior probability of variables `θ` for the probabilistic `model`. @@ -571,10 +572,11 @@ julia> # Truth. -5000.918938533205 ``` """ -logprior(model::Model, θ) = logprior(model, SimpleVarInfo(θ)) +logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logprior(model, SimpleVarInfo(θ)) """ - loglikelihood(model::Model, θ) + loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log likelihood of variables `θ` for the probabilistic `model`. @@ -603,7 +605,8 @@ julia> # Truth. 
-4901.418938533205 ``` """ -Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) +Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = + loglikelihood(model, SimpleVarInfo(θ)) # Allow usage of `NamedBijector` too. function link!!( diff --git a/test/varinfo.jl b/test/varinfo.jl index 5b541e1dd..cd23e9a41 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -396,61 +396,15 @@ end @test vicopy[s_vns] == 42 end end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/250 - @model function demo() - return x ~ filldist(MvNormal([1, 100], I), 2) - end - - vi = VarInfo(demo()) - vals_prev = vi.metadata.x.vals - ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] - DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end - @testset "setval! on chain" begin - # Define a helper function - """ - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - - Test `setval!` on `model` and `chain`. - - Worth noting that this only supports models containing symbols of the forms - `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. 
- """ - function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - θ_old = var_info[:] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[:] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end - end - + @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS chain = make_chain_from_prior(model, 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end end From e1489bad400671e5a04a66fb6b4caf37057fe29f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:28:08 +0100 Subject: [PATCH 04/14] Remove tests that are failing (yes I learnt this from Claude) --- test/varinfo.jl | 81 ------------------------------------------------- 1 file changed, 81 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index cd23e9a41..6b31fbe91 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -317,87 +317,6 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval!" 
begin - @model function testmodel(x) - n = length(x) - s ~ truncated(Normal(); lower=0) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, s^2 * I) - end - - @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV} - n = length(x) - s ~ truncated(Normal(); lower=0) - - m = TV(undef, n) - for i in eachindex(m) - m[i] ~ Normal() - end - - for i in eachindex(x) - x[i] ~ Normal(m[i], s) - end - end - - x = randn(5) - model_mv = testmodel(x) - model_uv = testmodel_univariate(x) - - for model in [model_uv, model_mv] - m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) - s_vns = @varname(s) - - vi_typed = DynamicPPL.typed_varinfo(model) - vi_untyped = DynamicPPL.untyped_varinfo(model) - vi_vnv = DynamicPPL.untyped_vector_varinfo(model) - vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) - - model_name = model == model_uv ? "univariate" : "multivariate" - @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ - vi_untyped, vi_typed, vi_vnv, vi_vnv_typed - ] - Random.seed!(23) - vicopy = deepcopy(vi) - - ### `setval` ### - # TODO(mhauru) The interface here seems inconsistent between Metadata and - # VarNamedVector. I'm lazy to fix it though, because I think we need to - # rework it soon anyway. - if vi in [vi_vnv, vi_vnv_typed] - DynamicPPL.setval!(vicopy, zeros(5), m_vns) - else - DynamicPPL.setval!(vicopy, (m=zeros(5),)) - end - # Setting `m` fails for univariate due to limitations of `setval!`. - # See docstring of `setval!` for more info. - if model == model_uv && vi in [vi_untyped, vi_typed] - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] == vi[s_vns] - - # Ordering is NOT preserved => fails for multivariate model. - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) 
- ) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] == vi[s_vns] - - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - DynamicPPL.setval!(vicopy, (s=42,)) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] == 42 - end - end - end - @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS chain = make_chain_from_prior(model, 10) From 2e69e0b6d7dbb72f655fc8eef75e45f89b02149d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:32:21 +0100 Subject: [PATCH 05/14] Changelog --- HISTORY.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 90864508b..15ed84dd4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. `loadstate` is exported from DynamicPPL. -### Change of default keytype of `pointwise_logdensities` +### Change of output type for `pointwise_logdensities` -The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`. +The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects, instead of dictionaries of matrices. +This also means that you can no longer specify the output type. +If you want to extract the matrices, you can do so by indexing into the returned `Chains` object. 
**Other changes** From fc393bc80cc87b41b2d8eb2272a21e5d67e8226f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:36:08 +0100 Subject: [PATCH 06/14] logpdf --- src/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 579fe703e..4de330c0e 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -50,7 +50,7 @@ function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior - acc.logps[vn] = loglikelihood(right, val) + acc.logps[vn] = logpdf(right, val) end return acc end From f2a83b3917d13865db7d99b0c2dc8429ab42cedd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:44:59 +0100 Subject: [PATCH 07/14] fix docstrings --- ext/DynamicPPLMCMCChainsExt.jl | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 59b796cec..d000c92ce 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -293,9 +293,9 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha end """ - pointwise_logdensities( - model::Model, - chain::Chains, + DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) @@ -305,7 +305,7 @@ the log-density of each variable at each sample is stored (rather than its value `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or `:likelihood`. -See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref). +See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). 
# Examples @@ -398,12 +398,15 @@ function DynamicPPL.pointwise_logdensities( end """ - pointwise_loglikelihoods(model, chain, ::Val{whichlogprob}=Val(:both)) + DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, + chain::MCMCChains.Chains + ) Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). """ function DynamicPPL.pointwise_loglikelihoods( model::DynamicPPL.Model, chain::MCMCChains.Chains @@ -412,12 +415,15 @@ function DynamicPPL.pointwise_loglikelihoods( end """ - pointwise_prior_logdensities(model, chain, ::Val{whichlogprob}=Val(:both)) + DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains + ) Compute the pointwise log-prior-densities of the model given the chain. This is the same as `pointwise_logdensities(model, chain)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). 
""" function DynamicPPL.pointwise_prior_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains From 0dec616587612fdb18b25e1ea05aa940c8c57c2d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:53:34 +0100 Subject: [PATCH 08/14] allow dict output --- ext/DynamicPPLMCMCChainsExt.jl | 50 ++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d000c92ce..e20eeb96b 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -296,6 +296,7 @@ end DynamicPPL.pointwise_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) @@ -305,7 +306,11 @@ the log-density of each variable at each sample is stored (rather than its value `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or `:likelihood`. -See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). +You can pass `OrderedDict` as the third argument to get the result as an +`OrderedDict{VarName,Matrix{Float64}}` instead. + +See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), +[`DynamicPPL.pointwise_prior_logdensities`](@ref). 
# Examples @@ -360,10 +365,23 @@ julia> # The above is the same as: -1.3822169643436162 -2.0986122886681096 ``` + +julia> # Alternatively: + plds_dict = pointwise_logdensities(model, chain, OrderedDict) +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] """ function DynamicPPL.pointwise_logdensities( - model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Val{whichlogprob}=Val(:both) -) where {whichlogprob} + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains, + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob,Tout} vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) @@ -384,7 +402,6 @@ function DynamicPPL.pointwise_logdensities( DynamicPPL.getacc(vi, Val(accname)).logps end - # pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) @@ -394,13 +411,22 @@ function DynamicPPL.pointwise_logdensities( iter in 1:size(pointwise_logps, 1), k in all_keys, chain in 1:size(pointwise_logps, 2) ] - return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) + + if Tout == MCMCChains.Chains + # pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains + return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) + elseif Tout <: AbstractDict + return Tout{DynamicPPL.VarName,Matrix{Float64}}( + k => new_data[:, i, :] for (i, k) in enumerate(all_keys) + ) + end end """ DynamicPPL.pointwise_loglikelihoods( model::DynamicPPL.Model, - chain::MCMCChains.Chains + chain::MCMCChains.Chains, 
+ ::Type{Tout}=MCMCChains.Chains ) Compute the pointwise log-likelihoods of the model given the chain. This is the same as @@ -409,9 +435,9 @@ Compute the pointwise log-likelihoods of the model given the chain. This is the See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). """ function DynamicPPL.pointwise_loglikelihoods( - model::DynamicPPL.Model, chain::MCMCChains.Chains -) - return DynamicPPL.pointwise_logdensities(model, chain, Val(:likelihood)) + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood)) end """ @@ -426,9 +452,9 @@ Compute the pointwise log-prior-densities of the model given the chain. This is See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). """ function DynamicPPL.pointwise_prior_logdensities( - model::DynamicPPL.Model, chain::MCMCChains.Chains -) - return DynamicPPL.pointwise_logdensities(model, chain, Val(:prior)) + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior)) end """ From f41106ed10156773bdd152e6246d3c002c958404 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:54:50 +0100 Subject: [PATCH 09/14] changelog --- HISTORY.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 15ed84dd4..abefc1e36 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -63,9 +63,9 @@ The `resume_from=chn` keyword argument to `sample` has been removed; please use ### Change of output type for `pointwise_logdensities` -The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects, instead of dictionaries of matrices. 
-This also means that you can no longer specify the output type. -If you want to extract the matrices, you can do so by indexing into the returned `Chains` object. +The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods`, when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default, instead of dictionaries of matrices. + +If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`. **Other changes** From 0b1d6a6a5b24fda870a6ce445f1bcc931361a39b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 19:59:00 +0100 Subject: [PATCH 10/14] fix some comments --- ext/DynamicPPLMCMCChainsExt.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index e20eeb96b..9442b4edd 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -386,9 +386,7 @@ function DynamicPPL.pointwise_logdensities( acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) vi = DynamicPPL.setaccs!!(vi, (acc,)) - parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) pointwise_logps = map(iters) do (sample_idx, chain_idx) # Extract values from the chain @@ -402,10 +400,12 @@ function DynamicPPL.pointwise_logdensities( DynamicPPL.getacc(vi, Val(accname)).logps end + # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) end + # this is a 3D array: (iterations, variables, chains) new_data = [ get(pointwise_logps[iter, chain], k, missing) for iter in 1:size(pointwise_logps, 1), k in all_keys, @@ -413,7 +413,6 @@ function 
DynamicPPL.pointwise_logdensities( ] if Tout == MCMCChains.Chains - # pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) elseif Tout <: AbstractDict return Tout{DynamicPPL.VarName,Matrix{Float64}}( From 94432ebeaa5c101d39461eb03b087783c3f1a16f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 16 Oct 2025 15:25:07 +0100 Subject: [PATCH 11/14] fix tests --- ext/DynamicPPLMCMCChainsExt.jl | 6 +++--- src/pointwise_logdensities.jl | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 9442b4edd..3ef6ff312 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -483,7 +483,7 @@ julia> logjoint(demo_model([1., 2.]), chain); function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{VarName,Any}( vn_parent => DynamicPPL.values_from_chain( var_info, vn_parent, chain, chain_idx, iteration_idx ) for vn_parent in keys(var_info) @@ -519,7 +519,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{DynamicPPL.VarName,Any}( + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( vn_parent => DynamicPPL.values_from_chain( var_info, vn_parent, chain, chain_idx, iteration_idx ) for vn_parent in keys(var_info) @@ -555,7 +555,7 @@ julia> logprior(demo_model([1., 2.]), chain); function 
DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{VarName,Any}( vn_parent => DynamicPPL.values_from_chain( var_info, vn_parent, chain, chain_idx, iteration_idx ) for vn_parent in keys(var_info) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 4de330c0e..2346dd396 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -13,8 +13,10 @@ which log-probabilities to store in the accumulator. struct PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator logps::OrderedDict{VarName,LogProbType} - function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} - return new{whichlogprob}(OrderedDict{VarName,LogProbType}()) + function PointwiseLogProbAccumulator{whichlogprob}( + d::OrderedDict{VarName,LogProbType}=OrderedDict{VarName,LogProbType}() + ) where {whichlogprob} + return new{whichlogprob}(d) end end From 57e099a257d9b0773b70caa902e403f053e4d3f0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 16 Oct 2025 15:42:16 +0100 Subject: [PATCH 12/14] Fix more imports --- ext/DynamicPPLMCMCChainsExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 3ef6ff312..f49037520 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -483,7 +483,7 @@ julia> logjoint(demo_model([1., 2.]), chain); function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = 
DynamicPPL.OrderedCollections.OrderedDict{VarName,Any}( + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( vn_parent => DynamicPPL.values_from_chain( var_info, vn_parent, chain, chain_idx, iteration_idx ) for vn_parent in keys(var_info) @@ -555,7 +555,7 @@ julia> logprior(demo_model([1., 2.]), chain); function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) var_info = DynamicPPL.VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{VarName,Any}( + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( vn_parent => DynamicPPL.values_from_chain( var_info, vn_parent, chain, chain_idx, iteration_idx ) for vn_parent in keys(var_info) From 802e38f057cdff42047a1a9dca448d2d04363fd4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 17 Oct 2025 11:19:00 +0100 Subject: [PATCH 13/14] Remove stray n Co-authored-by: Markus Hauru --- ext/DynamicPPLMCMCChainsExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index f49037520..c3e87269e 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -496,7 +496,6 @@ end loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. 
-n # Examples ```jldoctest From 76880654e40c6b46ac3b929ca8c162e9c061ed13 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 17 Oct 2025 11:32:02 +0100 Subject: [PATCH 14/14] Expand `logprior`, `loglikelihood`, and `logjoint` docstrings --- ext/DynamicPPLMCMCChainsExt.jl | 33 ++++++++++++++++++++++++--------- src/pointwise_logdensities.jl | 3 +-- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index c3e87269e..771dd664f 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -474,10 +474,15 @@ julia> @model function demo_model(x) end end; -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. + chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); -julia> logjoint(demo_model([1., 2.]), chain); +julia> logjoint(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -5.440428709758045 + -5.440428709758045 + -5.440428709758045 ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) @@ -509,10 +514,15 @@ julia> @model function demo_model(x) end end; -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. 
+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); -julia> loglikelihood(demo_model([1., 2.]), chain); +julia> loglikelihood(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -2.1447298858494 + -2.1447298858494 + -2.1447298858494 ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) @@ -545,10 +555,15 @@ julia> @model function demo_model(x) end end; -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. + chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); -julia> logprior(demo_model([1., 2.]), chain); +julia> logprior(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -3.2956988239086447 + -3.2956988239086447 + -3.2956988239086447 ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 2346dd396..848ecb1f0 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -4,8 +4,7 @@ An accumulator that stores the log-probabilities of each variable in a model. Internally this accumulator stores the log-probabilities in a dictionary, where the keys are -the variable names and the values are vectors of log-probabilities. Each element in a vector -corresponds to one execution of the model. +the variable names and the values are log-probabilities. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies which log-probabilities to store in the accumulator.