Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1623,7 +1630,7 @@ def logp(value, n, eta):

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


class MatrixNormalRV(RandomVariable):
Expand Down
13 changes: 6 additions & 7 deletions pymc/logprob/transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,12 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)
raise NotImplementedError(
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
else:
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)

if use_jacobian:
if value.name:
Expand Down
22 changes: 21 additions & 1 deletion tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import functools as ft
import re
import warnings

import numpy as np
Expand All @@ -33,6 +32,7 @@
import pymc as pm

from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
_OrderedMultinomial,
posdef,
Expand Down Expand Up @@ -2121,6 +2121,26 @@ def ref_rand(size, n, eta):
size=1000,
)

@pytest.mark.parametrize(
argnames="shape, expected_shape",
argvalues=[
((2,), ()),
pytest.param(
(3, 2),
(3,),
marks=pytest.mark.xfail(
raises=NotImplementedError,
reason="We do not support batch dimensions for pm.LKJCorr yet.",
),
),
],
)
def test_default_transform(self, shape, expected_shape):
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].type.shape == expected_shape


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
Expand Down