diff --git a/Project.toml b/Project.toml index c0179083b..08542f9f2 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -16,6 +17,7 @@ AbstractMCMC = "1.0" Bijectors = "0.5.2, 0.6, 0.7" Distributions = "0.22, 0.23" MacroTools = "0.5.1" +Requires = "0.5, 1.0" ZygoteRules = "0.2" julia = "1" @@ -37,7 +39,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] +test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "ReverseDiff", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] diff --git 
a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 281e427fd..95bfb27b5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -1,6 +1,7 @@ module DynamicPPL using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel +using Requires using Distributions using Bijectors using MacroTools @@ -36,13 +37,14 @@ export AbstractVarInfo, set_num_produce!, reset_num_produce!, increment_num_produce!, + getmode, set_retained_vns_del_by_spl!, is_flagged, unset_flag!, setgid!, updategid!, setorder!, - istrans, + islinked_and_trans, link!, invlink!, tonamedtuple, diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 88a28f840..1813d5ebf 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -9,3 +9,6 @@ ZygoteRules.@adjoint function push!( ) return push!(vi, vn, r, dist, gidset), _ -> nothing end +ZygoteRules.@adjoint function zygote_setval!(vi, val, vn) + return zygote_setval!(vi, val, vn), _ -> nothing +end \ No newline at end of file diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ebba1e088..53978696d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -23,16 +23,16 @@ function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi) return _tilde(sampler, right, vn, vi) end function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi) + @assert !islinked(vi) if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds) end return _tilde(sampler, right, vn, vi) end function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) + @assert !islinked(vi) if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds) end return _tilde(sampler, NoDist(right), vn, vi) end @@ -125,20 +125,20 @@ function assume( if 
haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vn, "del") + @assert !islinked(vi) unset_flag!(vi, vn, "del") r = init(dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + vi[vn, dist] = r setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + r = vi[vn, dist] end else + @assert !islinked(vi) r = init(dist, spl) push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) end - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function observe( @@ -167,11 +167,11 @@ function dot_tilde( inds, vi, ) + @assert !islinked(vi) if ctx.vars !== nothing var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) else vns, dist = get_vns_and_dist(right, left, vn) end @@ -189,11 +189,11 @@ function dot_tilde( inds, vi, ) + @assert !islinked(vi) if ctx.vars !== nothing var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) else vns, dist = get_vns_and_dist(right, left, vn) end @@ -214,14 +214,12 @@ function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) return value end - function get_vns_and_dist(dist::NamedDist, var, vn::VarName) return get_vns_and_dist(dist.dist, var, dist.name) end function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) return getvn.(1:size(var, 2)), dist - end function get_vns_and_dist( dist::Union{Distribution, AbstractArray{<:Distribution}}, @@ -256,7 +254,7 @@ function dot_assume( ) @assert length(dist) == size(var, 1) r = get_and_set_val!(vi, vns, dist, spl) - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + lp = 
sum(Bijectors.logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) var .= r return var, lp end @@ -269,7 +267,9 @@ function dot_assume( ) r = get_and_set_val!(vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + mode = getmode(vi) + trans = istrans(vi, vns[1]) && (mode isa LinkMode || mode isa InitLinkMode) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, trans)) var .= r return var, lp end @@ -293,23 +293,23 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") + @assert !islinked(vi) unset_flag!(vi, vns[1], "del") r = init(dist, spl, n) for i in 1:n vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + vi[vn, dist] = r[:, i] setorder!(vi, vn, get_num_produce(vi)) end else - r = vi[vns] + r = vi[vns, dist] end else + @assert !islinked(vi) r = init(dist, spl, n) for i in 1:n vn = vns[i] push!(vi, vn, r[:,i], dist, spl) - settrans!(vi, false, vn) end end return r @@ -324,24 +324,24 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") + @assert !islinked(vi) unset_flag!(vi, vns[1], "del") f = (vn, dist) -> init(dist, spl) r = f.(vns, dists) for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? 
dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + vi[vn, dist] = r[i] setorder!(vi, vn, get_num_produce(vi)) end else - r = reshape(vi[vec(vns)], size(vns)) + r = vi[vns, dists] end else + @assert !islinked(vi) f = (vn, dist) -> init(dist, spl) r = f.(vns, dists) push!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) end return r end @@ -354,7 +354,7 @@ function set_val!( ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) - vi[vn] = val[:,i] + vi[vn, dist] = val[:,i] end return val end @@ -367,7 +367,7 @@ function set_val!( @assert size(val) == size(vns) foreach(CartesianIndices(val)) do ind dist = dists isa AbstractArray ? dists[ind] : dists - vi[vns[ind]] = vectorize(dist, val[ind]) + vi[vns[ind], dist] = val[ind] end return val end diff --git a/src/prob_macro.jl b/src/prob_macro.jl index b047f711f..452a2863b 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -235,7 +235,6 @@ _setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c) for vn in md.$n.vns val = copy.(vec(c[Symbol(string(vn))].value)) setval!(vi, val, vn) - settrans!(vi, false, vn) end end end...) 
diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 996934c53..cb1b084c5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -39,6 +39,11 @@ function setlogp!(vi::ThreadSafeVarInfo, logp) return setlogp!(vi.varinfo, logp) end +Bijectors.link(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(link(vi.varinfo), vi.logps) +Bijectors.invlink(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(invlink(vi.varinfo), vi.logps) +initlink(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(initlink(vi.varinfo), vi.logps) + +getrange(vi::ThreadSafeVarInfo, vn::VarName) = getrange(vi.varinfo, vn) get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) @@ -50,20 +55,27 @@ function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) setgid!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) -setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) +getmode(vi::ThreadSafeVarInfo) = getmode(vi.varinfo) +issynced(vi::ThreadSafeVarInfo) = issynced(vi.varinfo) +function setsynced!(vi::ThreadSafeVarInfo, b::Bool) + setsynced!(vi.varinfo, b) + return vi +end +getmetadata(vi::ThreadSafeVarInfo, vn::VarName) = getmetadata(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) +init_dist_link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = init_dist_link!(vi.varinfo, spl) +init_dist_invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = init_dist_invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +getinitdist(vi::ThreadSafeVarInfo, vn::VarName) = 
getinitdist(vi.varinfo, vn) +has_fixed_support(vi::ThreadSafeVarInfo) = has_fixed_support(vi.varinfo) +set_fixed_support!(vi::ThreadSafeVarInfo, b::Bool) = set_fixed_support!(vi.varinfo, b) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) -getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) setindex!(vi.varinfo, val, spl) @@ -85,6 +97,10 @@ function empty!(vi::ThreadSafeVarInfo) fill!(vi.logps, zero(getlogp(vi))) return vi end +function empty!(vi::ThreadSafeVarInfo, spl::AbstractSampler) + empty!(vi.varinfo, spl) + return vi +end function push!( vi::ThreadSafeVarInfo, diff --git a/src/varinfo/ad.jl b/src/varinfo/ad.jl new file mode 100644 index 000000000..71fa3a6b7 --- /dev/null +++ b/src/varinfo/ad.jl @@ -0,0 +1,16 @@ +function __init__() + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + value(x::ForwardDiff.Dual) = ForwardDiff.value(x) + value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) + value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) + value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) + end + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + value(x::Tracker.TrackedReal) = Tracker.data(x) + value(x::Tracker.TrackedArray) = Tracker.data(x) + value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) + end +end \ No newline at end of file diff --git a/src/varinfo/indexing.jl b/src/varinfo/indexing.jl index f1760e685..69909463e 100644 --- a/src/varinfo/indexing.jl +++ 
b/src/varinfo/indexing.jl @@ -1,3 +1,5 @@ +## Vectorized value getters and setters ## + const VarView = Union{Int, UnitRange, Vector{Int}} """ @@ -5,17 +7,24 @@ const VarView = Union{Int, UnitRange, Vector{Int}} Return a view `vi.vals[vview]`. """ -getval(vi::UntypedVarInfo, vview::VarView) = view(vi.metadata.vals, vview) +function getval(vi::UntypedVarInfo, vview::VarView) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals + return view(vals, vview) +end """ setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) Set the value of `vi.vals[vview]` to `val`. """ -setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val +function setval!(vi::UntypedVarInfo, val, vview::VarView) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals + return vals[vview] = val +end function setval!(vi::UntypedVarInfo, val, vview::Vector{UnitRange}) + vals = getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals if length(vview) > 0 - vi.metadata.vals[[i for arr in vview for i in arr]] = val + vals[[i for arr in vview for i in arr]] = val end return val end @@ -27,7 +36,11 @@ Return the value(s) of `vn`. The values may or may not be transformed to Euclidean space. """ -getval(vi::VarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) +function getval(vi::AbstractVarInfo, vn::VarName) + metadata = getmetadata(vi, vn) + vals = getmode(vi) isa LinkMode ? metadata.trans_vals : metadata.vals + return view(vals, getrange(vi, vn)) +end """ setval!(vi::VarInfo, val, vn::VarName) @@ -36,7 +49,11 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val +function setval!(vi::AbstractVarInfo, val, vn::VarName) + metadata = getmetadata(vi, vn) + vals = getmode(vi) isa LinkMode ? 
metadata.trans_vals : metadata.vals + return vals[getrange(vi, vn)] = val +end """ getval(vi::VarInfo, vns::Vector{<:VarName}) @@ -56,12 +73,20 @@ Return the values of all the variables in `vi`. The values may or may not be transformed to Euclidean space. """ -getall(vi::UntypedVarInfo) = vi.metadata.vals -getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) -@generated function _getall(metadata::NamedTuple{names}) where {names} +function getall(vi::UntypedVarInfo) + return getmode(vi) isa LinkMode ? vi.metadata.trans_vals : vi.metadata.vals +end +function getall(vi::TypedVarInfo) + return vcat(_getall(vi.metadata, Val(getmode(vi) isa LinkMode))...) +end +@generated function _getall(metadata::NamedTuple{names}, ::Val{linked}) where {names, linked} exprs = [] for f in names - push!(exprs, :(metadata.$f.vals)) + if linked + push!(exprs, :(metadata.$f.trans_vals)) + else + push!(exprs, :(metadata.$f.vals)) + end end return :($(exprs...),) end @@ -73,9 +98,23 @@ Set the values of all the variables in `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) -@generated function _setall!(metadata::NamedTuple{names}, val, start = 0) where {names} +function setall!(vi::UntypedVarInfo, val) + vals = getmode(vi) isa LinkMode ? 
vi.metadata.trans_vals : vi.metadata.vals + return vals .= val +end +setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val, Val(getmode(vi) isa LinkMode)) +@generated function _setall!(metadata::NamedTuple{names}, val, ::Val{true}, start = 0) where {names} + expr = Expr(:block) + start = :(1) + for f in names + length = :(length(metadata.$f.trans_vals)) + finish = :($start + $length - 1) + push!(expr.args, :(metadata.$f.trans_vals .= val[$start:$finish])) + start = :($start + $length) + end + return expr +end +@generated function _setall!(metadata::NamedTuple{names}, val, ::Val{false}, start = 0) where {names} expr = Expr(:block) start = :(1) for f in names @@ -87,32 +126,111 @@ setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) return expr end -# The default getindex & setindex!() for get & set values -# NOTE: vi[vn] will always transform the variable to its original space and Julia type +## VarName getindex and setindex! ## + +function zygote_setval!(vi, val, vn) + return setval!(vi, val, vn) +end + """ - getindex(vi::VarInfo, vn::VarName) - getindex(vi::VarInfo, vns::Vector{<:VarName}) + getindex(vi::VarInfo, vn::VarName, dist::Distribution) + getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) -distribution(s). +distribution(s) `dist`. If the value(s) is (are) transformed to the Euclidean space, it is (they are) transformed back. """ -function getindex(vi::AbstractVarInfo, vn::VarName) +function Base.getindex( + vi::AbstractVarInfo, + vn::VarName, + dist::Distribution, +) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vn) - return istrans(vi, vn) ? 
- Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) : - reconstruct(dist, getval(vi, vn)) + trans = istrans(vi, vn) + if has_fixed_support(vi) + set_fixed_support!(vi, bijector(dist) == bijector(getinitdist(vi, vn))) + end + if getmode(vi) isa LinkMode && trans + trans_val = reconstruct(dist, getval(vi, vn)) + val = Bijectors.invlink(dist, trans_val) + zygote_setval!(invlink(vi), value(vectorize(dist, val)), vn) + elseif getmode(vi) isa InitLinkMode && trans + val = reconstruct(dist, getval(vi, vn)) + trans_val = Bijectors.link(dist, val) + zygote_setval!(link(vi), vectorize(dist, trans_val), vn) + else + val = reconstruct(dist, getval(vi, vn)) + end + return val end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) - @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - dist = getdist(vi, vns[1]) - return istrans(vi, vns[1]) ? - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) : - reconstruct(dist, getval(vi, vns), length(vns)) +function Base.getindex( + vi::AbstractVarInfo, + vn::VarName, +) + @assert getmode(vi) isa StandardMode + @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" + return reconstruct(getinitdist(vi, vn), getval(vi, vn)) +end +function Base.getindex( + vi::AbstractVarInfo, + vns::AbstractVector{<:VarName}, + dist::MultivariateDistribution, +) + return mapreduce(hcat, vns) do vn + vi[vn, dist] + end +end +function Base.getindex( + vi::AbstractVarInfo, + vns::AbstractArray{<:VarName}, + dists::Union{Distribution, AbstractArray{<:Distribution}}, +) + return broadcast(vns, dists) do vn, dist + vi[vn, dist] + end end +function Base.getindex( + vi::AbstractVarInfo, + vns::Vector{<:VarName}, +) + return map(vns) do vn + vi[vn] + end +end + +""" + setindex!(vi::VarInfo, val, vn::VarName) + +Set the current value(s) of the random variable `vn` in `vi` to `val`. + +The value(s) may or may not be transformed to Euclidean space. 
+""" +function setindex!(vi::AbstractVarInfo, val, vn::VarName, dist::Distribution) + @assert haskey(vi, vn) "[DynamicPPL] variable not found in VarInfo." + trans = istrans(vi, vn) + if getmode(vi) isa LinkMode && trans + trans_val = Bijectors.link(dist, val) + setval!(vi, vectorize(dist, trans_val), vn) + setval!(invlink(vi), vectorize(dist, val), vn) + elseif getmode(vi) isa InitLinkMode && trans + trans_val = Bijectors.link(dist, val) + setval!(vi, vectorize(dist, val), vn) + setval!(link(vi), vectorize(dist, trans_val), vn) + else + setval!(vi, vectorize(dist, val), vn) + end + return vi +end +function setindex!(vi::AbstractVarInfo, val, vn::VarName) + @assert getmode(vi) isa StandardMode + @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" + setval!(vi, vectorize(getinitdist(vi, vn), val), vn) + return vi +end + +## Sampler getindex and setindex! ## """ getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) @@ -121,32 +239,41 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi)) -getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi)) -getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) +function getindex(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) + return copy(getall(vi)) +end +function getindex(vi::UntypedVarInfo, spl::Sampler) + return copy(getval(vi, getranges(vi, spl))) +end function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple - ranges = _getranges(vi, spl) + ranges = getranges(vi, spl) # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return vcat(_getindex(vi.metadata, ranges)...) + return vcat(_getindex(vi.metadata, ranges, Val(getmode(vi) isa LinkMode))...) 
end # Recursively builds a tuple of the `vals` of all the symbols -@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} +@generated function _getindex( + metadata, + ranges::NamedTuple{names}, + ::Val{false}, +) where {names} expr = Expr(:tuple) for f in names push!(expr.args, :(metadata.$f.vals[ranges.$f])) end return expr end - -""" - setindex!(vi::VarInfo, val, vn::VarName) - -Set the current value(s) of the random variable `vn` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. -""" -setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) +@generated function _getindex( + metadata, + ranges::NamedTuple{names}, + ::Val{true}, +) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :(metadata.$f.trans_vals[ranges.$f])) + end + return expr +end """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) @@ -155,20 +282,34 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. 
""" -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) -setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) +function setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) + setall!(vi, val) + setsynced!(vi, false) + return vi +end +function setindex!(vi::UntypedVarInfo, val, spl::Sampler) + setval!(vi, val, getranges(vi, spl)) + setsynced!(vi, false) + return vi +end function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` - ranges = _getranges(vi, spl) - _setindex!(vi.metadata, val, ranges) - return val + ranges = getranges(vi, spl) + _setindex!(vi.metadata, val, ranges, Val(getmode(vi) isa LinkMode)) + setsynced!(vi, false) + return vi end # Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. -@generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} +@generated function _setindex!( + metadata, + val, + ranges::NamedTuple{names}, + ::Val{linked}, +) where {names, linked} expr = Expr(:block) offset = :(0) for f in names - f_vals = :(metadata.$f.vals) + f_vals = linked ? :(metadata.$f.trans_vals) : :(metadata.$f.vals) f_range = :(ranges.$f) start = :($offset + 1) len = :(length($f_range)) @@ -215,27 +356,34 @@ function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Sel return push!(vi, vn, r, dist, Set([gid])) end function push!( - vi::VarInfo, - vn::VarName, - r, - dist::Distribution, - gidset::Set{Selector} - ) - + vi::VarInfo, + vn::VarName, + val, + dist::Distribution, + gidset::Set{Selector}, +) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo @assert ~(haskey(vi, vn)) "[push!] 
attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end - val = vectorize(dist, r) - meta = getmetadata(vi, vn) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) - l = length(meta.vals); n = length(val) + + vectorized_val = vectorize(dist, val) + l = length(meta.vals); n = length(vectorized_val) push!(meta.ranges, l+1:l+n) - append!(meta.vals, val) + if getmode(vi) isa LinkMode || getmode(vi) isa InitLinkMode + append!(meta.vals, vectorized_val) + trans_val = Bijectors.link(dist, val) + append!(meta.trans_vals, vectorize(dist, trans_val)) + else + append!(meta.vals, vectorized_val) + append!(meta.trans_vals, vectorized_val) + setsynced!(vi, false) + end push!(meta.dists, dist) push!(meta.gids, gidset) push!(meta.orders, get_num_produce(vi)) diff --git a/src/varinfo/linking.jl b/src/varinfo/linking.jl index a55876df2..497bb7d2b 100644 --- a/src/varinfo/linking.jl +++ b/src/varinfo/linking.jl @@ -1,44 +1,95 @@ -# X -> R for all variables associated with given sampler """ - link!(vi::VarInfo, spl::Sampler) + islinked(vi::VarInfo, spl::Sampler) + +Check whether `vi` is in the transformed space for a particular sampler `spl`. +Turing's Hamiltonian samplers use the `link` and `invlink` functions from +[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable +(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of +real numbers. `islinked` checks if the number is in the constrained space or the real space. +""" +function islinked(vi::UntypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + return islinked(vi) && istrans(vi, vns[1]) +end +function islinked(vi::TypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + return islinked(vi) && _islinked(vi, vns) +end +@generated function _islinked(vi, vns::NamedTuple{names}) where {names} + out = [] + for f in names + push!(out, :(length(vns.$f) == 0 ? 
false : istrans(vi, vns.$f[1]))) + end + return Expr(:||, false, out...) +end +function islinked_and_trans(vi::AbstractVarInfo, vn::VarName) + return islinked(vi) && istrans(vi, vn) +end + +function Bijectors.link(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + LinkMode(), + vi.fixed_support, + vi.synced, + ) +end +function initlink(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + InitLinkMode(), + vi.fixed_support, + vi.synced, + ) +end +function Bijectors.invlink(vi::VarInfo) + return VarInfo( + vi.metadata, + vi.logp, + vi.num_produce, + StandardMode(), + vi.fixed_support, + vi.synced, + ) +end +islinked(vi::AbstractVarInfo) = getmode(vi) isa LinkMode || getmode(vi) isa InitLinkMode + +# X -> R for all variables associated with given sampler +""" + init_dist_link!(vi::VarInfo, spl::Sampler) Transform the values of the random variables sampled by `spl` in `vi` from the support of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. 
""" -function link!(vi::UntypedVarInfo, spl::Sampler) +function init_dist_link!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) - if ~istrans(vi, vns[1]) - for vn in vns - dist = getdist(vi, vn) - # TODO: Use inplace versions to avoid allocations - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") + vns = getvns(vi, spl) + for vn in vns + dist = getinitdist(vi, vn) + initlink(vi)[vn, dist] end + return vi end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, Val(getspace(spl))) +function init_dist_link!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _init_dist_link!(vi.metadata, vi, vns, Val(getspace(spl))) + return vi end -@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} +@generated function _init_dist_link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) for f in names if inspace(f, space) || length(space) == 0 push!(expr.args, quote f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") + # Iterate over all `f_vns` and transform + for vn in f_vns + dist = getinitdist(vi, vn) + initlink(vi)[vn, dist] end end) end @@ -46,45 +97,61 @@ end return expr end +function invlink!(vi::AbstractVarInfo, spl::AbstractSampler, model) + settrans!(vi, spl) + if !issynced(vi) + if has_fixed_support(vi) + init_dist_invlink!(vi, spl) + else + model(link(vi), spl) + end + setsynced!(vi, true) + end + return vi +end +function link!(vi::AbstractVarInfo, 
spl::AbstractSampler, model) + settrans!(vi, spl) + if !issynced(vi) + if has_fixed_support(vi) + init_dist_link!(vi, spl) + else + model(initlink(vi), spl) + end + setsynced!(vi, true) + end + return vi +end + # R -> X for all variables associated with given sampler """ - invlink!(vi::VarInfo, spl::AbstractSampler) - + init_dist_invlink!(vi::VarInfo, spl::AbstractSampler) Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - if istrans(vi, vns[1]) - for vn in vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") +function init_dist_invlink!(vi::UntypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + for vn in vns + dist = getinitdist(vi, vn) + link(vi)[vn, dist] end + return vi end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, Val(getspace(spl))) +function init_dist_invlink!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _init_dist_invlink!(vi.metadata, vi, vns, Val(getspace(spl))) + return vi end -@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} +@generated function _init_dist_invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) for f in names if inspace(f, space) || length(space) == 0 push!(expr.args, quote f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, 
false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") + # Iterate over all `f_vns` and transform + for vn in f_vns + dist = getinitdist(vi, vn) + link(vi)[vn, dist] end end) end @@ -92,44 +159,47 @@ end return expr end - +# X -> R for all variables associated with given sampler """ - islinked(vi::VarInfo, spl::Sampler) + settrans!(vi::VarInfo, spl::Sampler) -Check whether `vi` is in the transformed space for a particular sampler `spl`. - -Turing's Hamiltonian samplers use the `link` and `invlink` functions from -[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `islinked` checks if the number is in the constrained space or the real space. +Set the `"trans"` flag to `true` for all the variables in the space of `spl`. """ -function islinked(vi::UntypedVarInfo, spl::Sampler) - vns = _getvns(vi, spl) - return istrans(vi, vns[1]) +function settrans!(vi::UntypedVarInfo, spl::Sampler) + # TODO: Change to a lazy iterator over `vns` + vns = getvns(vi, spl) + if ~istrans(vi, vns[1]) + for vn in vns + settrans!(vi, true, vn) + end + end + return vi end -function islinked(vi::TypedVarInfo, spl::Sampler) - vns = _getvns(vi, spl) - return _islinked(vi, vns) +function settrans!(vi::TypedVarInfo, spl::AbstractSampler) + vns = getvns(vi, spl) + _settrans!(vi.metadata, vi, vns, Val(getspace(spl))) + return vi end -@generated function _islinked(vi, vns::NamedTuple{names}) where {names} - out = [] +@generated function _settrans!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} + expr = Expr(:block) for f in names - push!(out, :(length(vns.$f) == 0 ? 
false : istrans(vi, vns.$f[1]))) + if inspace(f, space) || length(space) == 0 + push!(expr.args, quote + f_vns = vi.metadata.$f.vns + if ~istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + settrans!(vi, true, vn) + end + end + end) + end end - return Expr(:||, false, out...) + return expr end -""" - istrans(vi::VarInfo, vn::VarName) - -Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if -they are in the support of `vn`'s distribution. -""" -istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") - """ settrans!(vi::VarInfo, trans::Bool, vn::VarName) - Set the `trans` flag value of `vn` in `vi`. """ function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) diff --git a/src/varinfo/types.jl b/src/varinfo/types.jl index 0ad82de02..23064ddf9 100644 --- a/src/varinfo/types.jl +++ b/src/varinfo/types.jl @@ -32,7 +32,7 @@ When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable. """ -struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} +struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TTransVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs :: TIdcs # Dict{<:VarName,Int} @@ -47,6 +47,10 @@ struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distrib # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` vals :: TVal # AbstractVector{<:Real} + # Vector of the transformed values of all the univariate, multivariate and matrix + # variables.
The value(s) of `vn` is/are `trans_vals[ranges[idcs[vn]]]` + trans_vals :: TTransVal # AbstractVector{<:Real} + # Vector of distributions correpsonding to `vns` dists :: TDists # AbstractVector{<:Distribution} @@ -68,6 +72,7 @@ Construct an empty type unstable instance of `Metadata`. """ function Metadata() vals = Vector{Real}() + trans_vals = Vector{Real}() flags = Dict{String, BitVector}() flags["del"] = BitVector() flags["trans"] = BitVector() @@ -77,6 +82,7 @@ function Metadata() Vector{VarName}(), Vector{UnitRange{Int}}(), vals, + trans_vals, Vector{Distribution}(), Vector{Set{Selector}}(), Vector{Int}(), @@ -88,6 +94,47 @@ end # VarInfo # ########### +abstract type VarInfoMode end + +""" + LinkMode + +For any random variable whose `"trans"` flag is set to `true`: +1. The transformed values are used in `getindex` and `setindex!`. +2. The untransformed values are computed and cached, and +3. The `logpdf_with_trans` is computed with `trans` set as `true`. + +For random variables whose `"trans"` flag is set to `false`, this is equivalent to +the `StandardMode`. This mode can be used when running HMC or MAP in the +unconstrained space. +""" +struct LinkMode <: VarInfoMode end + +""" + InitLinkMode + +For any random variable whose `"trans"` flag is set to `true`: +1. The untransformed values are used in `getindex` and `setindex!`. +2. The transformed values are computed and cached, and +3. The `logpdf_with_trans` is computed with `trans` set as `true`. + +For random variables whose `"trans"` flag is set to `false`, this is equivalent to +the `StandardMode`. This mode can be used to initialize a `VarInfo` for HMC or MAP. +""" +struct InitLinkMode <: VarInfoMode end + +""" + StandardMode + +For all random variables: +1. The untransformed values are used in `getindex` and `setindex!`. +2. The `logpdf` is computed, i.e. `logpdf_with_trans` with `trans` as `false`. + +This mode can be used when running non-HMC samplers or when doing MAP on the +constrained support directly.
+""" +struct StandardMode <: VarInfoMode end + """ ``` struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo @@ -109,48 +156,72 @@ Note: It is the user's responsibility to ensure that each "symbol" is visited at once whenever the model is called, regardless of any stochastic branching. Each symbol refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. """ -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta, Tlogp, Tmode <: VarInfoMode} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} + mode::Tmode + fixed_support::Base.RefValue{Bool} + synced::Base.RefValue{Bool} end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) function VarInfo(model::Model, ctx = DefaultContext()) vi = VarInfo() model(vi, SampleFromPrior(), ctx) return TypedVarInfo(vi) end - function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) new_vi = deepcopy(old_vi) new_vi[spl] = x return new_vi end function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) - md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))) + md = newmetadata(old_vi.metadata, Val(getspace(spl)), x, Val(getmode(old_vi) isa LinkMode)) + return VarInfo( + md, + Base.RefValue{eltype(x)}(getlogp(old_vi)), + Ref(get_num_produce(old_vi)), + old_vi.mode, + old_vi.fixed_support, + Ref(false), + ) end -@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space} +@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x, ::Val{islinked}) where {names, space, islinked} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) if inspace(f, space) || length(space) == 0 len = :(length($mdf.vals)) - push!(exprs, :($f = 
Metadata($mdf.idcs, - $mdf.vns, - $mdf.ranges, - x[($offset + 1):($offset + $len)], - $mdf.dists, - $mdf.gids, - $mdf.orders, - $mdf.flags - ) - ) - ) + if islinked + push!(exprs, :($f = Metadata($mdf.idcs, + $mdf.vns, + $mdf.ranges, + $mdf.vals, + x[($offset + 1):($offset + $len)], + $mdf.dists, + $mdf.gids, + $mdf.orders, + $mdf.flags + ) + ) + ) + else + push!(exprs, :($f = Metadata($mdf.idcs, + $mdf.vns, + $mdf.ranges, + x[($offset + 1):($offset + $len)], + $mdf.trans_vals, + $mdf.dists, + $mdf.gids, + $mdf.orders, + $mdf.flags + ) + ) + ) + end offset = :($offset + $len) else push!(exprs, :($f = $mdf)) @@ -160,6 +231,8 @@ end return :($(exprs...),) end +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0), StandardMode(), Ref(true), Ref(false)) + """ TypedVarInfo(vi::UntypedVarInfo) @@ -197,6 +270,7 @@ function TypedVarInfo(vi::UntypedVarInfo) _ranges = getindex.((meta.ranges,), inds) # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + _trans_vals = [copy.(meta.trans_vals[_ranges[i]]) for i in 1:n] sym_ranges = Vector{eltype(_ranges)}(undef, n) start = 0 for i in 1:n @@ -204,17 +278,28 @@ function TypedVarInfo(vi::UntypedVarInfo) start += length(_vals[i]) end sym_vals = foldl(vcat, _vals) + sym_trans_vals = foldl(vcat, _trans_vals) - push!(new_metas, Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, - sym_dists, sym_gids, sym_orders, sym_flags)) + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_trans_vals, + sym_dists, sym_gids, sym_orders, sym_flags + ) + ) end logp = getlogp(vi) num_produce = get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, Ref(logp), Ref(num_produce), vi.mode, vi.fixed_support, vi.synced) end TypedVarInfo(vi::TypedVarInfo) = vi + +#### +#### Printing +#### + function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) vi_str = 
""" /======================================================================= diff --git a/src/varinfo/utils.jl b/src/varinfo/utils.jl index 2ed244949..c8d308207 100644 --- a/src/varinfo/utils.jl +++ b/src/varinfo/utils.jl @@ -1,3 +1,13 @@ +has_fixed_support(vi::VarInfo) = vi.fixed_support[] +function set_fixed_support!(vi::VarInfo, b::Bool) + return vi.fixed_support[] = vi.fixed_support[] && b +end + +getmode(vi::VarInfo) = vi.mode +issynced(vi::VarInfo) = vi.synced[] +setsynced!(vi::VarInfo, b::Bool) = vi.synced[] = b +value(x) = x + """ empty!(meta::Metadata) @@ -10,6 +20,7 @@ function empty!(meta::Metadata) empty!(meta.vns) empty!(meta.ranges) empty!(meta.vals) + empty!(meta.trans_vals) empty!(meta.dists) empty!(meta.gids) empty!(meta.orders) @@ -20,6 +31,30 @@ function empty!(meta::Metadata) return meta end +""" + empty!(vi::VarInfo) + +Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to +zeros. + +This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. +""" +function empty!(vi::VarInfo) + _empty!(vi.metadata) + resetlogp!(vi) + reset_num_produce!(vi) + setsynced!(vi, false) + return vi +end +@inline _empty!(metadata::Metadata) = empty!(metadata) +@generated function _empty!(metadata::NamedTuple{names}) where {names} + expr = Expr(:block) + for f in names + push!(expr.args, :(empty!(metadata.$f))) + end + return expr +end + # Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined. if VERSION < v"1.1" _tail(nt::NamedTuple{names}) where names = NamedTuple{Base.tail(names)}(nt) @@ -59,11 +94,11 @@ function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) end """ - getdist(vi::VarInfo, vn::VarName) + getinitdist(vi::VarInfo, vn::VarName) Return the distribution from which `vn` was sampled in `vi`. 
""" -getdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] +getinitdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] """ getgid(vi::VarInfo, vn::VarName) @@ -83,14 +118,14 @@ syms(vi::TypedVarInfo) = keys(vi.metadata) # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler -@inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) +@inline function getidcs(vi::UntypedVarInfo, ::SampleFromPrior) return filter(i -> isempty(vi.metadata.gids[i]) , 1:length(vi.metadata.gids)) end # Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol -@inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) - return _getidcs(vi.metadata) +@inline function getidcs(vi::TypedVarInfo, ::SampleFromPrior) + return getidcs(vi.metadata) end -@generated function _getidcs(metadata::NamedTuple{names}) where {names} +@generated function getidcs(metadata::NamedTuple{names}) where {names} exprs = [] for f in names push!(exprs, :($f = findinds(metadata.$f))) @@ -100,26 +135,13 @@ end end # Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::AbstractVarInfo, spl::Sampler) - # NOTE: 0b00 is the sanity flag for - # |\____ getidcs (mask = 0b10) - # \_____ getranges (mask = 0b01) - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - # Checks if cache is valid, i.e. 
no new pushes were made, to return the cached idcs - # Otherwise, it recomputes the idcs and caches it - #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 - # spl.info[:idcs] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs - #end - return idcs +@inline function getidcs(vi::AbstractVarInfo, spl::Sampler) + return getidcs(vi, spl.selector, Val(getspace(spl))) end -@inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) -@inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) +@inline getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) +@inline getidcs(vi::TypedVarInfo, s::Selector, space) = getidcs(vi.metadata, s, space) # Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} +@generated function getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} exprs = [] # Iterate through each varname in metadata. 
for f in names @@ -145,14 +167,14 @@ end end # Get all vns of variables belonging to spl -_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -_getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = _getvns(vi, Selector(), Val(())) -_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space)) -function _getvns(vi::TypedVarInfo, s::Selector, space) - return _getvns(vi.metadata, _getidcs(vi, s, space)) +getvns(vi::AbstractVarInfo, spl::Sampler) = getvns(vi, spl.selector, Val(getspace(spl))) +getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = getvns(vi, Selector(), Val(())) +getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, getidcs(vi, s, space)) +function getvns(vi::TypedVarInfo, s::Selector, space) + return getvns(vi.metadata, getidcs(vi, s, space)) end # Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol -@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} +@generated function getvns(metadata, idcs::NamedTuple{names}) where {names} exprs = [] for f in names push!(exprs, :($f = metadata.$f.vns[idcs.$f])) @@ -162,28 +184,28 @@ end end # Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::AbstractVarInfo, spl::Sampler) +@inline function getranges(vi::AbstractVarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 # spl.info[:ranges] #else #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) + ranges = getranges(vi, spl.selector, Val(getspace(spl))) #spl.info[:ranges] = ranges return ranges #end end # Get the index (in vals) ranges of all 
the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, space) - return _getranges(vi, _getidcs(vi, s, space)) +@inline function getranges(vi::AbstractVarInfo, s::Selector, space) + return getranges(vi, getidcs(vi, s, space)) end -@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) +@inline function getranges(vi::UntypedVarInfo, idcs::Vector{Int}) mapreduce(i -> vi.metadata.ranges[i], vcat, idcs, init=Int[]) end -@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) +@inline getranges(vi::TypedVarInfo, idcs::NamedTuple) = getranges(vi.metadata, idcs) -@generated function _getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} +@generated function getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} exprs = [] for f in names push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) @@ -204,29 +226,6 @@ function set_flag!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end -""" - empty!(vi::VarInfo) - -Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. - -This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. -""" -function empty!(vi::VarInfo) - _empty!(vi.metadata) - resetlogp!(vi) - reset_num_produce!(vi) - return vi -end -@inline _empty!(metadata::Metadata) = empty!(metadata) -@generated function _empty!(metadata::NamedTuple{names}) where {names} - expr = Expr(:block) - for f in names - push!(expr.args, :(empty!(metadata.$f))) - end - return expr -end - # Functions defined only for UntypedVarInfo """ keys(vi::UntypedVarInfo) @@ -240,7 +239,20 @@ keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) Add `gid` to the set of sampler selectors associated with `vn` in `vi`. 
""" -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +function setgid!(vi::VarInfo, gid::Selector, vn::VarName) + push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) + return vi +end + +""" + istrans(vi::VarInfo, vn::VarName) + +Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if +they are in the support of `vn`'s distribution. +""" +function istrans(vi::AbstractVarInfo, vn::VarName) + return is_flagged(vi, vn, "trans") +end """ getlogp(vi::VarInfo) @@ -349,15 +361,14 @@ end return expr end -@inline function findvns(vi, f_vns) - if length(f_vns) == 0 - throw("Unidentified error, please report this error in an issue.") - end - return map(vn -> vi[vn], f_vns) -end - function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler, SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) + T = eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) + if T === Union{} + # To throw a meaningful error + return eltype(vi[spl]) + else + return T + end end """ @@ -386,6 +397,10 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) return vi end +####################################### +# Rand & replaying method for VarInfo # +####################################### + """ is_flagged(vi::VarInfo, vn::VarName, flag::String) @@ -411,7 +426,7 @@ Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `tr """ function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) + gidcs = getidcs(vi, spl) if get_num_produce(vi) == 0 for i = length(gidcs):-1:1 vi.metadata.flags["del"][gidcs[i]] = true @@ -427,7 +442,7 @@ function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) end function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) # Get the indices of `vns` that belong to 
`spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) + gidcs = getidcs(vi, spl) return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) end @generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names} @@ -465,3 +480,61 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) setgid!(vi, spl.selector, vn) end end + +""" + set_namedtuple!(vi::VarInfo, nt::NamedTuple) + +Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. +""" +function set_namedtuple!(vi::VarInfo, nt::NamedTuple) + @assert !islinked(vi) + for (n, vals) in pairs(nt) + vns = vi.metadata[n].vns + + n_vns = length(vns) + n_vals = length(vals) + v_isarr = vals isa AbstractArray + + if v_isarr && n_vals == 1 && n_vns > 1 + for (vn, val) in zip(vns, vals[1]) + vi[vn] = val + end + elseif v_isarr && n_vals > 1 && n_vns == 1 + vi[vns[1]] = vals + elseif v_isarr && n_vals == n_vns > 1 + for (vn, val) in zip(vns, vals) + vi[vn] = val + end + elseif v_isarr && n_vals == 1 && n_vns == 1 + vi[vns[1]] = vals[1] + elseif !(v_isarr) + vi[vns[1]] = vals + else + error("Cannot assign `NamedTuple` to `VarInfo`") + end + end +end + +function updategid!(vi::UntypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + for vn in vns + updategid!(vi, vn, spl) + end + return vi +end +function updategid!(vi::TypedVarInfo, spl::Sampler) + vns = getvns(vi, spl) + _updategid!(vi, spl, vns) + return vi +end +@generated function _updategid!(vi, spl, vns::NamedTuple{names}) where names + expr = Expr(:block) + for n in names + push!(expr.args, quote + for vn in vi.metadata.$n.vns + updategid!(vi, vn, spl) + end + end) + end + return expr +end diff --git a/src/varinfo/varinfo.jl b/src/varinfo/varinfo.jl index 8fa7de2ad..83c870435 100644 --- a/src/varinfo/varinfo.jl +++ b/src/varinfo/varinfo.jl @@ -7,3 +7,4 @@ include("types.jl") include("utils.jl") include("indexing.jl") include("linking.jl") 
+include("ad.jl") diff --git a/test/Turing/contrib/inference/dynamichmc.jl b/test/Turing/contrib/inference/dynamichmc.jl index 17d1221d9..66829b01b 100644 --- a/test/Turing/contrib/inference/dynamichmc.jl +++ b/test/Turing/contrib/inference/dynamichmc.jl @@ -53,29 +53,16 @@ function AbstractMCMC.sample_init!( model::Model, spl::Sampler{<:DynamicNUTS}, N::Integer; - kwargs... + kwargs..., ) # Set up lp function. function _lp(x) - gradient_logp(x, spl.state.vi, model, spl) + gradient_logp(x, link(spl.state.vi), model, spl) end # Set the parameters to a starting value. initialize_parameters!(spl; kwargs...) - - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - while !isfinite(l) || !isfinite(dl) - model(spl.state.vi, SampleFromUniform()) - link!(spl.state.vi, spl) - l, dl = _lp(spl.state.vi[spl]) - end - - if spl.selector.tag == :default && !islinked(spl.state.vi, spl) - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - end + link!(spl.state.vi, spl, model) results = mcmc_with_warmup( rng, @@ -99,7 +86,8 @@ function AbstractMCMC.step!( ) # Pop the next draw off the vector. draw = popfirst!(spl.state.draws) - spl.state.vi[spl] = draw + link(spl.state.vi)[spl] = draw + invlink!(spl.state.vi, spl, model) return Transition(spl) end @@ -118,7 +106,7 @@ end # Disable the progress logging for DynamicHMC, since it has its own progress meter. function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, N::Integer; chain_type=MCMCChains.Chains, @@ -127,7 +115,7 @@ end kwargs... 
) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end if resume_from === nothing return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; @@ -139,7 +127,7 @@ end function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, + model::DynamicPPL.AbstractModel, alg::DynamicNUTS, parallel::AbstractMCMC.AbstractMCMCParallel, N::Integer, @@ -149,7 +137,7 @@ function AbstractMCMC.sample( kwargs... ) if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + @warn "[DynamicNUTS] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" end return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; chain_type=chain_type, progress=false, kwargs...) diff --git a/test/Turing/contrib/inference/sghmc.jl b/test/Turing/contrib/inference/sghmc.jl index 83c488613..b1f9dda05 100644 --- a/test/Turing/contrib/inference/sghmc.jl +++ b/test/Turing/contrib/inference/sghmc.jl @@ -61,13 +61,9 @@ function step( is_first::Val{true}; kwargs... ) - spl.selector.tag != :default && link!(vi, spl) - # Initialize velocity v = zeros(Float64, size(vi[spl])) spl.info[:v] = v - - spl.selector.tag != :default && invlink!(vi, spl) return vi, true end @@ -84,13 +80,12 @@ function step( Turing.DEBUG && @debug "X-> R..." if spl.selector.tag != :default - link!(vi, spl) - model(vi, spl) + model(initlink(vi), spl) end Turing.DEBUG && @debug "recording old variables..." θ, v = vi[spl], spl.info[:v] - _, grad = gradient_logp(θ, vi, model, spl) + _, grad = gradient_logp(θ, link(vi), model, spl) verifygrad(grad) # Implements the update equations from (15) of Chen et al. (2014). @@ -197,7 +192,7 @@ function step( Turing.DEBUG && @debug "X-> R..." 
if spl.selector.tag != :default - link!(vi, spl) + link!(vi, spl, model) model(vi, spl) end diff --git a/test/Turing/inference/AdvancedSMC.jl b/test/Turing/inference/AdvancedSMC.jl index 77a4c7090..5acfeb5dd 100644 --- a/test/Turing/inference/AdvancedSMC.jl +++ b/test/Turing/inference/AdvancedSMC.jl @@ -321,21 +321,21 @@ function DynamicPPL.assume( elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = rand(dist) - vi[vn] = vectorize(dist, r) + vi[vn, dist] = r setgid!(vi, spl.selector, vn) setorder!(vi, vn, get_num_produce(vi)) else updategid!(vi, vn, spl) - r = vi[vn] + r = vi[vn, dist] end else # vn belongs to other sampler <=> conditionning on vn if haskey(vi, vn) - r = vi[vn] + r = vi[vn, dist] else r = rand(dist) push!(vi, vn, r, dist, Selector(:invalid)) end - lp = logpdf_with_trans(dist, r, istrans(vi, vn)) + lp = logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) acclogp!(vi, lp) end return r, 0 diff --git a/test/Turing/inference/Inference.jl b/test/Turing/inference/Inference.jl index b3db2982d..c4c8e86c0 100644 --- a/test/Turing/inference/Inference.jl +++ b/test/Turing/inference/Inference.jl @@ -3,9 +3,9 @@ module Inference using ..Core using ..Core: logZ using ..Utilities -using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, +using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, set_namedtuple!, islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans!, _getvns, getdist, CACHERESET, AbstractSampler, + settrans!, getvns, getinitdist, CACHERESET, AbstractSampler, Model, Sampler, SampleFromPrior, SampleFromUniform, Selector, AbstractSamplerState, DefaultContext, PriorContext, LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist, @@ -26,7 +26,7 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import ..Core: getchunksize, getADbackend import DynamicPPL: get_matching_type, - VarName, _getranges, _getindex, getval, _getvns + VarName, getval, getvns 
import EllipticalSliceSampling import Random import MCMCChains @@ -272,7 +272,6 @@ function initialize_parameters!( verbose::Bool=false, kwargs... ) - islinked(spl.state.vi, spl) && invlink!(spl.state.vi, spl) # Get `init_theta` if init_theta !== nothing verbose && @info "Using passed-in initial variable values" init_theta @@ -566,7 +565,7 @@ function get_matching_type( end function get_matching_type( spl::AbstractSampler, - vi, + vi, ::Type{<:AbstractFloat}, ) return floatof(eltype(vi, spl)) diff --git a/test/Turing/inference/ess.jl b/test/Turing/inference/ess.jl index 27a3b3f54..c649123e6 100644 --- a/test/Turing/inference/ess.jl +++ b/test/Turing/inference/ess.jl @@ -33,11 +33,11 @@ function Sampler(alg::ESS, model::Model, s::Selector) # sanity check vi = VarInfo(model) space = getspace(alg) - vns = _getvns(vi, s, Val(space)) + vns = getvns(vi, s, Val(space)) length(vns) == 1 || error("[ESS] does only support one variable ($(length(vns)) variables specified)") for vn in vns[1] - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("[ESS] only supports Gaussian prior distributions") end @@ -102,9 +102,9 @@ end function ESSModel(model::Model, spl::Sampler{<:ESS}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) μ = mapreduce(vcat, vns[1]) do vn - dist = getdist(vi, vn) + dist = getinitdist(vi, vn) vectorize(dist, mean(dist)) end @@ -115,7 +115,7 @@ end function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel) spl = model.spl vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) set_flag!(vi, vns[1][1], "del") model.model(vi, spl) return vi[spl] diff --git a/test/Turing/inference/hmc.jl b/test/Turing/inference/hmc.jl index 3b8b95c95..d3b35d6a4 100644 --- a/test/Turing/inference/hmc.jl +++ b/test/Turing/inference/hmc.jl @@ -102,8 +102,8 @@ end function update_hamiltonian!(spl, model, n) metric = gen_metric(n, spl) - ℓπ = gen_logπ(spl.state.vi, spl, 
model) - ∂ℓπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model) + ℓπ = gen_logπ(link(spl.state.vi), spl, model) + ∂ℓπ∂θ = gen_∂logπ∂θ(link(spl.state.vi), spl, model) spl.state.h = AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) return spl end @@ -125,24 +125,24 @@ function AbstractMCMC.sample_init!( initialize_parameters!(spl; verbose=verbose, kwargs...) if init_theta !== nothing # Doesn't support dynamic models - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) - else + elseif resume_from === nothing # Samples new values and sets trans to true, then computes the logp model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) while !isfinite(spl.state.z.ℓπ.value) || !isfinite(spl.state.z.ℓπ.gradient) model(empty!(spl.state.vi), SampleFromUniform()) - link!(spl.state.vi, spl) - theta = spl.state.vi[spl] + link!(spl.state.vi, spl, model) + theta = link(spl.state.vi)[spl] update_hamiltonian!(spl, model, length(theta)) # Refresh the internal cache phase point z's hamiltonian energy. spl.state.z = AHMC.phasepoint(rng, theta, spl.state.h) @@ -166,14 +166,14 @@ function AbstractMCMC.sample_init!( spl.alg.n_adapts = 0 end end - + # Convert to transformed space if we're using # non-Gibbs sampling. 
- if !islinked(spl.state.vi, spl) && spl.selector.tag == :default - link!(spl.state.vi, spl) - model(spl.state.vi, spl) - elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default - invlink!(spl.state.vi, spl) + if spl.selector.tag == :default + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) + else + invlink!(spl.state.vi, spl, model) model(spl.state.vi, spl) end end @@ -411,20 +411,22 @@ function AbstractMCMC.step!( spl.state.eval_num = 0 Turing.DEBUG && @debug "current ϵ: $ϵ" + updategid!(spl.state.vi, spl) # When a Gibbs component if spl.selector.tag != :default # Transform the space Turing.DEBUG && @debug "X-> R..." - link!(spl.state.vi, spl) - model(spl.state.vi, spl) + link!(spl.state.vi, spl, model) + model(link(spl.state.vi), spl) end # Get position and log density before transition - θ_old, log_density_old = spl.state.vi[spl], getlogp(spl.state.vi) + θ_old, θ_old_trans = spl.state.vi[spl], link(spl.state.vi)[spl] + log_density_old = getlogp(spl.state.vi) if spl.selector.tag != :default - update_hamiltonian!(spl, model, length(θ_old)) + update_hamiltonian!(spl, model, length(θ_old_trans)) resize!(spl.state.z.θ, length(θ_old)) - spl.state.z.θ .= θ_old + spl.state.z.θ .= θ_old_trans end # Transition @@ -443,17 +445,19 @@ function AbstractMCMC.step!( # Update `vi` based on acceptance if t.stat.is_accept - spl.state.vi[spl] = t.z.θ + link(spl.state.vi)[spl] = t.z.θ + invlink!(spl.state.vi, spl, model) setlogp!(spl.state.vi, t.stat.log_density) else spl.state.vi[spl] = θ_old + link(spl.state.vi)[spl] = θ_old_trans setlogp!(spl.state.vi, log_density_old) + DynamicPPL.setsynced!(spl.state.vi, true) end # Gibbs component specified cares # Transform the space back Turing.DEBUG && @debug "R -> X..." 
- spl.selector.tag != :default && invlink!(spl.state.vi, spl) return HamiltonianTransition(spl, t) end @@ -514,13 +518,13 @@ function DynamicPPL.assume( ) Turing.DEBUG && _debug("assuming...") updategid!(vi, vn, spl) - r = vi[vn] - # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn))) + r = vi[vn, dist] + # acclogp!(vi, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))) # r Turing.DEBUG && _debug("dist = $dist") Turing.DEBUG && _debug("vn = $vn") Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))") - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -532,9 +536,9 @@ function DynamicPPL.dot_assume( ) @assert length(dist) == size(var, 1) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:Hamiltonian}, @@ -544,9 +548,9 @@ function DynamicPPL.dot_assume( vi, ) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( @@ -606,11 +610,11 @@ function HMCState( vi = spl.state.vi # Link everything if needed. - !islinked(vi, spl) && link!(vi, spl) + link!(vi, spl, model) # Get the initial log pdf and gradient functions. - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = gen_logπ(vi, spl, model) + ∂logπ∂θ = gen_∂logπ∂θ(link(vi), spl, model) + logπ = gen_logπ(link(vi), spl, model) # Get the metric type. metricT = getmetricT(spl.alg) @@ -635,7 +639,7 @@ function HMCState( h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. # Unlink everything. 
- invlink!(vi, spl) + invlink!(vi, spl, model) return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) end diff --git a/test/Turing/inference/mh.jl b/test/Turing/inference/mh.jl index 764c8959b..393816678 100644 --- a/test/Turing/inference/mh.jl +++ b/test/Turing/inference/mh.jl @@ -54,43 +54,6 @@ alg_str(::Sampler{<:MH}) = "MH" # Utility functions # ##################### -""" - set_namedtuple!(vi::VarInfo, nt::NamedTuple) - -Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. -""" -function set_namedtuple!(vi::VarInfo, nt::NamedTuple) - for (n, vals) in pairs(nt) - vns = vi.metadata[n].vns - - n_vns = length(vns) - n_vals = length(vals) - v_isarr = vals isa AbstractArray - - if v_isarr && n_vals == 1 && n_vns > 1 - for (vn, val) in zip(vns, vals[1]) - vi[vn] = val isa AbstractArray ? val : [val] - end - elseif v_isarr && n_vals > 1 && n_vns == 1 - vi[vns[1]] = vals - elseif v_isarr && n_vals == n_vns > 1 - for (vn, val) in zip(vns, vals) - vi[vn] = [val] - end - elseif v_isarr && n_vals == 1 && n_vns == 1 - if vals[1] isa AbstractArray - vi[vns[1]] = vals[1] - else - vi[vns[1]] = [vals[1]] - end - elseif !(v_isarr) - vi[vns[1]] = [vals] - else - error("Cannot assign `NamedTuple` to `VarInfo`") - end - end -end - """ MHLogDensityFunction @@ -136,7 +99,7 @@ The second `NamedTuple` has model symbols as keys and their stored values as val """ function dist_val_tuple(spl::Sampler{<:MH}) vi = spl.state.vi - vns = _getvns(vi, spl) + vns = getvns(vi, spl) dt = _dist_tuple(spl.alg.proposals, vi, vns) vt = _val_tuple(vi, vns) return dt, vt @@ -149,7 +112,7 @@ end isempty(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) expr.args = Any[ - :($name = reconstruct(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)), + :($name = reconstruct(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)), DynamicPPL.getval(vi, vns.$name))) for name in names] return expr @@ -168,7 +131,7 @@ end :($name = props.$name) else # Otherwise, 
use the default proposal. - :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)))) + :($name = AMH.StaticProposal(unvectorize(DynamicPPL.getinitdist.(Ref(vi), vns.$name)))) end for name in names] return expr end @@ -229,8 +192,8 @@ function DynamicPPL.assume( vi, ) updategid!(vi, vn, spl) - r = vi[vn] - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + r = vi[vn, dist] + return r, logpdf_with_trans(dist, r, islinked_and_trans(vi, vn)) end function DynamicPPL.dot_assume( @@ -244,9 +207,9 @@ function DynamicPPL.dot_assume( getvn = i -> VarName(vn, vn.indexing * "[:,$i]") vns = getvn.(1:size(var, 2)) updategid!.(Ref(vi), vns, Ref(spl)) - r = vi[vns] + r = vi[vns, dist] var .= r - return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.dot_assume( spl::Sampler{<:MH}, @@ -258,9 +221,9 @@ function DynamicPPL.dot_assume( getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") vns = getvn.(CartesianIndices(var)) updategid!.(Ref(vi), vns, Ref(spl)) - r = reshape(vi[vec(vns)], size(var)) + r = vi[vns, dists] var .= r - return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return var, sum(logpdf_with_trans.(dists, r, islinked_and_trans(vi, vns[1]))) end function DynamicPPL.observe( diff --git a/test/runtests.jl b/test/runtests.jl index d029c84b5..46b33e81e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using DynamicPPL using Distributions using ForwardDiff +using ReverseDiff using Tracker using Zygote @@ -10,6 +11,7 @@ using Test dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] include(dir*"/test/Turing/Turing.jl") using .Turing +using DynamicPPL: link, invlink, islinked turnprogress(false) diff --git a/test/test_utils/testing_functions.jl b/test/test_utils/testing_functions.jl index dc14ed125..5a6fee086 100644 --- a/test/test_utils/testing_functions.jl +++ 
b/test/test_utils/testing_functions.jl @@ -19,7 +19,7 @@ function randr(vi::Turing.VarInfo, else if count Turing.checkindex(vn, vi, spl) end Turing.updategid!(vi, vn, spl) - return vi[vn] + return vi[vn, dist] end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 8fddb313f..67015ee97 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,3 +1,6 @@ +dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] +include(dir*"/test/test_utils/AllUtils.jl") + @testset "threadsafe.jl" begin @testset "constructor" begin vi = VarInfo(gdemo_default) diff --git a/test/varinfo.jl b/test/varinfo.jl index 5ec939807..28713f159 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,8 +1,8 @@ using .Turing, Random using AbstractMCMC: step! using DynamicPPL: Selector, reconstruct, invlink, CACHERESET, - SampleFromPrior, Sampler, SampleFromUniform, - _getidcs, set_retained_vns_del_by_spl!, is_flagged, + SampleFromPrior, Sampler, SampleFromUniform, getinitdist, + getidcs, set_retained_vns_del_by_spl!, is_flagged, set_flag!, unset_flag!, VarInfo, TypedVarInfo, getlogp, setlogp!, resetlogp!, acclogp!, vectorize, setorder!, updategid! 
@@ -32,7 +32,9 @@ include(dir*"/test/test_utils/AllUtils.jl") for f in fieldnames(typeof(tvi.metadata)) fmeta = getfield(tvi.metadata, f) for vn in fmeta.vns - @test tvi[vn] == vi[vn] + dist1 = getinitdist(tvi, vn) + dist2 = getinitdist(vi, vn) + @test tvi[vn, dist1] == vi[vn, dist2] ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] @@ -77,16 +79,16 @@ include(dir*"/test/test_utils/AllUtils.jl") @test ~isempty(vi) @test haskey(vi, vn) - @test length(vi[vn]) == 1 + @test length(vi[vn, dist]) == 1 @test length(vi[SampleFromPrior()]) == 1 - @test vi[vn] == r + @test vi[vn, dist] == r @test vi[SampleFromPrior()][1] == r - vi[vn] = [2*r] - @test vi[vn] == 2*r + vi[vn, dist] = 2*r + @test vi[vn, dist] == 2*r @test vi[SampleFromPrior()][1] == 2*r vi[SampleFromPrior()] = [3*r] - @test vi[vn] == 3*r + @test vi[vn, dist] == 3*r @test vi[SampleFromPrior()][1] == 3*r empty!(vi) @@ -152,9 +154,9 @@ include(dir*"/test/test_utils/AllUtils.jl") test_varinfo!(vi) test_varinfo!(empty!(TypedVarInfo(vi))) end - @testset "link!" 
begin + @testset "link" begin # Test linking spl and vi: - # link!, invlink!, istrans + # link, invlink @model gdemo(x, y) = begin s ~ InverseGamma(2,3) m ~ Uniform(0, 2) @@ -167,33 +169,26 @@ include(dir*"/test/test_utils/AllUtils.jl") meta = vi.metadata model(vi, SampleFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !DynamicPPL.istrans(vi, x), meta.vns) alg = HMC(0.1, 5) spl = Sampler(alg, model) v = copy(meta.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.vns) - @test meta.vals == v + @test DynamicPPL.islinked(DynamicPPL.link(vi)) + @test !DynamicPPL.islinked(DynamicPPL.invlink(DynamicPPL.link(vi))) + @test vi[SampleFromPrior()] == v + @test DynamicPPL.invlink(DynamicPPL.link(vi))[SampleFromPrior()] == v vi = TypedVarInfo(vi) meta = vi.metadata alg = HMC(0.1, 5) spl = Sampler(alg, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns) + @test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl) - @test all(x -> ~istrans(vi, x), meta.s.vns) - @test all(x -> ~istrans(vi, x), meta.m.vns) - @test meta.s.vals == v_s - @test meta.m.vals == v_m + @test vi[SampleFromPrior()] == [v_s; v_m] + @test DynamicPPL.invlink(DynamicPPL.link(vi))[SampleFromPrior()] == [v_s; v_m] end @testset "setgid!" 
begin vi = VarInfo() @@ -225,12 +220,12 @@ include(dir*"/test/test_utils/AllUtils.jl") elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = rand(dist) - vi[vn] = vectorize(dist, r) + vi[vn, dist] = r setorder!(vi, vn, get_num_produce(vi)) r else updategid!(vi, vn, spl) - vi[vn] + vi[vn, dist] end end @@ -412,11 +407,11 @@ include(dir*"/test/test_utils/AllUtils.jl") spl1 = Sampler(PG(5, :x, :y, :z), empty_model()) for i = 1:3 r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] + val = vi[vns[i], dists[i]] @test sum(val - r) <= 1e-9 end - idcs = _getidcs(vi, spl1) + idcs = getidcs(vi, spl1) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 else @@ -424,7 +419,7 @@ include(dir*"/test/test_utils/AllUtils.jl") end @test length(vi[spl1]) == 7 - idcs = _getidcs(vi, spl2) + idcs = getidcs(vi, spl2) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 else @@ -435,7 +430,7 @@ include(dir*"/test/test_utils/AllUtils.jl") vn_u = @varname u randr(vi, vn_u, dists[1], spl2, true) - idcs = _getidcs(vi, spl2) + idcs = getidcs(vi, spl2) if idcs isa NamedTuple @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 else @@ -489,4 +484,19 @@ include(dir*"/test/test_utils/AllUtils.jl") @test vi.metadata.w.gids[1] == Set([hmc.selector]) @test vi.metadata.u.gids[1] == Set([hmc.selector]) end + + @testset "random support distributions" begin + @model function grm(::Type{TV}=Vector{Float64}) where {TV} + b = TV(undef, 5) + b[1] ~ Normal(0,1) + for i in 2:5 + b[i] ~ truncated(Normal(0,1), b[i-1], Inf) + @test b[i] >= b[i-1] + end + end + chain = sample(grm(), HMC(0.05, 1), 100) + for s in 1:100, i in 2:5 + @test chain["b[$i]"].value[s] >= chain["b[$(i-1)]"].value[s] + end + end end