diff --git a/pymc3/backends/report.py b/pymc3/backends/report.py
index 4384b85cbf..42f6b8a976 100644
--- a/pymc3/backends/report.py
+++ b/pymc3/backends/report.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections import namedtuple
 import logging
 import enum
-import typing
+from typing import Any, Optional
+import dataclasses
+
 from ..util import is_transformed_name, get_untransformed_name

 import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
     BAD_ENERGY = 8


-SamplerWarning = namedtuple(
-    'SamplerWarning',
-    "kind, message, level, step, exec_info, extra")
+@dataclasses.dataclass
+class SamplerWarning:
+    kind: WarningType
+    message: str
+    level: str
+    step: Optional[int] = None
+    exec_info: Optional[Any] = None
+    extra: Optional[Any] = None
+    divergence_point_source: Optional[dict] = None
+    divergence_point_dest: Optional[dict] = None
+    divergence_info: Optional[Any] = None


 _LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):


 class SamplerReport:
-    """This object bundles warnings, convergence statistics and metadata of a sampling run."""
+    """Bundle warnings, convergence stats and metadata of a sampling run."""
+
     def __init__(self):
         self._chain_warnings = {}
         self._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
                    for warn in self._warnings)

     @property
-    def n_tune(self) -> typing.Optional[int]:
+    def n_tune(self) -> Optional[int]:
         """Number of tune iterations - not necessarily kept in trace!"""
         return self._n_tune

     @property
-    def n_draws(self) -> typing.Optional[int]:
+    def n_draws(self) -> Optional[int]:
         """Number of draw iterations."""
         return self._n_draws

     @property
-    def t_sampling(self) -> typing.Optional[float]:
+    def t_sampling(self) -> Optional[float]:
         """
         Number of seconds that the sampling procedure took.

@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
         if idata.posterior.sizes['chain'] == 1:
             msg = ("Only one chain was sampled, this makes it impossible to "
                    "run some convergence checks")
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
-                                  None, None, None)
+            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
             self._add_warnings([warn])
             return

@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
             msg = ("The rhat statistic is larger than 1.4 for some "
                    "parameters. The sampler did not converge.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'error', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.2:
             msg = ("The rhat statistic is larger than 1.2 for some "
                    "parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.05:
             msg = ("The rhat statistic is larger than 1.05 for some "
                    "parameters. This indicates slight problems during "
                    "sampling.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'info', extra=rhat)
             warnings.append(warn)

         eff_min = min(val.min() for val in ess.values())
-        n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
+        sizes = idata.posterior.sizes
+        n_samples = sizes['chain'] * sizes['draw']
         if eff_min < 200 and n_samples >= 500:
             msg = ("The estimated number of effective samples is smaller than "
                    "200 for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'error', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.1:
             msg = ("The number of effective samples is smaller than "
                    "10% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'warn', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.25:
             msg = ("The number of effective samples is smaller than "
                    "25% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'info', extra=ess)
             warnings.append(warn)

         self._add_warnings(warnings)
@@ -201,7 +211,7 @@ def filter_warns(warnings):
                     filtered.append(warn)
                 elif (start <= warn.step < stop and
                         (warn.step - start) % step == 0):
-                    warn = warn._replace(step=warn.step - start)
+                    warn = dataclasses.replace(warn, step=warn.step - start)
                     filtered.append(warn)
             return filtered

diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py
index 7bbc722dee..8551e45962 100644
--- a/pymc3/step_methods/hmc/base_hmc.py
+++ b/pymc3/step_methods/hmc/base_hmc.py
@@ -29,10 +29,16 @@
 logger = logging.getLogger("pymc3")

-HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
+HMCStepData = namedtuple(
+    "HMCStepData",
+    "end, accept_stat, divergence_info, stats"
+)
+DivergenceInfo = namedtuple(
+    "DivergenceInfo",
+    "message, exec_info, state, state_div"
+)


-DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")


 class BaseHMC(arraystep.GradientSharedStep):
     """Superclass to implement Hamiltonian/hybrid monte carlo."""
@@ -148,15 +154,14 @@ def astep(self, q0):
             self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
             message_energy = (
                 "Bad initial energy, check any log probabilities that "
-                "are inf or -inf, nan or very small:\n{}".format(error_logp.to_string())
+                "are inf or -inf, nan or very small:\n{}"
+                .format(error_logp.to_string())
             )
             warning = SamplerWarning(
                 WarningType.BAD_ENERGY,
                 message_energy,
                 "critical",
                 self.iter_count,
-                None,
-                None,
             )
             self._warnings.append(warning)
             raise SamplingError("Bad initial energy")
@@ -177,19 +182,32 @@ def astep(self, q0):
         self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
         if hmc_step.divergence_info:
             info = hmc_step.divergence_info
+            point = None
+            point_dest = None
+            info_store = None
             if self.tune:
                 kind = WarningType.TUNING_DIVERGENCE
-                point = None
             else:
                 kind = WarningType.DIVERGENCE
                 self._num_divs_sample += 1
                 # We don't want to fill up all memory with divergence info
-                if self._num_divs_sample < 100:
+                if self._num_divs_sample < 100 and info.state is not None:
                     point = self._logp_dlogp_func.array_to_dict(info.state.q)
-                else:
-                    point = None
+                if self._num_divs_sample < 100 and info.state_div is not None:
+                    point_dest = self._logp_dlogp_func.array_to_dict(
+                        info.state_div.q
+                    )
+                if self._num_divs_sample < 100:
+                    info_store = info
             warning = SamplerWarning(
-                kind, info.message, "debug", self.iter_count, info.exec_info, point
+                kind,
+                info.message,
+                "debug",
+                self.iter_count,
+                info.exec_info,
+                divergence_point_source=point,
+                divergence_point_dest=point_dest,
+                divergence_info=info_store,
             )
             self._warnings.append(warning)

@@ -243,9 +261,7 @@ def warnings(self):
             )

         if message:
-            warning = SamplerWarning(
-                WarningType.DIVERGENCES, message, "error", None, None, None
-            )
+            warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
             warnings.append(warning)

         warnings.extend(self.step_adapt.warnings())
diff --git a/pymc3/step_methods/hmc/hmc.py b/pymc3/step_methods/hmc/hmc.py
index 9c7a533461..6b68662fc2 100644
--- a/pymc3/step_methods/hmc/hmc.py
+++ b/pymc3/step_methods/hmc/hmc.py
@@ -116,23 +116,25 @@ def _hamiltonian_step(self, start, p0, step_size):
         energy_change = -np.inf

         state = start
+        last = state
         div_info = None
         try:
             for _ in range(n_steps):
+                last = state
                 state = self.integrator.step(step_size, state)
         except IntegrationError as e:
-            div_info = DivergenceInfo('Divergence encountered.', e, state)
+            div_info = DivergenceInfo('Integration failed.', e, last, None)
         else:
             if not np.isfinite(state.energy):
                 div_info = DivergenceInfo(
-                    'Divergence encountered, bad energy.', None, state)
+                    'Divergence encountered, bad energy.', None, last, state)
             energy_change = start.energy - state.energy
             if np.isnan(energy_change):
                 energy_change = -np.inf
             if np.abs(energy_change) > self.Emax:
                 div_info = DivergenceInfo(
                     'Divergence encountered, large integration error.',
-                    None, state)
+                    None, last, state)

         accept_stat = min(1, np.exp(energy_change))
diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py
index 30b0705cc8..65079d582d 100644
--- a/pymc3/step_methods/hmc/nuts.py
+++ b/pymc3/step_methods/hmc/nuts.py
@@ -210,7 +210,7 @@ def warnings(self):
                 "The chain reached the maximum tree depth. Increase "
                 "max_treedepth, increase target_accept or reparameterize."
             )
-            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
+            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
             warnings.append(warn)

         return warnings
@@ -331,6 +331,7 @@ def _single_step(self, left, epsilon):
         except IntegrationError as err:
             error_msg = str(err)
             error = err
+            right = None
         else:
             # h - H0
             energy_change = right.energy - self.start_energy
@@ -363,7 +364,7 @@ def _single_step(self, left, epsilon):
             )
             error = None
         tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
-        divergance_info = DivergenceInfo(error_msg, error, left)
+        divergance_info = DivergenceInfo(error_msg, error, left, right)
         return tree, divergance_info, False

     def _build_subtree(self, left, depth, epsilon):
diff --git a/pymc3/step_methods/step_sizes.py b/pymc3/step_methods/step_sizes.py
index 51262b8239..bdf0683c62 100644
--- a/pymc3/step_methods/step_sizes.py
+++ b/pymc3/step_methods/step_sizes.py
@@ -77,7 +77,7 @@ def warnings(self):
                    % (mean_accept, target_accept))
             info = {'target': target_accept, 'actual': mean_accept}
             warning = SamplerWarning(
-                WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
+                WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
             return [warning]
         else:
             return []
diff --git a/requirements.txt b/requirements.txt
index 4882d65218..c75496b7e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,3 @@ patsy>=0.5.1
 fastprogress>=0.2.0
 h5py>=2.7.0
 typing-extensions>=3.7.4
-contextvars; python_version < '3.7'
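
Reviewer note: the heart of this patch is the switch of SamplerWarning from a positional namedtuple to a dataclass with defaulted fields, so call sites can drop the trailing None, None, None padding and name the new divergence fields explicitly. A minimal self-contained sketch of that pattern follows; the two-member WarningType is a stand-in subset, not the full pymc3 enum:

# Sketch of the namedtuple -> dataclass migration in this diff.
# WarningType here is a stand-in; pymc3 defines more members.
import dataclasses
import enum
from typing import Any, Optional


class WarningType(enum.Enum):
    CONVERGENCE = 1
    DIVERGENCE = 2


@dataclasses.dataclass
class SamplerWarning:
    kind: WarningType
    message: str
    level: str
    step: Optional[int] = None
    exec_info: Optional[Any] = None
    extra: Optional[Any] = None
    divergence_point_source: Optional[dict] = None
    divergence_point_dest: Optional[dict] = None
    divergence_info: Optional[Any] = None


# Call sites now pass only what they have, by keyword:
warn = SamplerWarning(WarningType.CONVERGENCE, "rhat is large", "warn", extra={"x": 1.3})

# namedtuple._replace becomes dataclasses.replace (see filter_warns above):
shifted = dataclasses.replace(warn, step=0)
assert shifted.extra == warn.extra and shifted.step == 0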
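
Reviewer note: on the HMC/NUTS side, DivergenceInfo gains a state_div field so a warning can carry both ends of the divergent transition, and the integration loop keeps the last known-good state so that an IntegrationError still yields a source point (the divergent end stays None in that case). A standalone sketch of that control flow, with stand-in State and integrator types rather than the real pymc3 internals:

# Sketch of the last/state threading added to _hamiltonian_step.
# State, step, and IntegrationError are stand-ins for pymc3 internals.
from collections import namedtuple

DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div")
State = namedtuple("State", "q, energy")


class IntegrationError(RuntimeError):
    pass


def step(state):
    # Stand-in integrator that always fails, like a diverging leapfrog step.
    raise IntegrationError("leapfrog diverged")


def hamiltonian_step(start, n_steps):
    state = start
    last = state                 # last state known to be good
    div_info = None
    try:
        for _ in range(n_steps):
            last = state         # remember the pre-step (source) state
            state = step(state)
    except IntegrationError as e:
        # The divergent end is unknown here, so state_div stays None.
        div_info = DivergenceInfo("Integration failed.", e, last, None)
    return state, div_info


_, div = hamiltonian_step(State(q=0.0, energy=-1.0), n_steps=10)
assert div.state_div is None and div.state.q == 0.0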