Skip to content

Commit

Permalink
Merge pull request #1610 from devbhakt/autoconvergence
Browse files Browse the repository at this point in the history
Autocorrelation check for event_optimize
  • Loading branch information
abhisrkckl authored Aug 17, 2023
2 parents c7f97eb + 9380bab commit 7938a93
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG-unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ the released changes.
- Third-order Roemer delay terms to ELL1 model
- Options to add a TZR TOA (`AbsPhase`) during the creation of a `TimingModel` using `ModelBuilder.__call__`, `get_model`, and `get_model_and_toas`
- `pint.print_info()` function for bug reporting
- Added an autocorrelation function to check for chain convergence in `event_optimize`
### Fixed
- Deleting JUMP1 from flag tables will not prevent fitting
- Simulating TOAs from tim file when PLANET_SHAPIRO is true now works
Expand Down
99 changes: 95 additions & 4 deletions src/pint/scripts/event_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,82 @@ def get_fit_keyvals(model, phs=0.0, phserr=0.1):
return fitkeys, np.asarray(fitvals), np.asarray(fiterrs)


def run_sampler_autocorr(sampler, pos, nsteps, burnin, csteps=100, crit1=10):
    """Run the sampler, stopping early once the chains appear converged.

    Parameters
    ----------
    sampler
        The emcee ensemble sampler. It must provide ``sample()``,
        ``get_autocorr_time()`` and an ``iteration`` attribute.
    pos
        The initial positions of the walkers
    nsteps : int
        The maximum number of steps to run
    burnin : int
        The number of burn-in steps. No convergence check is made before
        this many steps have been taken.
    csteps : int
        The interval (in steps) at which the autocorrelation time is computed
        during the first stage. Must be at least 1.
    crit1 : int
        The ratio of chain length to autocorrelation time to satisfy convergence

    Returns
    -------
    list
        The mean estimated autocorrelation time recorded at each check that
        produced a NaN-free estimate.

    Raises
    ------
    ValueError
        If `csteps` is smaller than 1.

    Note
    ----
    Convergence is tested in two stages, both requiring that

    1. the chain is longer than `crit1` times the estimated autocorrelation
       time for every parameter, and
    2. the fractional change in the estimated autocorrelation time since the
       previous check is below a threshold.

    The first stage checks every `csteps` steps with a 10% threshold. Once it
    is satisfied, the second stage checks every ``csteps / 4`` steps (at
    least every step) with a 1% threshold. Sampling stops as soon as the
    second stage is satisfied and at least 1000 steps have been taken after
    the burn-in; otherwise it runs for the full `nsteps`.
    """
    if csteps < 1:
        raise ValueError(f"csteps must be at least 1, got {csteps}")
    # The second stage checks four times as often. Clamp to 1 so that
    # csteps < 4 does not produce a modulo-by-zero.
    fine_steps = max(1, int(csteps / 4))

    autocorr = []
    old_tau = np.inf
    converged1 = False
    converged2 = False
    converge_step = None
    for _ in sampler.sample(pos, iterations=nsteps, progress=True):
        iteration = sampler.iteration

        if not converged1:
            # Stage 1: past the burn-in, check every `csteps` steps for a
            # tau change below 10%.
            if iteration < burnin or iteration % csteps != 0:
                continue
            tau = sampler.get_autocorr_time(tol=0, quiet=True)
            if np.any(np.isnan(tau)):
                continue
            x = np.mean(tau)
            autocorr.append(x)
            converged1 = np.all(tau * crit1 < iteration)
            converged1 &= np.all(np.abs(old_tau - tau) / tau < 0.1)
            old_tau = tau
            if converged1:
                log.info(
                    "10 % convergence reached with a mean estimated integrated step: "
                    + str(x)
                )
            # Stage 2 starts on the next iteration.
            continue

        if not converged2:
            # Stage 2: check more frequently for a tau change below 1%.
            if iteration % fine_steps != 0:
                continue
            tau = sampler.get_autocorr_time(tol=0, quiet=True)
            if np.any(np.isnan(tau)):
                continue
            x = np.mean(tau)
            autocorr.append(x)
            converged2 = np.all(tau * crit1 < iteration)
            converged2 &= np.all(np.abs(old_tau - tau) / tau < 0.01)
            old_tau = tau
            converge_step = iteration

        # Even after convergence, keep sampling until at least 1000
        # post-burn-in steps have been taken.
        if converged2 and (iteration - burnin) >= 1000:
            log.info(f"Convergence reached at {converge_step}")
            break
    return autocorr


class emcee_fitter(Fitter):
def __init__(
self, toas=None, model=None, template=None, weights=None, phs=0.5, phserr=0.03
Expand Down Expand Up @@ -545,6 +621,13 @@ def main(argv=None):
default=False,
action="store_true",
)
parser.add_argument(
"--no-autocorr",
help="Turn the autocorrelation check function off",
default=False,
action="store_true",
dest="noautocorr",
)

args = parser.parse_args(argv)
pint.logging.setup(
Expand Down Expand Up @@ -820,21 +903,29 @@ def unwrapped_lnpost(theta):
pool=pool,
backend=backend,
)
sampler.run_mcmc(pos, nsteps)
if args.noautocorr:
sampler.run_mcmc(pos, nsteps, progress=True)
else:
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
pool.close()
pool.join()
except ImportError:
log.info("Pathos module not available, using single core")
sampler = emcee.EnsembleSampler(
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
)
sampler.run_mcmc(pos, nsteps)
if args.noautocorr:
sampler.run_mcmc(pos, nsteps, progress=True)
else:
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
else:
sampler = emcee.EnsembleSampler(
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
)
# The number is the number of points in the chain
sampler.run_mcmc(pos, nsteps)
if args.noautocorr:
sampler.run_mcmc(pos, nsteps, progress=True)
else:
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)

def chains_to_dict(names, sampler):
samples = np.transpose(sampler.get_chain(), (1, 0, 2))
Expand Down
31 changes: 28 additions & 3 deletions tests/test_event_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_parallel(tmp_path):
for i in range(samples1.shape[1]):
assert stats.ks_2samp(samples1[:, i], samples2[:, i])[1] == 1.0
except ImportError:
pytest.skip
pytest.skip(f"Pathos multiprocessing package not found")
finally:
os.chdir(p)
sys.stdout = saved_stdout
Expand All @@ -104,11 +104,11 @@ def test_backend(tmp_path):
samples = None

# Running with backend
os.chdir(tmp_path)
cmd = f"{eventfile} {parfile} {temfile} --weightcol=PSRJ0030+0451 --minWeight=0.9 --nwalkers=10 --nsteps=50 --burnin=10 --backend --clobber"
event_optimize.maxpost = -9e99
event_optimize.numcalls = 0
event_optimize.main(cmd.split())

reader = emcee.backends.HDFBackend("J0030+0451_chains.h5")
samples = reader.get_chain(discard=10)
assert samples is not None
Expand All @@ -124,7 +124,32 @@ def test_backend(tmp_path):
assert timestamp == os.path.getmtime("J0030+0451_chains.h5")

except ImportError:
pytest.skip
pytest.skip(f"h5py package not found")
finally:
os.chdir(p)
sys.stdout = saved_stdout


def test_autocorr(tmp_path):
    """Check that ``run_sampler_autocorr`` stops before ``nsteps`` is exhausted."""

    # Log posterior function based on the emcee tutorials
    def ln_prob(theta):
        ln_prior = -0.5 * np.sum((theta - 1.0) ** 2 / 100.0)
        return -0.5 * np.sum(theta**2) + ln_prior

    # Seed the global NumPy RNG so the starting positions are reproducible
    np.random.seed(0)

    # Random starting positions for 10 walkers with 5 dimensions
    coords = np.random.randn(10, 5)
    nwalkers, ndim = coords.shape

    # Establishing the sampler with far more steps than convergence should need
    nsteps = 500000
    sampler = emcee.EnsembleSampler(nwalkers, ndim, ln_prob)

    # Running the sampler with the autocorrelation check function from event_optimize
    autocorr = event_optimize.run_sampler_autocorr(sampler, coords, nsteps, burnin=10)

    # At least one autocorrelation estimate must have been recorded, and the
    # sampler must have stopped before using up every requested step.
    assert len(autocorr) > 0
    assert sampler.iteration < nsteps

0 comments on commit 7938a93

Please sign in to comment.