diff --git a/pymc3/step_methods/metropolis.py b/pymc3/step_methods/metropolis.py index 118412437ef..0dd37bf3d67 100644 --- a/pymc3/step_methods/metropolis.py +++ b/pymc3/step_methods/metropolis.py @@ -146,12 +146,25 @@ def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1., self.any_discrete = self.discrete.any() self.all_discrete = self.discrete.all() + # remember initial settings before tuning so they can be reset + self._untuned_settings = dict( + scaling=self.scaling, + steps_until_tune=tune_interval, + accepted=self.accepted + ) + self.mode = mode shared = pm.make_shared_replacements(vars, model) self.delta_logp = delta_logp(model.logpt, vars, shared) super().__init__(vars, shared) + def reset_tuning(self): + """Resets the tuned sampler parameters to their initial values.""" + for attr, initial_value in self._untuned_settings.items(): + setattr(self, attr, initial_value) + return + def astep(self, q0): if not self.steps_until_tune and self.tune: # Tune scaling parameter