Skip to content

Commit

Permalink
finish test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
peterhurford committed Oct 9, 2022
1 parent 8abc766 commit 5ae8c79
Show file tree
Hide file tree
Showing 7 changed files with 482 additions and 36 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

* **[Breaking change]** `credibility`, which defines the size of the interval (e.g., `credibility=0.8` for an 80% CI), is now a property of the distribution rather than the sampler. That is, you should now call `sample(norm(1, 3, credibility=0.8))` whereas previously it was `sample(norm(1, 3), credibility=0.8)`. This will allow mixing of distributions with different credibile ranges.
* **[Breaking change]** Numbers have been changed from functions to global variables. Use `thousand` or `K` instead of `thousand()` (old/deprecated).
* Fixed a bug with the implementation of `lclip` and `rclip`.
* `sample` now has a nice progress reporter if `verbose=True`.
* The `exponential` distribution now implements `lclip` and `rclip`.
* The `mixture` distribution can infer equal weights if no weights are given.
Expand All @@ -35,6 +36,7 @@

* Now has tests via pytest.
* The random numbers now come from a numpy generator as opposed to the previous deprecated `np.random` methods.
* The `sample` module (containing the `sample` function) has been renamed `samplers`.


## v0.6
Expand Down
2 changes: 1 addition & 1 deletion squigglepy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .distributions import * # noqa ignore=F405
from .numbers import * # noqa ignore=F405
from .sample import * # noqa ignore=F405
from .samplers import * # noqa ignore=F405
from .utils import * # noqa ignore=F405
from .rng import * # noqa ignore=F405
57 changes: 25 additions & 32 deletions squigglepy/sample.py → squigglepy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from .utils import event_occurs, _process_weights_values


def normal_sample(low=None, high=None, mean=None, sd=None, credibility=None):
def _get_rng():
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng


def normal_sample(low=None, high=None, mean=None, sd=None, credibility=0.9):
if mean is None:
if low > high:
raise ValueError('`high value` cannot be lower than `low value`')
Expand All @@ -20,12 +25,11 @@ def normal_sample(low=None, high=None, mean=None, sd=None, credibility=None):
else:
mu = mean
sigma = sd
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.normal(mu, sigma)
return _get_rng().normal(mu, sigma)


def lognormal_sample(low=None, high=None, mean=None, sd=None,
credibility=None):
credibility=0.9):
if (low is not None and low < 0) or (mean is not None and mean < 0):
raise ValueError('lognormal_sample cannot handle negative values')
if mean is None:
Expand All @@ -42,23 +46,21 @@ def lognormal_sample(low=None, high=None, mean=None, sd=None,
else:
mu = mean
sigma = sd
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.lognormal(mu, sigma)
return _get_rng().lognormal(mu, sigma)


def t_sample(low, high, t, credibility=None):
def t_sample(low, high, t, credibility=0.9):
if low > high:
raise ValueError('`high value` cannot be lower than `low value`')
elif low == high:
return low
else:
mu = (high + low) / 2
rangex = (high - low) / 2
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.standard_t(t) * rangex * 0.6/credibility + mu
return _get_rng().standard_t(t) * rangex * 0.6/credibility + mu


def log_t_sample(low, high, t, credibility=None):
def log_t_sample(low, high, t, credibility=0.9):
if low > high:
raise ValueError('`high value` cannot be lower than `low value`')
elif low < 0:
Expand All @@ -70,47 +72,39 @@ def log_t_sample(low, high, t, credibility=None):
log_high = np.log(high)
mu = (log_high + log_low) / 2
rangex = (log_high - log_low) / 2
from .rng import _squigglepy_internal_rng
return np.exp(_squigglepy_internal_rng.standard_t(t) * rangex * 0.6/credibility + mu)
return np.exp(_get_rng().standard_t(t) * rangex * 0.6/credibility + mu)


def binomial_sample(n, p):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.binomial(n, p)
return _get_rng().binomial(n, p)


def beta_sample(a, b):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.beta(a, b)
return _get_rng().beta(a, b)


def bernoulli_sample(p):
return int(event_occurs(p))


def triangular_sample(left, mode, right):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.triangular(left, mode, right)
return _get_rng().triangular(left, mode, right)


def poisson_sample(lam):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.poisson(lam)
return _get_rng().poisson(lam)


def exponential_sample(scale):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.exponential(scale)
return _get_rng().exponential(scale)


def gamma_sample(shape, scale):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.gamma(shape, scale)
return _get_rng().gamma(shape, scale)


def uniform_sample(low, high):
from .rng import _squigglepy_internal_rng
return _squigglepy_internal_rng.uniform(low, high)
return _get_rng().uniform(low, high)


def discrete_sample(items):
Expand All @@ -122,9 +116,8 @@ def discrete_sample(items):
weights = [i[0] for i in items]
values = [const(i[1]) for i in items]
else:
weights = None
values = [const(i) for i in items]
len_ = len(items)
weights = [1 / len_ for i in range(len_)]
else:
raise ValueError('inputs to discrete_sample must be a dict or list')

Expand Down Expand Up @@ -183,7 +176,7 @@ def sample(var, n=1, lclip=None, rclip=None, verbose=False):
out = normal_sample(var[0], var[1], credibility=var[3])

elif var[2] == 'norm-mean':
out = normal_sample(mean=var[0], sd=var[1])
out = normal_sample(mean=var[0], sd=var[1], credibility=var[3])

elif var[2] == 'log':
out = lognormal_sample(var[0], var[1], credibility=var[3])
Expand Down Expand Up @@ -239,11 +232,11 @@ def sample(var, n=1, lclip=None, rclip=None, verbose=False):

if lclip is None and lclip_ is not None:
lclip = lclip_
elif rclip is None and rclip_ is not None:
if rclip is None and rclip_ is not None:
rclip = rclip_
elif lclip is not None and lclip_ is not None:
if lclip is not None and lclip_ is not None:
lclip = max(lclip, lclip_)
elif rclip is not None and rclip_ is not None:
if rclip is not None and rclip_ is not None:
rclip = min(rclip, rclip_)

if lclip is not None and out < lclip:
Expand Down
4 changes: 2 additions & 2 deletions squigglepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ def laplace(s, n=None, time_passed=None,


def roll_die(sides, n=1):
from .sample import sample as samp
from .samplers import sample
from .distributions import discrete
if sides < 2:
raise ValueError('cannot roll less than a 2-sided die.')
elif not isinstance(sides, int):
raise ValueError('can only roll an integer number of sides')
else:
return samp(discrete(list(range(1, sides + 1))), n=n) if sides > 0 else None
return sample(discrete(list(range(1, sides + 1))), n=n) if sides > 0 else None


def flip_coin(n=1):
Expand Down
153 changes: 153 additions & 0 deletions tests/test_bayes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import pytest

from ..squigglepy.bayes import simple_bayes, bayesnet, update, average
from ..squigglepy.samplers import sample
from ..squigglepy.distributions import discrete, norm, beta, mixture
from ..squigglepy.rng import set_seed


def test_simple_bayes():
out = simple_bayes(prior=0.01,
likelihood_h=0.8,
likelihood_not_h=0.096)
assert round(out, 2) == 0.08


def test_bayesnet():
set_seed(42)
out = bayesnet(lambda: {'a': 1, 'b': 2},
find=lambda e: e['a'],
conditional_on=lambda e: e['b'],
n=100)
assert out == 1


def test_bayesnet_conditional():
def define_event():
a = sample(discrete([1, 2]))
b = 1 if a == 1 else 2
return {'a': a, 'b': b}

set_seed(42)
out = bayesnet(define_event,
find=lambda e: e['a'] == 1,
n=100)
assert round(out, 1) == 0.5

out = bayesnet(define_event,
find=lambda e: e['a'] == 1,
conditional_on=lambda e: e['b'] == 1,
n=100)
assert round(out, 1) == 1

out = bayesnet(define_event,
find=lambda e: e['a'] == 2,
conditional_on=lambda e: e['b'] == 1,
n=100)
assert round(out, 1) == 0

out = bayesnet(define_event,
find=lambda e: e['a'] == 1,
conditional_on=lambda e: e['b'] == 2,
n=100)
assert round(out, 1) == 0


def test_bayesnet_reduce_fn():
out = bayesnet(lambda: {'a': 1, 'b': 2},
find=lambda e: e['a'],
reduce_fn=sum,
n=100)
assert out == 100


def test_bayesnet_raw():
out = bayesnet(lambda: {'a': 1, 'b': 2},
find=lambda e: e['a'],
raw=True,
n=100)
assert out == [1] * 100


def test_bayesnet_cache():
from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches
n_caches = len(_squigglepy_internal_bayesnet_caches)

def define_event():
return {'a': 1, 'b': 2}
bayesnet(define_event,
find=lambda e: e['a'],
n=100)
from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches
n_caches2 = len(_squigglepy_internal_bayesnet_caches)
assert n_caches < n_caches2

bayesnet(define_event,
find=lambda e: e['a'],
n=100)
from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches
n_caches3 = len(_squigglepy_internal_bayesnet_caches)
assert n_caches2 == n_caches3

bayesnet(define_event,
find=lambda e: e['b'],
n=100)
from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches
n_caches4 = len(_squigglepy_internal_bayesnet_caches)
assert n_caches2 == n_caches4
assert _squigglepy_internal_bayesnet_caches.get(define_event)['metadata']['n'] == 100


def test_bayesnet_cache_n_error():
def define_event():
return {'a': 1, 'b': 2}
bayesnet(define_event,
find=lambda e: e['a'],
n=100)
with pytest.raises(ValueError) as excinfo:
bayesnet(define_event,
find=lambda e: e['a'],
n=1000)
assert '100 results cached but requested 1000' in str(excinfo.value)


def test_bayesnet_insufficent_samples_error():
with pytest.raises(ValueError) as excinfo:
bayesnet(lambda: {'a': 1, 'b': 2},
find=lambda e: e['a'],
conditional_on=lambda e: e['b'] == 3,
n=100)
assert 'insufficient samples' in str(excinfo.value)


def test_update_normal():
out = update(list(range(10)), list(range(5, 15)))
out[1] = round(out[1], 2)
expected = [7.0, 2.03, 'norm-mean', None, None]
assert out == expected


def test_update_normal_evidence_weight():
out = update(list(range(10)), list(range(5, 15)), evidence_weight=3)
out[1] = round(out[1], 2)
# TODO: This seems wrong?
expected = [16.5, 1.44, 'norm-mean', None, None]
assert out == expected


def test_update_beta():
out = update(beta(1, 1), beta(2, 2), type='beta')
expected = beta(3, 3)
assert out == expected


def test_update_not_implemented():
with pytest.raises(ValueError) as excinfo:
update(1, 2, type='error')
assert 'type `error` not supported' in str(excinfo.value)


def test_average():
out = average(norm(1, 2), norm(3, 4))
expected = mixture([norm(1, 2), norm(3, 4)], [0.5, 0.5])
assert out == expected
2 changes: 1 addition & 1 deletion tests/test_rng.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..squigglepy.rng import set_seed
from ..squigglepy.sample import sample
from ..squigglepy.samplers import sample
from ..squigglepy.distributions import norm


Expand Down
Loading

0 comments on commit 5ae8c79

Please sign in to comment.