Skip to content

Commit

Permalink
Use observation list instead of ordered dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
bkraske committed Jan 17, 2024
1 parent b2bf430 commit da5c91f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
8 changes: 5 additions & 3 deletions src/default_policy_sim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::Integer, fval)
S = statetype(pomdp)
O = obstype(pomdp)
odict = OrderedDict{O, Vector{Pair{Int, S}}}()
odict = Dict{O, Vector{Pair{Int, S}}}()
o_list = O[]

if steps <= 0
return length(b.scenarios)*fval(pomdp, b)
Expand All @@ -19,14 +20,15 @@ function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::I
push!(odict[o], k=>sp)
else
odict[o] = [k=>sp]
push!(o_list,o)
end

r_sum += r
end
end

next_r = 0.0
for (o, scenarios) in odict
for o in o_list
scenarios = odict[o]
bp = ScenarioBelief(scenarios, b.random_source, b.depth+1, o)
if length(scenarios) == 1
next_r += rollout(pomdp, policy, bp, steps-1, fval)
Expand Down
4 changes: 2 additions & 2 deletions src/planner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ function build_despot(p::DESPOTPlanner, b_0)
D = DESPOT(p, b_0)
b = 1
trial = 1
start = CPUtime_us()
start = time()

while D.mu[1]-D.l[1] > p.sol.epsilon_0 &&
CPUtime_us()-start < p.sol.T_max*1e6 &&
time()-start < p.sol.T_max &&
trial <= p.sol.max_trials
b = explore!(D, 1, p)
backup!(D, b, p)
Expand Down
11 changes: 8 additions & 3 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)
S = statetype(p.pomdp)
A = actiontype(p.pomdp)
O = obstype(p.pomdp)
odict = OrderedDict{O, Int}()
odict = Dict{O, Int}()
olist = O[]

belief = get_belief(D, b, p.rs)
for a in actions(p.pomdp, belief)
empty!(odict)
empty!(olist)
rsum = 0.0

for scen in D.scenarios[b]
Expand All @@ -74,12 +76,13 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)
push!(D.scenarios, Vector{Pair{Int, S}}())
bp = length(D.scenarios)
odict[o] = bp
push!(olist,o)
end
push!(D.scenarios[bp], first(scen)=>sp)
end
end

push!(D.ba_children, collect(values(odict)))
push!(D.ba_children, [odict[o] for o in olist])
ba = length(D.ba_children)
push!(D.ba_action, a)
push!(D.children[b], ba)
Expand All @@ -89,7 +92,9 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)

nbps = length(odict)
resize!(D, length(D.children) + nbps)
for (o, bp) in odict
for o in olist
bp = odict[o]

D.obs[bp] = o
D.children[bp] = Int[]
D.parent_b[bp] = b
Expand Down
2 changes: 1 addition & 1 deletion test/baby_sanity_check.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using POMDPs
using ARDESPOT
using POMDPToolbox
using POMDPTools
using POMDPModels
using ProgressMeter

Expand Down

0 comments on commit da5c91f

Please sign in to comment.