Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1623,7 +1630,7 @@ def logp(value, n, eta):

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


class MatrixNormalRV(RandomVariable):
Expand Down
6 changes: 6 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,12 @@ def ref_rand(size, n, eta):
size=1000,
)

@pytest.mark.parametrize(argnames="n", argvalues=[2, 3], ids=["n=2", "n=3"])
def test_default_transform(self, n):
with pm.Model() as m:
pm.LKJCorr("x", n=n, eta=1)
m.logp()


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
Expand Down