diff --git a/pymc3/stats.py b/pymc3/stats.py index 2944c7332b..58d3287d5e 100644 --- a/pymc3/stats.py +++ b/pymc3/stats.py @@ -116,12 +116,13 @@ def dic(trace, model=None): `float` representing the deviance information criterion of the model and trace """ model = modelcontext(model) + logp = model.logp - mean_deviance = -2 * np.mean([model.logp(pt) for pt in trace]) + mean_deviance = -2 * np.mean([logp(pt) for pt in trace]) free_rv_means = {rv.name: trace[rv.name].mean( axis=0) for rv in model.free_RVs} - deviance_at_mean = -2 * model.logp(free_rv_means) + deviance_at_mean = -2 * logp(free_rv_means) return 2 * mean_deviance - deviance_at_mean @@ -328,12 +329,13 @@ def bpic(trace, model=None): Optional model. Default None, taken from context. """ model = modelcontext(model) + logp = model.logp - mean_deviance = -2 * np.mean([model.logp(pt) for pt in trace]) + mean_deviance = -2 * np.mean([logp(pt) for pt in trace]) free_rv_means = {rv.name: trace[rv.name].mean( axis=0) for rv in model.free_RVs} - deviance_at_mean = -2 * model.logp(free_rv_means) + deviance_at_mean = -2 * logp(free_rv_means) return 3 * mean_deviance - 2 * deviance_at_mean