Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,3 @@ julia = "1.10.8"
[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
10 changes: 5 additions & 5 deletions ext/TuringDynamicHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
stepsize::S
end

function DynamicPPL.initialstep(
function Turing.Inference.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
spl::DynamicNUTS,
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
)
Expand All @@ -59,7 +59,7 @@ function DynamicPPL.initialstep(

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
)

# Perform initial step.
Expand All @@ -80,14 +80,14 @@ end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
spl::DynamicNUTS,
state::DynamicNUTSState;
kwargs...,
)
# Compute next sample.
vi = state.vi
ℓ = state.logdensity
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Create next sample and state.
Expand Down
4 changes: 3 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ export
maximum_a_posteriori,
maximum_likelihood,
MAP,
MLE
MLE,
# Chain save/resume
loadstate

end
37 changes: 15 additions & 22 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using DynamicPPL:
getsym,
getdist,
Model,
Sampler,
DefaultContext
using Distributions, Libtask, Bijectors
using DistributionsAD: VectorOfMultivariate
Expand Down Expand Up @@ -50,8 +49,7 @@ import Random
import MCMCChains
import StatsBase: predict

export InferenceAlgorithm,
Hamiltonian,
export Hamiltonian,
StaticHamiltonian,
AdaptiveHamiltonian,
MH,
Expand All @@ -71,15 +69,16 @@ export InferenceAlgorithm,
RepeatSampler,
Prior,
predict,
externalsampler
externalsampler,
init_strategy,
loadstate

###############################################
# Abstract interface for inference algorithms #
###############################################

const TURING_CHAIN_TYPE = MCMCChains.Chains
#########################################
# Generic AbstractMCMC methods dispatch #
#########################################

include("algorithm.jl")
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
include("abstractmcmc.jl")

####################
# Sampler wrappers #
Expand Down Expand Up @@ -312,8 +311,8 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
model::DynamicPPL.Model,
spl::AbstractSampler,
state,
chain_type::Type{MCMCChains.Chains};
save_state=false,
Expand Down Expand Up @@ -374,8 +373,8 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
model::DynamicPPL.Model,
spl::AbstractSampler,
state,
chain_type::Type{Vector{NamedTuple}};
kwargs...,
Expand Down Expand Up @@ -416,7 +415,7 @@ function group_varnames_by_symbol(vns)
return d
end

function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
return setinfo(c, merge(nt, c.info))
end
Expand All @@ -435,18 +434,12 @@ include("sghmc.jl")
include("emcee.jl")
include("prior.jl")

#################################################
# Generic AbstractMCMC methods dispatch #
#################################################

include("abstractmcmc.jl")

################
# Typing tools #
################

function DynamicPPL.get_matching_type(
spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
spl::Union{PG,SMC}, vi, ::Type{TV}
) where {T,N,TV<:Array{T,N}}
return Array{T,N}
end
Expand Down
142 changes: 125 additions & 17 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model)
new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
end
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
function _check_model(model::DynamicPPL.Model, ::AbstractSampler)
return _check_model(model)
end

"""
Turing.Inference.init_strategy(spl::AbstractSampler)

Get the default initialization strategy for a given sampler `spl`, i.e. how initial
parameters for sampling are chosen if not specified by the user. By default, this is
`InitFromPrior()`, which samples initial parameters from the prior distribution.
"""
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()

"""
_convert_initial_params(initial_params)

Convert `initial_params` to a `DynamicPPL.AbstractInitStrategy` if it is not already one, or
throw a useful error message.
"""
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
function _convert_initial_params(nt::NamedTuple)
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
return DynamicPPL.InitFromParams(nt)
end
function _convert_initial_params(d::AbstractDict{<:VarName})
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
return DynamicPPL.InitFromParams(d)
end
function _convert_initial_params(::AbstractVector{<:Real})
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
throw(ArgumentError(errmsg))
end
function _convert_initial_params(@nospecialize(_::Any))
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
throw(ArgumentError(errmsg))
end

"""
default_varinfo(rng, model, sampler)

Return a default varinfo object for the given `model` and `sampler`.
The default method for this returns an `NTVarInfo` (i.e. 'typed varinfo').
"""
function default_varinfo(
rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler
)
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
# immediately overwritten by a subsequent call to `init!!`. The reason why we
# _do_ create a varinfo with parameters here (as opposed to simply returning
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
# typed VarInfo would fail. This can happen if two VarNames have different types
# but share the same symbol (e.g. `x.a` and `x.b`).
# TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments
# and return an empty VarInfo instead.
return DynamicPPL.typed_varinfo(VarInfo(rng, model))
end

#########################################
# Default definitions for the interface #
#########################################

const DEFAULT_CHAIN_TYPE = MCMCChains.Chains

function AbstractMCMC.sample(
model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs...
model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs...
)
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::InferenceAlgorithm,
model::DynamicPPL.Model,
spl::AbstractSampler,
N::Integer;
initial_params=init_strategy(spl),
check_model::Bool=true,
chain_type=DEFAULT_CHAIN_TYPE,
kwargs...,
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg), N; chain_type, kwargs...)
check_model && _check_model(model, spl)
return AbstractMCMC.mcmcsample(
rng,
model,
spl,
N;
initial_params=_convert_initial_params(initial_params),
chain_type,
kwargs...,
)
end

function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
model::DynamicPPL.Model,
alg::AbstractSampler,
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
Expand All @@ -47,18 +107,66 @@ function AbstractMCMC.sample(
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::InferenceAlgorithm,
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::AbstractSampler,
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
chain_type=DEFAULT_CHAIN_TYPE,
check_model::Bool=true,
initial_params=fill(init_strategy(spl), n_chains),
kwargs...,
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(
rng, model, Sampler(alg), ensemble, N, n_chains; chain_type, kwargs...
check_model && _check_model(model, spl)
if !(initial_params isa AbstractVector) || length(initial_params) != n_chains
errmsg = "`initial_params` must be an AbstractVector of length `n_chains`; one element per chain"
throw(ArgumentError(errmsg))
end
return AbstractMCMC.mcmcsample(
rng,
model,
spl,
ensemble,
N,
n_chains;
chain_type,
initial_params=map(_convert_initial_params, initial_params),
kwargs...,
)
end

function loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(
ArgumentError(
"the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
),
)
end
return chain.info[:samplerstate]
end

# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
function initialstep end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::AbstractSampler;
initial_params,
kwargs...,
)
# Generate the default varinfo. Note that any parameters inside this varinfo
# will be immediately overwritten by the next call to `init!!`.
vi = default_varinfo(rng, model, spl)

# Fill it with initial parameters. Note that, if `InitFromParams` is used, the
# parameters provided must be in unlinked space (when inserted into the
# varinfo, they will be adjusted to match the linking status of the
# varinfo).
_, vi = DynamicPPL.init!!(rng, model, vi, initial_params)

# Call the actual function that does the first step.
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end
Comment on lines +150 to +172
Copy link
Member Author

@penelopeysm penelopeysm Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method of step is actually a little bit evil. It used to be less bad because it only applied to Sampler{<:InferenceAlgorithm}, but now it applies to all AbstractSampler, which actually does cause some method ambiguities (which I've pointed out in my other comments).

On top of that, this is just generally a bit inflexible when it comes to warmup steps since it's only defined as a method for step and not step_warmup.

I think that in the next version of Turing this method should be removed. However, I've opted to preserve it for now because I don't want to make too many conceptual changes in this PR (the diff is already too large).

16 changes: 0 additions & 16 deletions src/mcmc/algorithm.jl

This file was deleted.

Loading