From c1c3a0b3574240b5dc5b8ff2f865b29bd1741342 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 19 Aug 2021 16:53:45 +0300 Subject: [PATCH] fix number of chains --- pymc3/tests/test_bart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_bart.py b/pymc3/tests/test_bart.py index 17d0e78f25e..cf697fa2b70 100644 --- a/pymc3/tests/test_bart.py +++ b/pymc3/tests/test_bart.py @@ -33,7 +33,7 @@ def test_model(): sigma = pm.HalfNormal("sigma", 1) mu = pm.BART("mu", X, Y, m=50) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample() + idata = pm.sample(chains=4) mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples") np.testing.assert_allclose(mean, Y, 0.5) @@ -43,7 +43,7 @@ def test_model(): mu_ = pm.BART("mu_", X, Y, m=50) mu = pm.Deterministic("mu", pm.math.invlogit(mu_)) y = pm.Bernoulli("y", mu, observed=Y) - idata = pm.sample() + idata = pm.sample(chains=4) mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples") np.testing.assert_allclose(mean, Y, atol=0.5)