Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: convert distributions to Distributions.jl distributions #495

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,28 @@ using Random: AbstractRNG

# Should use Module.function directly in the code instead of doing this
import Distributions: support, pdf, mode, mean
using Distributions: DiscreteUnivariateDistribution, Distribution
using Distributions: VariateForm, Multivariate, Matrixvariate, Univariate
using Distributions: ValueSupport, Discrete
import Random: rand

using UnicodePlots: barplot

"""
Try to guess the Distributions.VariateForm for a distribution based on the sample type.
"""
function infer_variate_form(T::Type)
if T <: AbstractVector
return Multivariate
elseif T <: AbstractMatrix
return Matrixvariate
elseif T <: Number
return Univariate
else
return VariateForm
end
end

export
weighted_iterator
include("weighted_iteration.jl")
Expand Down
6 changes: 4 additions & 2 deletions lib/POMDPTools/src/POMDPDistributions/bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ Create a distribution over Boolean values (`true` or `false`).

`p_true` is the probability of the `true` outcome; the probability of `false` is 1-`p_true`.
"""
struct BoolDistribution
struct BoolDistribution <: DiscreteUnivariateDistribution
p::Float64 # probability of true
end

pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p
pdf(d::BoolDistribution, s::Real) = convert(Bool, s) ? d.p : 1.0-d.p
Distributions.logpdf(d::BoolDistribution, s) = log(pdf(d, s))

rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p
rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p

Base.iterate(d::BoolDistribution) = ((d.p, true), true)
function Base.iterate(d::BoolDistribution, state::Bool)
Expand Down
61 changes: 40 additions & 21 deletions lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Create a sparse categorical distribution.

This is optimized for value iteration with a fast implementation of `weighted_iterator`. Both `pdf` and `rand` are order n.
"""
struct SparseCat{V, P}
struct SparseCat{V, P, F} <: Distribution{F, Discrete}
vals::V
probs::P
end
Expand All @@ -23,13 +23,16 @@ function SparseCat(v, p::AbstractArray)
SparseCat(v, cp)
end
# the method above gets all arrays *except* ones that have a numeric eltype, which are handled below
SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p)}(v, p)
SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p)

function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat})
d = s[]
SparseCat(v, p) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p)

rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) = rand(rng, s[])

function rand(rng::AbstractRNG, d::SparseCat)
r = sum(d.probs)*rand(rng)
tot = zero(eltype(d.probs))
for (v, p) in d
for (v, p) in weighted_iterator(d)
tot += p
if r < tot
return v
Expand All @@ -47,17 +50,26 @@ function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat})
error("Error sampling from SparseCat distribution with vals $(d.vals) and probs $(d.probs)") # try to help with type stability
end

Distributions.sampler(d::SparseCat) = Random.SamplerTrivial(d)
Random.Sampler(::AbstractRNG, d::SparseCat, repetition::Union{Val{1}, Val{Inf}}) = Random.SamplerTrivial(d)

# to resolve ambiguity between pdf(::UnivariateDistribution, ::Real) and pdf(::SparseCat, ::Any)
pdf(d::SparseCat, s) = _pdf(d, s)
pdf(d::SparseCat, s::Real) = _pdf(d, s)

Distributions.logpdf(d::SparseCat, x) = log(pdf(d, x))

# slow linear search :(
function pdf(d::SparseCat, s)
for (v, p) in d
function _pdf(d::SparseCat, s)
for (v, p) in weighted_iterator(d)
if v == s
return p
end
end
return zero(eltype(d.probs))
end

function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray}
function _pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray}
for (i,v) in enumerate(d.vals)
if v == s
return d.probs[i]
Expand All @@ -67,19 +79,25 @@ function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray}
end



support(d::SparseCat) = d.vals

weighted_iterator(d::SparseCat) = d
struct SparseCatIterator{D<:SparseCat}
d::D
end


weighted_iterator(d::SparseCat) = SparseCatIterator(d)

# iterator for general SparseCat
# this has some type stability problems
function Base.iterate(d::SparseCat)
function Base.iterate(i::SparseCatIterator)
d = i.d
val, vstate = iterate(d.vals)
prob, pstate = iterate(d.probs)
return ((val=>prob), (vstate, pstate))
end
function Base.iterate(d::SparseCat, dstate::Tuple)
function Base.iterate(i::SparseCatIterator, dstate::Tuple)
d = i.d
vstate, pstate = dstate
vnext = iterate(d.vals, vstate)
pnext = iterate(d.probs, pstate)
Expand All @@ -94,21 +112,22 @@ end
# iterator for SparseCat with indexed members
const Indexed = Union{AbstractArray, Tuple, NamedTuple}

function Base.iterate(d::SparseCat{V,P}, state::Integer=1) where {V<:Indexed, P<:Indexed}
if state > length(d)
function Base.iterate(i::SparseCatIterator{<:SparseCat{<:Indexed,<:Indexed}}, state::Integer=1)
if state > length(i)
return nothing
end
return (d.vals[state]=>d.probs[state], state+1)
return (i.d.vals[state]=>i.d.probs[state], state+1)
end

Base.length(d::SparseCat) = min(length(d.vals), length(d.probs))
Base.eltype(D::Type{SparseCat{V,P}}) where {V, P} = Pair{eltype(V), eltype(P)}
sampletype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V)
Random.gentype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V)
Base.length(i::SparseCatIterator) = min(length(i.d.vals), length(i.d.probs))
Base.eltype(D::Type{SparseCatIterator{SparseCat{V,P,F}}}) where {V,P,F} = Pair{eltype(V), eltype(P)}

sampletype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V)
Random.gentype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V)

function mean(d::SparseCat)
vsum = zero(eltype(d.vals))
for (v, p) in d
for (v, p) in weighted_iterator(d)
vsum += v*p
end
return vsum/sum(d.probs)
Expand All @@ -117,7 +136,7 @@ end
function mode(d::SparseCat)
bestp = zero(eltype(d.probs))
bestv = first(d.vals)
for (v, p) in d
for (v, p) in weighted_iterator(d)
if p >= bestp
bestp = p
bestv = v
Expand Down
8 changes: 5 additions & 3 deletions lib/POMDPTools/src/Policies/playback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ mutable struct PlaybackPolicy{A<:AbstractArray, P<:Policy, V<:AbstractArray{<:Re
end

# Constructor for the PlaybackPolicy
PlaybackPolicy(actions::AbstractArray, backup_policy::Policy; logpdfs::AbstractArray{<:Real} = Float64[]) = PlaybackPolicy(actions, backup_policy, logpdfs, 1)
function PlaybackPolicy(actions::AbstractArray,
backup_policy::Policy = FunctionPolicy(s->error("PlaybackPolicy out of actions."));
logpdfs::AbstractArray{<:Real} = Float64[])
return PlaybackPolicy(actions, backup_policy, logpdfs, 1)
end

# Action selection for the PlaybackPolicy
function POMDPs.action(p::PlaybackPolicy, s)
Expand All @@ -41,5 +45,3 @@ function Distributions.logpdf(p::PlaybackPolicy, h)
return sum(p.logpdfs[1:N])
end
end


Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@test POMDPDistributions.infer_variate_form(typeof([1 2; 3 4])) == Distributions.Matrixvariate
@test POMDPDistributions.infer_variate_form(typeof([1, 2])) == Distributions.Multivariate
@test POMDPDistributions.infer_variate_form(typeof(1)) == Distributions.Univariate
@test POMDPDistributions.infer_variate_form(Any) == Distributions.VariateForm

p = product_distribution([SparseCat([1, 2, 3], [0.5, 0.2, 0.3]), BoolDistribution(1.0)])
@test rand(p) isa AbstractVector
@test pdf(p, [1, 1]) == 0.5

@test_broken p = Product([SparseCat([:a,:b,:c], [0.5, 0.2, 0.3]), BoolDistribution(1.0)])
6 changes: 6 additions & 0 deletions lib/POMDPTools/test/policies/test_playback_policy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp))
hist2 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3))
@test hist == hist2

## Test with default error policy
playback = PlaybackPolicy(collect(action_hist(hist)))
@test all(playback.actions .== action_hist(hist))
hist3 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3))
@test_throws ErrorException action(playback, GWPos(3,3))

## Test log probability
Distributions.logpdf(p::RandomPolicy, h) = length(h)*log(1. / length(actions(p.problem)))
playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp), logpdfs = -ones(length(hist)))
Expand Down
2 changes: 2 additions & 0 deletions lib/POMDPTools/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using SparseArrays: sparse

import CommonRLInterface

using Distributions: Distributions, product_distribution, Product

@testset "POMDPTools.jl" begin
@testset "POMDPDistributions" begin
Expand All @@ -23,6 +24,7 @@ import CommonRLInterface
include("distributions/test_pretty_printing.jl")
include("distributions/test_sparse_cat.jl")
include("distributions/test_uniform.jl")
include("distributions/test_distributions_jl_integration.jl")
end

@testset "ModelTools" begin
Expand Down