Skip to content

Commit

Permalink
Generalize Product to ProductDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Aug 31, 2021
1 parent 34cd1ac commit 5870272
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 37 deletions.
1 change: 0 additions & 1 deletion src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ export
NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Product,
Poisson,
PoissonBinomial,
QQPair,
Expand Down
91 changes: 55 additions & 36 deletions src/product.jl
Original file line number Diff line number Diff line change
@@ -1,66 +1,85 @@
import Statistics: mean, var, cov

"""
Product <: MultivariateDistribution
ProductDistribution <: Distribution{<:ValueSupport,<:ArrayLikeVariate}
An N dimensional `MultivariateDistribution` constructed from a vector of N independent
`UnivariateDistribution`s.
A distribution of `M + N`-dimensional arrays, constructed from an `N`-dimensional array of
independent `M`-dimensional distributions by stacking them.
```julia
Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions.
```
Users should use [`product_distribution`](@ref) to construct a product distribution of
independent distributions instead of constructing a `ProductDistribution` directly.
"""
struct Product{
struct ProductDistribution{
N,
S<:ValueSupport,
T<:UnivariateDistribution{S},
V<:AbstractVector{T},
} <: MultivariateDistribution{S}
T<:Distribution{<:ArrayLikeVariate,S},
V<:AbstractArray{T},
} <: Distribution{ArrayLikeVariate{N},S}
v::V
function Product(v::V) where
V<:AbstractVector{T} where
T<:UnivariateDistribution{S} where
S<:ValueSupport
return new{S, T, V}(v)

function ProductDistribution(v::AbstractArray{T,N}) where {S<:ValueSupport, M, T<:Distribution{ArrayLikeVariate{M},S}, N}
return new{M + N, S, T, typeof(v)}(v)
end
end

length(d::Product) = length(d.v)
function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport,
T<:UnivariateDistribution{S}}
# aliases
const VectorOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} =
ProductDistribution{1,S,T,V}
const MatrixOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractMatrix{T}} =
ProductDistribution{2,S,T,V}
const VectorOfMultivariateDistribution{S<:ValueSupport,T<:MultivariateDistribution{S},V<:AbstractVector{T}} =
ProductDistribution{2,S,T,V}

## deprecations
# type parameters can't be deprecated it seems: https://github.com/JuliaLang/julia/issues/9830
# so we define an alias and deprecate the corresponding constructor
const Product{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = ProductDistribution{1,S,T,V}
Base.@deprecate Product(v::AbstractVector{<:UnivariateDistribution}) ProductDistribution(v)

## General definitions
function Base.eltype(::Type{<:ProductDistribution{S,T}}) where {S<:ValueSupport,T<:Distribution{S,<:ArrayLikeVariate}}
return eltype(T)
end

_rand!(rng::AbstractRNG, d::Product, x::AbstractVector{<:Real}) =

## Vector of univariate distributions
length(d::VectorOfUnivariateDistribution) = length(d.v)

_rand!(rng::AbstractRNG, d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) =
broadcast!(dn->rand(rng, dn), x, d.v)
_logpdf(d::Product, x::AbstractVector{<:Real}) =
_logpdf(d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) =
sum(n->logpdf(d.v[n], x[n]), 1:length(d))

mean(d::Product) = mean.(d.v)
var(d::Product) = var.(d.v)
cov(d::Product) = Diagonal(var(d))
entropy(d::Product) = sum(entropy, d.v)
insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x))
mean(d::VectorOfUnivariateDistribution) = map(mean, d.v)
var(d::VectorOfUnivariateDistribution) = map(var, d.v)
cov(d::VectorOfUnivariateDistribution) = Diagonal(var(d))
entropy(d::VectorOfUnivariateDistribution) = sum(entropy, d.v)
function insupport(d::VectorOfUnivariateDistribution, x::AbstractVector)
length(d) == length(x) && all(insupport(vi, xi) for (vi, xi) in zip(d.v, x))
end

"""
product_distribution(dists::AbstractVector{<:UnivariateDistribution})
product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate{M}},N})
Creates a multivariate product distribution `P` from a vector of univariate distributions.
Fallback is the `Product constructor`, but specialized methods can be defined
for distributions with a special multivariate product.
Create a distribution of `M + N`-dimensional arrays as a product distribution of
independent `M`-dimensional distributions by stacking them.
The function falls back to constructing a [`ProductDistribution`](@ref) distribution but
specialized methods can be defined.
"""
function product_distribution(dists::AbstractVector{<:UnivariateDistribution})
return Product(dists)
function product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate}})
return ProductDistribution(dists)
end

"""
product_distribution(dists::AbstractVector{<:Normal})
Computes the multivariate Normal distribution obtained by stacking the univariate
normal distributions. The result is a multivariate Gaussian with a diagonal
covariance matrix.
Create a multivariate normal distribution by stacking the univariate normal distributions.
The resulting distribution of type [`MvNormal`](@ref) has a diagonal covariance matrix.
"""
function product_distribution(dists::AbstractVector{<:Normal})
µ = mean.(dists)
σ2 = var.(dists)
µ = map(mean, dists)
σ2 = map(var, dists)
return MvNormal(µ, Diagonal(σ2))
end

0 comments on commit 5870272

Please sign in to comment.