-
Notifications
You must be signed in to change notification settings - Fork 416
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize
Product
to ProductDistribution
- Loading branch information
Showing
2 changed files
with
55 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,85 @@ | ||
import Statistics: mean, var, cov | ||
|
||
""" | ||
Product <: MultivariateDistribution | ||
ProductDistribution <: Distribution{<:ValueSupport,<:ArrayLikeVariate} | ||
An N dimensional `MultivariateDistribution` constructed from a vector of N independent | ||
`UnivariateDistribution`s. | ||
A distribution of `M + N`-dimensional arrays, constructed from an `N`-dimensional array of | ||
independent `M`-dimensional distributions by stacking them. | ||
```julia | ||
Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions. | ||
``` | ||
Users should use [`product_distribution`](@ref) to construct a product distribution of | ||
independent distributions instead of constructing a `ProductDistribution` directly. | ||
""" | ||
struct Product{ | ||
struct ProductDistribution{ | ||
N, | ||
S<:ValueSupport, | ||
T<:UnivariateDistribution{S}, | ||
V<:AbstractVector{T}, | ||
} <: MultivariateDistribution{S} | ||
T<:Distribution{<:ArrayLikeVariate,S}, | ||
V<:AbstractArray{T}, | ||
} <: Distribution{ArrayLikeVariate{N},S} | ||
v::V | ||
function Product(v::V) where | ||
V<:AbstractVector{T} where | ||
T<:UnivariateDistribution{S} where | ||
S<:ValueSupport | ||
return new{S, T, V}(v) | ||
|
||
function ProductDistribution(v::AbstractArray{T,N}) where {S<:ValueSupport, M, T<:Distribution{ArrayLikeVariate{M},S}, N} | ||
return new{M + N, S, T, typeof(v)}(v) | ||
end | ||
end | ||
|
||
length(d::Product) = length(d.v) | ||
function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport, | ||
T<:UnivariateDistribution{S}} | ||
# aliases | ||
const VectorOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = | ||
ProductDistribution{1,S,T,V} | ||
const MatrixOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractMatrix{T}} = | ||
ProductDistribution{2,S,T,V} | ||
const VectorOfMultivariateDistribution{S<:ValueSupport,T<:MultivariateDistribution{S},V<:AbstractVector{T}} = | ||
ProductDistribution{2,S,T,V} | ||
|
||
## deprecations | ||
# type parameters can't be deprecated it seems: https://github.com/JuliaLang/julia/issues/9830 | ||
# so we define an alias and deprecate the corresponding constructor | ||
const Product{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = ProductDistribution{1,S,T,V} | ||
Base.@deprecate Product(v::AbstractVector{<:UnivariateDistribution}) ProductDistribution(v) | ||
|
||
## General definitions | ||
function Base.eltype(::Type{<:ProductDistribution{S,T}}) where {S<:ValueSupport,T<:Distribution{S,<:ArrayLikeVariate}} | ||
return eltype(T) | ||
end | ||
|
||
_rand!(rng::AbstractRNG, d::Product, x::AbstractVector{<:Real}) = | ||
|
||
## Vector of univariate distributions | ||
length(d::VectorOfUnivariateDistribution) = length(d.v) | ||
|
||
_rand!(rng::AbstractRNG, d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) = | ||
broadcast!(dn->rand(rng, dn), x, d.v) | ||
_logpdf(d::Product, x::AbstractVector{<:Real}) = | ||
_logpdf(d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) = | ||
sum(n->logpdf(d.v[n], x[n]), 1:length(d)) | ||
|
||
mean(d::Product) = mean.(d.v) | ||
var(d::Product) = var.(d.v) | ||
cov(d::Product) = Diagonal(var(d)) | ||
entropy(d::Product) = sum(entropy, d.v) | ||
insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x)) | ||
mean(d::VectorOfUnivariateDistribution) = map(mean, d.v) | ||
var(d::VectorOfUnivariateDistribution) = map(var, d.v) | ||
cov(d::VectorOfUnivariateDistribution) = Diagonal(var(d)) | ||
entropy(d::VectorOfUnivariateDistribution) = sum(entropy, d.v) | ||
function insupport(d::VectorOfUnivariateDistribution, x::AbstractVector) | ||
length(d) == length(x) && all(insupport(vi, xi) for (vi, xi) in zip(d.v, x)) | ||
end | ||
|
||
""" | ||
product_distribution(dists::AbstractVector{<:UnivariateDistribution}) | ||
product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate{M}},N}) | ||
Creates a multivariate product distribution `P` from a vector of univariate distributions. | ||
Fallback is the `Product constructor`, but specialized methods can be defined | ||
for distributions with a special multivariate product. | ||
Create a distribution of `M + N`-dimensional arrays as a product distribution of | ||
independent `M`-dimensional distributions by stacking them. | ||
The function falls back to constructing a [`ProductDistribution`](@ref) distribution but | ||
specialized methods can be defined. | ||
""" | ||
function product_distribution(dists::AbstractVector{<:UnivariateDistribution}) | ||
return Product(dists) | ||
function product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate}}) | ||
return ProductDistribution(dists) | ||
end | ||
|
||
""" | ||
product_distribution(dists::AbstractVector{<:Normal}) | ||
Computes the multivariate Normal distribution obtained by stacking the univariate | ||
normal distributions. The result is a multivariate Gaussian with a diagonal | ||
covariance matrix. | ||
Create a multivariate normal distribution by stacking the univariate normal distributions. | ||
The resulting distribution of type [`MvNormal`](@ref) has a diagonal covariance matrix. | ||
""" | ||
function product_distribution(dists::AbstractVector{<:Normal}) | ||
µ = mean.(dists) | ||
σ2 = var.(dists) | ||
µ = map(mean, dists) | ||
σ2 = map(var, dists) | ||
return MvNormal(µ, Diagonal(σ2)) | ||
end |