Skip to content

Commit

Permalink
Dispatch progress bar method in EnsembleProblems
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Oct 7, 2024
1 parent cec3e08 commit 863d7ec
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 160 deletions.
2 changes: 2 additions & 0 deletions src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import SciMLBase:
ODEProblem,
SDEProblem,
EnsembleProblem,
EnsembleSerial,
EnsembleThreads,
EnsembleDistributed,
FullSpecialize,
CallbackSet,
ContinuousCallback,
Expand Down
196 changes: 120 additions & 76 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,29 @@ function _mcsolve_prob_func(prob, i, repeat)
return remake(prob, p = prm)
end

# Standard output function
function _mcsolve_output_func(sol, i)
resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1)
resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1)
put!(sol.prob.p.progr_channel, true)
return (sol, false)
end

# Output function with progress bar update
function _mcsolve_output_func_progress(sol, i)
next!(sol.prob.p.progr_trajectories)
return _mcsolve_output_func(sol, i)
end

# Output function with distributed channel update for progress bar
function _mcsolve_output_func_distributed(sol, i)
put!(sol.prob.p.progr_channel, true)
return _mcsolve_output_func(sol, i)

Check warning on line 99 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L97-L99

Added lines #L97 - L99 were not covered by tests
end

_mcsolve_dispatch_output_func() = _mcsolve_output_func

Check warning on line 102 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L102

Added line #L102 was not covered by tests
_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress
_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed

Check warning on line 104 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L104

Added line #L104 was not covered by tests

function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
sol_i = sol[:, i]
!isempty(sol_i.prob.kwargs[:saveat]) ?
Expand Down Expand Up @@ -293,9 +309,12 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
prob_func::Function=_mcsolve_prob_func,
output_func::Function=_mcsolve_output_func,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)
Generates the `EnsembleProblem` of `ODEProblem`s for the ensemble of trajectories of the Monte Carlo wave function time evolution of an open quantum system.
Expand Down Expand Up @@ -343,9 +362,12 @@ If the environmental measurements register a quantum jump, the wave function und
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
- `prob_func::Function`: Function to use for generating the ODEProblem.
- `output_func::Function`: Function to use for generating the output of a single trajectory.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: Additional keyword arguments to pass to the solver.
# Notes
Expand All @@ -369,29 +391,51 @@ function mcsolveEnsembleProblem(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
seeds::Union{Nothing,Vector{Int}} = nothing,
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
prob_mc = mcsolveProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
jump_callback = jump_callback,
kwargs...,
)
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
if ensemble_method isa EnsembleDistributed
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@async while take!(progr_channel)
next!(progr)
end
params = merge(params, (progr_channel = progr_channel,))

Check warning on line 409 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L405-L409

Added lines #L405 - L409 were not covered by tests
else
params = merge(params, (progr_trajectories = progr,))
end

# Stop the async task if an error occurs
try
prob_mc = mcsolveProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
jump_callback = jump_callback,
kwargs...,
)

ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false)
ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false)

return ensemble_prob
return ensemble_prob
catch e
if ensemble_method isa EnsembleDistributed
put!(progr_channel, false)

Check warning on line 435 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L435

Added line #L435 was not covered by tests
end
rethrow()
end
end

@doc raw"""
Expand All @@ -408,7 +452,7 @@ end
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
)
Expand Down Expand Up @@ -493,43 +537,34 @@ function mcsolve(
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
if !isnothing(seeds) && length(seeds) != ntraj
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
end

progr = ProgressBar(ntraj, enable = getVal(progress_bar))
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@async while take!(progr_channel)
next!(progr)
end

# Stop the async task if an error occurs
try
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = merge(params, (progr_channel = progr_channel,)),
seeds = seeds,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
kwargs...,
)
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
ntraj = ntraj,
ensemble_method = ensemble_method,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
progress_bar = progress_bar,
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
catch e
put!(progr_channel, false)
rethrow()
end
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
end

function mcsolve(
Expand All @@ -538,33 +573,42 @@ function mcsolve(
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
)
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)

put!(sol[:, 1].prob.p.progr_channel, false)

_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
states =
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
Vector{Vector{QuantumObject}}(undef, length(sol))
jump_times = Vector{Vector{Float64}}(undef, length(sol))
jump_which = Vector{Vector{Int16}}(undef, length(sol))

foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionMCSol(
ntraj,
_sol_1.prob.p.times,
states,
expvals,
expvals_all,
jump_times,
jump_which,
sol.converged,
_sol_1.alg,
_sol_1.prob.kwargs[:abstol],
_sol_1.prob.kwargs[:reltol],
)
try
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)

if ensemble_method isa EnsembleDistributed
put!(sol[:, 1].prob.p.progr_channel, false)

Check warning on line 580 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L580

Added line #L580 was not covered by tests
end

_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
states =
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
Vector{Vector{QuantumObject}}(undef, length(sol))
jump_times = Vector{Vector{Float64}}(undef, length(sol))
jump_which = Vector{Vector{Int16}}(undef, length(sol))

foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionMCSol(
ntraj,
_sol_1.prob.p.times,
states,
expvals,
expvals_all,
jump_times,
jump_which,
sol.converged,
_sol_1.alg,
_sol_1.prob.kwargs[:abstol],
_sol_1.prob.kwargs[:reltol],
)
catch e
if ensemble_method isa EnsembleDistributed
put!(ens_prob_mc.prob.p.progr_channel, false)

Check warning on line 610 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L609-L610

Added lines #L609 - L610 were not covered by tests
end
rethrow()

Check warning on line 612 in src/time_evolution/mcsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/mcsolve.jl#L612

Added line #L612 was not covered by tests
end
end
Loading

0 comments on commit 863d7ec

Please sign in to comment.