Skip to content

Commit

Permalink
Fix bayes.update
Browse files Browse the repository at this point in the history
  • Loading branch information
peterhurford committed Oct 10, 2022
1 parent 42cd148 commit f577f85
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 109 deletions.
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

* Distributions are now implemented as classes (rather than lists).

#### Bayesian library updates
* **[Breaking change]** `bayes.update` now updates normal distributions using the `mean` and `sd` of the distribution objects rather than statistics estimated from samples.
* **[Breaking change]** `bayes.update` no longer takes a `type` parameter; it now infers the type from the distributions passed in.
* **[Breaking change]** Corrected a bug in how `bayes.update` implemented `evidence_weight` when updating normal distributions.


## v0.7

Expand Down
22 changes: 12 additions & 10 deletions squigglepy/bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,27 @@ def bayesnet(event_fn, n=1, find=None, conditional_on=None,
return reduce_fn(events)


def update(prior, evidence, evidence_weight=1):
    """Return the posterior from updating `prior` with `evidence`.

    Both arguments are distribution objects; the kind of update is chosen
    from their `type` attributes, which must match.

    Parameters
    ----------
    prior : distribution
        A 'norm' distribution (its `mean` and `sd` are read) or a 'beta'
        distribution (its `a` and `b` are read).
    evidence : distribution
        A distribution of the same type as `prior`.
    evidence_weight : number, default 1
        Multiplier applied to the evidence term in the normal update. It is
        not used by the beta update.

    Returns
    -------
    distribution
        A `norm` for two normals, a `beta` for two betas.

    Raises
    ------
    ValueError
        If the two distributions have different types, or share a type other
        than 'norm' or 'beta'.
    """
    if prior.type == 'norm' and evidence.type == 'norm':
        prior_mean = prior.mean
        prior_var = prior.sd ** 2
        evidence_mean = evidence.mean
        evidence_var = evidence.sd ** 2
        # The weighted mean and the posterior sd share one denominator, so
        # `evidence_weight` scales both consistently.
        denominator = evidence_weight * prior_var + evidence_var
        return norm(mean=((evidence_var * prior_mean +
                           evidence_weight * (prior_var * evidence_mean)) /
                          denominator),
                    sd=math.sqrt((evidence_var * prior_var) / denominator))
    elif prior.type == 'beta' and evidence.type == 'beta':
        prior_a = prior.a
        prior_b = prior.b
        evidence_a = evidence.a
        evidence_b = evidence.b
        return beta(prior_a + evidence_a, prior_b + evidence_b)
    elif prior.type != evidence.type:
        raise ValueError('can only update distributions of the same type.')
    else:
        raise ValueError('type `{}` not supported.'.format(prior.type))


def average(prior, evidence, weights=[0.5, 0.5]):
Expand Down
23 changes: 23 additions & 0 deletions squigglepy/distributions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np
from scipy import stats

from .utils import _process_weights_values


Expand Down Expand Up @@ -65,13 +68,23 @@ def __init__(self, x=None, y=None, mean=None, sd=None,
self.lclip = lclip
self.rclip = rclip
self.type = 'norm'

if self.x is not None and self.y is not None and self.x > self.y:
raise ValueError('`high value` cannot be lower than `low value`')

if (self.x is None or self.y is None) and self.sd is None:
raise ValueError('must define either x/y or mean/sd')
elif (self.x is not None or self.y is not None) and self.sd is not None:
raise ValueError('must define either x/y or mean/sd -- cannot define both')
elif self.sd is not None and self.mean is None:
self.mean = 0

if self.mean is None and self.sd is None:
self.mean = (self.x + self.y) / 2
cdf_value = 0.5 + 0.5 * self.credibility
normed_sigma = stats.norm.ppf(cdf_value)
self.sd = (self.y - self.mean) / normed_sigma


def norm(x=None, y=None, credibility=0.9, mean=None, sd=None,
lclip=None, rclip=None):
Expand All @@ -91,13 +104,23 @@ def __init__(self, x=None, y=None, mean=None, sd=None,
self.lclip = lclip
self.rclip = rclip
self.type = 'lognorm'

if self.x is not None and self.y is not None and self.x > self.y:
raise ValueError('`high value` cannot be lower than `low value`')

if (self.x is None or self.y is None) and self.sd is None:
raise ValueError('must define either x/y or mean/sd')
elif (self.x is not None or self.y is not None) and self.sd is not None:
raise ValueError('must define either x/y or mean/sd -- cannot define both')
elif self.sd is not None and self.mean is None:
self.mean = 0

if self.mean is None and self.sd is None:
self.mean = (np.log(self.x) + np.log(self.y)) / 2
cdf_value = 0.5 + 0.5 * self.credibility
normed_sigma = stats.norm.ppf(cdf_value)
self.sd = (np.log(self.y) - self.mean) / normed_sigma


def lognorm(x=None, y=None, credibility=0.9, mean=None, sd=None,
lclip=None, rclip=None):
Expand Down
52 changes: 8 additions & 44 deletions squigglepy/samplers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from scipy import stats
from tqdm import tqdm

from .distributions import const, BaseDistribution
Expand All @@ -12,41 +11,12 @@ def _get_rng():
return _squigglepy_internal_rng


def normal_sample(low=None, high=None, mean=None, sd=None, credibility=0.9):
    """Draw one value from a normal distribution.

    NOTE(review): this appears to be the pre-commit form; a two-argument
    `normal_sample(mean, sd)` follows later in the same diff.

    Parameters
    ----------
    low, high : number, optional
        Bounds of the central `credibility` interval; used when `mean` is
        None.
    mean, sd : number, optional
        Used directly as the distribution parameters when `mean` is given.
    credibility : float, default 0.9
        Probability mass taken to lie between `low` and `high`.

    Returns
    -------
    number
        `low` unchanged when `low == high`, otherwise one draw from the RNG.

    Raises
    ------
    ValueError
        If `low` is greater than `high`.
    """
    if mean is not None:
        return _get_rng().normal(mean, sd)

    if low > high:
        raise ValueError('`high value` cannot be lower than `low value`')
    if low == high:
        # Degenerate interval: there is nothing to sample.
        return low

    mu = (high + low) / 2
    # `high` sits at the upper edge of the central interval, i.e. at the
    # (0.5 + credibility / 2) quantile of the distribution.
    cdf_value = 0.5 + 0.5 * credibility
    normed_sigma = stats.norm.ppf(cdf_value)
    sigma = (high - mu) / normed_sigma
    return _get_rng().normal(mu, sigma)


def lognormal_sample(low=None, high=None, mean=None, sd=None,
                     credibility=0.9):
    """Draw one value from a lognormal distribution.

    NOTE(review): this appears to be the pre-commit form; a two-argument
    `lognormal_sample(mean, sd)` follows later in the same diff.

    Parameters
    ----------
    low, high : number, optional
        Bounds of the central `credibility` interval; used when `mean` is
        None. Their logarithms are taken to get the underlying normal.
    mean, sd : number, optional
        Passed straight to the RNG's `lognormal` when `mean` is given.
    credibility : float, default 0.9
        Probability mass taken to lie between `low` and `high`.

    Returns
    -------
    number
        `low` unchanged when `low == high`, otherwise one draw from the RNG.

    Raises
    ------
    ValueError
        If `low` or `mean` is negative, or if `low` is greater than `high`.
    """
    if (low is not None and low < 0) or (mean is not None and mean < 0):
        raise ValueError('lognormal_sample cannot handle negative values')

    if mean is not None:
        return _get_rng().lognormal(mean, sd)

    if low > high:
        raise ValueError('`high value` cannot be lower than `low value`')
    if low == high:
        # Degenerate interval: there is nothing to sample.
        return low

    log_low = np.log(low)
    log_high = np.log(high)
    mu = (log_high + log_low) / 2
    # `log_high` sits at the (0.5 + credibility / 2) quantile in log space.
    cdf_value = 0.5 + 0.5 * credibility
    normed_sigma = stats.norm.ppf(cdf_value)
    sigma = (log_high - mu) / normed_sigma
    return _get_rng().lognormal(mu, sigma)
def normal_sample(mean, sd):
    """Return one draw from the module RNG's `normal(mean, sd)`."""
    return _get_rng().normal(mean, sd)


def lognormal_sample(mean, sd):
    """Return one draw from the module RNG's `lognormal(mean, sd)`."""
    return _get_rng().lognormal(mean, sd)


def t_sample(low, high, t, credibility=0.9):
Expand Down Expand Up @@ -173,16 +143,10 @@ def sample(var, n=1, lclip=None, rclip=None, verbose=False):
out = discrete_sample(var.items)

elif var.type == 'norm':
if var.x is not None and var.y is not None:
out = normal_sample(var.x, var.y, credibility=var.credibility)
else:
out = normal_sample(mean=var.mean, sd=var.sd)
out = normal_sample(mean=var.mean, sd=var.sd)

elif var.type == 'lognorm':
if var.x is not None and var.y is not None:
out = lognormal_sample(var.x, var.y, credibility=var.credibility)
else:
out = lognormal_sample(mean=var.mean, sd=var.sd)
out = lognormal_sample(mean=var.mean, sd=var.sd)

elif var.type == 'binomial':
out = binomial_sample(n=var.n, p=var.p)
Expand Down
14 changes: 5 additions & 9 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,9 @@ def define_event():

print('Test 6...')
prior = sq.norm(1, 5)
evidence = sq.norm(2, 3)
posterior = bayes.update(prior, evidence)
# Report a failure if EITHER parameter is off; joining the two checks with
# `and` would only fire when both were wrong at once.
# Expected values for 90% intervals (1, 5) and (2, 3): mean ~2.529 and
# sd ~0.2949, which round to 2.53 and 0.29.
if round(posterior.mean, 2) != 2.53 or round(posterior.sd, 2) != 0.29:
    print('ERROR 6')
    import pdb
    pdb.set_trace()
Expand All @@ -182,7 +178,7 @@ def define_event():
average = bayes.average(prior, evidence)
average_samples = sq.sample(average, n=K)
out = (np.mean(average_samples), np.std(average_samples))
if round(out[0], 2) != 2.73 and round(out[1], 2) != 0.97:
if round(out[0], 2) != 2.74 and round(out[1], 2) != 0.93:
print('ERROR 7')
import pdb
pdb.set_trace()
Expand Down Expand Up @@ -229,7 +225,7 @@ def define_event():
n=10*K,
find=lambda e: (e['mary_calls'] and e['john_calls']),
conditional_on=lambda e: e['earthquake'])
if round(out, 2) != 0.1:
if round(out, 2) != 0.29:
print('ERROR 8')
import pdb
pdb.set_trace()
Expand All @@ -240,7 +236,7 @@ def define_event():
n=10*K,
find=lambda e: e['burglary'],
conditional_on=lambda e: (e['mary_calls'] and e['john_calls']))
if round(out, 2) != 0.32:
if round(out, 2) != 0.29:
print('ERROR 9')
import pdb
pdb.set_trace()
Expand Down
25 changes: 15 additions & 10 deletions tests/test_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ..squigglepy.bayes import simple_bayes, bayesnet, update, average
from ..squigglepy.samplers import sample
from ..squigglepy.distributions import discrete, norm, beta
from ..squigglepy.distributions import discrete, norm, beta, gamma
from ..squigglepy.rng import set_seed


Expand Down Expand Up @@ -212,31 +212,36 @@ def test_bayesnet_insufficent_samples_error():


def test_update_normal():
    """Updating one normal with another yields a normal posterior."""
    out = update(norm(1, 10), norm(5, 15))
    assert out.type == 'norm'
    assert round(out.mean, 2) == 7.51
    assert round(out.sd, 2) == 2.03


def test_update_normal_evidence_weight():
    """A larger `evidence_weight` pulls the posterior toward the evidence."""
    out = update(norm(1, 10), norm(5, 15), evidence_weight=3)
    assert out.type == 'norm'
    assert round(out.mean, 2) == 8.69
    assert round(out.sd, 2) == 1.48


def test_update_beta():
    """Updating a beta with a beta adds the `a` and `b` parameters."""
    out = update(beta(1, 1), beta(2, 2))
    assert out.type == 'beta'
    assert out.a == 3
    assert out.b == 3


def test_update_not_implemented():
    """Two distributions of the same but unsupported type are rejected."""
    with pytest.raises(ValueError) as excinfo:
        update(gamma(1), gamma(2))
    # Checked after the `with` block so the assertion runs once the error
    # has been captured.
    assert 'type `gamma` not supported' in str(excinfo.value)


def test_update_not_matching():
    """Distributions of different types cannot be combined."""
    with pytest.raises(ValueError) as excinfo:
        update(norm(1, 2), beta(1, 2))
    # Checked after the `with` block so the assertion runs once the error
    # has been captured.
    assert 'can only update distributions of the same type' in str(excinfo.value)


def test_average():
Expand Down
20 changes: 16 additions & 4 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_norm():
assert norm(1, 2).type == 'norm'
assert norm(1, 2).x == 1
assert norm(1, 2).y == 2
assert norm(1, 2).mean is None
assert norm(1, 2).sd is None
assert norm(1, 2).mean == 1.5
assert round(norm(1, 2).sd, 2) == 0.3
assert norm(1, 2).credibility == 0.9
assert norm(1, 2).lclip is None
assert norm(1, 2).rclip is None
Expand Down Expand Up @@ -79,6 +79,12 @@ def test_norm_overdefinition_value_error():
assert 'cannot define both' in str(execinfo.value)


def test_norm_low_gt_high():
    """`norm` rejects a low bound that is greater than the high bound."""
    with pytest.raises(ValueError) as execinfo:
        norm(10, 5)
    # Checked after the `with` block so the assertion runs once the error
    # has been captured.
    assert '`high value` cannot be lower than `low value`' in str(execinfo.value)


def test_norm_passes_lclip_rclip():
obj = norm(1, 2, lclip=0, rclip=3)
assert obj.type == 'norm'
Expand All @@ -104,8 +110,8 @@ def test_lognorm():
assert lognorm(1, 2).type == 'lognorm'
assert lognorm(1, 2).x == 1
assert lognorm(1, 2).y == 2
assert lognorm(1, 2).mean is None
assert lognorm(1, 2).sd is None
assert round(lognorm(1, 2).mean, 2) == 0.35
assert round(lognorm(1, 2).sd, 2) == 0.21
assert lognorm(1, 2).credibility == 0.9
assert lognorm(1, 2).lclip is None
assert lognorm(1, 2).rclip is None
Expand Down Expand Up @@ -146,6 +152,12 @@ def test_lognorm_overdefinition_value_error():
assert 'cannot define both' in str(execinfo.value)


def test_lognorm_low_gt_high():
    """`lognorm` rejects a low bound that is greater than the high bound."""
    with pytest.raises(ValueError) as execinfo:
        lognorm(10, 5)
    # Checked after the `with` block so the assertion runs once the error
    # has been captured.
    assert '`high value` cannot be lower than `low value`' in str(execinfo.value)


def test_lognorm_passes_lclip_rclip():
obj = lognorm(1, 2, lclip=0, rclip=3)
assert obj.type == 'lognorm'
Expand Down
42 changes: 10 additions & 32 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,17 @@ def standard_t(self, t):

@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_norm(mocker):
assert normal_sample(1, 2) == (1.5, 0.3)
assert normal_sample(1, 2) == (1, 2)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_norm_with_mean_sd(mocker):
assert normal_sample(mean=1, sd=2) == (1, 2)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_norm_with_credibility(mocker):
assert normal_sample(1, 2, credibility=0.7) == (1.5, 0.48)


def test_norm_low_gt_high():
with pytest.raises(ValueError) as execinfo:
normal_sample(10, 5)
assert '`high value` cannot be lower than `low value`' in str(execinfo.value)
def test_sample_norm(mocker):
assert sample(norm(1, 2)) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_sample_norm(mocker):
assert sample(norm(1, 2)) == (1.5, 0.3)
def test_sample_norm_with_credibility(mocker):
assert sample(norm(1, 2, credibility=0.7)) == (1.5, 0.48)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
Expand All @@ -86,28 +75,17 @@ def test_sample_norm_passes_lclip_rclip():

@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_lognorm(mocker):
assert lognormal_sample(1, 2) == (0.35, 0.21)
assert lognormal_sample(1, 2) == (1, 2)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_lognorm_with_mean_sd(mocker):
assert lognormal_sample(mean=1, sd=2) == (1, 2)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_lognorm_with_credibility(mocker):
assert lognormal_sample(1, 2, credibility=0.7) == (0.35, 0.33)


def test_lognorm_low_gt_high():
with pytest.raises(ValueError) as execinfo:
lognormal_sample(10, 5)
assert '`high value` cannot be lower than `low value`' in str(execinfo.value)
def test_sample_lognorm(mocker):
assert sample(lognorm(1, 2)) == (0.35, 0.21)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
def test_sample_lognorm(mocker):
assert sample(lognorm(1, 2)) == (0.35, 0.21)
def test_sample_lognorm_with_credibility(mocker):
assert sample(lognorm(1, 2, credibility=0.7)) == (0.35, 0.33)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
Expand Down

0 comments on commit f577f85

Please sign in to comment.