Skip to content

Commit

Permalink
Add pipeline interface and remove threading
Browse files Browse the repository at this point in the history
A pipeline objects holds references to the channels and tasks involved
in the execution. It further holds the configurations of each stage of
the pipeline.

Addresses #18
Opens #1
See https://github.com/JuliaLang/julia/issues/37706
  • Loading branch information
jonas-schulze committed Sep 25, 2020
1 parent 911cdec commit 9799f38
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 148 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[extras]
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Expand Down
28 changes: 23 additions & 5 deletions src/ParaReal.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
module ParaReal

import DiffEqBase
import Base.Threads,
DiffEqBase,
Distributed

using Base.Iterators: countfrom, repeat
using Base.Threads: nthreads, @threads
using Distributed: Future,
RemoteChannel,
procs,
remotecall,
workers,
@fetchfrom,
@spawnat
using LinearAlgebra: norm
using UnPack: @unpack

export ParaRealAlgorithm

const T = Base.Threads
const D = Distributed

include("types.jl")
include("stages.jl")
include("pipeline.jl")

include("solution.jl")
include("problem.jl")

include("solve.jl")
include("worker.jl")

include("utils.jl")
include("compat.jl")

export ParaRealAlgorithm

end # module
84 changes: 84 additions & 0 deletions src/pipeline.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
function init_pipeline(workers::Vector{Int})
conns = map(RemoteChannel, workers)
nsteps = length(workers)
results = RemoteChannel(() -> Channel(nsteps))
configs = Vector{StageConfig}(undef, nsteps)

# Initialize first stages:
for i in 1:nsteps-1
prev = conns[i]
next = conns[i+1]
configs[i] = StageConfig(step=i,
nsteps=nsteps,
prev=prev,
next=next,
results=results)
end

# Initialize final stage:
prev = next = conns[nsteps]
# Pass a `::ValueChannel` instead of `nothing` as to not trigger another
# compilation. The value of `next` will never be accessed anyway.
configs[nsteps] = StageConfig(step=nsteps,
nsteps=nsteps,
prev=prev,
next=next,
results=results)

Pipeline(conns=conns,
results=results,
workers=workers,
configs=configs)
end

function start_pipeline!(pipeline::Pipeline, prob, alg; kwargs...)
is_pipeline_started(pipeline) && error("Pipeline already started")
@unpack workers, configs = pipeline
tasks = map(workers, configs) do w, c
D.@spawnat w execute_stage(prob, alg, c; kwargs...)
end
pipeline.tasks = tasks
nothing
end

function send_initial_value(pipeline::Pipeline, prob)
u0 = initialvalue(prob)
c = first(pipeline.conns)
put!(c, u0)
close(c)
nothing
end

"""
is_pipeline_started(pl::Pipeline) -> Bool
Determine whether the stages of a pipeline have been started executing.
"""
is_pipeline_started(pl::Pipeline) = pl.tasks !== nothing

"""
is_pipeline_done(pl::Pipeline) -> Bool
Determine whether all the stages of a pipeline have exited.
"""
is_pipeline_done(pl::Pipeline) = is_pipeline_started(pl) && all(isready, pl.tasks)

"""
is_pipeline_failed(pl::Pipeline) -> Bool)
Determine whether some stage of a pipeline has exited because an exception was thrown.
"""
function is_pipeline_failed(pl::Pipeline)
is_pipeline_started(pl) || return false
@unpack tasks = pl
for t in tasks
isready(t) || continue
try
# should return nothing:
fetch(t)
catch
return true
end
end
return false
end
13 changes: 12 additions & 1 deletion src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ struct GlobalSolution{S}
end
end

function collect_solutions(results, nsteps)
function collect_solutions(pipeline::Pipeline)
@unpack results, workers = pipeline
nsteps = length(workers)

# Collect local solutions. Sorting them shouldn't be necessary,
# but as there is networking involved, we're rather safe than sorry:
step, sol = take!(results)
Expand Down Expand Up @@ -60,3 +63,11 @@ function assemble_solution(

DiffEqBase.build_solution(prob, alg, ts, us, retcode=gsol.retcode)
end

"""
nextvalue(sol)
Extract the initial value for the next ParaReal iteration.
Defaults to `sol[end]`.
"""
nextvalue(sol) = sol[end]
79 changes: 14 additions & 65 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,85 +1,34 @@
using Distributed: workers, @spawnat, RemoteChannel, procs
using Base.Threads: nthreads, @threads
using Base.Iterators: countfrom, repeat

DiffEqBase.solve(prob::DiffEqBase.DEProblem, alg::ParaRealAlgorithm; kwargs...) = solve(prob, alg; kwargs...)

function solve(
prob,
alg::ParaRealAlgorithm;
ws = workers(),
nt = Base.VERSION >= v"1.3" ? nthreads() : 1,
workers = workers(),
maxiters = 10,
)

issubset(ws, procs()) || error("Unknown worker ids in `$workers`, no subset of `$(procs())`")
Base.VERSION >= v"1.3" || nt == 1 || error("Multiple threads/tasks per worker require Julia v1.3")

u0 = initialvalue(prob)
uType = typeof(u0)
uChannel = Channel{uType}
uRemoteChannel = RemoteChannel{uChannel}
createchan = () -> uChannel(1)

@debug "Setting up connections"
# Create connections between the pipeline stages:
nsteps = length(ws) * nt
conns = Vector{uRemoteChannel}(undef, nsteps+1)
i = 1
for w in ws
for _ in 1:nt
conns[i] = RemoteChannel(createchan, w)
i += 1
end
end
@assert i == nsteps+1
conns[i] = conns[i-1]
# conns[end] will never be accessed anyway
D.myid() in workers &&
error("Cannot use the managing process as a worker process (FIXME)")
issubset(workers, D.procs()) ||
error("Unknown worker ids in `$workers`, no subset of `$(D.procs())`")
allunique(workers) ||
@warn "Multiple tasks per worker won't run in parallel. Use for debugging only."

# Create a connection back home for the local solutions:
results = RemoteChannel(() -> Channel(nsteps))
@debug "Initializing global cache"
pipeline = init_pipeline(workers)

@debug "Starting worker tasks"
# Wire up the pipeline:
if Base.VERSION >= v"1.3"
# TODO: use `@spawn` instead of `@threads for` for better composability.
# For as long as tasks can't jump between threads, `@spawn` is quiet unreliable
# in populating different threads. Therefore we stick to `@threads` for now.
tasks = asyncmap(ws, countfrom(1, nt)) do w, i
_conns = @view conns[i:i+nt]
@spawnat w begin
@threads for j in 1:nt
in = _conns[j]
out = _conns[j+1]
_solve(prob, alg,
j+i-1, nsteps,
in, out,
results;
maxiters=maxiters)
end
end
end
else
tasks = asyncmap(enumerate(ws)) do (i, w)
@spawnat w _solve(prob, alg,
i, nsteps,
conns[i], conns[i+1],
results;
maxiters=maxiters)
end
end
start_pipeline!(pipeline, prob, alg, maxiters=maxiters)

@debug "Sending initial value"
# Kick off the pipeline:
firstchan = first(conns)
put!(firstchan, u0)
close(firstchan)
send_initial_value(pipeline, prob)

# Make sure there were no errors:
wait.(tasks)
@debug "Waiting for completion"
wait.(pipeline.tasks)

@debug "Collecting local solutions"
sol = collect_solutions(results, nsteps)
sol = collect_solutions(pipeline)

@debug "Reassembling global solution"
return assemble_solution(prob, alg, sol)
Expand Down
67 changes: 19 additions & 48 deletions src/worker.jl → src/stages.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,15 @@
using LinearAlgebra: norm
using DiffEqBase: solution_new_retcode

# TODO: put into worker setup / config:
# step, n,
# prev, next,
# tol
#
# Maybe also integrate worker state (mutable struct)

"""
_solve(step, n, prob, prev, next)
# Arguments
* `prob::ODEProblem` global problem to be solved
* `alg::ParaRealAlgorithm`
* `step::Integer` current step in the pipeline
* `n::Integer` total number of steps in the pipeline
* `prev::AbstractChannel` where to get new `u0`-values from
* `next::AbstractChannel` where to put `u0`-values for the next pipeline step
"""
function _solve(prob,
function execute_stage(prob,
alg::ParaRealAlgorithm,
step::Integer,
n::Integer,
prev::RemoteChannel{<:AbstractChannel{uType}},
next::RemoteChannel{<:AbstractChannel{uType}},
result::RemoteChannel;
config::StageConfig;
maxiters = 10,
tol = 1e-5,
maxiters = 100,
) where uType
)

@unpack step, nsteps, prev, next, results = config
finalstage = step == nsteps

# Initialize local problem instance
tspan = local_tspan(step, n, prob.tspan)
tspan = local_tspan(step, nsteps, prob.tspan)

# Allocate buffers
u = initialvalue(prob)
Expand All @@ -46,8 +23,10 @@ function _solve(prob,

converged = false
niters = 0
@debug "Waiting for data" step pid=D.myid() tid=T.threadid()
for u0 in prev
niters += 1
@debug "Received new initial value" step niters
prob = remake(prob, u0=u0, tspan=tspan) # copies :-(

# Abort if maximum number of iterations is reached.
Expand All @@ -63,16 +42,16 @@ function _solve(prob,
# Hand correction of coarse solution on to the next workers.
# Note that there is no correction to be done in the first iteration.
if niters == 1
step == n || put!(next, coarse_u)
finalstage || put!(next, coarse_u)
else
alg.update!(correction, coarse_u, fine_u, coarse_u_old)
diff = norm(correction - fine_u, 1) / norm(correction, 1)
if diff < tol
@debug "Worker $step/$n converged after $niters/$maxiters iterations"
@debug "Converged successfully" step niters
converged = true
break
else
step == n || put!(next, correction)
finalstage || put!(next, correction)
end
end

Expand All @@ -82,21 +61,21 @@ function _solve(prob,
end

if niters > maxiters
@warn "Worker $step/$n reached maximum number of iterations: $maxiters"
@warn "Reached reached maximum number of iterations: $maxiters" step
end

# If this worker converged, there is no need to pass on the
# next/same solution again. If, instead, the previous worker
# converged, closing `prev`, send the last fine solution to `next`
# as the (eventually) converged solution of this worker.
converged || step == n || put!(next, fine_u)
step == n || close(next)
converged || finalstage || put!(next, fine_u)
finalstage || close(next)

retcode = niters > maxiters ? :MaxIters : :Success
sol = LocalSolution(fine_sol, retcode)
@debug "Worker $step/$n sending results"
put!(result, (step, sol)) # Redo? return via `return` instead of channel
@debug "Worker $step/$n finished"
@debug "Sending results" step
put!(results, (step, sol)) # Redo? return via `return` instead of channel
@debug "Finished" step niters
nothing
end

Expand All @@ -116,11 +95,3 @@ function fsolve end

csolve(prob, alg::ParaRealAlgorithm) = alg.coarse(prob)
fsolve(prob, alg::ParaRealAlgorithm) = alg.fine(prob)

"""
nextvalue(sol)
Extract the initial value for the next ParaReal iteration.
Defaults to `sol[end]`.
"""
nextvalue(sol) = sol[end]
20 changes: 20 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,23 @@ function default_update!(y_new, y_coarse, y_fine, y_coarse_old)
@. y_new = y_coarse+y_fine-y_coarse_old
nothing
end

const ValueChannel = RemoteChannel{Channel{Any}}
const RemoteTask = Future

Base.@kwdef struct StageConfig
step::Int # corresponding step in the pipeline
nsteps::Int # total number of steps in the pipeline
prev::ValueChannel # where to get new `u0`-values from
next::ValueChannel # where to put `u0`-values for the next pipeline step
results::RemoteChannel # where to put the solution objects after convergence
end

Base.@kwdef mutable struct Pipeline
conns::Vector{ValueChannel}
results::ValueChannel

workers::Vector{Int}
configs::Vector{StageConfig}
tasks::Union{Vector{RemoteTask}, Nothing} = nothing
end
Loading

0 comments on commit 9799f38

Please sign in to comment.