Skip to content

Commit

Permalink
fix number of chains
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Aug 24, 2021
1 parent d1982dc commit c1c3a0b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymc3/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_model():
sigma = pm.HalfNormal("sigma", 1)
mu = pm.BART("mu", X, Y, m=50)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample()
idata = pm.sample(chains=4)
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, 0.5)
Expand All @@ -43,7 +43,7 @@ def test_model():
mu_ = pm.BART("mu_", X, Y, m=50)
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
y = pm.Bernoulli("y", mu, observed=Y)
idata = pm.sample()
idata = pm.sample(chains=4)
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, atol=0.5)
Expand Down

0 comments on commit c1c3a0b

Please sign in to comment.