Skip to content

Commit

Permalink
Convert base_scaling stats to one stat of type object, modify tests for stats and also test for declining acceptance rate
Browse files Browse the repository at this point in the history
  • Loading branch information
gmingas committed Jul 16, 2020
1 parent 401631f commit 9d25cd2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 27 deletions.
21 changes: 9 additions & 12 deletions pymc3/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,8 @@ class MLDA(ArrayStepShared):
stats_dtypes = [{
'accept': np.float64,
'accepted': np.bool,
'tune': np.bool
'tune': np.bool,
'base_scaling': object
}]

def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=None,
Expand Down Expand Up @@ -1102,13 +1103,6 @@ def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=Non
self.tune,
self.subsampling_rates[-1])

# Update stats data types dictionary given vars and base_blocked
if self.base_blocked or len(self.vars) == 1:
self.stats_dtypes[0]['base_scaling'] = np.float64
else:
for name in self.var_names:
self.stats_dtypes[0]['base_scaling_' + name] = np.float64

def astep(self, q0):
"""One MLDA step, given current sample q0"""
# Check if the tuning flag has been changed and if yes,
Expand Down Expand Up @@ -1159,12 +1153,15 @@ def astep(self, q0):
# Capture latest base chain scaling stats from next step method
self.base_scaling_stats = {}
if isinstance(self.next_step_method, CompoundStep):
scaling_list = []
for method in self.next_step_method.methods:
self.base_scaling_stats["base_scaling_" + method.vars[0].name] = method.scaling
elif isinstance(self.next_step_method, Metropolis):
self.base_scaling_stats["base_scaling"] = self.next_step_method.scaling
scaling_list.append(method.scaling)
self.base_scaling_stats = {"base_scaling": np.array(scaling_list)}
elif not isinstance(self.next_step_method, MLDA):
# next method is any block sampler
self.base_scaling_stats = {"base_scaling": np.array(self.next_step_method.scaling)}
else:
# next method is MLDA
# next method is MLDA - propagate dict from lower levels
self.base_scaling_stats = self.next_step_method.base_scaling_stats
stats = {**stats, **self.base_scaling_stats}

Expand Down
32 changes: 17 additions & 15 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,10 +1126,10 @@ def test_acceptance_rate_against_coarseness(self):
Normal("x", 5.0, 1.0)

with Model() as coarse_model_1:
Normal("x", 5.5, 1.5)
Normal("x", 6.0, 2.0)

with Model() as coarse_model_2:
Normal("x", 6.0, 2.0)
Normal("x", 20.0, 5.0)

possible_coarse_models = [coarse_model_0,
coarse_model_1,
Expand All @@ -1139,9 +1139,9 @@ def test_acceptance_rate_against_coarseness(self):
with Model():
Normal("x", 5.0, 1.0)
for coarse_model in possible_coarse_models:
step = MLDA(coarse_models=[coarse_model], subsampling_rates=1,
tune=False)
trace = sample(chains=1, draws=500, tune=0, step=step)
step = MLDA(coarse_models=[coarse_model], subsampling_rates=3,
tune=True)
trace = sample(chains=1, draws=500, tune=100, step=step)
acc.append(trace.get_sampler_stats('accepted').mean())
assert acc[0] > acc[1] > acc[2], "Acceptance rate is not " \
"strictly increasing when" \
Expand Down Expand Up @@ -1197,10 +1197,10 @@ def test_tuning_and_scaling_on(self):
assert trace.get_sampler_stats('tune', chains=0)[ts - 1]
assert not trace.get_sampler_stats('tune', chains=0)[ts]
assert not trace.get_sampler_stats('tune', chains=0)[-1]
assert trace.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
assert trace.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.

def test_tuning_and_scaling_off(self):
"""Test that tuning is deactivated when sample()'s tune=0 and that
Expand Down Expand Up @@ -1239,17 +1239,19 @@ def test_tuning_and_scaling_off(self):

assert not trace_0.get_sampler_stats('tune', chains=0)[0]
assert not trace_0.get_sampler_stats('tune', chains=0)[-1]
assert trace_0.get_sampler_stats('base_scaling_x', chains=0)[0] == \
trace_0.get_sampler_stats('base_scaling_x', chains=0)[-1] == 100.
assert trace_0.get_sampler_stats('base_scaling', chains=0)[0][0] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][0] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[0][1] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][1] == 100.

assert trace_1.get_sampler_stats('tune', chains=0)[0]
assert trace_1.get_sampler_stats('tune', chains=0)[ts_1 - 1]
assert not trace_1.get_sampler_stats('tune', chains=0)[ts_1]
assert not trace_1.get_sampler_stats('tune', chains=0)[-1]
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.

def test_trace_length(self):
"""Check if trace length is as expected."""
Expand Down

0 comments on commit 9d25cd2

Please sign in to comment.