diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index c5faad5346..7fa3b924da 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -650,10 +650,16 @@ def __init__(self, nu, V, *args, **kwargs): self.p = p = tt.as_tensor_variable(V.shape[0]) self.V = V = tt.as_tensor_variable(V) self.mean = nu * V - self.mode = tt.switch(1 * (nu >= p + 1), + self.mode = tt.switch(tt.ge(nu, p + 1), (nu - p - 1) * V, np.nan) + def random(self, point=None, size=None): + nu, V = draw_values([self.nu, self.V], point=point) + size= 1 if size is None else size + return generate_samples(stats.wishart.rvs, np.asscalar(nu), V, + broadcast_shape=(size,)) + def logp(self, X): nu = self.nu p = self.p