Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP distributions with particle parameters #78

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/MonteCarloMeasurements.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ export Normal, MvNormal, Cauchy, Beta, Exponential, Gamma, Laplace, Uniform, fit

export unsafe_comparisons, @unsafe, set_comparison_function

export bymap, bypmap, @bymap, @bypmap, @prob, Workspace, with_workspace, has_particles, mean_object
export bymap, bypmap, @bymap, @bypmap, @prob, Workspace, with_workspace, has_particles, mean_object, change_representation

export ParticleDistribution

include("types.jl")
include("register_primitive.jl")
Expand All @@ -99,6 +101,7 @@ include("deconstruct.jl")
include("diff.jl")
include("plotting.jl")
include("optimize.jl")
include("particle_distributions.jl")

# This is defined here so that @bymap is loaded
LinearAlgebra.norm2(p::AbstractVector{<:AbstractParticles}) = bymap(LinearAlgebra.norm2,p)
Expand Down
44 changes: 44 additions & 0 deletions src/deconstruct.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,47 @@
"""
change_representation(F, p::T) where T

Convert from type `Particles{F}` to `F{Particles}`

Example:
```julia
@unsafe d = Normal(Particles(),10+Particles())
d2 = change_representation(Normal, d);
d3 = change_representation(Normal, d2)
d3.μ == d.μ
d3.σ == d.σ
```
"""
function change_representation(F, p::AbstractParticles{T,N}) where {T,N}
fields = map(fieldnames(T)) do fn
getfield.(p.particles, fn)
end
F(Particles.(fields)...)
end

"""
change_representation(F, p::T) where T

Convert from type `F{Particles}` to `Particles{F}`

Example:
```julia
@unsafe d = Normal(Particles(),10+Particles())
d2 = change_representation(Normal, d);
d3 = change_representation(Normal, d2)
d3.μ == d.μ
d3.σ == d.σ
```
"""
function change_representation(F, p::T) where T
fields = map(fieldnames(T)) do fn
getfield(p, fn)
end
N = nparticles(fields[1])
Fs = [F(getindex.(fields,i)...) for i in 1:N]
Particles(Fs)
end

"""
has_particles(P)
Determine whether or no the object `P` has some kind of particles inside it. This function examins fields of `P` recursively and looks inside arrays etc.
Expand Down
57 changes: 57 additions & 0 deletions src/particle_distributions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
struct ParticleDistribution{D,N,P}
d::Vector{D}
constructor::P
end


"""
ParticleDistribution(constructor::Type{<:Distribution}, p...)

A `ParticleDistribution` represents a hierarchical distribution where the parameters of the distribution are `Particles`. The internal representation is as a `Vector{Distribution{FloatType}}` for efficient drawing of random numbers etc. But construction and printing is done as if it was a type `Distribution{Particles}`.

# Example
```julia
julia> pd = ParticleDistribution(Normal, 1±0.1, 1±0.1)
ParticleNormal{Float64}(
μ: 1.0 ± 0.1
σ: 1.0 ± 0.1
)

julia> rand(pd)
Part10000(1.012 ± 1.01)
"""
function ParticleDistribution(constructor::Type{<:Distribution}, p...)
N = nparticles(p[1])
dists = [constructor(getindex.(p, i)...) for i in 1:N]
ParticleDistribution{eltype(dists), N, typeof(constructor)}(dists, constructor)
end

Base.length(d::ParticleDistribution) = length(d.d[1])
Base.eltype(d::ParticleDistribution{D,N}) where {D,N} = Particles{eltype(D),N}

Particles(a::BitArray) = Particles(Vector(a))

Check warning on line 32 in src/particle_distributions.jl

View check run for this annotation

Codecov / codecov/patch

src/particle_distributions.jl#L32

Added line #L32 was not covered by tests

function Base.rand(rng::AbstractRNG, d::ParticleDistribution{D,N}) where {D,N}
eltype(d)(rand.(rng, d.d))
end

Base.rand(d::ParticleDistribution) = rand(Random.GLOBAL_RNG, d)

function Base.show(io::IO, d::ParticleDistribution{D}) where D
fields = map(fieldnames(D)) do fn
getfield.(d.d, fn)
end
println(io, "Particle", D, "(")
for (i,fn) in enumerate(fieldnames(D))
println(io, " ", string(fn), ": ", Particles(fields[i]))
end
print(io, ")")
end

Base.getindex(d::ParticleDistribution, i...) = getindex(d.d, i...)


function Distributions.logpdf(pd::ParticleDistribution{D,N}, x) where {D,N}
T = float(eltype(D))
Particles{T,N}(logpdf.(pd.d, x))
end
24 changes: 24 additions & 0 deletions src/register_primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
register_primitive(f, eval=eval)

Register both single and multi-argument function so that it works with particles. If you want to register functions from within a module, you must pass the modules `eval` function.

Example:
```julia
module MyMod
using MonteCarloMeasurements
register_primitive(floor, MyMod.eval)
end
```
"""
function register_primitive(ff, eval=eval)
register_primitive_multi(ff, eval)
Expand All @@ -15,6 +23,14 @@ end
register_primitive_multi(ff, eval=eval)

Register a multi-argument function so that it works with particles. If you want to register functions from within a module, you must pass the modules `eval` function.

Example:
```julia
module MyMod
using MonteCarloMeasurements
register_primitive(floor, MyMod.eval)
end
```
"""
function register_primitive_multi(ff, eval=eval)
f = nameof(ff)
Expand Down Expand Up @@ -85,6 +101,14 @@ end
register_primitive_single(ff, eval=eval)

Register a single-argument function so that it works with particles. If you want to register functions from within a module, you must pass the modules `eval` function.

Example:
```julia
module MyMod
using MonteCarloMeasurements
register_primitive(floor, MyMod.eval)
end
```
"""
function register_primitive_single(ff, eval=eval)
f = nameof(ff)
Expand Down
4 changes: 1 addition & 3 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ function systematic_sample(rng::AbstractRNG, N, d=Normal(0,1); permute=true)
T = eltype(d)
e = T(0.5/N) # rand()/N
y = e:1/N:1
o = map(y) do y
quantile(d,y)
end
o = quantile.(d,y)
permute && permute!(o, randperm(rng, N))
return eltype(o) == T ? o : T.(o)
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ Random.seed!(0)
include("test_deconstruct.jl")
include("test_sleefpirates.jl")
include("test_measurements.jl")
include("test_particle_distributions.jl")

end

Expand Down
47 changes: 47 additions & 0 deletions test/test_particle_distributions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using MonteCarloMeasurements, Distributions
@testset "Particle Distributions" begin
@info "Testing Particle Distributions"



pd = ParticleDistribution(Bernoulli, Particles(1000, Beta(2, 3)))
# @btime rand($pd) # 23.304 ns (0 allocations: 0 bytes)
# @btime rand(Bernoulli(0.3)) # 10.050 ns (0 allocations: 0 bytes)

@test pd[1] isa Bernoulli
@test length(pd) == 1
@test rand(pd) isa eltype(pd)
@test length(pd.d) == 1000
@test_nowarn display(pd)
@test logpdf(pd, 1).particles == [logpdf(d,1) for d in pd.d]

pd = ParticleDistribution(
Normal,
Particles(1000, Normal(10, 3)),
Particles(1000, Normal(2, 0.1)),
)

@test pd[1] isa Normal
@test length(pd) == 1
@test rand(pd) isa eltype(pd)
@test length(pd.d) == 1000
@test_nowarn display(pd)
@test logpdf(pd, 1).particles == [logpdf(d,1) for d in pd.d]



# @btime rand($pd) # 27.726 ns (0 allocations: 0 bytes)
# @btime rand(Normal(10,2)) # 12.788 ns (0 allocations: 0 bytes)


@unsafe d = Normal(Particles(),10+Particles())
@test d isa Normal{<:Particles{Float64}}
d2 = change_representation(Normal, d);
@test d2 isa Particles{Normal{Float64}}
d3 = change_representation(Normal, d2)
@test d3.μ == d.μ
@test d3.σ == d.σ



end