diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 645f5a47a6..741bfa7216 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -8,6 +8,7 @@
 - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
 - Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).
 - Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
+- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
 
 ### Documentation
 
diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py
index a117211b1e..1244fd5d12 100755
--- a/pymc3/distributions/multivariate.py
+++ b/pymc3/distributions/multivariate.py
@@ -597,14 +597,10 @@ def __init__(self, n, p, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         p = p / tt.sum(p, axis=-1, keepdims=True)
-        n = np.squeeze(n)  # works also if n is a tensor
 
         if len(self.shape) > 1:
             self.n = tt.shape_padright(n)
             self.p = p if p.ndim > 1 else tt.shape_padleft(p)
-        elif n.ndim == 1:
-            self.n = tt.shape_padright(n)
-            self.p = p if p.ndim > 1 else tt.shape_padleft(p)
         else:
             # n is a scalar, p is a 1d array
             self.n = tt.as_tensor_variable(n)
diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index 595de26a6a..7b19e00049 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -1447,6 +1447,28 @@ def test_multinomial_vec_2d_p(self):
             decimal=4,
         )
 
+    def test_batch_multinomial(self):
+        n = 10
+        vals = np.zeros((4, 5, 3), dtype="int32")
+        p = np.zeros_like(vals, dtype=theano.config.floatX)
+        inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
+        np.put_along_axis(vals, inds, n, axis=-1)
+        np.put_along_axis(p, inds, 1, axis=-1)
+
+        dist = Multinomial.dist(n=n, p=p, shape=vals.shape)
+        value = tt.tensor3(dtype="int32")
+        value.tag.test_value = np.zeros_like(vals, dtype="int32")
+        logp = tt.exp(dist.logp(value))
+        f = theano.function(inputs=[value], outputs=logp)
+        assert_almost_equal(
+            f(vals),
+            np.ones(vals.shape[:-1] + (1,)),
+            decimal=select_by_precision(float64=6, float32=3),
+        )
+
+        sample = dist.random(size=2)
+        assert_allclose(sample, np.stack([vals, vals], axis=0))
+
     def test_categorical_bounds(self):
         with Model():
             x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
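For reference, here is a minimal sketch (not part of the diff) of what this change enables: a `Multinomial` with batch dimensions beyond 2, with a scalar `n` broadcast across all batch entries. It assumes PyMC3 3.x with the Theano backend; the shapes and probability values are illustrative, not from the PR.

```python
# Sketch only: illustrates the >2-dimensional batch support added by #4169.
# Assumes PyMC3 3.x (Theano backend); shapes and values here are made up.
import numpy as np
import pymc3 as pm

# A (4, 5) batch of 3-category probability vectors; the last axis sums to 1.
p = np.full((4, 5, 3), 1 / 3)

with pm.Model():
    # Total shape (4, 5, 3): batch dimensions (4, 5) plus the category axis.
    # Per the modified __init__, the scalar n=10 is right-padded with
    # tt.shape_padright and broadcast over the batch dimensions.
    x = pm.Multinomial("x", n=10, p=p, shape=(4, 5, 3))
```

The simplification in `multivariate.py` appears to follow from this broadcasting: once `n` is no longer squeezed, the generic `len(self.shape) > 1` branch also covers the case the removed `elif n.ndim == 1` branch used to handle.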