Skip to content

Commit

Permalink
Merge pull request #39 from TuringLang/transducer
Browse files Browse the repository at this point in the history
Add a `Sample` transducer
  • Loading branch information
cpfiffer authored May 21, 2020
2 parents 74bac80 + 14831b8 commit 509b5cf
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
BangBang = "0.3.19"
Expand All @@ -23,6 +24,7 @@ LoggingExtras = "0.4"
ProgressLogging = "0.1"
StatsBase = "0.32, 0.33"
TerminalLoggers = "0.1"
Transducers = "0.4.30"
julia = "1"

[extras]
Expand Down
2 changes: 2 additions & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import LoggingExtras
import ProgressLogging
import StatsBase
import TerminalLoggers
import Transducers

import Distributed
import Logging
Expand Down Expand Up @@ -74,5 +75,6 @@ include("logging.jl")
include("interface.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")

end # module AbstractMCMC
39 changes: 39 additions & 0 deletions src/transducer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: Transducers.Transducer
rng::A
model::M
sampler::S
kwargs::K
end

function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
return Sample(Random.GLOBAL_RNG, model, sampler; kwargs...)
end

function Sample(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler;
kwargs...
)
sample_init!(rng, model, sampler, 0)
return Sample(rng, model, sampler, kwargs)
end

function Transducers.start(rf::Transducers.R_{<:Sample}, result)
return Transducers.wrap(rf, nothing, Transducers.start(Transducers.inner(rf), result))
end

function Transducers.next(rf::Transducers.R_{<:Sample}, result, input)
t = Transducers.xform(rf)
Transducers.wrapping(rf, result) do state, iresult
transition = step!(t.rng, t.model, t.sampler, 1, state; t.kwargs...)
iinput = transition
iresult = Transducers.next(Transducers.inner(rf), iresult, transition)
return transition, iresult
end
end

function Transducers.complete(rf::Transducers.R_{Sample}, result)
_private_state, inner_result = Transducers.unwrap(rf, result)
return Transducers.complete(Transducers.inner(rf), inner_result)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using ConsoleProgressMonitor: ProgressLogger
using IJulia
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger
using Transducers

using Distributed
import Logging
Expand Down Expand Up @@ -276,4 +277,6 @@ include("interface.jl")
MySampler(), 10, 10;
chain_type = MyChain)
end

include("transducer.jl")
end
50 changes: 50 additions & 0 deletions test/transducer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@testset "transducer.jl" begin
Random.seed!(1234)

@testset "Basic sampling" begin
N = 1_000
local chain
Logging.with_logger(TerminalLogger()) do
xf = AbstractMCMC.Sample(MyModel(), MySampler();
sleepy = true, logger = true)
chain = collect(xf, withprogress(1:N; interval=1e-3))
end

# test output type and size
@test chain isa Vector{<:MyTransition}
@test length(chain) == N

# test some statistical properties
tail_chain = @view chain[2:end]
@test mean(x.a for x in tail_chain) 0.5 atol=6e-2
@test var(x.a for x in tail_chain) 1 / 12 atol=5e-3
@test mean(x.b for x in tail_chain) 0.0 atol=5e-2
@test var(x.b for x in tail_chain) 1 atol=6e-2
end

@testset "drop" begin
xf = AbstractMCMC.Sample(MyModel(), MySampler())
chain = collect(xf |> Drop(1), 1:10)
@test chain isa Vector{MyTransition{Float64,Float64}}
@test length(chain) == 9
end

# Reproduce iterator example
@testset "iterator example" begin
# filter missing values and split transitions
xf = AbstractMCMC.Sample(MyModel(), MySampler()) |>
OfType(MyTransition{Float64,Float64}) |> Map(x -> (x.a, x.b))
as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b)
push!(as, a)
push!(bs, b)
as, bs
end

@test length(as) == length(bs) == 998

@test mean(as) 0.5 atol=1e-2
@test var(as) 1 / 12 atol=5e-3
@test mean(bs) 0.0 atol=5e-2
@test var(bs) 1 atol=5e-2
end
end

0 comments on commit 509b5cf

Please sign in to comment.