From d30aa9d121a9c6dcf16208a934ce6cde644a7a31 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 2 Aug 2017 01:18:07 +0200 Subject: [PATCH] Fix bug in nuts stats --- pymc3/step_methods/hmc/nuts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index 93495f3646..5fd99bded2 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -197,12 +197,12 @@ def astep(self, q0): for _ in range(max_treedepth): direction = logbern(np.log(0.5)) * 2 - 1 - diverging, turning = tree.extend(direction) + diverging_info, turning = tree.extend(direction) q, q_grad = tree.proposal.q, tree.proposal.q_grad - if diverging or turning: - if diverging: - self.report._add_divergence(self.tune, *diverging) + if diverging_info or turning: + if diverging_info: + self.report._add_divergence(self.tune, *diverging_info) break w = 1. / (self.m + self.t0) @@ -223,7 +223,7 @@ def astep(self, q0): 'step_size': step_size, 'tune': self.tune, 'step_size_bar': np.exp(self.log_step_size_bar), - 'diverging': diverging, + 'diverging': bool(diverging_info), } stats.update(tree.stats())