Skip to content

Commit 4e03b95

Browse files
authored
Merge branch 'master' into gls-chisq
2 parents 2e3b450 + c6b2c82 commit 4e03b95

10 files changed

+178
-37
lines changed

AUTHORS.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Active developers are indicated by (*). Authors of the PINT paper are indicated
2020
* Anne Archibald (#*)
2121
* Matteo Bachetti (#)
2222
* Bastian Beischer
23-
* Deven Bhakta
23+
* Deven Bhakta (*)
2424
* Chloe Champagne (#)
2525
* Jonathan Colen (#)
2626
* Thankful Cromartie

CHANGELOG-unreleased.md

+2
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ the released changes.
1414
- Updated `CONTRIBUTING.rst` with the latest information.
1515
- Moved design matrix normalization code from `pint.fitter` to the new `pint.utils.normalize_designmatrix()` function.
1616
- Made `Residuals` independent of `GLSFitter` (GLS chi2 is now computed using the new function `Residuals._calc_gls_chi2()`).
17+
- Made `TimingModel.params` and `TimingModel.ordered_params` identical. Deprecated `TimingModel.ordered_params`.
1718
### Added
1819
- Third-order Roemer delay terms to ELL1 model
1920
- Options to add a TZR TOA (`AbsPhase`) during the creation of a `TimingModel` using `ModelBuilder.__call__`, `get_model`, and `get_model_and_toas`
2021
- `pint.print_info()` function for bug reporting
22+
- Added an autocorrelation function to check for chain convergence in `event_optimize`
2123
### Fixed
2224
- Deleting JUMP1 from flag tables will not prevent fitting
2325
- Simulating TOAs from tim file when PLANET_SHAPIRO is true now works

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ the unreleased changes. This file should only be changed while tagging a new ver
1414
- Unreleased CHANGELOG entries should now be entered in `CHANGELOG-unreleased.md` instead of `CHANGELOG.md`. Updated documentation accordingly.
1515
- Changed tests to remove `unittest` and use pure pytest format
1616
- Changed deprecated `sampler.chain` usage
17+
- Download data automatically in the profiling script `high_level_benchmark.py` instead of silently giving wrong results.
1718
### Added
1819
- `SpindownBase` as the abstract base class for `Spindown` and `PeriodSpindown` in the `How_to_build_a_timing_model_component.py` example.
1920
- `SolarWindDispersionBase` as the abstract base class for solar wind dispersion components.

profiling/.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
J0740+6620.cfr+19.tim
2+
bench_*_summary

profiling/high_level_benchmark.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
import os
1818
import platform
19+
import urllib.request
1920
from prfparser import parse_file
2021

2122

@@ -67,11 +68,19 @@ def get_results(script, outfile):
6768
parser = argparse.ArgumentParser(
6869
description="High-level summary of python file timing."
6970
)
71+
72+
if not os.path.isfile("J0740+6620.cfr+19.tim"):
73+
print("Downloading data file J0740+6620.cfr+19.tim ...")
74+
urllib.request.urlretrieve(
75+
"https://data.nanograv.org/static/data/J0740+6620.cfr+19.tim",
76+
"J0740+6620.cfr+19.tim",
77+
)
78+
79+
script1 = "bench_load_TOAs.py"
7080
script2 = "bench_chisq_grid.py"
7181
script3 = "bench_chisq_grid_WLSFitter.py"
7282
script4 = "bench_MCMC.py"
7383

74-
script1 = "bench_load_TOAs.py"
7584
# time scripts
7685
output1 = bench_file(script1)
7786
output2 = bench_file(script2)

src/pint/fitter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def get_summary(self, nodmx=False):
363363

364364
# to handle all parameter names, determine the longest length for the first column
365365
longestName = 0 # optionally specify the minimum length here instead of 0
366-
for pn in self.model.params_ordered:
366+
for pn in self.model.params:
367367
if nodmx and pn.startswith("DMX"):
368368
continue
369369
if len(pn) > longestName:
@@ -378,7 +378,7 @@ def get_summary(self, nodmx=False):
378378
s += ("{:<" + spacingName + "s} {:>20s} {:>28s} {}\n").format(
379379
"=" * longestName, "=" * 20, "=" * 28, "=" * 5
380380
)
381-
for pn in self.model.params_ordered:
381+
for pn in self.model.params:
382382
if nodmx and pn.startswith("DMX"):
383383
continue
384384
prefitpar = getattr(self.model_init, pn)

src/pint/models/timing_model.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,11 @@ class TimingModel:
183183
removed with methods on this object, and for many of them additional
184184
parameters in families (``DMXEP_1234``) can be added.
185185
186-
Parameters in a TimingModel object are listed in the ``model.params`` and
187-
``model.params_ordered`` objects. Each Parameter can be set as free or
188-
frozen using its ``.frozen`` attribute, and a list of the free parameters
189-
is available through the ``model.free_params`` property; this can also
190-
be used to set which parameters are free. Several methods are available
191-
to get and set some or all parameters in the forms of dictionaries.
186+
Parameters in a TimingModel object are listed in the ``model.params`` object.
187+
Each Parameter can be set as free or frozen using its ``.frozen`` attribute,
188+
and a list of the free parameters is available through the ``model.free_params``
189+
property; this can also be used to set which parameters are free. Several methods
190+
are available to get and set some or all parameters in the forms of dictionaries.
192191
193192
TimingModel objects also support a number of functions for computing
194193
various things like orbital phase, and barycentric versions of TOAs,
@@ -500,20 +499,30 @@ def __getattr__(self, name):
500499
)
501500

502501
@property_exists
503-
def params(self):
504-
"""List of all parameter names in this model and all its components (order is arbitrary)."""
505-
# FIXME: any reason not to just use params_ordered here?
506-
p = self.top_level_params
507-
for cp in self.components.values():
508-
p = p + cp.params
509-
return p
502+
def params_ordered(self):
503+
"""List of all parameter names in this model and all its components.
504+
This is the same as `params`."""
505+
506+
# Historically, this was different from `params` because Python
507+
# dictionaries were unordered until Python 3.7. Now there is no reason for
508+
# them to be different.
509+
510+
warn(
511+
"`TimingModel.params_ordered` is now deprecated and may be removed in the future. "
512+
"Use `TimingModel.params` instead. It gives the same output as `TimingModel.params_ordered`.",
513+
DeprecationWarning,
514+
)
515+
516+
return self.params
510517

511518
@property_exists
512-
def params_ordered(self):
519+
def params(self):
513520
"""List of all parameter names in this model and all its components, in a sensible order."""
521+
514522
# Define the order of components in the list
515523
# Any not included will be printed between the first and last set.
516524
# FIXME: make order completely canonical (sort components by name?)
525+
517526
start_order = ["astrometry", "spindown", "dispersion"]
518527
last_order = ["jump_delay"]
519528
compdict = self.get_components_by_category()
@@ -551,15 +560,15 @@ def params_ordered(self):
551560
def free_params(self):
552561
"""List of all the free parameters in the timing model. Can be set to change which are free.
553562
554-
These are ordered as ``self.params_ordered`` does.
563+
These are ordered as ``self.params`` does.
555564
556565
Upon setting, order does not matter, and aliases are accepted.
557566
ValueError is raised if a parameter is not recognized.
558567
559568
On setting, parameter aliases are converted with
560569
:func:`pint.models.timing_model.TimingModel.match_param_aliases`.
561570
"""
562-
return [p for p in self.params_ordered if not getattr(self, p).frozen]
571+
return [p for p in self.params if not getattr(self, p).frozen]
563572

564573
@free_params.setter
565574
def free_params(self, params):
@@ -620,7 +629,7 @@ def get_params_dict(self, which="free", kind="quantity"):
620629
if which == "free":
621630
ps = self.free_params
622631
elif which == "all":
623-
ps = self.params_ordered
632+
ps = self.params
624633
else:
625634
raise ValueError("get_params_dict expects which to be 'all' or 'free'")
626635
c = OrderedDict()
@@ -2014,10 +2023,7 @@ def compare(
20142023
log.debug("Check verbosity - only warnings/info will be displayed")
20152024
othermodel = copy.deepcopy(othermodel)
20162025

2017-
if (
2018-
"POSEPOCH" in self.params_ordered
2019-
and "POSEPOCH" in othermodel.params_ordered
2020-
):
2026+
if "POSEPOCH" in self.params and "POSEPOCH" in othermodel.params:
20212027
if (
20222028
self.POSEPOCH.value is not None
20232029
and othermodel.POSEPOCH.value is not None
@@ -2028,7 +2034,7 @@ def compare(
20282034
% (other_model_name, model_name)
20292035
)
20302036
othermodel.change_posepoch(self.POSEPOCH.value)
2031-
if "PEPOCH" in self.params_ordered and "PEPOCH" in othermodel.params_ordered:
2037+
if "PEPOCH" in self.params and "PEPOCH" in othermodel.params:
20322038
if (
20332039
self.PEPOCH.value is not None
20342040
and self.PEPOCH.value != othermodel.PEPOCH.value
@@ -2037,7 +2043,7 @@ def compare(
20372043
"Updating PEPOCH in %s to match %s" % (other_model_name, model_name)
20382044
)
20392045
othermodel.change_pepoch(self.PEPOCH.value)
2040-
if "DMEPOCH" in self.params_ordered and "DMEPOCH" in othermodel.params_ordered:
2046+
if "DMEPOCH" in self.params and "DMEPOCH" in othermodel.params:
20412047
if (
20422048
self.DMEPOCH.value is not None
20432049
and self.DMEPOCH.value != othermodel.DMEPOCH.value
@@ -2072,7 +2078,7 @@ def compare(
20722078
f"{model_name} is in ECL({self.ECL.value}) coordinates but {other_model_name} is in ICRS coordinates and convertcoordinates=False"
20732079
)
20742080

2075-
for pn in self.params_ordered:
2081+
for pn in self.params:
20762082
par = getattr(self, pn)
20772083
if par.value is None:
20782084
continue
@@ -2299,8 +2305,8 @@ def compare(
22992305
)
23002306

23012307
# Now print any parameters in othermodel that were missing in self.
2302-
mypn = self.params_ordered
2303-
for opn in othermodel.params_ordered:
2308+
mypn = self.params
2309+
for opn in othermodel.params:
23042310
if opn in mypn and getattr(self, opn).value is not None:
23052311
continue
23062312
if nodmx and opn.startswith("DMX"):

src/pint/scripts/event_optimize.py

+95-4
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,82 @@ def get_fit_keyvals(model, phs=0.0, phserr=0.1):
247247
return fitkeys, np.asarray(fitvals), np.asarray(fiterrs)
248248

249249

250+
def run_sampler_autocorr(sampler, pos, nsteps, burnin, csteps=100, crit1=10):
251+
"""Runs the sampler and checks for chain convergence. Return the converged sampler and the mean autocorrelation time per 100 steps
252+
Parameters
253+
----------
254+
Sampler
255+
The Emcee Ensemble Sampler
256+
pos
257+
The Initial positions of the walkers
258+
nsteps : int
259+
The number of integration steps
260+
csteps : int
261+
The interval at which the autocorrelation time is computed.
262+
crit1 : int
263+
The ratio of chain length to autocorrelation time to satisfy convergence
264+
Returns
265+
-------
266+
The sampler and the mean autocorrelation times
267+
Note
268+
----
269+
The function checks for convergence of the chains every specified number of steps.
270+
The criteria to check for convergence is:
271+
1. the chain has to be longer than the specified ratio times the estimated autocorrelation time
272+
2. the change in the estimated autocorrelation time is less than 1%
273+
"""
274+
autocorr = []
275+
old_tau = np.inf
276+
converged1 = False
277+
converged2 = False
278+
for sample in sampler.sample(pos, iterations=nsteps, progress=True):
279+
if not converged1:
280+
# Checks if the iteration is past the burnin and checks for convergence at 10% tau change
281+
if sampler.iteration >= burnin and sampler.iteration % csteps == 0:
282+
tau = sampler.get_autocorr_time(tol=0, quiet=True)
283+
if np.any(np.isnan(tau)):
284+
continue
285+
else:
286+
x = np.mean(tau)
287+
autocorr.append(x)
288+
converged1 = np.all(tau * crit1 < sampler.iteration)
289+
converged1 &= np.all(np.abs(old_tau - tau) / tau < 0.1)
290+
# log.info("The mean estimated integrated autocorrelation step is: " + str(x))
291+
old_tau = tau
292+
if converged1:
293+
log.info(
294+
"10 % convergence reached with a mean estimated integrated step: "
295+
+ str(x)
296+
)
297+
else:
298+
continue
299+
else:
300+
continue
301+
else:
302+
if not converged2:
303+
# Checks for convergence at every 25 steps instead of 100 and tau change is 1%
304+
if sampler.iteration % int(csteps / 4) == 0:
305+
tau = sampler.get_autocorr_time(tol=0, quiet=True)
306+
if np.any(np.isnan(tau)):
307+
continue
308+
else:
309+
x = np.mean(tau)
310+
autocorr.append(x)
311+
converged2 = np.all(tau * crit1 < sampler.iteration)
312+
converged2 &= np.all(np.abs(old_tau - tau) / tau < 0.01)
313+
# log.info("The mean estimated integrated autocorrelation step is: " + str(x))
314+
old_tau = tau
315+
converge_step = sampler.iteration
316+
else:
317+
continue
318+
if converged2 and (sampler.iteration - burnin) >= 1000:
319+
log.info(f"Convergence reached at {converge_step}")
320+
break
321+
else:
322+
continue
323+
return autocorr
324+
325+
250326
class emcee_fitter(Fitter):
251327
def __init__(
252328
self, toas=None, model=None, template=None, weights=None, phs=0.5, phserr=0.03
@@ -545,6 +621,13 @@ def main(argv=None):
545621
default=False,
546622
action="store_true",
547623
)
624+
parser.add_argument(
625+
"--no-autocorr",
626+
help="Turn the autocorrelation check function off",
627+
default=False,
628+
action="store_true",
629+
dest="noautocorr",
630+
)
548631

549632
args = parser.parse_args(argv)
550633
pint.logging.setup(
@@ -820,21 +903,29 @@ def unwrapped_lnpost(theta):
820903
pool=pool,
821904
backend=backend,
822905
)
823-
sampler.run_mcmc(pos, nsteps)
906+
if args.noautocorr:
907+
sampler.run_mcmc(pos, nsteps, progress=True)
908+
else:
909+
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
824910
pool.close()
825911
pool.join()
826912
except ImportError:
827913
log.info("Pathos module not available, using single core")
828914
sampler = emcee.EnsembleSampler(
829915
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
830916
)
831-
sampler.run_mcmc(pos, nsteps)
917+
if args.noautocorr:
918+
sampler.run_mcmc(pos, nsteps, progress=True)
919+
else:
920+
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
832921
else:
833922
sampler = emcee.EnsembleSampler(
834923
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
835924
)
836-
# The number is the number of points in the chain
837-
sampler.run_mcmc(pos, nsteps)
925+
if args.noautocorr:
926+
sampler.run_mcmc(pos, nsteps, progress=True)
927+
else:
928+
autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
838929

839930
def chains_to_dict(names, sampler):
840931
samples = np.transpose(sampler.get_chain(), (1, 0, 2))

tests/test_design_matrix.py

+5
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,8 @@ def test_combine_designmatrix_all(self):
110110
]
111111
== 0.0
112112
)
113+
114+
def test_param_order(self):
115+
params_dm = self.model.designmatrix(self.toas, incoffset=False)[1]
116+
params_free = self.model.free_params
117+
assert params_dm == params_free

0 commit comments

Comments
 (0)