Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,18 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
vals = np.stack(
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
)

if vals.shape[-1] == 1:
return vals[..., 0]
else:
return vals
vals = vals[..., 0]

if vals.dtype == np.dtype(object):
try:
vals = np.vstack(vals)
except ValueError:
# Most likely due to non-identical shapes. Just stick with the object-array.
pass

return vals

def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
"""Get sampler statistics."""
Expand Down
14 changes: 4 additions & 10 deletions pymc/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,28 +946,22 @@ def extract_Q_estimate(trace, levels):
MLDA with variance reduction has been used for sampling.
"""

Q_0_raw = trace.get_sampler_stats("Q_0")
# total number of base level samples from all iterations
total_base_level_samples = sum(it.shape[0] for it in Q_0_raw)
Q_0 = np.concatenate(Q_0_raw).reshape((1, total_base_level_samples))
Q_0_raw = trace.get_sampler_stats("Q_0").squeeze()
Q_0 = np.concatenate(Q_0_raw)[None, ::]
ess_Q_0 = az.ess(np.array(Q_0, np.float64))
Q_0_var = Q_0.var() / ess_Q_0

Q_diff_means = []
Q_diff_vars = []
for l in range(1, levels):
Q_diff_raw = trace.get_sampler_stats(f"Q_{l}_{l-1}")
# total number of samples from all iterations
total_level_samples = sum(it.shape[0] for it in Q_diff_raw)
Q_diff = np.concatenate(Q_diff_raw).reshape((1, total_level_samples))
Q_diff_raw = trace.get_sampler_stats(f"Q_{l}_{l-1}").squeeze()
Q_diff = np.hstack(Q_diff_raw)[None, ::]
ess_diff = az.ess(np.array(Q_diff, np.float64))

Q_diff_means.append(Q_diff.mean())
Q_diff_vars.append(Q_diff.var() / ess_diff)

Q_mean = Q_0.mean() + sum(Q_diff_means)
Q_se = np.sqrt(Q_0_var + sum(Q_diff_vars))

return Q_mean, Q_se


Expand Down
8 changes: 4 additions & 4 deletions pymc/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,13 +1750,13 @@ def perform(self, node, inputs, outputs):
Q_mean_vr, Q_se_vr = extract_Q_estimate(trace, 3)

# check that returned values are floats and finite.
assert isinstance(Q_mean_standard, float)
assert isinstance(Q_mean_standard, np.floating)
assert np.isfinite(Q_mean_standard)
assert isinstance(Q_mean_vr, float)
assert isinstance(Q_mean_vr, np.floating)
assert np.isfinite(Q_mean_vr)
assert isinstance(Q_se_standard, float)
assert isinstance(Q_se_standard, np.floating)
assert np.isfinite(Q_se_standard)
assert isinstance(Q_se_vr, float)
assert isinstance(Q_se_vr, np.floating)
assert np.isfinite(Q_se_vr)

# check consistency of QoI across levels.
Expand Down