Skip to content

Commit

Permalink
Merge pull request #1187 from devmotion/mcmcsample
Browse files Browse the repository at this point in the history
Use default keyword arguments when sampling from Sampler
  • Loading branch information
cpfiffer authored Mar 27, 2020
2 parents edbee1e + 308bf1c commit 15e9528
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
AbstractMCMC = "0.5.2"
AbstractMCMC = "0.5.5"
AdvancedHMC = "0.2.20"
AdvancedMH = "0.4"
Bijectors = "0.6.4"
Expand Down
39 changes: 33 additions & 6 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,24 @@ function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
N::Integer;
kwargs...
)
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
sampler::Sampler,
N::Integer;
chain_type=MCMCChains.Chains,
resume_from=nothing,
progress=PROGRESS[],
kwargs...
)
if resume_from === nothing
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
chain_type=chain_type, progress=progress, kwargs...)
return AbstractMCMC.mcmcsample(rng, model, sampler, N;
chain_type=chain_type, progress=progress, kwargs...)
else
return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...)
end
Expand All @@ -175,12 +185,23 @@ function AbstractMCMC.psample(
alg::InferenceAlgorithm,
N::Integer,
n_chains::Integer;
kwargs...
)
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; kwargs...)
end

function AbstractMCMC.psample(
rng::AbstractRNG,
model::AbstractModel,
sampler::Sampler,
N::Integer,
n_chains::Integer;
chain_type=MCMCChains.Chains,
progress=PROGRESS[],
kwargs...
)
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
chain_type=chain_type, progress=progress, kwargs...)
return AbstractMCMC.mcmcpsample(rng, model, sampler, N, n_chains;
chain_type=chain_type, progress=progress, kwargs...)
end

function AbstractMCMC.sample_init!(
Expand Down Expand Up @@ -406,11 +427,17 @@ function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
return setinfo(c, merge(nt, c.info))
end

function resume(c::MCMCChains.Chains, n_iter::Int; chain_type=MCMCChains.Chains, progress=PROGRESS[], kwargs...)
function resume(
c::MCMCChains.Chains,
n_iter::Int;
chain_type=MCMCChains.Chains,
progress=PROGRESS[],
kwargs...
)
@assert !isempty(c.info) "[Turing] cannot resume from a chain without state info"

# Sample a new chain.
newchain = AbstractMCMC.sample(
newchain = AbstractMCMC.mcmcsample(
c.info[:range],
c.info[:model],
c.info[:spl],
Expand Down
6 changes: 6 additions & 0 deletions test/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Turing, Random, Test
using DynamicPPL: getlogp
import MCMCChains

dir = splitdir(splitdir(pathof(Turing))[1])[1]
include(dir*"/test/test_utils/AllUtils.jl")
Expand All @@ -17,6 +18,11 @@ include(dir*"/test/test_utils/AllUtils.jl")
# Smoke test for default psample call.
chain = psample(gdemo_default, HMC(0.1, 7), 1000, 4)
check_gdemo(chain)

# run sampler: progress logging should be disabled and
# it should return a Chains object
sampler = Sampler(HMC(0.1, 7), gdemo_default)
@test psample(gdemo_default, sampler, 1000, 4) isa MCMCChains.Chains
end
end
@testset "chain save/resume" begin
Expand Down
5 changes: 5 additions & 0 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Random, Turing, Test
import AbstractMCMC
import MCMCChains
import Turing.Inference

dir = splitdir(splitdir(pathof(Turing))[1])[1]
Expand Down Expand Up @@ -29,6 +30,10 @@ include(dir*"/test/test_utils/AllUtils.jl")
@test g.state.samplers[1].selector != g.selector
@test g.state.samplers[2].selector != g.selector
@test g.state.samplers[1].selector != g.state.samplers[2].selector

# run sampler: progress logging should be disabled and
# it should return a Chains object
@test sample(gdemo_default, g, N) isa MCMCChains.Chains
end
@numerical_testset "gibbs inference" begin
Random.seed!(100)
Expand Down

0 comments on commit 15e9528

Please sign in to comment.