Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize Product #1391

Merged
merged 31 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
34cd1ac
Move `src/multivariate/product.jl`
devmotion Aug 31, 2021
2758c6a
Generalize `Product` to `ProductDistribution`
devmotion Aug 31, 2021
541e265
Add implementations for more general product distributions
devmotion Sep 1, 2021
5c7a82a
Merge branch 'master' into dw/product
devmotion Nov 9, 2021
c951804
Unify and generalize `rand!`, `logpdf` and `pdf`
devmotion Nov 14, 2021
f4d9184
Merge branch 'master' into dw/product
devmotion Nov 14, 2021
39ad15e
Revert unrelated changes and fix tests
devmotion Nov 14, 2021
5417bb5
Propagate `@inbounds`
devmotion Nov 14, 2021
0ce8ee3
Remove unneeded implementation
devmotion Nov 14, 2021
b0a369d
Fix typos
devmotion Nov 14, 2021
52bd470
Fix some dispatches
devmotion Nov 14, 2021
bf140cc
More fixes
devmotion Nov 15, 2021
343d266
Support tuple of distributions and mix of discrete + continuous
devmotion Nov 15, 2021
120c3f1
Fix additional test errors
devmotion Nov 15, 2021
867a21e
Fix method ambiguity
devmotion Nov 15, 2021
f86549b
Fix `VonMisesFisherSampler`
devmotion Nov 15, 2021
6847f71
Fix mixture sampler
devmotion Nov 15, 2021
84ee30c
Simplify multinomial sampler
devmotion Nov 15, 2021
bab927b
Fix `loglikelihood` for univariate distributions
devmotion Nov 16, 2021
1414e97
Add ReshapedDistribution
devmotion Nov 16, 2021
7df43ce
Fix typo
devmotion Nov 16, 2021
768ea61
Revert some changes
devmotion Nov 16, 2021
19c7c49
Update product.jl
devmotion Nov 16, 2021
b3c6a9c
Merge branch 'master' into dw/product
devmotion Nov 28, 2021
2e7341b
Remove duplicate `eachvariate`/`EachVariate`
devmotion Nov 28, 2021
97d3a0a
Reintroduce `Product`
devmotion Nov 28, 2021
167f319
Improve type inference
devmotion Nov 28, 2021
b190369
Add explanations of `ValueSupport`
devmotion Nov 28, 2021
e1fa4b5
Fix typo
devmotion Nov 28, 2021
55e6c5b
Remove another breaking change
devmotion Nov 29, 2021
1a574fa
Merge branch 'master' into dw/product
devmotion Jun 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,21 @@ The `VariateForm` sub-types defined in `Distributions.jl` are:

### ValueSupport

```@doc
```@docs
Distributions.ValueSupport
```

The `ValueSupport` sub-types defined in `Distributions.jl` are:

**Type** | **Element type** | **Descriptions**
--- | --- | ---
`Discrete` | `Int` | Samples take discrete values
`Continuous` | `Float64` | Samples take continuous real values
```@docs
Distributions.Discrete
Distributions.Continuous
```

**Type** | **Default element type** | **Description** | **Examples**
--- | --- | --- | ---
`Discrete` | `Int` | Samples take countably many values | $\{0,1,2,3\}$, $\mathbb{N}$
`Continuous` | `Float64` | Samples take uncountably many values | $[0, 1]$, $\mathbb{R}$

Multiple samples are often organized into an array, depending on the variate form.

Expand Down
3 changes: 2 additions & 1 deletion src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ export
Pareto,
PGeneralizedGaussian,
SkewedExponentialPower,
Product,
Product, # deprecated
Poisson,
PoissonBinomial,
QQPair,
Expand Down Expand Up @@ -294,6 +294,7 @@ include("cholesky/lkjcholesky.jl")
include("samplers.jl")

# others
include("product.jl")
include("reshaped.jl")
include("truncate.jl")
include("censored.jl")
Expand Down
38 changes: 31 additions & 7 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,42 @@ const Matrixvariate = ArrayLikeVariate{2}
abstract type CholeskyVariate <: VariateForm end

"""
`S <: ValueSupport` specifies the support of sample elements,
either discrete or continuous.
ValueSupport

Abstract type that specifies the support of elements of samples.

It is either [`Discrete`](@ref) or [`Continuous`](@ref).
"""
abstract type ValueSupport end

"""
Discrete <: ValueSupport

This type represents the support of a discrete random variable.

It is countable. For instance, it can be a finite set or a countably infinite set such as
the natural numbers.

See also: [`Continuous`](@ref), [`ValueSupport`](@ref)
"""
struct Discrete <: ValueSupport end

"""
Continuous <: ValueSupport

This types represents the support of a continuous random variable.

It is uncountably infinite. For instance, it can be an interval on the real line.

See also: [`Discrete`](@ref), [`ValueSupport`](@ref)
"""
struct Continuous <: ValueSupport end

# promotions (e.g., in product distribution):
# combination of discrete support (countable) and continuous support (uncountable) yields
# continuous support (uncountable)
Base.promote_rule(::Type{Continuous}, ::Type{Discrete}) = Continuous

## Sampleable

"""
Expand All @@ -42,7 +71,6 @@ Any `Sampleable` implements the `Base.rand` method.
"""
abstract type Sampleable{F<:VariateForm,S<:ValueSupport} end


variate_form(::Type{<:Sampleable{VF}}) where {VF} = VF
value_support(::Type{<:Sampleable{<:VariateForm,VS}}) where {VS} = VS

Expand Down Expand Up @@ -142,10 +170,6 @@ const ContinuousMultivariateDistribution = Distribution{Multivariate, Continuou
const DiscreteMatrixDistribution = Distribution{Matrixvariate, Discrete}
const ContinuousMatrixDistribution = Distribution{Matrixvariate, Continuous}

variate_form(::Type{<:Distribution{VF}}) where {VF} = VF

value_support(::Type{<:Distribution{VF,VS}}) where {VF,VS} = VS

# allow broadcasting over distribution objects
# to be decided: how to handle multivariate/matrixvariate distributions?
Broadcast.broadcastable(d::UnivariateDistribution) = Ref(d)
Expand Down
36 changes: 12 additions & 24 deletions src/multivariate/product.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Statistics: mean, var, cov
# Deprecated product distribution
# TODO: Remove in next breaking release

"""
Product <: MultivariateDistribution
Expand All @@ -20,6 +21,10 @@ struct Product{
V<:AbstractVector{T} where
T<:UnivariateDistribution{S} where
S<:ValueSupport
Base.depwarn(
"`Product(v)` is deprecated, please use `product_distribution(v)`",
:Product,
)
Comment on lines +24 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm just brainfarting, but it seems like this combined with the @deprecate below means that no matter which constructor I use for a vector of univariate distributions, I'm going to get a deprecation-warning?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, at the moment one always get's a depwarn ... #1589

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's unfortunate and fixed by #1590. I thought someone would approve the PR and I would be able to make a bugfix release shortly after the issue was discovered but it seems nobody has approved it within almost two weeks. I'm going to merge it and tag a release now since the issue is quite annoying for downstream packages (and AFAICT even causes time outs in Turing).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @devmotion !

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonderful stuff @devmotion !

return new{S, T, V}(v)
end
end
Expand All @@ -43,26 +48,9 @@ insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x))
minimum(d::Product) = map(minimum, d.v)
maximum(d::Product) = map(maximum, d.v)

"""
product_distribution(dists::AbstractVector{<:UnivariateDistribution})

Creates a multivariate product distribution `P` from a vector of univariate distributions.
Fallback is the `Product constructor`, but specialized methods can be defined
for distributions with a special multivariate product.
"""
function product_distribution(dists::AbstractVector{<:UnivariateDistribution})
return Product(dists)
end

"""
product_distribution(dists::AbstractVector{<:Normal})

Computes the multivariate Normal distribution obtained by stacking the univariate
normal distributions. The result is a multivariate Gaussian with a diagonal
covariance matrix.
"""
function product_distribution(dists::AbstractVector{<:Normal})
µ = mean.(dists)
σ2 = var.(dists)
return MvNormal(µ, Diagonal(σ2))
end
# TODO: remove deprecation when `Product` is removed
# it will return a `ProductDistribution` then which is already the default for
# higher-dimensional arrays and distributions
Base.@deprecate product_distribution(
dists::AbstractVector{<:UnivariateDistribution}
) Product(dists)
2 changes: 1 addition & 1 deletion src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ for fname in ["dirichlet.jl",
"mvnormalcanon.jl",
"mvlognormal.jl",
"mvtdist.jl",
"product.jl",
"product.jl", # deprecated
"vonmisesfisher.jl"]
include(joinpath("multivariate", fname))
end
Loading