Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,11 +2121,11 @@ def ref_rand(size, n, eta):
size=1000,
)

@pytest.mark.parametrize(argnames="n", argvalues=[2, 3], ids=["n=2", "n=3"])
def test_default_transform(self, n):
def test_default_transform(self):
with pm.Model() as m:
pm.LKJCorr("x", n=n, eta=1)
m.logp()
x = pm.LKJCorr("x", n=2, eta=1, shape=(3, 2))
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].shape == (3,)


class TestLKJCholeskyCov(BaseTestDistributionRandom):
Expand Down