Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Dirichlet-multinomial distribution. #3639

Closed
wants to merge 14 commits into from
Closed
2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from .multivariate import MvStudentT
from .multivariate import Dirichlet
from .multivariate import Multinomial
from .multivariate import DirichletMultinomial
from .multivariate import Wishart
from .multivariate import WishartBartlett
from .multivariate import LKJCholeskyCov
Expand Down Expand Up @@ -150,6 +151,7 @@
'MvStudentT',
'Dirichlet',
'Multinomial',
'DirichletMultinomial',
'Wishart',
'WishartBartlett',
'LKJCholeskyCov',
Expand Down
82 changes: 81 additions & 1 deletion pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@


__all__ = ['MvNormal', 'MvStudentT', 'Dirichlet',
'Multinomial', 'Wishart', 'WishartBartlett',
'Multinomial', 'DirichletMultinomial',
'Wishart', 'WishartBartlett',
'LKJCorr', 'LKJCholeskyCov', 'MatrixNormal',
'KroneckerNormal']

Expand Down Expand Up @@ -709,6 +710,85 @@ def logp(self, x):
)


class DirichletMultinomial(Discrete):
    R"""Dirichlet Multinomial log-likelihood.

    Dirichlet mixture of multinomials distribution, with a marginalized PMF.

    .. math::

        f(x \mid n, \alpha) = \frac{\Gamma(n + 1)\Gamma(\sum\alpha_k)}
                                   {\Gamma(n + \sum\alpha_k)}
                              \prod_{k=1}^K
                              \frac{\Gamma(x_k + \alpha_k)}
                                   {\Gamma(x_k + 1)\Gamma(\alpha_k)}

    ==========  ===========================================
    Support     :math:`x \in \{0, 1, \ldots, n\}` such that
                :math:`\sum x_i = n`
    Mean        :math:`n \frac{\alpha_i}{\sum{\alpha_k}}`
    ==========  ===========================================

    Parameters
    ----------
    alpha : two-dimensional array
        Dirichlet parameter. Elements must be non-negative.
        Dimension of each element of the distribution is the length
        of the second dimension of alpha.
    n : one-dimensional array
        Total counts in each replicate.
    """

    def __init__(self, n, alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = tt.as_tensor_variable(alpha)
        self.n = tt.as_tensor_variable(n)

        # Mean of the marginalized distribution: n * E[p], where
        # E[p] = alpha / sum(alpha) under the Dirichlet prior.
        p = self.alpha / self.alpha.sum(-1, keepdims=True)
        self.mean = tt.shape_padright(self.n) * p

        # Integer mode: round the mean, then push any rounding shortfall
        # (or excess) into one entry so each row still sums exactly to n.
        mode = tt.cast(tt.round(self.mean), 'int32')
        diff = tt.shape_padright(self.n) - tt.sum(mode, axis=-1, keepdims=True)
        inc_bool_arr = tt.abs_(diff) > 0
        mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()],
                                diff[inc_bool_arr.nonzero()])
        self.mode = mode

    def logp(self, x):
        """Calculate log-probability of DirichletMultinomial distribution
        at specified value.

        Parameters
        ----------
        x : integer array
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        alpha = self.alpha
        n = self.n
        sum_alpha = alpha.sum(axis=-1)

        # Marginalized PMF of the Dirichlet-multinomial (log scale).
        const = (gammaln(n + 1) + gammaln(sum_alpha)) - gammaln(n + sum_alpha)
        series = gammaln(x + alpha) - (gammaln(x + 1) + gammaln(alpha))
        result = const + series.sum(axis=-1)
        # bound() zeroes out the density outside the support: non-negative
        # counts summing to n, with strictly positive concentration.
        return bound(result,
                     tt.all(tt.ge(x, 0)),
                     tt.all(tt.gt(alpha, 0)),
                     tt.all(tt.ge(n, 0)),
                     tt.all(tt.eq(x.sum(axis=-1), n)),
                     broadcast_conditions=False)

    def random(self, point=None, size=None, repeat=None):
        """Draw random values from DirichletMultinomial distribution.

        Draws one Dirichlet probability vector per replicate, then a
        multinomial sample of that replicate's total count.

        NOTE(review): `size` and `repeat` are currently ignored; one draw
        per row of alpha is returned regardless — TODO support them.

        Parameters
        ----------
        point : dict, optional
        size : int, optional
        repeat : optional

        Returns
        -------
        array
        """
        alpha, n = draw_values([self.alpha, self.n], point=point, size=size)
        # Counts are integral; use an integer output array rather than
        # inheriting alpha's float dtype.
        out = np.empty(alpha.shape, dtype=np.int64)
        for i in range(len(n)):
            p = np.random.dirichlet(alpha[i, :])
            out[i, :] = np.random.multinomial(n[i], p)
        return out

    def _repr_latex_(self, name=None, dist=None):
        if dist is None:
            dist = self
        n = dist.n
        alpha = dist.alpha
        # Fixed typo (\matit -> \mathit) and the missing comma between terms.
        return (r'${} \sim \text{{DirichletMultinomial}}('
                r'\mathit{{n}}={}, \mathit{{\alpha}}={})$'
                ).format(name, get_variable_name(n), get_variable_name(alpha))


def posdef(AA):
try:
linalg.cholesky(AA)
Expand Down
49 changes: 49 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Multinomial,
VonMises,
Dirichlet,
DirichletMultinomial,
MvStudentT,
MvNormal,
MatrixNormal,
Expand Down Expand Up @@ -1421,6 +1422,54 @@ def test_multinomial_vec_2d_p(self):
decimal=4,
)

@pytest.mark.parametrize('alpha,n', [
[[[.25, .25, .25, .25]], [1]],
[[[.3, .6, .05, .05]], [2]],
[[[.3, .6, .05, .05]], [10]],
[[[.3, .6, .05, .05],
[.25, .25, .25, .25]],
[10, 2]],
])
def test_dirichlet_multinomial_mode(self, alpha, n):
alpha = np.array(alpha)
n = np.array(n)
with Model() as model:
m = DirichletMultinomial('m', n, alpha,
shape=alpha.shape)
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)

@pytest.mark.parametrize('alpha,n,enum', [
[[[.25, .25, .25, .25]], [1], [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]]]
])
def test_dirichlet_multinomial_pmf(self, alpha, n, enum):
alpha = np.array(alpha)
n = np.array(n)
with Model() as model:
m = DirichletMultinomial('m', n=n, alpha=alpha,
shape=alpha.shape)
logp = lambda x: m.distribution.logp(np.array([x])).eval()
p_all_poss = [np.exp(logp(x)) for x in enum]
assert_almost_equal(np.sum(p_all_poss), 1)

@pytest.mark.parametrize('alpha,n', [
[[[.25, .25, .25, .25]], [1]],
[[[.3, .6, .05, .05]], [2]],
[[[.3, .6, .05, .05]], [10]],
[[[.3, .6, .05, .05],
[.25, .25, .25, .25]],
[10, 2]],
])
def test_dirichlet_multinomial_random(self, alpha, n):
alpha = np.array(alpha)
n = np.array(n)
with Model() as model:
m = DirichletMultinomial('m', n=n, alpha=alpha,
shape=alpha.shape)
m.random()

def test_categorical_bounds(self):
with Model():
x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
Expand Down
5 changes: 5 additions & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,11 @@ class TestBetaBinomial(BaseTestCases.BaseTestCase):
params = {"n": 5, "alpha": 1.0, "beta": 1.0}


class TestDirichletMultinomial(BaseTestCases.BaseTestCase):
    # Double quotes and explicit float literals to match the neighboring
    # test-case classes in this file (e.g. TestBetaBinomial, TestBernoulli).
    distribution = pm.DirichletMultinomial
    params = {"n": [5], "alpha": [[1.0, 1.0, 1.0, 1.0]]}


class TestBernoulli(BaseTestCases.BaseTestCase):
distribution = pm.Bernoulli
params = {"p": 0.5}
Expand Down