Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Aug 24, 2021
1 parent 14d2128 commit d1982dc
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 19 deletions.
69 changes: 69 additions & 0 deletions pymc3/tests/test_bart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np

import pymc3 as pm


def test_split_node():
split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
assert split_node.index == 5
assert split_node.idx_split_variable == 2
assert split_node.split_value == 3.0
assert split_node.depth == 2
assert split_node.get_idx_parent_node() == 2
assert split_node.get_idx_left_child() == 11
assert split_node.get_idx_right_child() == 12


def test_leaf_node():
leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
assert leaf_node.index == 5
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
assert leaf_node.value == 3.14
assert leaf_node.get_idx_parent_node() == 2
assert leaf_node.get_idx_left_child() == 11
assert leaf_node.get_idx_right_child() == 12


def test_model():
X = np.linspace(7, 15, 100)
Y = np.sin(np.random.normal(X, 0.2)) + 3
X = X[:, None]

with pm.Model() as model:
sigma = pm.HalfNormal("sigma", 1)
mu = pm.BART("mu", X, Y, m=50)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample()
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, 0.5)

Y = np.repeat([0, 1], 50)
with pm.Model() as model:
mu_ = pm.BART("mu_", X, Y, m=50)
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
y = pm.Bernoulli("y", mu, observed=Y)
idata = pm.sample()
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, atol=0.5)


def test_bart_vi():
X = np.random.normal(0, 1, size=(3, 250)).T
Y = np.random.normal(0, 1, size=250)
X[:, 0] = np.random.normal(Y, 0.1)

with pm.Model() as model:
mu = pm.BART("mu", X, Y, m=10)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample(random_seed=3415, chains=1)
var_imp = (
idata.sample_stats["variable_inclusion"]
.stack(samples=("chain", "draw"))
.mean("samples")
)
var_imp /= var_imp.sum()
assert var_imp[0] > var_imp[1:].sum()
np.testing.assert_almost_equal(var_imp.sum(), 1)
19 changes: 0 additions & 19 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,25 +174,6 @@ def test_trace_report(self, step_cls, discard):
assert trace.report.n_draws == 100
assert isinstance(trace.report.t_sampling, float)

def test_bart_vi(self):
X = np.random.normal(0, 1, size=(3, 250)).T
Y = np.random.normal(0, 1, size=250)
X[:, 0] = np.random.normal(Y, 0.1)

with pm.Model() as model:
mu = pm.BART("mu", X, Y, m=10)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample(random_seed=3415, chains=1)
var_imp = (
idata.sample_stats["variable_inclusion"]
.stack(samples=("chain", "draw"))
.mean("samples")
)
var_imp /= var_imp.sum()
assert var_imp[0] > var_imp[1:].sum()
npt.assert_almost_equal(var_imp.sum(), 1)

def test_return_inferencedata(self):
with self.model:
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
Expand Down

0 comments on commit d1982dc

Please sign in to comment.