Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface Changeover #793

Merged
merged 200 commits into from
Sep 12, 2019
Merged

Interface Changeover #793

merged 200 commits into from
Sep 12, 2019

Conversation

cpfiffer
Copy link
Member

@cpfiffer cpfiffer commented May 25, 2019

To do

Supported samplers (src/inference/)

  • IS
  • MH
  • Gibbs
  • HMC
    • HMCDA
    • NUTS
  • SMC
    • PG
    • CSMC

Contributed samplers (src/contrib/inference)

  • DynamicNUTS
  • [ ] SGLD
  • [ ] SGHMC
  • [ ] PIMH
  • [ ] PMMH
  • [ ] IPMCMC

Summary

This will be the comprehensive PR containing all the code to change us over to using the new interface. Since the changes must be done for all the samplers pretty much all at once (we don't want half to use the new interface and the other half to use the old version), this will probably be a large PR. There is an hmc-interface.jl file which is not really ready for consumption yet, so you can ignore it.

Currently, I have a (very) rough proof-of-concept for SMC to get myself on stable ground for the rest of the samplers. The following interface is supported:

using Turing

@model test() = begin
    x ~ Normal(0, 1)
end

chn = sample(test(), SMC(50), 500)

I've introduced a couple of new things since the discussion in #746.

First, I'm attempting to unify all the samplers under the existing Sampler type we have. The info field is not going to be removed in this PR because of all the VarInfo stuff -- that will be a separate undertaking. The Sampler type now has an additional field called state, which contains a mutable struct that replaces the sampler-specific information like logevidence or eval_num. The ultimate goal is to have this go away at some point and move to a more general sampler state, but changing over to the front-end interface and changing the entire backend would be too much for one PR.

Second, this isn't yet benchmarked for performance. One issue I have is how to package up the TransitionTypes into something that has name & value information together, without too much overhead. At the moment, each step bundles up all the names from vi and all the vi.vals, but I think there's probably a better design.

More to come on this.

@yebai
Copy link
Member

yebai commented Jun 8, 2019

Hey @cpfiffer, the following issues have been raised for a while. It seems they are mostly interface related features. Can you take a look at them in this PR if possible?

@cpfiffer
Copy link
Member Author

cpfiffer commented Sep 8, 2019

docs/site/ must have been left over from all my weird messing around with the website, sorry it made it in. The line change is now back to a much more reasonable ~2k lines.

@xukai92
Copy link
Member

xukai92 commented Sep 8, 2019

Seems that everything is done. Great PR!

I will merge it in 24 hours if no rejection.

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really great work @cpfiffer - this has clearly been a lot of work. I've left a number of nit-picky style related comments, and also some requests for additional inline documentation.

One thing in particular: the docstrings for sample_init!, step!, and sample_end! could probably do with being improved a bit. sample_init! and sample_end! are a little bit terse, and step! produces three separate docstrings, where one would suffice. Would also be helpful to document whether or not it's expected that the user will return the modified sampler or not.

docs/src/interface-manual.md Outdated Show resolved Hide resolved
docs/src/interface-manual.md Outdated Show resolved Hide resolved
docs/src/interface-manual.md Show resolved Hide resolved
docs/src/using-turing/autodiff.md Show resolved Hide resolved
docs/src/using-turing/sampler-viz.md Show resolved Hide resolved
"""
islinked(vi::VT, spl::Sampler) where VT<:VarInfo

Returns `true` if a `VarInfo` is linked for a particular sampler `spl`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A slightly more informative docstring would be appreciated. What does linked mean in this context?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to say "in the transformed space" rather than "linked", is that helpful?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's helpful, but maybe add a reference to some other documentation for Bijectors or something? I mean, it's clear what's happening here if you know anything about the relationship between Turing and Bijectors etc, but for a newbie it might still be a bit unclear?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to read:

islinked(vi::VT, spl::Sampler) where VT<:VarInfo

Returns true if a VarInfo is in the transformed space for a particular sampler spl.

Turing's Hamiltonian samplers use the link and invlink functions from
Bijectors.jl to map a constrained variable
(for example, one bounded to the space [0, 1]) from its constrained space to the set of
real numbers. islinked checks if the number is in the constrained space or the real space.

src/inference/gibbs.jl Outdated Show resolved Hide resolved
src/inference/gibbs.jl Outdated Show resolved Hide resolved
src/inference/gibbs.jl Outdated Show resolved Hide resolved
src/interface/Interface.jl Outdated Show resolved Hide resolved
@xukai92
Copy link
Member

xukai92 commented Sep 9, 2019

I meet a bug in the following example with thie branch

using Turing

@model testmissing() = begin
    n ~ Categorical(10)
    l = tzeros(Int, n)
    for i in 1:n
        l[i] ~ Categorical(2)
    end
    3 ~ Normal(sum(l), 1)
end

chain = sample(testmissing(), PG(10), 1_000)

gives

Progress: 100%|█████████████████████████████████████████| Time: 0:00:02
ArgumentError: number of columns of each array must match (got (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9))

Stacktrace:
 [1] _typed_vcat(::Type{Float64}, ::NTuple{1000,Array{Float64,2}}) at ./abstractarray.jl:1305
 [2] typed_vcat(::Type{Float64}, ::Array{Float64,2}, ::Array{Float64,2}, ::Array{Float64,2}, ::Vararg{Array{Float64,2},N} where N) at ./abstractarray.jl:1319
 [3] vcat(::Array{Float64,2}, ::Array{Float64,2}, ::Array{Float64,2}, ::Array{Float64,2}, ::Vararg{Array{Float64,2},N} where N) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/SparseArrays/src/sparsevector.jl:1069
 [4] #Chains#10(::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Type, ::Random.MersenneTwister, ::Turing.Model{Tuple{:n,:l},Tuple{},getfield(Main, Symbol("###inner_function#190586#61")),NamedTuple{(),Tuple{}},NamedTuple{(),Tuple{}}}, ::Turing.Sampler{PG{()},Turing.Inference.PGState{Turing.Core.RandomVariables.VarInfo{NamedTuple{(:n, :l),Tuple{Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:n},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:n},1},Array{Int64,1},Array{Set{Turing.Selector},1}},Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:l},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:l},1},Array{Int64,1},Array{Set{Turing.Selector},1}}}},Float64},Float64}}, ::Int64, ::Array{Turing.Inference.ParticleTransition,1}) at /Users/kai/projects/TuringLang/Turing.jl/src/inference/Inference.jl:321
 [5] Type at /Users/kai/projects/TuringLang/Turing.jl/src/inference/Inference.jl:308 [inlined]
 [6] #sample#3(::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Random.MersenneTwister, ::Turing.Model{Tuple{:n,:l},Tuple{},getfield(Main, Symbol("###inner_function#190586#61")),NamedTuple{(),Tuple{}},NamedTuple{(),Tuple{}}}, ::Turing.Sampler{PG{()},Turing.Inference.PGState{Turing.Core.RandomVariables.VarInfo{NamedTuple{(:n, :l),Tuple{Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:n},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:n},1},Array{Int64,1},Array{Set{Turing.Selector},1}},Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:l},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:l},1},Array{Int64,1},Array{Set{Turing.Selector},1}}}},Float64},Float64}}, ::Int64) at /Users/kai/projects/TuringLang/Turing.jl/src/interface/Interface.jl:173
 [7] #sample at ./none:0 [inlined]
 [8] #sample#2 at /Users/kai/projects/TuringLang/Turing.jl/src/interface/Interface.jl:126 [inlined]
 [9] (::getfield(StatsBase, Symbol("#kw##sample")))(::NamedTuple{(:progress,),Tuple{Bool}}, ::typeof(sample), ::Turing.Model{Tuple{:n,:l},Tuple{},getfield(Main, Symbol("###inner_function#190586#61")),NamedTuple{(),Tuple{}},NamedTuple{(),Tuple{}}}, ::Turing.Sampler{PG{()},Turing.Inference.PGState{Turing.Core.RandomVariables.VarInfo{NamedTuple{(:n, :l),Tuple{Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:n},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:n},1},Array{Int64,1},Array{Set{Turing.Selector},1}},Turing.Core.RandomVariables.Metadata{Dict{Turing.Core.RandomVariables.VarName{:l},Int64},Array{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}},1},Array{Turing.Core.RandomVariables.VarName{:l},1},Array{Int64,1},Array{Set{Turing.Selector},1}}}},Float64},Float64}}, ::Int64) at ./none:0
 [10] #sample#2(::Nothing, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Turing.Model{Tuple{:n,:l},Tuple{},getfield(Main, Symbol("###inner_function#190586#61")),NamedTuple{(),Tuple{}},NamedTuple{(),Tuple{}}}, ::PG{()}, ::Int64) at /Users/kai/projects/TuringLang/Turing.jl/src/inference/Inference.jl:148
 [11] sample(::Turing.Model{Tuple{:n,:l},Tuple{},getfield(Main, Symbol("###inner_function#190586#61")),NamedTuple{(),Tuple{}},NamedTuple{(),Tuple{}}}, ::PG{()}, ::Int64) at /Users/kai/projects/TuringLang/Turing.jl/src/inference/Inference.jl:147
 [12] top-level scope at In[14]:9

This example works in master but the returned chain cannot be displayed in at least Jupyter notebook.

@yebai yebai mentioned this pull request Sep 9, 2019
@cpfiffer
Copy link
Member Author

That's a very strange bug. Seems to only happen sometimes. I'm investigating.

@xukai92
Copy link
Member

xukai92 commented Sep 11, 2019

Another 24-hour merge notifcation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants