Skip to content

Commit

Permalink
add requirement for dims to be immutable for the prior as it is requi…
Browse files Browse the repository at this point in the history
…red for the limit case masking
  • Loading branch information
ferrine committed Jun 22, 2023
1 parent a030dda commit ed5d36c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pymc_experimental/distributions/multivariate/r2d2m2cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ def R2D2M2CP(
*broadcast_dims, dim = dims
input_sigma = pt.as_tensor(input_sigma)
output_sigma = pt.as_tensor(output_sigma)
with pm.Model(name):
with pm.Model(name) as model:
if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims):
raise ValueError(f"{dims!r} should be constant length imutable dims")
if r2_std is not None:
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
phi = _phi(
Expand Down
24 changes: 24 additions & 0 deletions pymc_experimental/tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,27 @@ def test_zero_length_rvs_not_created(self, model: pm.Model):
"b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
)
assert not model.free_RVs, model.free_RVs

def test_immutable_dims(self, model: pm.Model):
model.add_coord("a", range(2), mutable=True)
model.add_coord("b", range(2), mutable=False)
with pytest.raises(ValueError, match="should be constant length immutable dims"):
pmx.distributions.R2D2M2CP(
"beta0",
1,
[1, 1],
dims="a",
r2=0.8,
positive_probs=[0.5, 1],
positive_probs_std=[0.3, 0],
)
with pytest.raises(ValueError, match="should be constant length immutable dims"):
pmx.distributions.R2D2M2CP(
"beta0",
1,
[1, 1],
dims=("a", "b"),
r2=0.8,
positive_probs=[0.5, 1],
positive_probs_std=[0.3, 0],
)

0 comments on commit ed5d36c

Please sign in to comment.