From da5c91f1fb67ea2c4d0d5dc5aedb9125bbc741ec Mon Sep 17 00:00:00 2001 From: bkraske <71412733+bkraske@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:27:05 -0700 Subject: [PATCH] Use observation list instead of ordered dicts --- src/default_policy_sim.jl | 8 +++++--- src/planner.jl | 4 ++-- src/tree.jl | 11 ++++++++--- test/baby_sanity_check.jl | 2 +- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/default_policy_sim.jl b/src/default_policy_sim.jl index 2706f1c..46c1adc 100644 --- a/src/default_policy_sim.jl +++ b/src/default_policy_sim.jl @@ -1,7 +1,8 @@ function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::Integer, fval) S = statetype(pomdp) O = obstype(pomdp) - odict = OrderedDict{O, Vector{Pair{Int, S}}}() + odict = Dict{O, Vector{Pair{Int, S}}}() + o_list = O[] if steps <= 0 return length(b.scenarios)*fval(pomdp, b) @@ -19,14 +20,15 @@ function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::I push!(odict[o], k=>sp) else odict[o] = [k=>sp] + push!(o_list,o) end r_sum += r end end - next_r = 0.0 - for (o, scenarios) in odict + for o in o_list + scenarios = odict[o] bp = ScenarioBelief(scenarios, b.random_source, b.depth+1, o) if length(scenarios) == 1 next_r += rollout(pomdp, policy, bp, steps-1, fval) diff --git a/src/planner.jl b/src/planner.jl index c911c8d..ef772d9 100644 --- a/src/planner.jl +++ b/src/planner.jl @@ -2,10 +2,10 @@ function build_despot(p::DESPOTPlanner, b_0) D = DESPOT(p, b_0) b = 1 trial = 1 - start = CPUtime_us() + start = time() while D.mu[1]-D.l[1] > p.sol.epsilon_0 && - CPUtime_us()-start < p.sol.T_max*1e6 && + time()-start < p.sol.T_max && trial <= p.sol.max_trials b = explore!(D, 1, p) backup!(D, b, p) diff --git a/src/tree.jl b/src/tree.jl index 2e5257e..2ca490c 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -56,11 +56,13 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner) S = statetype(p.pomdp) A = actiontype(p.pomdp) O = obstype(p.pomdp) - odict = OrderedDict{O, Int}() + odict = Dict{O, Int}() + olist = O[] belief = get_belief(D, b, p.rs) for a in actions(p.pomdp, belief) empty!(odict) + empty!(olist) rsum = 0.0 for scen in D.scenarios[b] @@ -74,12 +76,13 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner) push!(D.scenarios, Vector{Pair{Int, S}}()) bp = length(D.scenarios) odict[o] = bp + push!(olist,o) end push!(D.scenarios[bp], first(scen)=>sp) end end - push!(D.ba_children, collect(values(odict))) + push!(D.ba_children, [odict[o] for o in olist]) ba = length(D.ba_children) push!(D.ba_action, a) push!(D.children[b], ba) @@ -89,7 +92,9 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner) nbps = length(odict) resize!(D, length(D.children) + nbps) - for (o, bp) in odict + for o in olist + bp = odict[o] + D.obs[bp] = o D.children[bp] = Int[] D.parent_b[bp] = b diff --git a/test/baby_sanity_check.jl b/test/baby_sanity_check.jl index ba2779c..b4d364d 100644 --- a/test/baby_sanity_check.jl +++ b/test/baby_sanity_check.jl @@ -1,6 +1,6 @@ using POMDPs using ARDESPOT -using POMDPToolbox +using POMDPTools using POMDPModels using ProgressMeter