diff --git a/docs/src/types.md b/docs/src/types.md index 9812cd2f6..daa1ce920 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -33,16 +33,21 @@ The `VariateForm` sub-types defined in `Distributions.jl` are: ### ValueSupport -```@doc +```@docs Distributions.ValueSupport ``` The `ValueSupport` sub-types defined in `Distributions.jl` are: -**Type** | **Element type** | **Descriptions** ---- | --- | --- -`Discrete` | `Int` | Samples take discrete values -`Continuous` | `Float64` | Samples take continuous real values +```@docs +Distributions.Discrete +Distributions.Continuous +``` + +**Type** | **Default element type** | **Description** | **Examples** +--- | --- | --- | --- +`Discrete` | `Int` | Samples take countably many values | $\{0,1,2,3\}$, $\mathbb{N}$ +`Continuous` | `Float64` | Samples take uncountably many values | $[0, 1]$, $\mathbb{R}$ Multiple samples are often organized into an array, depending on the variate form. diff --git a/src/Distributions.jl b/src/Distributions.jl index 2e4ef62e5..af5616d09 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -145,7 +145,7 @@ export Pareto, PGeneralizedGaussian, SkewedExponentialPower, - Product, + Product, # deprecated Poisson, PoissonBinomial, QQPair, @@ -294,6 +294,7 @@ include("cholesky/lkjcholesky.jl") include("samplers.jl") # others +include("product.jl") include("reshaped.jl") include("truncate.jl") include("censored.jl") diff --git a/src/common.jl b/src/common.jl index 3602372a5..aecad0e1a 100644 --- a/src/common.jl +++ b/src/common.jl @@ -23,13 +23,42 @@ const Matrixvariate = ArrayLikeVariate{2} abstract type CholeskyVariate <: VariateForm end """ -`S <: ValueSupport` specifies the support of sample elements, -either discrete or continuous. + ValueSupport + +Abstract type that specifies the support of elements of samples. + +It is either [`Discrete`](@ref) or [`Continuous`](@ref). 
""" abstract type ValueSupport end + +""" + Discrete <: ValueSupport + +This type represents the support of a discrete random variable. + +It is countable. For instance, it can be a finite set or a countably infinite set such as +the natural numbers. + +See also: [`Continuous`](@ref), [`ValueSupport`](@ref) +""" struct Discrete <: ValueSupport end + +""" + Continuous <: ValueSupport + +This type represents the support of a continuous random variable. + +It is uncountably infinite. For instance, it can be an interval on the real line. + +See also: [`Discrete`](@ref), [`ValueSupport`](@ref) +""" struct Continuous <: ValueSupport end +# promotions (e.g., in product distribution): +# combination of discrete support (countable) and continuous support (uncountable) yields +# continuous support (uncountable) +Base.promote_rule(::Type{Continuous}, ::Type{Discrete}) = Continuous + ## Sampleable """ @@ -42,7 +71,6 @@ Any `Sampleable` implements the `Base.rand` method. """ abstract type Sampleable{F<:VariateForm,S<:ValueSupport} end - variate_form(::Type{<:Sampleable{VF}}) where {VF} = VF value_support(::Type{<:Sampleable{<:VariateForm,VS}}) where {VS} = VS @@ -142,10 +170,6 @@ const ContinuousMultivariateDistribution = Distribution{Multivariate, Continuou const DiscreteMatrixDistribution = Distribution{Matrixvariate, Discrete} const ContinuousMatrixDistribution = Distribution{Matrixvariate, Continuous} -variate_form(::Type{<:Distribution{VF}}) where {VF} = VF - -value_support(::Type{<:Distribution{VF,VS}}) where {VF,VS} = VS - # allow broadcasting over distribution objects # to be decided: how to handle multivariate/matrixvariate distributions? 
Broadcast.broadcastable(d::UnivariateDistribution) = Ref(d) diff --git a/src/multivariate/product.jl b/src/multivariate/product.jl index 4e304d76b..d2bcd7a95 100644 --- a/src/multivariate/product.jl +++ b/src/multivariate/product.jl @@ -1,4 +1,5 @@ -import Statistics: mean, var, cov +# Deprecated product distribution +# TODO: Remove in next breaking release """ Product <: MultivariateDistribution @@ -20,6 +21,10 @@ struct Product{ V<:AbstractVector{T} where T<:UnivariateDistribution{S} where S<:ValueSupport + Base.depwarn( + "`Product(v)` is deprecated, please use `product_distribution(v)`", + :Product, + ) return new{S, T, V}(v) end end @@ -43,26 +48,9 @@ insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x)) minimum(d::Product) = map(minimum, d.v) maximum(d::Product) = map(maximum, d.v) -""" - product_distribution(dists::AbstractVector{<:UnivariateDistribution}) - -Creates a multivariate product distribution `P` from a vector of univariate distributions. -Fallback is the `Product constructor`, but specialized methods can be defined -for distributions with a special multivariate product. -""" -function product_distribution(dists::AbstractVector{<:UnivariateDistribution}) - return Product(dists) -end - -""" - product_distribution(dists::AbstractVector{<:Normal}) - -Computes the multivariate Normal distribution obtained by stacking the univariate -normal distributions. The result is a multivariate Gaussian with a diagonal -covariance matrix. 
-""" -function product_distribution(dists::AbstractVector{<:Normal}) - µ = mean.(dists) - σ2 = var.(dists) - return MvNormal(µ, Diagonal(σ2)) -end +# TODO: remove deprecation when `Product` is removed +# it will return a `ProductDistribution` then which is already the default for +# higher-dimensional arrays and distributions +Base.@deprecate product_distribution( + dists::AbstractVector{<:UnivariateDistribution} +) Product(dists) diff --git a/src/multivariates.jl b/src/multivariates.jl index 1a087f1ba..477c78ba5 100644 --- a/src/multivariates.jl +++ b/src/multivariates.jl @@ -116,7 +116,7 @@ for fname in ["dirichlet.jl", "mvnormalcanon.jl", "mvlognormal.jl", "mvtdist.jl", - "product.jl", + "product.jl", # deprecated "vonmisesfisher.jl"] include(joinpath("multivariate", fname)) end diff --git a/src/product.jl b/src/product.jl new file mode 100644 index 000000000..049a5888d --- /dev/null +++ b/src/product.jl @@ -0,0 +1,246 @@ +""" + ProductDistribution <: Distribution{<:ArrayLikeVariate,<:ValueSupport} + +A distribution of `M + N`-dimensional arrays, constructed from an `N`-dimensional array of +independent `M`-dimensional distributions by stacking them. + +Users should use [`product_distribution`](@ref) to construct a product distribution of +independent distributions instead of constructing a `ProductDistribution` directly. 
+""" +struct ProductDistribution{N,M,D,S<:ValueSupport,T} <: Distribution{ArrayLikeVariate{N},S} + dists::D + size::Dims{N} + + function ProductDistribution{N,M,D}(dists::D) where {N,M,D} + isempty(dists) && error("product distribution must consist of at least one distribution") + return new{N,M,D,_product_valuesupport(dists),_product_eltype(dists)}( + dists, + _product_size(dists), + ) + end +end + +function ProductDistribution(dists::AbstractArray{<:Distribution{ArrayLikeVariate{M}},N}) where {M,N} + return ProductDistribution{M + N,M,typeof(dists)}(dists) +end + +function ProductDistribution(dists::Tuple{Vararg{<:Distribution{ArrayLikeVariate{M}},N}}) where {M,N} + return ProductDistribution{M + 1,M,typeof(dists)}(dists) +end + +# default definitions (type stable e.g. for arrays with concrete `eltype`) +_product_valuesupport(dists) = mapreduce(value_support ∘ typeof, promote_type, dists) +_product_eltype(dists) = mapreduce(eltype, promote_type, dists) + +# type-stable and faster implementations for tuples +function _product_valuesupport(dists::Tuple{Vararg{<:Distribution}}) + return __product_promote_type(value_support, typeof(dists)) +end +function _product_eltype(dists::Tuple{Vararg{<:Distribution}}) + return __product_promote_type(eltype, typeof(dists)) +end + +__product_promote_type(f::F, ::Type{Tuple{D}}) where {F,D<:Distribution} = f(D) +function __product_promote_type(f::F, ::Type{T}) where {F,T} + return promote_type( + f(Base.tuple_type_head(T)), + __product_promote_type(f, Base.tuple_type_tail(T)), + ) +end + +function _product_size(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate{M}},N}) where {M,N} + size_d = size(first(dists)) + all(size(d) == size_d for d in dists) || error("all distributions must be of the same size") + size_dists = size(dists) + return ntuple(i -> i <= M ? 
size_d[i] : size_dists[i-M], Val(M + N)) +end +function _product_size(dists::Tuple{Vararg{<:Distribution{<:ArrayLikeVariate{M}},N}}) where {M,N} + size_d = size(first(dists)) + all(size(d) == size_d for d in dists) || error("all distributions must be of the same size") + return ntuple(i -> i <= M ? size_d[i] : N, Val(M + 1)) +end + +## aliases +const VectorOfUnivariateDistribution{D,S<:ValueSupport,T} = ProductDistribution{1,0,D,S,T} +const MatrixOfUnivariateDistribution{D,S<:ValueSupport,T} = ProductDistribution{2,0,D,S,T} +const ArrayOfUnivariateDistribution{N,D,S<:ValueSupport,T} = ProductDistribution{N,0,D,S,T} + +const FillArrayOfUnivariateDistribution{N,D<:Fill{<:Any,N},S<:ValueSupport,T} = ProductDistribution{N,0,D,S,T} + +## General definitions +function Base.eltype(::Type{<:ProductDistribution{<:Any,<:Any,<:Any,<:ValueSupport,T}}) where {T} + return T +end + +size(d::ProductDistribution) = d.size + +mean(d::ProductDistribution) = reshape(mapreduce(vec ∘ mean, vcat, d.dists), size(d)) +var(d::ProductDistribution) = reshape(mapreduce(vec ∘ var, vcat, d.dists), size(d)) +cov(d::ProductDistribution) = Diagonal(vec(var(d))) + +## For product distributions of univariate distributions +mean(d::ArrayOfUnivariateDistribution) = map(mean, d.dists) +mean(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(mean, d.dists)) +var(d::ArrayOfUnivariateDistribution) = map(var, d.dists) +var(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(var, d.dists)) + +function insupport(d::ArrayOfUnivariateDistribution{N}, x::AbstractArray{<:Real,N}) where {N} + size(d) == size(x) && all(insupport(vi, xi) for (vi, xi) in zip(d.dists, x)) +end + +minimum(d::ArrayOfUnivariateDistribution) = map(minimum, d.dists) +minimum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(minimum, d.dists)) +maximum(d::ArrayOfUnivariateDistribution) = map(maximum, d.dists) +maximum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(maximum, d.dists)) + +function 
entropy(d::ArrayOfUnivariateDistribution) + # we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020) + return sum(Broadcast.instantiate(Broadcast.broadcasted(entropy, d.dists))) +end +# fix type instability with tuples +entropy(d::VectorOfUnivariateDistribution{<:Tuple}) = sum(entropy, d.dists) + +## Vector of univariate distributions +length(d::VectorOfUnivariateDistribution) = length(d.dists) + +## For matrix distributions +cov(d::ProductDistribution{2}, ::Val{false}) = reshape(cov(d), size(d)..., size(d)...) + +# `_rand!` for arrays of univariate distributions +function _rand!( + rng::AbstractRNG, + d::ArrayOfUnivariateDistribution{N}, + x::AbstractArray{<:Real,N}, +) where {N} + @inbounds for (i, di) in zip(eachindex(x), d.dists) + x[i] = rand(rng, di) + end + return x +end + +# `_logpdf` for arrays of univariate distributions +# we have to fix a method ambiguity +function _logpdf(d::ArrayOfUnivariateDistribution, x::AbstractArray{<:Real,N}) where {N} + return __logpdf(d, x) +end +_logpdf(d::MatrixOfUnivariateDistribution, x::AbstractMatrix{<:Real}) = __logpdf(d, x) +function __logpdf(d::ArrayOfUnivariateDistribution, x::AbstractArray{<:Real,N}) where {N} + # we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020) + # without allocations to compute `sum(logpdf.(d.dists, x))` + broadcasted = Broadcast.broadcasted(logpdf, d.dists, x) + return sum(Broadcast.instantiate(broadcasted)) +end + +# more efficient implementation of `_rand!` for `Fill` array of univariate distributions +function _rand!( + rng::AbstractRNG, + d::FillArrayOfUnivariateDistribution{N}, + x::AbstractArray{<:Real,N}, +) where {N} + return @inbounds rand!(rng, sampler(first(d.dists)), x) +end + +# more efficient implementation of `_logpdf` for `Fill` array of univariate distributions +# we have to fix a method ambiguity +function _logpdf( + d::FillArrayOfUnivariateDistribution{N}, x::AbstractArray{<:Real,N} +) where {N} + return __logpdf(d, x) +end 
+_logpdf(d::FillArrayOfUnivariateDistribution{2}, x::AbstractMatrix{<:Real}) = __logpdf(d, x) +function __logpdf( + d::FillArrayOfUnivariateDistribution{N}, x::AbstractArray{<:Real,N} +) where {N} + return @inbounds loglikelihood(first(d.dists), x) +end + +# `_rand!` for arrays of distributions +function _rand!( + rng::AbstractRNG, + d::ProductDistribution{N,M}, + A::AbstractArray{<:Real,N}, +) where {N,M} + @inbounds for (di, Ai) in zip(d.dists, eachvariate(A, ArrayLikeVariate{M})) + rand!(rng, di, Ai) + end + return A +end + +# `_logpdf` for arrays of distributions +# we have to fix a method ambiguity +_logpdf(d::ProductDistribution{N}, x::AbstractArray{<:Real,N}) where {N} = __logpdf(d, x) +_logpdf(d::ProductDistribution{2}, x::AbstractMatrix{<:Real}) = __logpdf(d, x) +function __logpdf( + d::ProductDistribution{N,M}, + x::AbstractArray{<:Real,N}, +) where {N,M} + # we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020) + # to compute `sum(logpdf.(d.dists, eachvariate))` + @inbounds broadcasted = Broadcast.broadcasted( + logpdf, d.dists, eachvariate(x, ArrayLikeVariate{M}), + ) + return sum(Broadcast.instantiate(broadcasted)) +end + +# more efficient implementation of `_rand!` for `Fill` arrays of distributions +function _rand!( + rng::AbstractRNG, + d::ProductDistribution{N,M,<:Fill}, + A::AbstractArray{<:Real,N}, +) where {N,M} + @inbounds rand!(rng, sampler(first(d.dists)), A) + return A +end + +# more efficient implementation of `_logpdf` for `Fill` arrays of distributions +# we have to fix a method ambiguity +function _logpdf( + d::ProductDistribution{N,M,<:Fill}, + x::AbstractArray{<:Real,N}, +) where {N,M} + return __logpdf(d, x) +end +function _logpdf( + d::ProductDistribution{2,M,<:Fill}, + x::AbstractMatrix{<:Real}, +) where {M} + return __logpdf(d, x) +end +function __logpdf( + d::ProductDistribution{N,M,<:Fill}, + x::AbstractArray{<:Real,N}, +) where {N,M} + return @inbounds loglikelihood(first(d.dists), x) +end + +""" + 
product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate{M}},N}) + +Create a distribution of `M + N`-dimensional arrays as a product distribution of +independent `M`-dimensional distributions by stacking them. + +The function falls back to constructing a [`ProductDistribution`](@ref) distribution but +specialized methods can be defined. +""" +function product_distribution(dists::AbstractArray{<:Distribution{<:ArrayLikeVariate}}) + return ProductDistribution(dists) +end + +function product_distribution( + dist::Distribution{ArrayLikeVariate{N}}, dists::Distribution{ArrayLikeVariate{N}}..., +) where {N} + return ProductDistribution((dist, dists...)) +end + +""" + product_distribution(dists::AbstractVector{<:Normal}) + +Create a multivariate normal distribution by stacking the univariate normal distributions. + +The resulting distribution of type [`MvNormal`](@ref) has a diagonal covariance matrix. +""" +function product_distribution(dists::AbstractVector{<:Normal}) + µ = map(mean, dists) + σ2 = map(var, dists) + return MvNormal(µ, Diagonal(σ2)) +end diff --git a/test/product.jl b/test/product.jl index d7bf0ae78..7d19898db 100644 --- a/test/product.jl +++ b/test/product.jl @@ -1,6 +1,14 @@ -using Distributions, Test, Random, LinearAlgebra, FillArrays +using Distributions +using FillArrays + +using Test +using Random +using LinearAlgebra + using Distributions: Product +# TODO: remove when `Product` is removed +@testset "Deprecated `Product` distribution" begin @testset "Testing normal product distributions" begin Random.seed!(123456) N = 11 @@ -8,8 +16,8 @@ using Distributions: Product μ = randn(N) ds = Normal.(μ, 1.0) x = rand.(ds) - d_product = product_distribution(ds) - @test d_product isa MvNormal + d_product = @test_deprecated(Product(ds)) + @test d_product isa Product # Check that methods for `Product` are consistent. 
@test length(d_product) == length(ds) @test eltype(d_product) === eltype(ds[1]) @@ -31,7 +39,7 @@ end ubound = rand(N) ds = Uniform.(-ubound, ubound) x = rand.(ds) - d_product = product_distribution(ds) + d_product = @test_deprecated(product_distribution(ds)) @test d_product isa Product # Check that methods for `Product` are consistent. @test length(d_product) == length(ds) @@ -62,7 +70,7 @@ end support = fill(a, N) ds = DiscreteNonParametric.(support, Ref([0.5, 0.5])) x = rand.(ds) - d_product = product_distribution(ds) + d_product = @test_deprecated(product_distribution(ds)) @test d_product isa Product # Check that methods for `Product` are consistent. @test length(d_product) == length(ds) @@ -89,4 +97,252 @@ end @test mean(d) === Fill(0.0, N) @test cov(d) === Diagonal(Fill(var(Laplace(0.0, 2.3)), N)) end +end + +@testset "Testing normal product distributions" begin + Random.seed!(123456) + N = 11 + + # Construct independent distributions and `ProductDistribution` from these. + μ = randn(N) + + ds1 = Normal.(μ, 1.0) + d_product1 = @inferred(product_distribution(ds1)) + @test d_product1 isa Distributions.DiagNormal + + ds2 = Fill(Normal(first(μ), 1.0), N) + d_product2 = @inferred(product_distribution(ds2)) + @test d_product2 isa MvNormal{Float64,Distributions.ScalMat{Float64},<:Fill{Float64,1}} + + # Check that methods for `ProductDistribution` are consistent. 
+ for (ds, d_product) in ((ds1, d_product1), (ds2, d_product2)) + @test length(d_product) == length(ds) + @test eltype(d_product) === eltype(ds[1]) + @test mean(d_product) == mean.(ds) + @test var(d_product) == var.(ds) + @test cov(d_product) == Diagonal(var.(ds)) + @test entropy(d_product) ≈ sum(entropy.(ds)) + + x = rand(d_product) + @test x isa typeof(rand.(collect(ds))) + @test length(x) == N + @test logpdf(d_product, x) ≈ sum(logpdf.(ds, x)) + end +end + +@testset "Testing generic VectorOfUnivariateDistribution" begin + Random.seed!(123456) + N = 11 + + # Construct independent distributions and `ProductDistribution` from these. + ubound = rand(N) + + ds1 = Uniform.(0.0, ubound) + # Replace with + # d_product1 = @inferred(product_distribution(ds1)) + # when `Product` is removed + d_product1 = @inferred(Distributions.ProductDistribution(ds1)) + @test d_product1 isa Distributions.VectorOfUnivariateDistribution{<:Vector,Continuous,Float64} + + d_product2 = @inferred(product_distribution(ntuple(i -> Uniform(0.0, ubound[i]), 11)...)) + @test d_product2 isa Distributions.VectorOfUnivariateDistribution{<:Tuple,Continuous,Float64} + + ds3 = Fill(Uniform(0.0, first(ubound)), N) + # Replace with + # d_product3 = @inferred(product_distribution(ds3)) + # when `Product` is removed + d_product3 = @inferred(Distributions.ProductDistribution(ds3)) + @test d_product3 isa Distributions.VectorOfUnivariateDistribution{<:Fill,Continuous,Float64} + + # Check that methods for `VectorOfUnivariateDistribution` are consistent. 
+ for (ds, d_product) in ((ds1, d_product1), (ds1, d_product2), (ds3, d_product3)) + @test length(d_product) == length(ds) + @test eltype(d_product) === eltype(ds[1]) + @test @inferred(mean(d_product)) == mean.(ds) + @test @inferred(var(d_product)) == var.(ds) + @test @inferred(cov(d_product)) == Diagonal(var.(ds)) + @test @inferred(entropy(d_product)) == sum(entropy.(ds)) + @test insupport(d_product, zeros(N)) + @test insupport(d_product, maximum.(ds)) + @test !insupport(d_product, maximum.(ds) .+ 1) + @test !insupport(d_product, zeros(N + 1)) + + @test minimum(d_product) == map(minimum, ds) + @test maximum(d_product) == map(maximum, ds) + @test extrema(d_product) == (map(minimum, ds), map(maximum, ds)) + + x = @inferred(rand(d_product)) + @test x isa typeof(rand.(collect(ds))) + @test length(x) == length(d_product) + @test insupport(d_product, x) + @test @inferred(logpdf(d_product, x)) ≈ sum(logpdf.(ds, x)) + # ensure that samples are different, in particular if `Fill` is used + @test length(unique(x)) == N + end +end + +@testset "Testing discrete non-parametric VectorOfUnivariateDistribution" begin + Random.seed!(123456) + N = 11 + + for a in ([0, 1], [-0.5, 0.5]) + # Construct independent distributions and `ProductDistribution` from these. 
+ ds1 = DiscreteNonParametric.(fill(a, N), Ref([0.5, 0.5])) + # Replace with + # d_product1 = @inferred(product_distribution(ds1)) + # when `Product` is removed + d_product1 = @inferred(Distributions.ProductDistribution(ds1)) + @test d_product1 isa Distributions.VectorOfUnivariateDistribution{<:Vector{<:DiscreteNonParametric},Discrete,eltype(a)} + + d_product2 = @inferred(product_distribution(ntuple(_ -> DiscreteNonParametric(a, [0.5, 0.5]), 11)...)) + @test d_product2 isa Distributions.VectorOfUnivariateDistribution{<:NTuple{N,<:DiscreteNonParametric},Discrete,eltype(a)} + + ds3 = Fill(DiscreteNonParametric(a, [0.5, 0.5]), N) + # Replace with + # d_product3 = @inferred(product_distribution(ds3)) + # when `Product` is removed + d_product3 = @inferred(Distributions.ProductDistribution(ds3)) + @test d_product3 isa Distributions.VectorOfUnivariateDistribution{<:Fill{<:DiscreteNonParametric,1},Discrete,eltype(a)} + + # Check that methods for `VectorOfUnivariateDistribution` are consistent. + for (ds, d_product) in ((ds1, d_product1), (ds1, d_product3), (ds3, d_product2)) + @test length(d_product) == length(ds) + @test eltype(d_product) === eltype(ds[1]) + @test @inferred(mean(d_product)) == mean.(ds) + @test @inferred(var(d_product)) == var.(ds) + @test @inferred(cov(d_product)) == Diagonal(var.(ds)) + @test @inferred(entropy(d_product)) == sum(entropy.(ds)) + @test insupport(d_product, fill(a[2], N)) + @test !insupport(d_product, fill(a[2] + 1, N)) + @test !insupport(d_product, fill(a[2], N + 1)) + + @test minimum(d_product) == map(minimum, ds) + @test maximum(d_product) == map(maximum, ds) + @test extrema(d_product) == (map(minimum, ds), map(maximum, ds)) + + x = @inferred(rand(d_product)) + @test x isa typeof(rand.(collect(ds))) + @test length(x) == length(d_product) + @test insupport(d_product, x) + @test @inferred(logpdf(d_product, x)) ≈ sum(logpdf.(ds, x)) + # ensure that samples are different, in particular if `Fill` is used + @test length(unique(x)) == 2 + end 
+ end +end +@testset "Testing tuple of continuous and discrete distribution" begin + Random.seed!(123456) + N = 11 + + ds = (Bernoulli(0.3), Uniform(0.0, 0.7), Categorical([0.4, 0.2, 0.4])) + d_product = @inferred(product_distribution(ds...)) + @test d_product isa Distributions.VectorOfUnivariateDistribution{<:Tuple,Continuous,Float64} + + ds_vec = vcat(ds...) + + @test length(d_product) == 3 + @test eltype(d_product) === Float64 + @test @inferred(mean(d_product)) == mean.(ds_vec) + @test @inferred(var(d_product)) == var.(ds_vec) + @test @inferred(cov(d_product)) == Diagonal(var.(ds_vec)) + @test @inferred(entropy(d_product)) == sum(entropy.(ds_vec)) + @test insupport(d_product, [0, 0.2, 3]) + @test !insupport(d_product, [-0.5, 0.2, 3]) + @test !insupport(d_product, [0, -0.5, 3]) + @test !insupport(d_product, [0, 0.2, -0.5]) + + @test @inferred(minimum(d_product)) == map(minimum, ds_vec) + @test @inferred(maximum(d_product)) == map(maximum, ds_vec) + @test @inferred(extrema(d_product)) == (map(minimum, ds_vec), map(maximum, ds_vec)) + + x = @inferred(rand(d_product)) + @test x isa Vector{Float64} + @test length(x) == length(d_product) + @test insupport(d_product, x) + @test @inferred(logpdf(d_product, x)) ≈ sum(logpdf.(ds, x)) +end + +@testset "Testing generic MatrixOfUnivariateDistribution" begin + Random.seed!(123456) + M, N = 11, 16 + + # Construct independent distributions and `ProductDistribution` from these. + ubound = rand(M, N) + + ds1 = Uniform.(0.0, ubound) + d_product1 = @inferred(product_distribution(ds1)) + @test d_product1 isa Distributions.MatrixOfUnivariateDistribution{<:Matrix{<:Uniform},Continuous,Float64} + + ds2 = Fill(Uniform(0.0, first(ubound)), M, N) + d_product2 = @inferred(product_distribution(ds2)) + @test d_product2 isa Distributions.MatrixOfUnivariateDistribution{<:Fill{<:Uniform,2},Continuous,Float64} + + # Check that methods for `MatrixOfUnivariateDistribution` are consistent. 
+ for (ds, d_product) in ((ds1, d_product1), (ds2, d_product2)) + @test size(d_product) == size(ds) + @test eltype(d_product) === eltype(ds[1]) + @test @inferred(mean(d_product)) == mean.(ds) + @test @inferred(var(d_product)) == var.(ds) + @test @inferred(cov(d_product)) == Diagonal(vec(var.(ds))) + @test @inferred(cov(d_product, Val(false))) == reshape(Diagonal(vec(var.(ds))), M, N, M, N) + + @test minimum(d_product) == map(minimum, ds) + @test maximum(d_product) == map(maximum, ds) + @test extrema(d_product) == (map(minimum, ds), map(maximum, ds)) + + x = @inferred(rand(d_product)) + @test size(x) == size(d_product) + @test x isa typeof(rand.(collect(ds))) + @test @inferred(logpdf(d_product, x)) ≈ sum(logpdf.(ds, x)) + # ensure that samples are different, in particular if `Fill` is used + @test length(unique(x)) == length(d_product) + end +end + +@testset "Testing generic array of multivariate distribution" begin + Random.seed!(123456) + M = 3 + + for N in ((11,), (11, 3)) + # Construct independent distributions and `ProductDistribution` from these. + alphas = [normalize!(rand(M), 1) for _ in Iterators.product(map(x -> 1:x, N)...)] + + ds1 = Dirichlet.(alphas) + d_product1 = @inferred(product_distribution(ds1)) + @test d_product1 isa Distributions.ProductDistribution{length(N) + 1,1,<:Array{<:Dirichlet{Float64},length(N)},Continuous,Float64} + + ds2 = Fill(Dirichlet(first(alphas)), N...) + d_product2 = @inferred(product_distribution(ds2)) + @test d_product2 isa Distributions.ProductDistribution{length(N) + 1,1,<:Fill{<:Dirichlet{Float64},length(N)},Continuous,Float64} + + # Check that methods for `ProductDistribution` are consistent. + for (ds, d_product) in ((ds1, d_product1), (ds2, d_product2)) + @test size(d_product) == (length(ds[1]), size(ds)...) 
+ @test eltype(d_product) === eltype(ds[1]) + @test @inferred(mean(d_product)) == reshape(mapreduce(mean, (x, y) -> cat(x, y; dims=ndims(ds) + 1), ds), size(d_product)) + @test @inferred(var(d_product)) == reshape(mapreduce(var, (x, y) -> cat(x, y; dims=ndims(ds) + 1), ds), size(d_product)) + @test @inferred(cov(d_product)) == Diagonal(mapreduce(var, vcat, ds)) + + if d_product isa MatrixDistribution + @test @inferred(cov(d_product, Val(false))) == reshape( + Diagonal(mapreduce(var, vcat, ds)), M, length(ds), M, length(ds) + ) + end + + x = @inferred(rand(d_product)) + @test size(x) == size(d_product) + @test x isa typeof(mapreduce(rand, (x, y) -> cat(x, y; dims=ndims(ds) + 1), ds)) + + # inference broken for non-Fill arrays + y = reshape(x, Val(2)) + if ds isa Fill + @test @inferred(logpdf(d_product, x)) ≈ sum(logpdf(d, y[:, i]) for (i, d) in enumerate(ds)) + else + @test logpdf(d_product, x) ≈ sum(logpdf(d, y[:, i]) for (i, d) in enumerate(ds)) + end + # ensure that samples are different, in particular if `Fill` is used + @test length(unique(x)) == length(d_product) + end + end +end \ No newline at end of file