Skip to content

Commit

Permalink
WIP: use channels for inter-task communication.
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Oct 24, 2017
1 parent 4f461c0 commit c860616
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/samplers/pgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ assume{A<:Union{PG,SMC},D<:Distribution}(spl::Sampler{A}, dists::Vector{D}, vn::
error("[Turing] PG and SMC doesn't support vectorizing assume statement")

observe{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, value, vi) =
produce(logpdf(dist, value))
put!(current_task().storage[:turing_chnl], logpdf(dist, value))

observe{A<:Union{PG,SMC},D<:Distribution}(spl::Sampler{A}, ds::Vector{D}, value::Any, vi::VarInfo) =
error("[Turing] PG and SMC doesn't support vectorizing observe statement")
8 changes: 8 additions & 0 deletions src/trace/taskcopy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ function Base.copy(t::Task)
newt.state = t.state
newt.result = t.result
newt.parent = t.parent
newt.storage[:turing_chnl] = deepcopy(t.storage[:turing_chnl]) # Channel(0);
newt.storage[:turing_chnl].putters[1] = newt
push!(newt.storage[:turing_chnl].takers, current_task())
bind(newt.storage[:turing_chnl], newt)
if istaskstarted(t)
schedule(newt);
newt.state = :queued
end
if :last in fieldnames(t)
newt.last = nothing
end
Expand Down
11 changes: 8 additions & 3 deletions src/trace/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ end
function (::Type{Trace{T}}){T}(f::Function)
res = Trace{T}();
# Task(()->f());
res.task = Task( () -> begin res=f(); produce(Val{:done}); res; end )
res.task = Task( () -> begin res=f(); put!(current_task().storage[:turing_chnl], Val{:done}); res; end )
if isa(res.task.storage, Void)
res.task.storage = ObjectIdDict()
end
res.task.storage[:turing_chnl] = Channel(0); schedule(res.task);
res.task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res
end
Expand All @@ -61,10 +62,11 @@ function (::Type{Trace{T}}){T}(f::Function, spl::Sampler, vi :: VarInfo)
res.vi = deepcopy(vi)
res.vi.index = 0
res.vi.num_produce = 0
res.task = Task( () -> begin vi_new=f(vi, spl); produce(Val{:done}); vi_new; end )
res.task = Task( () -> begin vi_new=f(vi, spl); put!(current_task().storage[:turing_chnl], Val{:done}); vi_new; end )
if isa(res.task.storage, Void)
res.task.storage = ObjectIdDict()
end
res.task.storage[:turing_chnl] = Channel(0); schedule(res.task);
res.task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res
end
Expand All @@ -73,7 +75,10 @@ const TraceR = Trace{:R} # Task Copy
const TraceC = Trace{:C} # Replay

# step to the next observe statement, return log likelihood
Base.consume(t::Trace) = (t.vi.num_produce += 1; Base.consume(t.task))
Base.consume(t::Trace) = begin
t.vi.num_produce += 1;
Base.take!(t.task.storage[:turing_chnl])
end

# Task copying version of fork for both TraceR and TraceC.
function forkc(trace :: Trace)
Expand Down
2 changes: 1 addition & 1 deletion test/resample.jl/particlecontainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function f()
ct = current_trace()
vn = VarName(gensym(), :x, "[$n]", 1)
Turing.assume(spl, dist, vn, ct.vi); n += 1;
produce(0)
put!(current_task().storage[:turing_chnl], 0)
vn = VarName(gensym(), :x, "[$n]", 1)
Turing.assume(spl, dist, vn, ct.vi); n += 1;
t[1] = 1 + t[1]
Expand Down
10 changes: 5 additions & 5 deletions test/tarray.jl/tarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@ function f_cta()
t = TArray(Int, 1);
t[1] = 0;
while true
produce(t[1])
put!(current_task().storage[:turing_chnl], t[1])
t[1]
t[1] = 1 + t[1]
end
end

t = Task(f_cta)

consume(t); consume(t)
take!(t.storage[:turing_chnl]); take!(t.storage[:turing_chnl])
a = copy(t);
consume(a); consume(a)
take!(a.storage[:turing_chnl]); take!(a.storage[:turing_chnl])

Base.@assert consume(t) == 2
Base.@assert consume(a) == 4
Base.@assert take!(t.storage[:turing_chnl]) == 2
Base.@assert take!(a.storage[:turing_chnl]) == 4

# Base.@assert TArray(Float64, 5)[1] != 0 REVIEW: can we remove this? (Kai)
Base.@assert tzeros(Float64, 5)[1]==0
Expand Down
8 changes: 4 additions & 4 deletions test/tarray.jl/tarray2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ function f()
u = unique(classes[1:i])
Base.@assert maximum(u) == length(u)
# println("[$(current_task())] classes: ", classes[1:i], "; urn:", urn.counts) REVIEW: can we remove this (Kai)
produce(classes[i])
put!(current_task().storage[:turing_chnl], classes[i])
end
end

t = Task(f)

consume(t);
take!(t.storage[:turing_chnl]);
a = [copy(t) for i = 1:10];

for i =1:20
consume(t);
map((x)->consume(x),a)
take!(t.storage[:turing_chnl]);
map((x)->take!(x.storage[:turing_chnl]),a)
end
23 changes: 13 additions & 10 deletions test/taskcopy.jl/clonetask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,36 @@ using Turing
function f_ct()
t = 0;
while true
produce(t)
put!(current_task().storage[:turing_chnl], t)
t = 1 + t
end
end

t = Task(f_ct)
t = Task(f_ct); t.storage = ObjectIdDict();
t.storage[:turing_chnl] = Channel(0);
schedule(t)

consume(t); consume(t)

take!(t.storage[:turing_chnl]); take!(t.storage[:turing_chnl])
a = copy(t);
consume(a); consume(a)
take!(a.storage[:turing_chnl]); take!(a.storage[:turing_chnl])

# Test case 2: heap allocated objects are shallowly copied.

function f_ct2()
t = [0 1 2];
while true
#println(pointer_from_objref(t)); REVIEW: can we remove this comments (Kai)
produce(t[1])
put!(current_task().storage[:turing_chnl], t[1])
t[1] = 1 + t[1]
end
end

t = Task(f_ct2)

consume(t); consume(t)
take!(t.storage[:turing_chnl]); take!(t.storage[:turing_chnl])
a = copy(t);
consume(a); consume(a)
take!(a.storage[:turing_chnl]); take!(a.storage[:turing_chnl])

# REVIEW: comments below need to be updated (Kai)
# more: add code in copy() to handle invalid cases for cloning tasks.
Expand All @@ -41,14 +44,14 @@ function f_ct3()
t = [0];
o = (x) -> x + 1; # not heap allocated?
while true
produce(t[1])
put!(current_task().storage[:turing_chnl], t[1])
t[1] = 1 + t[1]
end
return o
end

t = Task(f_ct3)

consume(t); consume(t);
take!(t.storage[:turing_chnl]); take!(t.storage[:turing_chnl]);
a = copy(t);
consume(a); consume(a)
take!(a.storage[:turing_chnl]); take!(a.storage[:turing_chnl])
2 changes: 1 addition & 1 deletion test/trace.jl/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function f2()
ct = current_trace()
vn = VarName(gensym(), :x, "[$n]", 1)
Turing.assume(spl, dist, vn, ct.vi); n += 1;
produce(t[1]);
put!(current_task().storage[:turing_chnl], t[1]);
vn = VarName(gensym(), :x, "[$n]", 1)
Turing.assume(spl, dist, vn, ct.vi); n += 1;
t[1] = 1 + t[1]
Expand Down

0 comments on commit c860616

Please sign in to comment.